From 0f70a414faab5e27f26233b868c16d892093a992 Mon Sep 17 00:00:00 2001 From: danymat Date: Fri, 21 Oct 2022 16:29:44 +0200 Subject: [PATCH] feat(py) Better copy of nodes to generator (#106) --- lua/neogen/configurations/python.lua | 67 ++++++++++++++++++++-------- lua/neogen/templates/reST.lua | 4 +- lua/neogen/utilities/helpers.lua | 28 ++++++++++++ lua/neogen/utilities/nodes.lua | 2 +- 4 files changed, 79 insertions(+), 22 deletions(-) diff --git a/lua/neogen/configurations/python.lua b/lua/neogen/configurations/python.lua index e04e500..5bdb736 100644 --- a/lua/neogen/configurations/python.lua +++ b/lua/neogen/configurations/python.lua @@ -22,18 +22,23 @@ return { ["function_definition"] = { ["0"] = { extract = function(node) - local results = {} - local tree = { { retrieve = "all", node_type = "parameters", subtree = { - { retrieve = "all", node_type = "identifier", extract = true }, + { retrieve = "all", node_type = "identifier", extract = true, as = i.Parameter }, { retrieve = "all", node_type = "default_parameter", - subtree = { { retrieve = "all", node_type = "identifier", extract = true } }, + subtree = { + { + retrieve = "all", + node_type = "identifier", + extract = true, + as = i.Parameter, + }, + }, }, { retrieve = "all", @@ -45,7 +50,14 @@ return { retrieve = "all", node_type = "typed_default_parameter", extract = true, - subtree = { { retrieve = "all", node_type = "identifier", extract = true } }, + subtree = { + { + retrieve = "all", + node_type = "identifier", + extract = true, + as = i.Tparam, + }, + }, }, { retrieve = "first", @@ -70,6 +82,7 @@ return { node_type = "return_statement", recursive = true, extract = true, + as = i.Return, }, { retrieve = "all", @@ -94,8 +107,9 @@ return { }, } local nodes = nodes_utils:matching_nodes_from(node, tree) + local temp = {} if nodes[i.Tparam] then - results[i.Tparam] = {} + temp[i.Tparam] = {} for _, n in pairs(nodes[i.Tparam]) do local type_subtree = { { retrieve = "all", node_type = "identifier", extract = true, as = i.Parameter }, @@ -103,14 +117,15 @@ return { } local typed_parameters = nodes_utils:matching_nodes_from(n, type_subtree) typed_parameters = extractors:extract_from_matched(typed_parameters) - table.insert(results[i.Tparam], typed_parameters) + table.insert(temp[i.Tparam], typed_parameters) end end local res = extractors:extract_from_matched(nodes) + res[i.Tparam] = temp[i.Tparam] -- Return type hints takes precedence over all other types for generating template if res[i.ReturnTypeHint] then - res["return_statement"] = nil + res[i.HasReturn] = nil if res[i.ReturnTypeHint][1] == "None" then res[i.ReturnTypeHint] = nil end @@ -136,17 +151,31 @@ return { end end - results[i.HasParameter] = (res.typed_parameter or res.identifier) and { true } or nil - results[i.Type] = res.type - results[i.Parameter] = res.identifier - results[i.Return] = res.return_statement - results[i.ReturnTypeHint] = res[i.ReturnTypeHint] - results[i.HasReturn] = (res.return_statement or res.anonymous_return or res[i.ReturnTypeHint]) - and { true } - or nil - results[i.ArbitraryArgs] = res[i.ArbitraryArgs] - results[i.Kwargs] = res[i.Kwargs] - results[i.Throw] = res[i.Throw] + local results = helpers.copy({ + [i.HasParameter] = function(t) + return t[i.Parameter] and { true } or nil + end, + [i.Type] = true, + [i.Parameter] = function(t) + return t[i.Parameter] + end, + [i.Return] = true, + [i.HasReturn] = true, + [i.ReturnTypeHint] = true, + [i.ArbitraryArgs] = true, + [i.Kwargs] = true, + [i.Throw] = true, + [i.Tparam] = true, + }, res) or {} + + -- Generates a "flag" return + results[i.HasReturn] = (results[i.ReturnTypeHint] or results[i.Return]) and { true } or nil + + -- Removes generation for returns that are not typed + if results[i.ReturnTypeHint] then + results[i.Return] = nil + end + return results end, }, diff --git a/lua/neogen/templates/reST.lua b/lua/neogen/templates/reST.lua index 848bcdf..f7ad144 100644 --- a/lua/neogen/templates/reST.lua +++ b/lua/neogen/templates/reST.lua @@ -27,7 +27,7 @@ return { }, { i.ClassAttribute, ":param %s: $1" }, { i.Throw, ":raises %s: $1", { type = { "func" } } }, - { i.HasReturn, ":return: $1", { type = { "func" } } }, - { i.HasReturn, ":rtype: $1", { type = { "func" } } }, + { i.Return, ":return: $1", { type = { "func" }, after_each = ":rtype: $1" } }, + { i.ReturnTypeHint, ":return: $1", { type = { "func" } } }, { nil, '"""' }, } diff --git a/lua/neogen/utilities/helpers.lua b/lua/neogen/utilities/helpers.lua index bd0ea71..6cc1918 100644 --- a/lua/neogen/utilities/helpers.lua +++ b/lua/neogen/utilities/helpers.lua @@ -32,4 +32,32 @@ return { get_node_text = function(node, bufnr) return vim.split(vim.treesitter.query.get_node_text(node, bufnr or 0), "\n") end, + + --- Copies a table to another table depending of the parameters that we want to expose + ---TODO: create a doc for the table structure + ---@param rules table the rules that we want to execute + ---@param table table the table to copy + ---@return table? + ---@private + copy = function(rules, table) + P(rules, table) + local copy = {} + + for parameter, rule in pairs(rules) do + local parameter_value = table[parameter] + + if parameter_value then + if type(rule) == "function" then + copy[parameter] = vim.tbl_deep_extend("error", rule(table), copy[parameter] or {}) + elseif rule == true and parameter_value ~= nil then + copy[parameter] = parameter_value + else + vim.notify("Incorrect rule format for parameter " .. parameter, vim.log.levels.ERROR) + return + end + end + end + + return copy + end, } diff --git a/lua/neogen/utilities/nodes.lua b/lua/neogen/utilities/nodes.lua index 1bdd4c7..5cf1ec3 100644 --- a/lua/neogen/utilities/nodes.lua +++ b/lua/neogen/utilities/nodes.lua @@ -57,7 +57,7 @@ return { --- @param tree table a nested table : { retrieve = "all|first", node_type = node_name, subtree = tree, recursive = true } --- If you want to extract the node, do not specify the subtree and instead: extract = true --- Optional: you can specify position = number instead of retrieve, and it will fetch the child node at position number - --- @param result table the table of results + --- @param result? table the table of results --- @return table result a table of k,v where k are node_types and v all matched nodes matching_nodes_from = function(self, parent, tree, result) result = result or {}