diff --git a/lua/cmp/config.lua b/lua/cmp/config.lua index ddda370..42fdc4f 100644 --- a/lua/cmp/config.lua +++ b/lua/cmp/config.lua @@ -1,3 +1,4 @@ +local mapping = require('cmp.config.mapping') local cache = require('cmp.utils.cache') local keymap = require('cmp.utils.keymap') local misc = require('cmp.utils.misc') @@ -91,9 +92,11 @@ end ---@return cmp.ConfigSchema config.normalize = function(c) if c.mapping then + local normalized = {} for k, v in pairs(c.mapping) do - c.mapping[keymap.normalize(k)] = v + normalized[keymap.normalize(k)] = mapping(v, { 'i' }) end + c.mapping = normalized end return c end diff --git a/lua/cmp/config/mapping.lua b/lua/cmp/config/mapping.lua index 100b446..a4754f1 100644 --- a/lua/cmp/config/mapping.lua +++ b/lua/cmp/config/mapping.lua @@ -1,35 +1,14 @@ -local api = require('cmp.utils.api') - local mapping mapping = setmetatable({}, { __call = function(_, invoke, modes) if type(invoke) == 'function' then - return { - invoke = function(...) - invoke(...) - end, - modes = modes or { 'i' }, - __type = 'mapping', - } - elseif type(invoke) == 'table' then - if invoke.__type == 'mapping' then - return invoke - else - return mapping(function(fallback) - if api.is_insert_mode() and invoke.i then - return invoke.i(fallback) - elseif api.is_cmdline_mode() and invoke.c then - return invoke.c(fallback) - elseif api.is_select_mode() and invoke.s then - return invoke.s(fallback) - elseif api.is_visual_mode() and invoke.x then - return invoke.x(fallback) - else - fallback() - end - end, vim.tbl_keys(invoke)) + local map = {} + for _, mode in ipairs(modes or { 'i' }) do + map[mode] = invoke end + return map end + return invoke end, }) diff --git a/lua/cmp/core.lua b/lua/cmp/core.lua index 01f683c..42c20d0 100644 --- a/lua/cmp/core.lua +++ b/lua/cmp/core.lua @@ -104,14 +104,10 @@ end ---Keypress handler core.on_keymap = function(self, keys, fallback) - for key, action in pairs(config.get().mapping) do - if keymap.equals(key, keys) then - if type(action) == 'function' then - action(fallback) - else - action.invoke(fallback) - end - return + local mode = api.get_mode() + for key, mapping in pairs(config.get().mapping) do + if keymap.equals(key, keys) and mapping[mode] then + return mapping[mode](fallback) end end @@ -139,14 +135,8 @@ end ---Prepare completion core.prepare = function(self) - for keys, action in pairs(config.get().mapping) do - if type(action) == 'function' then - action = { - modes = { 'i' }, - action = action, - } - end - for _, mode in ipairs(action.modes) do + for keys, mapping in pairs(config.get().mapping) do + for mode in pairs(mapping) do keymap.listen(mode, keys, function(...) self:on_keymap(...) end) diff --git a/lua/cmp/types/cmp.lua b/lua/cmp/types/cmp.lua index db8a41f..c1fb3df 100644 --- a/lua/cmp/types/cmp.lua +++ b/lua/cmp/types/cmp.lua @@ -50,6 +50,7 @@ cmp.ItemField.Menu = 'menu' ---@field public __call fun(c: cmp.ConfigSchema) ---@field public buffer fun(c: cmp.ConfigSchema) ---@field public global fun(c: cmp.ConfigSchema) +---@field public cmdline fun(type: string, c: cmp.ConfigSchema) ---@class cmp.SourceBaseApiParams ---@field public option table @@ -59,6 +60,12 @@ cmp.ItemField.Menu = 'menu' ---@field public offset number ---@field public completion_context lsp.CompletionContext +---@class cmp.Mapping +---@field public i nil|function(fallback: function): void +---@field public c nil|function(fallback: function): void +---@field public x nil|function(fallback: function): void +---@field public s nil|function(fallback: function): void + ---@class cmp.ConfigSchema ---@field private revision number ---@field public enabled fun():boolean|boolean @@ -69,7 +76,7 @@ cmp.ItemField.Menu = 'menu' ---@field public sorting cmp.SortingConfig ---@field public formatting cmp.FormattingConfig ---@field public snippet cmp.SnippetConfig ----@field public mapping table +---@field public mapping table ---@field public sources cmp.SourceConfig[] ---@field public experimental cmp.ExperimentalConfig diff --git a/lua/cmp/utils/api.lua b/lua/cmp/utils/api.lua index e98ba63..a91f3e3 100644 --- a/lua/cmp/utils/api.lua +++ b/lua/cmp/utils/api.lua @@ -1,5 +1,17 @@ local api = {} +api.get_mode = function() + if api.is_insert_mode() then + return 'i' + elseif api.is_visual_mode() then + return 'x' + elseif api.is_select_mode() then + return 's' + elseif api.is_cmdline_mode() then + return 'c' + end +end + api.is_insert_mode = function() return vim.tbl_contains({ 'i',