diff --git a/lua/telescope/actions/mt.lua b/lua/telescope/actions/mt.lua index adb3293..02edea5 100644 --- a/lua/telescope/actions/mt.lua +++ b/lua/telescope/actions/mt.lua @@ -14,7 +14,28 @@ local run_replace_or_original = function(replacements, original_func, ...) return original_func(...) end -action_mt.create = function(mod) +local append_action_copy = function(new, v, old) + table.insert(new, v) + new._func[v] = old._func[v] + new._static_pre[v] = old._static_pre[v] + new._pre[v] = old._pre[v] + new._replacements[v] = old._replacements[v] + new._static_post[v] = old._static_post[v] + new._post[v] = old._post[v] +end + +--TODO(conni2461): Not a fan of this solution/hack. Needs to be addressed +local all_mts = {} + +--- an action is metatable which allows replacement(prepend or append) of the function +---@class Action +---@field _func table: the original action function +---@field _static_pre table: will allways run before the function even if its replaced +---@field _pre table: the functions that will run before the action +---@field _replacements table: the function that replaces this action +---@field _static_post table: will allways run after the function even if its replaced +---@field _post table: the functions that will run after the action +action_mt.create = function() local mt = { __call = function(t, ...) local values = {} @@ -27,7 +48,7 @@ action_mt.create = function(mod) end local result = { - run_replace_or_original(t._replacements[action_name], mod[action_name], ...), + run_replace_or_original(t._replacements[action_name], t._func[action_name], ...), } for _, res in ipairs(result) do table.insert(values, res) @@ -45,18 +66,23 @@ action_mt.create = function(mod) end, __add = function(lhs, rhs) - local new_actions = {} + local new_action = setmetatable({}, action_mt.create()) for _, v in ipairs(lhs) do - table.insert(new_actions, v) + append_action_copy(new_action, v, lhs) end for _, v in ipairs(rhs) do - table.insert(new_actions, v) + append_action_copy(new_action, v, rhs) + end + new_action.clear = function() + lhs.clear() + rhs.clear() end - return setmetatable(new_actions, getmetatable(lhs)) + return new_action end, + _func = {}, _static_pre = {}, _pre = {}, _replacements = {}, @@ -120,33 +146,47 @@ action_mt.create = function(mod) return self end + table.insert(all_mts, mt) return mt end -action_mt.transform = function(k, mt, mod, v) +action_mt.transform = function(k, mt, _, v) local res = setmetatable({ k }, mt) if type(v) == "table" then res._static_pre[k] = v.pre res._static_post[k] = v.post - mod[k] = v.action + res._func[k] = v.action + else + res._func[k] = v end return res end action_mt.transform_mod = function(mod) - local mt = action_mt.create(mod) - -- Pass the metatable of the module if applicable. -- This allows for custom errors, lookups, etc. local redirect = setmetatable({}, getmetatable(mod) or {}) for k, v in pairs(mod) do - redirect[k] = action_mt.transform(k, mt, mod, v) + local mt = action_mt.create() + redirect[k] = action_mt.transform(k, mt, _, v) end - redirect._clear = mt.clear + redirect._clear = function() + for k, v in pairs(redirect) do + if k ~= "_clear" then + pcall(v.clear) + end + end + end return redirect end +action_mt.clear_all = function() + for _, v in ipairs(all_mts) do + pcall(v.clear) + end +end + return action_mt diff --git a/lua/telescope/mappings.lua b/lua/telescope/mappings.lua index 24e22f6..487891c 100644 --- a/lua/telescope/mappings.lua +++ b/lua/telescope/mappings.lua @@ -222,10 +222,6 @@ mappings.apply_keymap = function(prompt_bufnr, attach_mappings, buffer_keymap) end end end - - vim.cmd( - string.format([[autocmd BufDelete %s :lua require('telescope.mappings').clear(%s)]], prompt_bufnr, prompt_bufnr) - ) end mappings.execute_keymap = function(prompt_bufnr, keymap_identifier) diff --git a/lua/telescope/pickers.lua b/lua/telescope/pickers.lua index 0e26781..a9c7fb2 100644 --- a/lua/telescope/pickers.lua +++ b/lua/telescope/pickers.lua @@ -8,7 +8,6 @@ local channel = require("plenary.async.control").channel local popup = require "plenary.popup" local actions = require "telescope.actions" -local action_set = require "telescope.actions.set" local config = require "telescope.config" local debounce = require "telescope.debounce" local deprecated = require "telescope.deprecated" @@ -52,14 +51,16 @@ function Picker:new(opts) error "layout_strategy and get_window_options are not compatible keys" end - -- Reset actions for any replaced / enhanced actions. - -- TODO: Think about how we could remember to NOT have to do this... - -- I almost forgot once already, cause I'm not smart enough to always do it. - actions._clear() - action_set._clear() - deprecated.options(opts) + -- We need to clear at the beginning not on close because after close we can still have select:post + -- etc ... + require("telescope.actions.mt").clear_all() + -- TODO(conni2461): This seems like the better solution but it won't clear actions that were never mapped + -- for _, v in ipairs(keymap_store[prompt_bufnr]) do + -- pcall(v.clear) + -- end + local layout_strategy = get_default(opts.layout_strategy, config.values.layout_strategy) local obj = setmetatable({ @@ -1446,6 +1447,7 @@ function pickers.on_close_prompt(prompt_bufnr) end picker.close_windows(status) + mappings.clear(prompt_bufnr) end function pickers.on_resize_window(prompt_bufnr) diff --git a/lua/tests/automated/action_spec.lua b/lua/tests/automated/action_spec.lua index 752ed53..634b2d7 100644 --- a/lua/tests/automated/action_spec.lua +++ b/lua/tests/automated/action_spec.lua @@ -3,9 +3,7 @@ local action_set = require "telescope.actions.set" local transform_mod = require("telescope.actions.mt").transform_mod -local eq = function(a, b) - assert.are.same(a, b) -end +local eq = assert.are.same describe("actions", function() it("should allow creating custom actions", function() @@ -207,6 +205,29 @@ describe("actions", function() eq(true, called_post) end) + it("static_pre static_post", function() + local called_pre = false + local called_post = false + local static_post = 0 + local a = transform_mod { + x = { + pre = function() + called_pre = true + end, + action = function() + return "x" + end, + post = function() + called_post = true + end, + }, + } + + eq("x", a.x()) + eq(true, called_pre) + eq(true, called_post) + end) + it("can call both", function() local a = transform_mod { x = function() @@ -298,6 +319,102 @@ describe("actions", function() eq("modified: 5", a.x(5)) end) + it("handles add with two different tables", function() + local count_a = 0 + local count_b = 0 + local a = transform_mod { + x = function() + count_a = count_a + 1 + end, + } + local b = transform_mod { + y = function() + count_b = count_b + 1 + end, + } + + local called_count = 0 + local count_inc = function() + called_count = called_count + 1 + end + + a.x:enhance { + post = count_inc, + } + b.y:enhance { + post = count_inc, + } + + local x_plus_y = a.x + b.y + x_plus_y() + + eq(2, called_count) + eq(1, count_a) + eq(1, count_b) + end) + + it("handles tripple concat with static pre post", function() + local count_a = 0 + local count_b = 0 + local count_c = 0 + local static_pre = 0 + local static_post = 0 + local a = transform_mod { + x = { + pre = function() + static_pre = static_pre + 1 + end, + action = function() + count_a = count_a + 1 + end, + post = function() + static_post = static_post + 1 + end, + }, + } + local b = transform_mod { + y = { + pre = function() + static_pre = static_pre + 1 + end, + action = function() + count_b = count_b + 1 + end, + post = function() + static_post = static_post + 1 + end, + }, + } + local c = transform_mod { + z = { + pre = function() + static_pre = static_pre + 1 + end, + action = function() + count_c = count_c + 1 + end, + post = function() + static_post = static_post + 1 + end, + }, + } + + local replace_count = 0 + a.x:replace(function() + replace_count = replace_count + 1 + end) + + local x_plus_y_plus_z = a.x + b.y + c.z + x_plus_y_plus_z() + + eq(0, count_a) + eq(1, count_b) + eq(1, count_c) + eq(1, replace_count) + eq(3, static_pre) + eq(3, static_post) + end) + describe("action_set", function() it("can replace `action_set.edit`", function() action_set.edit:replace(function(_, arg)