diff --git a/lua/cmp/utils/misc.lua b/lua/cmp/utils/misc.lua index 8f58f07..fbd7de1 100644 --- a/lua/cmp/utils/misc.lua +++ b/lua/cmp/utils/misc.lua @@ -156,28 +156,39 @@ misc.set = function(t, keys, v) c[keys[#keys]] = v end ----Copy table ----@generic T ----@param tbl T ----@return T -misc.copy = function(tbl) - if type(tbl) ~= 'table' then - return tbl - end +do + local function do_copy(tbl, seen) + if type(tbl) ~= 'table' then + return tbl + end + if seen[tbl] then + return seen[tbl] + end + + if islist(tbl) then + local copy = {} + seen[tbl] = copy + for i, value in ipairs(tbl) do + copy[i] = do_copy(value, seen) + end + return copy + end - if islist(tbl) then local copy = {} - for i, value in ipairs(tbl) do - copy[i] = misc.copy(value) + seen[tbl] = copy + for key, value in pairs(tbl) do + copy[key] = do_copy(value, seen) end return copy end - local copy = {} - for key, value in pairs(tbl) do - copy[key] = misc.copy(value) + ---Copy table + ---@generic T + ---@param tbl T + ---@return T + misc.copy = function(tbl) + return do_copy(tbl, {}) end - return copy end ---Safe version of vim.str_utfindex diff --git a/lua/cmp/utils/misc_spec.lua b/lua/cmp/utils/misc_spec.lua index f687155..db84d0d 100644 --- a/lua/cmp/utils/misc_spec.lua +++ b/lua/cmp/utils/misc_spec.lua @@ -5,6 +5,33 @@ local misc = require('cmp.utils.misc') describe('misc', function() before_each(spec.before) + it('copy', function() + -- basic. + local tbl, copy + tbl = { + a = { + b = 1, + }, + } + copy = misc.copy(tbl) + assert.are_not.equal(tbl, copy) + assert.are_not.equal(tbl.a, copy.a) + assert.are.same(tbl, copy) + + -- self reference. + tbl = { + a = { + b = 1, + }, + } + tbl.a.c = tbl.a + copy = misc.copy(tbl) + assert.are_not.equal(tbl, copy) + assert.are_not.equal(tbl.a, copy.a) + assert.are_not.equal(tbl.a.c, copy.a.c) + assert.are.same(tbl, copy) + end) + it('merge', function() local merged merged = misc.merge({