diff --git a/lua/neogen/configurations/python.lua b/lua/neogen/configurations/python.lua index 72d20c4..b608c93 100644 --- a/lua/neogen/configurations/python.lua +++ b/lua/neogen/configurations/python.lua @@ -12,6 +12,32 @@ local parent = { type = { "expression_statement" }, } + +--- Check `node` for the first parent matching `type_name`. +--- +--- This function is *inclusive*. It tests `node` before checking any parent so +--- `node` is a possible return value. +--- +---@param node TSNode A tree-sitter node to check for some parent. +---@param type_name string The tree-sitter type to check. e.g. `"function_definition"`. +---@return TSNode? # The found node, if any. +local get_nearest_parent = function(node, type_name) + local current = node + + while current ~= nil + do + if current:type() == type_name + then + return current + end + + current = current:parent() + end + + return nil +end + + --- Modify `nodes` if the found return(s) are **all** bare-returns. --- --- A bare-return is used to return early from a function and aren't meant to be @@ -19,13 +45,13 @@ local parent = { --- --- If at least one return is not a bare-return then this function does nothing. --- ----@param nodes table +---@param nodes table local validate_bare_returns = function(nodes) - local return_node = nodes[i.Return] + local return_nodes = nodes[i.Return] local has_data = false - for _, value in pairs(return_node) do - if value:child_count() > 1 + for _, node in pairs(return_nodes) do + if node:child_count() > 1 then has_data = true end @@ -38,6 +64,30 @@ local validate_bare_returns = function(nodes) end +--- Check if any of ``nodes`` has ``parent`` as a direct function parent. +--- +---@param nodes table +--- The extracted tree-sitter Return nodes to consider. If at least one +--- found return has `parent` then this function does nothing. But if +--- no `parent` is found then neogen ignores return annotations. +---@param parent TSNode +--- The direct function to check for. If there is any function_definition +--- between each node in `nodes` and this `parent` then we consider that node +--- "ignorable". +local validate_direct_returns = function(nodes, parent) + local return_nodes = nodes[i.Return] + + for _, node in pairs(return_nodes) do + if get_nearest_parent(node, "function_definition") == parent + then + return + end + end + + nodes[i.Return] = nil +end + + --- Remove `i.Return` details from `nodes` if a Python generator was found. --- --- If there is at least one `yield` found, Python converts the function to a generator. @@ -178,6 +228,7 @@ return { if nodes[i.Return] then validate_bare_returns(nodes) + validate_direct_returns(nodes, node) end validate_yield_nodes(nodes) @@ -217,27 +268,27 @@ return { end local results = helpers.copy({ - [i.HasParameter] = function(t) - return (t[i.Parameter] or t[i.Tparam]) and { true } - end, - [i.HasReturn] = function(t) - return (t[i.ReturnTypeHint] or t[i.Return]) and { true } - end, - [i.HasThrow] = function(t) - return t[i.Throw] and { true } - end, - [i.Type] = true, - [i.Parameter] = true, - [i.Return] = true, - [i.ReturnTypeHint] = true, - [i.HasYield] = function(t) - return t[i.Yield] and { true } - end, - [i.ArbitraryArgs] = true, - [i.Kwargs] = true, - [i.Throw] = true, - [i.Tparam] = true, - }, res) or {} + [i.HasParameter] = function(t) + return (t[i.Parameter] or t[i.Tparam]) and { true } + end, + [i.HasReturn] = function(t) + return (t[i.ReturnTypeHint] or t[i.Return]) and { true } + end, + [i.HasThrow] = function(t) + return t[i.Throw] and { true } + end, + [i.Type] = true, + [i.Parameter] = true, + [i.Return] = true, + [i.ReturnTypeHint] = true, + [i.HasYield] = function(t) + return t[i.Yield] and { true } + end, + [i.ArbitraryArgs] = true, + [i.Kwargs] = true, + [i.Throw] = true, + [i.Tparam] = true, + }, res) or {} -- Removes generation for returns that are not typed if results[i.ReturnTypeHint] then diff --git a/lua/neogen/init.lua b/lua/neogen/init.lua index a3eb079..3a41707 100644 --- a/lua/neogen/init.lua +++ b/lua/neogen/init.lua @@ -301,7 +301,7 @@ end --- with multiple annotation conventions. ---@tag neogen-changelog ---@toc_entry Changes in neogen plugin -neogen.version = "2.16.0" +neogen.version = "2.16.1" --minidoc_afterlines_end return neogen