From 77515a574d634ec3cfa0acf9431e5315ea6abe03 Mon Sep 17 00:00:00 2001 From: Daniel Mathiot Date: Tue, 24 Aug 2021 09:41:15 +0200 Subject: [PATCH] Refactor to simpler tree parsing You can now parse the syntax tree more efficiently with the new utility functions. For examples, see the Lua and Python configurations. --- README.md | 24 ++-- lua/neogen.lua | 1 - lua/neogen/configurations/lua.lua | 41 ++++--- lua/neogen/configurations/python.lua | 160 ++++++++++++--------------- lua/neogen/generators/default.lua | 4 +- lua/neogen/granulators/default.lua | 2 +- lua/neogen/utilities/extractors.lua | 52 ++------- lua/neogen/utilities/nodes.lua | 41 ++++++- 8 files changed, 161 insertions(+), 164 deletions(-) diff --git a/README.md b/README.md index c771dca..e27ef6b 100644 --- a/README.md +++ b/README.md @@ -197,26 +197,28 @@ end ```lua data = { - -- If function or local_function is found as a parent ["function|local_function"] = { - -- Get second child from the parent node + -- get second child from the parent node ["2"] = { - -- This second child has to be of type "parameters", otherwise does nothing + -- it has to be of type "parameters" match = "parameters", - - -- Extractor function that returns a set of TSname = values with values being of type string[] + extract = function(node) - local regular_params = neogen.utilities.extractors:extract_children_text("identifier")(node) - local varargs = neogen.utilities.extractors:extract_children_text("spread")(node) - + local tree = { + { retrieve = "all", node_type = "identifier", extract = true }, + { retrieve = "all", node_type = "spread", extract = true } + } + local nodes = neogen.utilities.nodes:matching_nodes_from(node, tree) + local res = neogen.utilities.extractors:extract_from_matched(nodes) + return { - parameters = regular_params, - vararg = varargs, + parameters = res.identifier, + vararg = res.spread, } end, }, }, -}, +} ``` Notes: diff --git a/lua/neogen.lua b/lua/neogen.lua index 53e1d1b..48490bf
100644 --- a/lua/neogen.lua +++ b/lua/neogen.lua @@ -3,7 +3,6 @@ assert(ok, "neogen requires nvim-treesitter to operate :(") neogen = {} - -- Require utilities neogen.utilities = {} require("neogen.utilities.extractors") diff --git a/lua/neogen/configurations/lua.lua b/lua/neogen/configurations/lua.lua index 20bd7e4..e5c2ea0 100644 --- a/lua/neogen/configurations/lua.lua +++ b/lua/neogen/configurations/lua.lua @@ -1,5 +1,3 @@ -local ts_utils = require("nvim-treesitter.ts_utils") - return { -- Search for these nodes parent = { "function", "local_function", "local_variable_declaration", "field", "variable_declaration" }, @@ -13,12 +11,16 @@ return { match = "parameters", extract = function(node) - local regular_params = neogen.utilities.extractors:extract_children_text("identifier")(node) - local varargs = neogen.utilities.extractors:extract_children_text("spread")(node) + local tree = { + { retrieve = "all", node_type = "identifier", extract = true }, + { retrieve = "all", node_type = "spread", extract = true }, + } + local nodes = neogen.utilities.nodes:matching_nodes_from(node, tree) + local res = neogen.utilities.extractors:extract_from_matched(nodes) return { - parameters = regular_params, - vararg = varargs, + parameters = res.identifier, + vararg = res.spread, } end, }, @@ -28,20 +30,25 @@ return { match = "function_definition", extract = function(node) - local regular_params = neogen.utilities.extractors:extract_children_from({ - [1] = "extract", - }, "identifier")(node) + local tree = { + { + retrieve = "first", + node_type = "parameters", + subtree = { + { retrieve = "all", node_type = "identifier", extract = true }, + { retrieve = "all", node_type = "spread", extract = true }, + }, + }, + { retrieve = "first", node_type = "return_statement", extract = true }, + } - local varargs = neogen.utilities.extractors:extract_children_from({ - [1] = "extract", - }, "spread")(node) - - local return_statement = 
neogen.utilities.extractors:extract_children_text("return_statement")(node) + local nodes = neogen.utilities.nodes:matching_nodes_from(node, tree) + local res = neogen.utilities.extractors:extract_from_matched(nodes) return { - parameters = regular_params, - vararg = varargs, - return_statement = return_statement, + parameters = res.identifier, + vararg = res.spread, + return_statement = res.return_statement, } end, }, diff --git a/lua/neogen/configurations/python.lua b/lua/neogen/configurations/python.lua index 1ed1ba5..d3c367d 100644 --- a/lua/neogen/configurations/python.lua +++ b/lua/neogen/configurations/python.lua @@ -8,68 +8,48 @@ return { data = { ["function_definition"] = { ["0"] = { - extract = function (node) - local results = { - parameters = {}, - return_statement = {} + extract = function(node) + local results = {} + + local tree = { + { + retrieve = "all", + node_type = "parameters", + subtree = { + { retrieve = "all", node_type = "identifier", extract = true }, + + { + retrieve = "all", + node_type = "default_parameter", + subtree = { { retrieve = "all", node_type = "identifier", extract = true } }, + }, + { + retrieve = "all", + node_type = "typed_parameter", + subtree = { { retrieve = "all", node_type = "identifier", extract = true } }, + }, + { + retrieve = "all", + node_type = "typed_default_parameter", + subtree = { { retrieve = "all", node_type = "identifier", extract = true } }, + }, + }, + }, + { + retrieve = "first", + node_type = "block", + subtree = { + { retrieve = "all", node_type = "return_statement", extract = true }, + }, + }, } + local nodes = neogen.utilities.nodes:matching_nodes_from(node, tree) + local res = neogen.utilities.extractors:extract_from_matched(nodes) - local params = neogen.utilities.nodes:matching_child_nodes(node, "parameters")[1] - - if #params == 0 then - results.parameters = nil - end - - local found_nodes - - -- Find regular parameters - local regular_params = 
neogen.utilities.extractors:extract_children_text("identifier")(params) - if #regular_params == 0 then - regular_params = nil - end - - for _, _params in pairs(regular_params) do - table.insert(results.parameters, _params) - end - - results.parameters = regular_params - - -- Find regular optional parameters - found_nodes = neogen.utilities.nodes:matching_child_nodes(params, "default_parameter") - for _,_node in pairs(found_nodes) do - local _params = neogen.utilities.extractors:extract_children_text("identifier")(_node)[1] - table.insert(results.parameters, _params) - end - - -- Find typed params - found_nodes = neogen.utilities.nodes:matching_child_nodes(params, "typed_parameter") - for _,_node in pairs(found_nodes) do - local _params = neogen.utilities.extractors:extract_children_text("identifier")(_node)[1] - table.insert(results.parameters, _params) - end - - -- TODO Find optional typed params - found_nodes = neogen.utilities.nodes:matching_child_nodes(params, "typed_default_parameter") - for _,_node in pairs(found_nodes) do - local _params = neogen.utilities.extractors:extract_children_text("identifier")(_node)[1] - table.insert(results.parameters, _params) - end - - - local body = neogen.utilities.nodes:matching_child_nodes(node, "block")[1] - if body ~= nil then - local return_statement = neogen.utilities.nodes:matching_child_nodes(body, "return_statement") - - if #return_statement == 0 then - return_statement = nil - end - - results.return_statement = return_statement - end - - + results.parameters = res.identifier + results.return_statement = res.return_statement return results - end + end, }, }, ["class_definition"] = { @@ -77,34 +57,40 @@ return { match = "block", extract = function(node) - local results = { - attributes = {} + local results = {} + local tree = { + { + retrieve = "first", + node_type = "function_definition", + subtree = { + { + retrieve = "first", + node_type = "block", + subtree = { + { + retrieve = "all", + node_type = 
"expression_statement", + subtree = { + { retrieve = "first", node_type = "assignment", extract = true }, + }, + }, + }, + }, + }, + }, } - local init_function = neogen.utilities.nodes:matching_child_nodes(node, "function_definition")[1] + local nodes = neogen.utilities.nodes:matching_nodes_from(node, tree) - if init_function == nil then - return - end - - local body = neogen.utilities.nodes:matching_child_nodes(init_function, "block")[1] - - if body == nil then - return - end - - local expressions = neogen.utilities.nodes:matching_child_nodes(body, "expression_statement") - for _,expression in pairs(expressions) do - local assignment = neogen.utilities.nodes:matching_child_nodes(expression, "assignment")[1] - if assignment ~= nil then - local left_side = assignment:field("left")[1] - local left_attribute = left_side:field("attribute")[1] - table.insert(results.attributes, ts_utils.get_node_text(left_attribute)[1]) - end + results.attributes = {} + for _, assignment in pairs(nodes["assignment"]) do + local left_side = assignment:field("left")[1] + local left_attribute = left_side:field("attribute")[1] + table.insert(results.attributes, ts_utils.get_node_text(left_attribute)[1]) end return results - end + end, }, }, }, @@ -115,7 +101,7 @@ return { generator = nil, template = { - annotation_convention = "numpydoc", -- required: Which annotation convention to use (default_generator) + annotation_convention = "google_docstrings", -- required: Which annotation convention to use (default_generator) append = { position = "after", child_name = "block" }, -- optional: where to append the text (default_generator) use_default_comment = false, -- If you want to prefix the template with the default comment for the language, e.g for python: # (default_generator) google_docstrings = { @@ -128,11 +114,11 @@ return { }, numpydoc = { { nil, '"""' }, - { "parameters", "%s: ", { before_first_item = { "", "Parameters", "----------" } } }, - { "attributes", "%s: ", { before_first_item = 
{ "", "Attributes", "----------" } } }, - { "return_statement", "", { before_first_item = { "", "Returns", "-------" } } }, + { "parameters", "%s: ", { before_first_item = { "", "Parameters", "----------" } } }, + { "attributes", "%s: ", { before_first_item = { "", "Attributes", "----------" } } }, + { "return_statement", "", { before_first_item = { "", "Returns", "-------" } } }, { nil, "" }, - { nil, '"""' } - } + { nil, '"""' }, + }, }, } diff --git a/lua/neogen/generators/default.lua b/lua/neogen/generators/default.lua index 52cf3fe..de6363a 100644 --- a/lua/neogen/generators/default.lua +++ b/lua/neogen/generators/default.lua @@ -19,7 +19,7 @@ neogen.default_generator = function(parent, data, template) if append.position == "after" then local child_node = neogen.utilities.nodes:matching_child_nodes(parent, append.child_name)[1] if child_node ~= nil then - row_to_place, col_to_place, _ , _ = child_node:range() + row_to_place, col_to_place, _, _ = child_node:range() end end @@ -58,7 +58,7 @@ neogen.default_generator = function(parent, data, template) -- Will append the item before all their nodes if opts.before_first_item and data[type] then for _, value in pairs(opts.before_first_item) do - table.insert(result, prefix .. value) + table.insert(result, prefix .. 
value) end end diff --git a/lua/neogen/granulators/default.lua b/lua/neogen/granulators/default.lua index f112b2f..1251f97 100644 --- a/lua/neogen/granulators/default.lua +++ b/lua/neogen/granulators/default.lua @@ -16,7 +16,7 @@ neogen.default_granulator = function(parent_node, node_data) local child_node if tonumber(i) == 0 then - child_node = parent_node + child_node = parent_node else child_node = parent_node:named_child(tonumber(i) - 1) end diff --git a/lua/neogen/utilities/extractors.lua b/lua/neogen/utilities/extractors.lua index bc8d98a..60a478b 100644 --- a/lua/neogen/utilities/extractors.lua +++ b/lua/neogen/utilities/extractors.lua @@ -1,50 +1,18 @@ local ts_utils = require("nvim-treesitter.ts_utils") neogen.utilities.extractors = { - --- Return a function to extract content of required children from a node + --- Extract the content from each node from data --- @param _ any self - --- @param name string the children we want to extract (if multiple childrens, separate each one with "|") - --- @return function cb function taking a node and getting the content of each children we want from name - extract_children_text = function(_, name) - return function(node) - local result = {} - local split = vim.split(name, "|", true) - - for child in node:iter_children() do - if vim.tbl_contains(split, child:type()) then - table.insert(result, ts_utils.get_node_text(child)[1]) - end + --- @param data table a list of k,v values where k is the node_type and v a table of nodes + --- @return any result the same table as data but with node texts instead + extract_from_matched = function(_, data) + local result = {} + for k, v in pairs(data) do + local get_text = function(node) + return ts_utils.get_node_text(node)[1] end - - return result - end - end, - - --- Extract content from specified children from a tree - --- the tree parameter can be a nested { [key] = value} with key being the - --- * key: is which children we want to extract the values from (e.g first children is 
1) - --- * value: "extract" or { [key] = value }. If value is "extract", it will extract the key child node - --- Example (extract the first child node from the first child node of the parent node): - --- [1] = { - --- [1] = "extract" - --- } - --- @param tree table see description - --- @param name string the children we want to extract (if multiple children, separate each one with "|") - extract_children_from = function(self, tree, name) - return function(node) - local result = {} - - for i, subtree in pairs(tree) do - local child_node = node:named_child(tonumber(i) - 1) - - if subtree == "extract" then - return self:extract_children_text(name)(child_node) - else - return self:extract_children_from(subtree, name)(child_node) - end - end - - return result + result[k] = vim.tbl_map(get_text, v) end + return result end, } diff --git a/lua/neogen/utilities/nodes.lua b/lua/neogen/utilities/nodes.lua index cb1268e..95a3fb0 100644 --- a/lua/neogen/utilities/nodes.lua +++ b/lua/neogen/utilities/nodes.lua @@ -2,16 +2,51 @@ neogen.utilities.nodes = { --- Get a list of child nodes that match the provided node name --- @param _ any --- @param parent userdata the parent's node - --- @param node_name string the node type to search for + --- @param node_name string the node type to search for (if multiple childrens, separate each one with "|") --- @return table a table of nodes that matched the name - matching_child_nodes = function (_, parent, node_name) + matching_child_nodes = function(_, parent, node_name) local results = {} + local split = vim.split(node_name, "|", true) for child in parent:iter_children() do - if child:type() == node_name then + if vim.tbl_contains(split, child:type()) then table.insert(results, child) end end return results end, + + --- Get all required nodes from tree + --- @param parent userdata the parent node + --- @param tree table a nested table : { retrieve = "all|first", node_type = node_name, subtree = tree } + --- If you want to extract the 
node, do not specify the subtree and instead: extract = true + --- @param result table the table of results + --- @return table result a table of k,v where k are node_types and v all matched nodes + matching_nodes_from = function(self, parent, tree, result) + result = result or {} + + for _, subtree in pairs(tree) do + -- Match all child nodes + local matched = self:matching_child_nodes(parent, subtree.node_type) + + -- Only keep first matched child node + if subtree.retrieve == "first" and #matched ~= 0 then + matched = { matched[1] } + end + + for _, child in pairs(matched) do + -- Add to results + if subtree.extract == true then + if result[subtree.node_type] == nil then + result[subtree.node_type] = {} + end + table.insert(result[subtree.node_type], child) + else + local test = self:matching_nodes_from(child, subtree.subtree, result) + result = vim.tbl_deep_extend("keep", result, test) + end + end + end + return result + end, }