diff --git a/lua/neogen.lua b/lua/neogen.lua index 01ad38d..d843339 100644 --- a/lua/neogen.lua +++ b/lua/neogen.lua @@ -1,67 +1,269 @@ -local ts_utils = require("nvim-treesitter.ts_utils") +local ok, ts_utils = pcall(require, "nvim-treesitter.ts_utils") -neogen = {} +assert(ok, "neogen requires nvim-treesitter to operate :(") -local configuration = require('neogen.config') +local neogen = { + utility = { + wrap = function(name, ...) + local args = { ... } -neogen.generate = function () - local comment = {} + return function() + name(table.unpack(args)) + end + end, - -- Try to find the upper function - local cursor = ts_utils.get_node_at_cursor(0) - local function_node = cursor - while function_node ~= nil do - if function_node:type() == "function_definition" then break end - if function_node:type() == "function" then break end - if function_node:type() == "local_function" then break end - function_node = function_node:parent() - end - local line = ts_utils.get_node_range(function_node) + extract_children = function(_, name) + return function(node) + local result = {} + local split = vim.split(name, "|", true) - -- find the starting position in the line function - local line_content = vim.api.nvim_buf_get_lines(0, line, line+1, false)[1] - local offset = line_content:match("^%s+") or "" + for child in node:iter_children() do + if vim.tbl_contains(split, child:type()) then + table.insert(result, ts_utils.get_node_text(child)[1]) + end + end - local return_comment = offset .. "---@return " - local param_comment = offset .. "---@param " + return result + end + end, - -- Parse and iterate over each found query - local returned = vim.treesitter.get_query("lua", "neogen") - for id, node in returned:iter_captures(function_node) do + extract_children_from = function(self, name, nodes) + return function(node) + local result = {} - -- Try to add params - if returned.captures[id] == "params" then - local params = ts_utils.get_node_text(node)[1]:sub(2,-2) - for p in string.gmatch(params, '[^,]+') do - p = p:gsub("%s+", "") -- remove trailing spaces - table.insert(comment, param_comment .. p .. " ") + for i, value in ipairs(nodes) do + local child_node = node:named_child(i - 1) + + if value == "extract" then + return self:extract_children(name)(child_node) + else + return self:extract_children_from(name, value)(node) + end + end + + return result + end + end, + }, + + default_locator = function(node_info, nodes_to_match) + if vim.tbl_contains(nodes_to_match, node_info.current:type()) then + return node_info.current + end + + while node_info.current and not vim.tbl_contains(nodes_to_match, node_info.current:type()) do + node_info.current = node_info.current:parent() + end + + return node_info.current + end, + + default_granulator = function(parent_node, node_data) + local result = {} + + for parent_type, child_data in pairs(node_data) do + local matches = vim.split(parent_type, "|", true) + if vim.tbl_contains(matches, parent_node:type()) then + for i, _ in pairs(node_data[parent_type]) do + local data = child_data[i] + + local child_node = parent_node:named_child(tonumber(i) - 1) + + if not child_node then + return + end + + if child_node:type() == data.match or not data.match then + local extract = {} + + if data.extract then + extract = data.extract(child_node) + + if data.type then + -- Extract information into a one-dimensional array + local one_dimensional_arr = {} + + for _, values in pairs(extract) do + table.insert(one_dimensional_arr, values) + end + + result[data.type] = one_dimensional_arr + else + for type, extracted_data in pairs(extract) do + result[type] = extracted_data + end + end + else + extract = ts_utils.get_node_text(child_node) + result[data.type] = extract + end + end + end end end - -- Try to add return statement - if returned.captures[id] == "return" then - table.insert(comment, return_comment) + return result + end, + + default_generator = function(parent, data, template) + local start_row, start_column, _, _ = ts_utils.get_node_range(parent) + local commentstring, generated_template = vim.trim(vim.api.nvim_buf_get_option(0, "commentstring"):format("")) + + if not template then + generated_template = { + { nil, "" }, + { "name", " @Summary " }, + { "parameters", " @Param " }, + { "return", " @Return " }, + } + elseif type(template) == "function" then + generated_template = template(parent, commentstring, data) + else + generated_template = template end - end - -- At the end, add description annotation - table.insert(comment, 1, offset .. "---") + local function parse_generated_template() + local result = {} + local prefix = (" "):rep(start_column) .. commentstring - if #comment == 0 then return end + for _, values in ipairs(generated_template) do + local type = values[1] - -- Write on top of function - vim.fn.append(line, comment) - vim.fn.cursor(line+1, #comment[1]) - vim.api.nvim_command('startinsert!') + if not type then + table.insert(result, prefix .. values[2]:format("")) + else + if data[type] then + if #vim.tbl_values(data[type]) == 1 then + table.insert(result, prefix .. values[2]:format(data[type][1])) + else + for _, value in ipairs(data[type]) do + table.insert(result, prefix .. values[2]:format(value)) + end + end + end + end + end + + return result + end + + return start_row, parse_generated_template() + end, +} + +-- TODO: Move code here +neogen.generate = function(searcher, generator) end + +neogen.auto_generate = function(custom_template) + vim.treesitter.get_parser(0):for_each_tree(function(tree, language_tree) + local searcher = neogen.configuration.languages[language_tree:lang()] + + if searcher then + searcher.locator = searcher.locator or neogen.default_locator + searcher.granulator = searcher.granulator or neogen.default_granulator + searcher.generator = searcher.generator or neogen.default_generator + + local located_parent_node = searcher.locator({ + root = tree:root(), + current = ts_utils.get_node_at_cursor(0), + }, searcher.parent) + + if not located_parent_node then + return + end + + local data = searcher.granulator(located_parent_node, searcher.data) + + if data and not vim.tbl_isempty(data) then + local to_place, content = searcher.generator( + located_parent_node, + data, + custom_template or searcher.template + ) + + vim.fn.append(to_place, content) + end + end + end) end function neogen.generate_command() - vim.api.nvim_command('command! -range -bar Neogen lua require("neogen").generate()') + vim.api.nvim_command('command! -range -bar Neogen lua require("neogen").auto_generate()') end neogen.setup = function(opts) - local config = opts or configuration - if config.enabled == true then neogen.generate_command() end + neogen.configuration = vim.tbl_deep_extend("keep", opts or {}, { + -- DEFAULT CONFIGURATION + languages = { + lua = { + -- Search for these nodes + parent = { "function", "local_function", "local_variable_declaration", "field" }, + + -- Traverse down these nodes and extract the information as necessary + data = { + ["function|local_function"] = { + ["2"] = { + match = "parameters", + + extract = function(node) + local regular_params = neogen.utility:extract_children("identifier")(node) + local varargs = neogen.utility:extract_children("spread")(node) + + return { + parameters = regular_params, + vararg = varargs, + } + end, + }, + }, + ["local_variable_declaration|field"] = { + ["2"] = { + match = "function_definition", + + extract = function(node) + local regular_params = neogen.utility:extract_children_from("identifier", { + [1] = "extract", + })(node) + + local varargs = neogen.utility:extract_children_from("spread", { + [1] = "extract", + })(node) + + return { + parameters = regular_params, + vararg = varargs, + } + end, + }, + }, + }, + + -- Custom lua locator that escapes from comments + locator = function(node_info, nodes_to_match) + -- We're dealing with a lua comment and we need to escape its grasp + if node_info.current:type() == "source" then + local start_row, _, _, _ = ts_utils.get_node_range(node_info.current) + vim.api.nvim_win_set_cursor(0, { start_row, 0 }) + node_info.current = ts_utils.get_node_at_cursor() + end + + return neogen.default_locator(node_info, nodes_to_match) + end, + + -- Use default granulator and generator + granulator = nil, + generator = nil, + + template = { + { nil, "-" }, + { "parameters", "-@param %s any" }, + { "vararg", "-@vararg any" }, + }, + }, + }, + }) + + neogen.generate_command() end return neogen diff --git a/lua/neogen/config.lua b/lua/neogen/config.lua deleted file mode 100644 index 78ea978..0000000 --- a/lua/neogen/config.lua +++ /dev/null @@ -1,5 +0,0 @@ -neogen.configuration = { - enabled = false, -} - -return neogen diff --git a/stylua.toml b/stylua.toml new file mode 100644 index 0000000..bb258b9 --- /dev/null +++ b/stylua.toml @@ -0,0 +1,6 @@ +column_width = 120 +line_endings = "Unix" +indent_type = "Spaces" +indent_width = 4 +quote_style = "AutoPreferDouble" +no_call_parentheses = false