diff --git a/lua/telescope/actions.lua b/lua/telescope/actions/init.lua similarity index 80% rename from lua/telescope/actions.lua rename to lua/telescope/actions/init.lua index a35ffee..4bc230a 100644 --- a/lua/telescope/actions.lua +++ b/lua/telescope/actions/init.lua @@ -6,43 +6,14 @@ local log = require('telescope.log') local path = require('telescope.path') local state = require('telescope.state') +local transform_mod = require('telescope.actions.mt').transform_mod + local actions = setmetatable({}, { __index = function(_, k) error("Actions does not have a value: " .. tostring(k)) end }) -local action_mt = { - __call = function(t, ...) - local values = {} - for _, v in ipairs(t) do - local result = {v(...)} - for _, res in ipairs(result) do - table.insert(values, res) - end - end - - return unpack(values) - end, - - __add = function(lhs, rhs) - local new_actions = {} - for _, v in ipairs(lhs) do - table.insert(new_actions, v) - end - - for _, v in ipairs(rhs) do - table.insert(new_actions, v) - end - - return setmetatable(new_actions, getmetatable(lhs)) - end -} - -local transform_action = function(a) - return setmetatable({a}, action_mt) -end - --- Get the current picker object for the prompt function actions.get_current_picker(prompt_bufnr) return state.get_status(prompt_bufnr).picker @@ -68,8 +39,8 @@ function actions.add_selection(prompt_bufnr) end --- Get the current entry -function actions.get_selected_entry(prompt_bufnr) - return actions.get_current_picker(prompt_bufnr):get_selection() +function actions.get_selected_entry() + return state.get_global_key('selected_entry') end function actions.preview_scrolling_up(prompt_bufnr) @@ -81,7 +52,7 @@ function actions.preview_scrolling_down(prompt_bufnr) end -- TODO: It seems sometimes we get bad styling. -local function goto_file_selection(prompt_bufnr, command) +function actions._goto_file_selection(prompt_bufnr, command) local entry = actions.get_selected_entry(prompt_bufnr) if not entry then @@ -95,7 +66,7 @@ local function goto_file_selection(prompt_bufnr, command) -- TODO: Check for off-by-one row = entry.row or entry.lnum col = entry.col - else + elseif not entry.bufnr then -- TODO: Might want to remove this and force people -- to put stuff into `filename` local value = entry.value @@ -124,11 +95,11 @@ local function goto_file_selection(prompt_bufnr, command) actions.close(prompt_bufnr) - filename = path.normalize(filename, vim.fn.getcwd()) - if entry_bufnr then vim.cmd(string.format(":%s #%d", command, entry_bufnr)) else + filename = path.normalize(filename, vim.fn.getcwd()) + local bufnr = vim.api.nvim_get_current_buf() if filename ~= vim.api.nvim_buf_get_name(bufnr) then vim.cmd(string.format(":%s %s", command, filename)) @@ -151,19 +122,19 @@ function actions.center(_) end function actions.goto_file_selection_edit(prompt_bufnr) - goto_file_selection(prompt_bufnr, "edit") + actions._goto_file_selection(prompt_bufnr, "edit") end function actions.goto_file_selection_split(prompt_bufnr) - goto_file_selection(prompt_bufnr, "new") + actions._goto_file_selection(prompt_bufnr, "new") end function actions.goto_file_selection_vsplit(prompt_bufnr) - goto_file_selection(prompt_bufnr, "vnew") + actions._goto_file_selection(prompt_bufnr, "vnew") end function actions.goto_file_selection_tabedit(prompt_bufnr) - goto_file_selection(prompt_bufnr, "tabedit") + actions._goto_file_selection(prompt_bufnr, "tabedit") end function actions.close_pum(_) @@ -218,10 +189,9 @@ actions.insert_value = function(prompt_bufnr) return entry.value end -for k, v in pairs(actions) do - actions[k] = transform_action(v) -end - -actions._transform_action = transform_action - +-- ================================================== +-- Transforms modules and sets the corect metatables. +-- ================================================== +actions = transform_mod(actions) return actions + diff --git a/lua/telescope/actions/mt.lua b/lua/telescope/actions/mt.lua new file mode 100644 index 0000000..909e7bb --- /dev/null +++ b/lua/telescope/actions/mt.lua @@ -0,0 +1,96 @@ + +local action_mt = {} + +action_mt.create = function(mod) + local mt = { + __call = function(t, ...) + local values = {} + for _, v in ipairs(t) do + local func = t._replacements[v] or mod[v] + + if t._pre[v] then + t._pre[v](...) + end + + local result = {func(...)} + for _, res in ipairs(result) do + table.insert(values, res) + end + + if t._post[v] then + t._post[v](...) + end + end + + return unpack(values) + end, + + __add = function(lhs, rhs) + local new_actions = {} + for _, v in ipairs(lhs) do + table.insert(new_actions, v) + end + + for _, v in ipairs(rhs) do + table.insert(new_actions, v) + end + + return setmetatable(new_actions, getmetatable(lhs)) + end, + + _pre = {}, + _replacements = {}, + _post = {}, + } + + mt.__index = mt + + mt.clear = function() + mt._pre = {} + mt._replacements = {} + mt._post = {} + end + + --- Replace the reference to the function with a new one temporarily + function mt:replace(v) + assert(#self == 1, "Cannot replace an already combined action") + + local action_name = self[1] + mt._replacements[action_name] = v + end + + function mt:enhance(opts) + assert(#self == 1, "Cannot enhance already combined actions") + + local action_name = self[1] + if opts.pre then + mt._pre[action_name] = opts.pre + end + + if opts.post then + mt._post[action_name] = opts.post + end + end + + return mt +end + +action_mt.transform = function(k, mt) + return setmetatable({k}, mt) +end + +action_mt.transform_mod = function(mod) + local mt = action_mt.create(mod) + + local redirect = {} + + for k, _ in pairs(mod) do + redirect[k] = action_mt.transform(k, mt) + end + + redirect._clear = mt.clear + + return redirect +end + +return action_mt diff --git a/lua/telescope/builtin.lua b/lua/telescope/builtin.lua index c792f54..4102906 100644 --- a/lua/telescope/builtin.lua +++ b/lua/telescope/builtin.lua @@ -804,12 +804,15 @@ builtin.current_buffer_fuzzy_find = function(opts) table.insert(lines_with_numbers, {k, v}) end + local bufnr = vim.api.nvim_get_current_buf() + pickers.new(opts, { prompt_title = 'Current Buffer Fuzzy', finder = finders.new_table { results = lines_with_numbers, entry_maker = function(enumerated_line) return { + bufnr = bufnr, display = enumerated_line[2], ordinal = enumerated_line[2], @@ -818,17 +821,13 @@ builtin.current_buffer_fuzzy_find = function(opts) end }, sorter = sorters.get_generic_fuzzy_sorter(), - attach_mappings = function(prompt_bufnr, map) - local goto_line = function() - local selection = actions.get_selected_entry(prompt_bufnr) - actions.close(prompt_bufnr) - - vim.api.nvim_win_set_cursor(0, {selection.lnum, 0}) - vim.cmd [[stopinsert]] - end - - map('n', '', goto_line) - map('i', '', goto_line) + attach_mappings = function(prompt_bufnr) + actions._goto_file_selection:enhance { + post = vim.schedule_wrap(function() + local selection = actions.get_selected_entry(prompt_bufnr) + vim.api.nvim_win_set_cursor(0, {selection.lnum, 0}) + end), + } return true end diff --git a/lua/telescope/pickers.lua b/lua/telescope/pickers.lua index d3bba7c..ac31348 100644 --- a/lua/telescope/pickers.lua +++ b/lua/telescope/pickers.lua @@ -61,6 +61,9 @@ function Picker:new(opts) error("layout_strategy and get_window_options are not compatible keys") end + -- Reset actions for any replaced / enhanced actions. + actions._clear() + local layout_strategy = get_default(opts.layout_strategy, config.values.layout_strategy) return setmetatable({ @@ -708,6 +711,8 @@ function Picker:set_selection(row) local status = state.get_status(self.prompt_bufnr) local results_bufnr = status.results_bufnr + state.set_global_key("selected_entry", entry) + if not vim.api.nvim_buf_is_valid(results_bufnr) then return end diff --git a/lua/telescope/state.lua b/lua/telescope/state.lua index a014a0d..6a06eb1 100644 --- a/lua/telescope/state.lua +++ b/lua/telescope/state.lua @@ -1,12 +1,21 @@ local state = {} TelescopeGlobalState = TelescopeGlobalState or {} +TelescopeGlobalState.global = TelescopeGlobalState.global or {} --- Set the status for a particular prompt bufnr function state.set_status(prompt_bufnr, status) TelescopeGlobalState[prompt_bufnr] = status end +function state.set_global_key(key, value) + TelescopeGlobalState.global[key] = value +end + +function state.get_global_key(key) + return TelescopeGlobalState.global[key] +end + function state.get_status(prompt_bufnr) return TelescopeGlobalState[prompt_bufnr] or {} end diff --git a/lua/tests/automated/action_spec.lua b/lua/tests/automated/action_spec.lua new file mode 100644 index 0000000..e85cb3c --- /dev/null +++ b/lua/tests/automated/action_spec.lua @@ -0,0 +1,164 @@ +require('plenary.test_harness'):setup_busted() + +local transform_mod = require('telescope.actions.mt').transform_mod + +local eq = function(a, b) + assert.are.same(a, b) +end + +describe('actions', function() + it('should allow creating custom actions', function() + local a = transform_mod { + x = function() return 5 end, + } + + + eq(5, a.x()) + end) + + it('allows adding actions', function() + local a = transform_mod { + x = function() return "x" end, + y = function() return "y" end, + } + + local x_plus_y = a.x + a.y + + eq({"x", "y"}, {x_plus_y()}) + end) + + it('ignores nils from added actions', function() + local a = transform_mod { + x = function() return "x" end, + y = function() return "y" end, + nil_maker = function() return nil end, + } + + local x_plus_y = a.x + a.nil_maker + a.y + + eq({"x", "y"}, {x_plus_y()}) + end) + + it('allows overriding an action', function() + local a = transform_mod { + x = function() return "x" end, + y = function() return "y" end, + } + + -- actions.file_goto_selection_edit:replace(...) + a.x:replace(function() return "foo" end) + eq("foo", a.x()) + + a._clear() + eq("x", a.x()) + end) + + it('enhance.pre', function() + local a = transform_mod { + x = function() return "x" end, + y = function() return "y" end, + } + + local called_pre = false + + a.y:enhance { + pre = function() + called_pre = true + end, + } + eq("y", a.y()) + eq(true, called_pre) + end) + + it('enhance.post', function() + local a = transform_mod { + x = function() return "x" end, + y = function() return "y" end, + } + + local called_post = false + + a.y:enhance { + post = function() + called_post = true + end, + } + eq("y", a.y()) + eq(true, called_post) + end) + + it('can call both', function() + local a = transform_mod { + x = function() return "x" end, + y = function() return "y" end, + } + + local called_count = 0 + local count_inc = function() + called_count = called_count + 1 + end + + a.y:enhance { + pre = count_inc, + post = count_inc, + } + + eq("y", a.y()) + eq(2, called_count) + end) + + it('can call both even when combined', function() + local a = transform_mod { + x = function() return "x" end, + y = function() return "y" end, + } + + local called_count = 0 + local count_inc = function() + called_count = called_count + 1 + end + + a.y:enhance { + pre = count_inc, + post = count_inc, + } + + a.x:enhance { + post = count_inc + } + + local x_plus_y = a.x + a.y + x_plus_y() + + eq(3, called_count) + end) + + it('clears enhance', function() + local a = transform_mod { + x = function() return "x" end, + y = function() return "y" end, + } + + local called_post = false + + a.y:enhance { + post = function() + called_post = true + end, + } + + a._clear() + + eq("y", a.y()) + eq(false, called_post) + end) + + it('handles passing arguments', function() + local a = transform_mod { + x = function(bufnr) return string.format("bufnr: %s") end, + } + + a.x:replace(function(bufnr) return string.format("modified: %s", bufnr) end) + eq("modified: 5", a.x(5)) + end) +end)