From ad3c1adbc365a3b70d8b6e1e90946d98eeacf20f Mon Sep 17 00:00:00 2001 From: Jonas Strittmatter <40792180+smjonas@users.noreply.github.com> Date: Fri, 11 Feb 2022 06:23:59 +0100 Subject: [PATCH] Refactor context module, return false (not {}) in context.in_treesitter_capture (#777) --- lua/cmp/config/context.lua | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/lua/cmp/config/context.lua b/lua/cmp/config/context.lua index 95aed7a..584f38a 100644 --- a/lua/cmp/config/context.lua +++ b/lua/cmp/config/context.lua @@ -11,6 +11,7 @@ context.in_syntax_group = function(group) return true end end + return false end ---Check if cursor is in treesitter capture @@ -20,6 +21,7 @@ context.in_treesitter_capture = function(capture) local highlighter = require('vim.treesitter.highlighter') local ts_utils = require('nvim-treesitter.ts_utils') local buf = vim.api.nvim_get_current_buf() + local row, col = unpack(vim.api.nvim_win_get_cursor(0)) row = row - 1 if vim.api.nvim_get_mode().mode == 'i' then @@ -28,7 +30,7 @@ context.in_treesitter_capture = function(capture) local self = highlighter.active[buf] if not self then - return {} + return false end local node_types = {} @@ -40,30 +42,24 @@ context.in_treesitter_capture = function(capture) local root = tstree:root() local root_start_row, _, root_end_row, _ = root:range() - if root_start_row > row or root_end_row < row then return end local query = self:get_query(tree:lang()) - if not query:query() then return end local iter = query:query():iter_captures(root, self.bufnr, row, row + 1) - for _, node, _ in iter do if ts_utils.is_in_node_range(node, row, col) then table.insert(node_types, node:type()) end end end, true) - if vim.tbl_contains(node_types, capture) then - return true - else - return false - end + + return vim.tbl_contains(node_types, capture) end return context