Refactor context module, return false (not {}) in context.in_treesitter_capture (#777)

This commit is contained in:
Jonas Strittmatter
2022-02-11 06:23:59 +01:00
committed by GitHub
parent 3a2f1bbc55
commit ad3c1adbc3

View File

@@ -11,6 +11,7 @@ context.in_syntax_group = function(group)
return true return true
end end
end end
return false
end end
---Check if cursor is in treesitter capture ---Check if cursor is in treesitter capture
@@ -20,6 +21,7 @@ context.in_treesitter_capture = function(capture)
local highlighter = require('vim.treesitter.highlighter') local highlighter = require('vim.treesitter.highlighter')
local ts_utils = require('nvim-treesitter.ts_utils') local ts_utils = require('nvim-treesitter.ts_utils')
local buf = vim.api.nvim_get_current_buf() local buf = vim.api.nvim_get_current_buf()
local row, col = unpack(vim.api.nvim_win_get_cursor(0)) local row, col = unpack(vim.api.nvim_win_get_cursor(0))
row = row - 1 row = row - 1
if vim.api.nvim_get_mode().mode == 'i' then if vim.api.nvim_get_mode().mode == 'i' then
@@ -28,7 +30,7 @@ context.in_treesitter_capture = function(capture)
local self = highlighter.active[buf] local self = highlighter.active[buf]
if not self then if not self then
return {} return false
end end
local node_types = {} local node_types = {}
@@ -40,30 +42,24 @@ context.in_treesitter_capture = function(capture)
local root = tstree:root() local root = tstree:root()
local root_start_row, _, root_end_row, _ = root:range() local root_start_row, _, root_end_row, _ = root:range()
if root_start_row > row or root_end_row < row then if root_start_row > row or root_end_row < row then
return return
end end
local query = self:get_query(tree:lang()) local query = self:get_query(tree:lang())
if not query:query() then if not query:query() then
return return
end end
local iter = query:query():iter_captures(root, self.bufnr, row, row + 1) local iter = query:query():iter_captures(root, self.bufnr, row, row + 1)
for _, node, _ in iter do for _, node, _ in iter do
if ts_utils.is_in_node_range(node, row, col) then if ts_utils.is_in_node_range(node, row, col) then
table.insert(node_types, node:type()) table.insert(node_types, node:type())
end end
end end
end, true) end, true)
if vim.tbl_contains(node_types, capture) then
return true return vim.tbl_contains(node_types, capture)
else
return false
end
end end
return context return context