fix: action replace/enhance if the replaced/enhanced action as combined (#1814)

This commit is contained in:
Simon Hauser
2022-03-31 18:42:38 +02:00
committed by GitHub
parent b83d6d4711
commit d38ad438f3
2 changed files with 93 additions and 11 deletions

View File

@@ -24,9 +24,23 @@ local append_action_copy = function(new, v, old)
new._post[v] = old._post[v] new._post[v] = old._post[v]
end end
--TODO(conni2461): Not a fan of this solution/hack. Needs to be addressed -- TODO(conni2461): Not a fan of this solution/hack. Needs to be addressed
local all_mts = {} local all_mts = {}
--TODO(conni2461): It gets worse. This is so bad but because we have now n mts for n actions
-- We have to check all actions for relevant mts to set replace and before, after
-- Its not bad for performance because its being called on startup when we attach mappings.
-- Its just a bad solution
local find_all_relevant_mts = function(action_name, f)
for _, mt in ipairs(all_mts) do
for fun, _ in pairs(mt._func) do
if fun == action_name then
f(mt)
end
end
end
end
--- an action is metatable which allows replacement(prepend or append) of the function --- an action is metatable which allows replacement(prepend or append) of the function
---@class Action ---@class Action
---@field _func table<string, function>: the original action function ---@field _func table<string, function>: the original action function
@@ -122,12 +136,14 @@ action_mt.create = function()
assert(#self == 1, "Cannot replace an already combined action") assert(#self == 1, "Cannot replace an already combined action")
local action_name = self[1] local action_name = self[1]
find_all_relevant_mts(action_name, function(another)
if not mt._replacements[action_name] then if not another._replacements[action_name] then
mt._replacements[action_name] = {} another._replacements[action_name] = {}
end end
table.insert(mt._replacements[action_name], 1, tbl) table.insert(another._replacements[action_name], 1, tbl)
end)
return self return self
end end
@@ -135,13 +151,15 @@ action_mt.create = function()
assert(#self == 1, "Cannot enhance already combined actions") assert(#self == 1, "Cannot enhance already combined actions")
local action_name = self[1] local action_name = self[1]
find_all_relevant_mts(action_name, function(another)
if opts.pre then if opts.pre then
mt._pre[action_name] = opts.pre another._pre[action_name] = opts.pre
end end
if opts.post then if opts.post then
mt._post[action_name] = opts.post another._post[action_name] = opts.post
end end
end)
return self return self
end end

View File

@@ -282,6 +282,70 @@ describe("actions", function()
eq(3, called_count) eq(3, called_count)
end) end)
it(
"can call replace fn even when combined before replace registered the fn (because that happens with mappings)",
function()
local a = transform_mod {
x = function()
return "x"
end,
y = function()
return "y"
end,
}
local called_count = 0
local count_inc = function()
called_count = called_count + 1
end
local x_plus_y = a.x + a.y
a.x:replace(function()
count_inc()
end)
a.y:replace(function()
count_inc()
end)
x_plus_y()
eq(2, called_count)
end
)
it(
"can call enhance fn even when combined before enhance registed fns (because that happens with mappings)",
function()
local a = transform_mod {
x = function()
return "x"
end,
y = function()
return "y"
end,
}
local called_count = 0
local count_inc = function()
called_count = called_count + 1
end
local x_plus_y = a.x + a.y
a.y:enhance {
pre = count_inc,
post = count_inc,
}
a.x:enhance {
post = count_inc,
}
x_plus_y()
eq(3, called_count)
end
)
it("clears enhance", function() it("clears enhance", function()
local a = transform_mod { local a = transform_mod {
x = function() x = function()