From d38ad438f3bb4e3721b9964172c8c9d70d5d06a8 Mon Sep 17 00:00:00 2001 From: Simon Hauser Date: Thu, 31 Mar 2022 18:42:38 +0200 Subject: [PATCH] fix: action replace/enhance if the replaced/enhanced action as combined (#1814) --- lua/telescope/actions/mt.lua | 40 +++++++++++++----- lua/tests/automated/action_spec.lua | 64 +++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 11 deletions(-) diff --git a/lua/telescope/actions/mt.lua b/lua/telescope/actions/mt.lua index 02edea5..804d972 100644 --- a/lua/telescope/actions/mt.lua +++ b/lua/telescope/actions/mt.lua @@ -24,9 +24,23 @@ local append_action_copy = function(new, v, old) new._post[v] = old._post[v] end ---TODO(conni2461): Not a fan of this solution/hack. Needs to be addressed +-- TODO(conni2461): Not a fan of this solution/hack. Needs to be addressed local all_mts = {} +--TODO(conni2461): It gets worse. This is so bad but because we have now n mts for n actions +-- We have to check all actions for relevant mts to set replace and before, after +-- Its not bad for performance because its being called on startup when we attach mappings. +-- Its just a bad solution +local find_all_relevant_mts = function(action_name, f) + for _, mt in ipairs(all_mts) do + for fun, _ in pairs(mt._func) do + if fun == action_name then + f(mt) + end + end + end +end + --- an action is metatable which allows replacement(prepend or append) of the function ---@class Action ---@field _func table: the original action function @@ -122,12 +136,14 @@ action_mt.create = function() assert(#self == 1, "Cannot replace an already combined action") local action_name = self[1] + find_all_relevant_mts(action_name, function(another) + if not another._replacements[action_name] then + another._replacements[action_name] = {} + end - if not mt._replacements[action_name] then - mt._replacements[action_name] = {} - end + table.insert(another._replacements[action_name], 1, tbl) + end) - table.insert(mt._replacements[action_name], 1, tbl) return self end @@ -135,13 +151,15 @@ action_mt.create = function() assert(#self == 1, "Cannot enhance already combined actions") local action_name = self[1] - if opts.pre then - mt._pre[action_name] = opts.pre - end + find_all_relevant_mts(action_name, function(another) + if opts.pre then + another._pre[action_name] = opts.pre + end - if opts.post then - mt._post[action_name] = opts.post - end + if opts.post then + another._post[action_name] = opts.post + end + end) return self end diff --git a/lua/tests/automated/action_spec.lua b/lua/tests/automated/action_spec.lua index 634b2d7..c1dd8b3 100644 --- a/lua/tests/automated/action_spec.lua +++ b/lua/tests/automated/action_spec.lua @@ -282,6 +282,70 @@ describe("actions", function() eq(3, called_count) end) + it( + "can call replace fn even when combined before replace registered the fn (because that happens with mappings)", + function() + local a = transform_mod { + x = function() + return "x" + end, + y = function() + return "y" + end, + } + + local called_count = 0 + local count_inc = function() + called_count = called_count + 1 + end + + local x_plus_y = a.x + a.y + a.x:replace(function() + count_inc() + end) + a.y:replace(function() + count_inc() + end) + + x_plus_y() + + eq(2, called_count) + end + ) + + it( + "can call enhance fn even when combined before enhance registed fns (because that happens with mappings)", + function() + local a = transform_mod { + x = function() + return "x" + end, + y = function() + return "y" + end, + } + + local called_count = 0 + local count_inc = function() + called_count = called_count + 1 + end + + local x_plus_y = a.x + a.y + a.y:enhance { + pre = count_inc, + post = count_inc, + } + + a.x:enhance { + post = count_inc, + } + + x_plus_y() + + eq(3, called_count) + end + ) + it("clears enhance", function() local a = transform_mod { x = function()