From abb5c7519d40314240effc70d781149c0f097af3 Mon Sep 17 00:00:00 2001 From: Folke Lemaitre Date: Thu, 25 May 2023 19:46:53 +0200 Subject: [PATCH] feat: cmp async (#1583) --- doc/cmp.txt | 13 +++- lua/cmp/config/default.lua | 7 +- lua/cmp/context.lua | 6 ++ lua/cmp/core.lua | 20 +++-- lua/cmp/entry.lua | 16 +++- lua/cmp/source.lua | 51 ++++++------- lua/cmp/types/cmp.lua | 2 + lua/cmp/utils/async.lua | 152 +++++++++++++++++++++++++++++++++++-- lua/cmp/view.lua | 4 + 9 files changed, 220 insertions(+), 51 deletions(-) diff --git a/doc/cmp.txt b/doc/cmp.txt index fd6daef..0c13dfe 100644 --- a/doc/cmp.txt +++ b/doc/cmp.txt @@ -413,11 +413,22 @@ performance.throttle~ This is used to delay filtering and displaying completions. *cmp-config.performance.fetching_timeout* - performance.fetching_timeout~ +performance.fetching_timeout~ `number` Sets the timeout of candidate fetching process. The nvim-cmp will wait to display the most prioritized source. + *cmp-config.performance.async_budget* +performance.async_budget~ + `number` + Maximum time (in ms) an async function is allowed to run during + one step of the event loop. + + *cmp-config.performance.max_view_entries* +performance.max_view_entries~ + `number` + Maximum number of items to show in the entries list. + *cmp-config.preselect* preselect~ `cmp.PreselectMode` diff --git a/lua/cmp/config/default.lua b/lua/cmp/config/default.lua index 5259520..02296c9 100644 --- a/lua/cmp/config/default.lua +++ b/lua/cmp/config/default.lua @@ -19,6 +19,8 @@ return function() debounce = 60, throttle = 30, fetching_timeout = 500, + async_budget = 1, + max_view_entries = 200, }, preselect = types.cmp.PreselectMode.Item, @@ -88,7 +90,10 @@ return function() }, view = { - entries = { name = 'custom', selection_order = 'top_down' }, + entries = { + name = 'custom', + selection_order = 'top_down', + }, }, window = { diff --git a/lua/cmp/context.lua b/lua/cmp/context.lua index 0c25462..0411a54 100644 --- a/lua/cmp/context.lua +++ b/lua/cmp/context.lua @@ -16,6 +16,7 @@ local api = require('cmp.utils.api') ---@field public cursor_line string ---@field public cursor_after_line string ---@field public cursor_before_line string +---@field public aborted boolean local context = {} ---Create new empty context @@ -55,9 +56,14 @@ context.new = function(prev_context, option) self.cursor.character = misc.to_utfindex(self.cursor_line, self.cursor.col) self.cursor_before_line = string.sub(self.cursor_line, 1, self.cursor.col - 1) self.cursor_after_line = string.sub(self.cursor_line, self.cursor.col) + self.aborted = false return self end +context.abort = function(self) + self.aborted = true +end + ---Return context creation reason. ---@return cmp.ContextReason context.get_reason = function(self) diff --git a/lua/cmp/core.lua b/lua/cmp/core.lua index d7432da..2775cf2 100644 --- a/lua/cmp/core.lua +++ b/lua/cmp/core.lua @@ -56,6 +56,7 @@ end ---@param option? cmp.ContextOption ---@return cmp.Context core.get_context = function(self, option) + self.context:abort() local prev = self.context:clone() prev.prev_context = nil prev.cache = nil @@ -296,7 +297,7 @@ core.complete = function(self, ctx) end ---Update completion menu -core.filter = async.throttle(function(self) +local async_filter = async.wrap(function(self) self.filter.timeout = config.get().performance.throttle -- Check invalid condition. @@ -323,20 +324,17 @@ core.filter = async.throttle(function(self) local ctx = self:get_context() -- Display completion results. - self.view:open(ctx, sources) + local did_open = self.view:open(ctx, sources) + local fetching = #self:get_sources(function(s) + return s.status == source.SourceStatus.FETCHING + end) -- Check onetime config. - if #self:get_sources(function(s) - if s.status == source.SourceStatus.FETCHING then - return true - elseif #s:get_entries(ctx) > 0 then - return true - end - return false - end) == 0 then + if not did_open and fetching == 0 then config.set_onetime({}) end -end, config.get().performance.throttle) +end) +core.filter = async.throttle(async_filter, config.get().performance.throttle) ---Confirm completion. ---@param e cmp.Entry diff --git a/lua/cmp/entry.lua b/lua/cmp/entry.lua index 54c57bb..6da8e76 100644 --- a/lua/cmp/entry.lua +++ b/lua/cmp/entry.lua @@ -376,7 +376,10 @@ entry.match = function(self, input, matching_config) } local score, matches, filter_text, _ + local checked = {} ---@type string[] + filter_text = self:get_filter_text() + checked[filter_text] = true score, matches = matcher.match(input, filter_text, option) -- Support the language server that doesn't respect VSCode's behaviors. @@ -390,16 +393,21 @@ entry.match = function(self, input, matching_config) accept = accept or string.find(self:get_completion_item().textEdit.newText, prefix, 1, true) if accept then filter_text = prefix .. self:get_filter_text() - score, matches = matcher.match(input, filter_text, option) + if not checked[filter_text] then + checked[filter_text] = true + score, matches = matcher.match(input, filter_text, option) + end end end end end - local vim_item = self:get_vim_item(self:get_offset()) - if filter_text ~= vim_item.abbr then + if score == 0 then + local vim_item = self:get_vim_item(self:get_offset()) filter_text = vim_item.abbr or vim_item.word - _, matches = matcher.match(input, filter_text, option) + if not checked[filter_text] then + _, matches = matcher.match(input, filter_text, option) + end end return { score = score, matches = matches } diff --git a/lua/cmp/source.lua b/lua/cmp/source.lua index 87a95a8..28ae135 100644 --- a/lua/cmp/source.lua +++ b/lua/cmp/source.lua @@ -18,7 +18,6 @@ local char = require('cmp.utils.char') ---@field public incomplete boolean ---@field public is_triggered_by_symbol boolean ---@field public entries cmp.Entry[] ----@field public filtered {entries: cmp.Entry[], ctx: cmp.Context} ---@field public offset integer ---@field public request_offset integer ---@field public context cmp.Context @@ -54,7 +53,6 @@ source.reset = function(self) self.is_triggered_by_symbol = false self.incomplete = false self.entries = {} - self.filtered = {} self.offset = -1 self.request_offset = -1 self.completion_context = nil @@ -90,28 +88,26 @@ source.get_entries = function(self, ctx) return {} end - if self.filtered.ctx and self.filtered.ctx.id == ctx.id then - return self.filtered.entries - end + local target_entries = self.entries - local target_entries = (function() - local key = { 'get_entries', self.revision } - for i = ctx.cursor.col, self.offset, -1 do - key[3] = string.sub(ctx.cursor_before_line, 1, i) - local prev_entries = self.cache:get(key) - if prev_entries then - return prev_entries - end + local prev = self.cache:get({ 'get_entries', self.revision }) + + if prev and ctx.cursor.row == prev.ctx.cursor.row then + if ctx.cursor.col == prev.ctx.cursor.col then + return prev.entries end - return self.entries - end)() + -- only use prev entries when cursor is moved forward. + -- and the pattern offset is the same. + if ctx.cursor.col >= prev.ctx.cursor.col and ctx.offset == prev.ctx.offset then + target_entries = prev.entries + end + end local entry_filter = self:get_entry_filter() local inputs = {} ---@type cmp.Entry[] local entries = {} - local max_item_count = self:get_source_config().max_item_count or 200 local matching_config = self:get_matching_config() for _, e in ipairs(target_entries) do local o = e:get_offset() @@ -128,20 +124,16 @@ source.get_entries = function(self, ctx) if entry_filter(e, ctx) then entries[#entries + 1] = e - if max_item_count and #entries >= max_item_count then - break - end end end + async.yield() + if ctx.aborted then + async.abort() + end end - -- only save to cache, when there are no additional entries that could match the filter - -- This also prevents too much memory usage - if #entries < max_item_count then - self.cache:set({ 'get_entries', tostring(self.revision), ctx.cursor_before_line }, entries) - end + self.cache:set({ 'get_entries', self.revision }, { entries = entries, ctx = ctx }) - self.filtered = { entries = entries, ctx = ctx } return entries end @@ -337,7 +329,10 @@ source.complete = function(self, ctx, callback) context = ctx, completion_context = completion_context, }), - self.complete_dedup(vim.schedule_wrap(function(response) + self.complete_dedup(function(response) + if self.context ~= ctx then + return + end ---@type lsp.CompletionResponse response = response or {} @@ -358,7 +353,7 @@ source.complete = function(self, ctx, callback) end end self.revision = self.revision + 1 - if #self:get_entries(ctx) == 0 then + if #self.entries == 0 then self.offset = old_offset self.entries = old_entries self.revision = self.revision + 1 @@ -372,7 +367,7 @@ source.complete = function(self, ctx, callback) self.status = prev_status end callback() - end)) + end) ) return true end diff --git a/lua/cmp/types/cmp.lua b/lua/cmp/types/cmp.lua index 8d544d7..7bf58d2 100644 --- a/lua/cmp/types/cmp.lua +++ b/lua/cmp/types/cmp.lua @@ -99,6 +99,8 @@ cmp.ItemField = { ---@field public debounce integer ---@field public throttle integer ---@field public fetching_timeout integer +---@field public async_budget integer Maximum time (in ms) an async function is allowed to run during one step of the event loop. +---@field public max_view_entries integer ---@class cmp.WindowConfig ---@field completion cmp.WindowConfig diff --git a/lua/cmp/utils/async.lua b/lua/cmp/utils/async.lua index fdfc7aa..c62b565 100644 --- a/lua/cmp/utils/async.lua +++ b/lua/cmp/utils/async.lua @@ -1,4 +1,5 @@ local feedkeys = require('cmp.utils.feedkeys') +local config = require('cmp.config') local async = {} @@ -9,6 +10,7 @@ local async = {} ---@field public stop function ---@field public __call function +---@type uv_timer_t[] local timers = {} vim.api.nvim_create_autocmd('VimLeavePre', { @@ -27,7 +29,8 @@ vim.api.nvim_create_autocmd('VimLeavePre', { ---@return cmp.AsyncThrottle async.throttle = function(fn, timeout) local time = nil - local timer = vim.loop.new_timer() + local timer = assert(vim.loop.new_timer()) + local _async = nil ---@type Async? timers[#timers + 1] = timer return setmetatable({ running = false, @@ -37,9 +40,15 @@ async.throttle = function(fn, timeout) return not self.running end) end, - stop = function() - time = nil + stop = function(reset_time) + if reset_time ~= false then + time = nil + end timer:stop() + if _async then + _async:cancel() + _async = nil + end end, }, { __call = function(self, ...) @@ -50,12 +59,23 @@ async.throttle = function(fn, timeout) end self.running = true - timer:stop() + self.stop(false) timer:start(math.max(1, self.timeout - (vim.loop.now() - time)), 0, function() vim.schedule(function() time = nil - fn(unpack(args)) - self.running = false + local ret = fn(unpack(args)) + if async.is_async(ret) then + ---@cast ret Async + _async = ret + _async:await(function(_, error) + self.running = false + if error and error ~= 'abort' then + vim.notify(error, vim.log.levels.ERROR) + end + end) + else + self.running = false + end end) end) end, @@ -147,4 +167,124 @@ async.debounce_next_tick_by_keymap = function(callback) end end +local Scheduler = {} +Scheduler._queue = {} +Scheduler._executor = assert(vim.loop.new_check()) + +function Scheduler.step() + local budget = config.get().performance.async_budget * 1e6 + local start = vim.loop.hrtime() + while #Scheduler._queue > 0 and vim.loop.hrtime() - start < budget do + local a = table.remove(Scheduler._queue, 1) + a:_step() + if a.running then + table.insert(Scheduler._queue, a) + end + end + if #Scheduler._queue == 0 then + return Scheduler._executor:stop() + end +end + +---@param a Async +function Scheduler.add(a) + table.insert(Scheduler._queue, a) + if not Scheduler._executor:is_active() then + Scheduler._executor:start(vim.schedule_wrap(Scheduler.step)) + end +end + +--- @alias AsyncCallback fun(result?:any, error?:string) + +--- @class Async +--- @field running boolean +--- @field result? any +--- @field error? string +--- @field callbacks AsyncCallback[] +--- @field thread thread +local Async = {} +Async.__index = Async + +function Async.new(fn) + local self = setmetatable({}, Async) + self.callbacks = {} + self.running = true + self.thread = coroutine.create(fn) + Scheduler.add(self) + return self +end + +---@param result? any +---@param error? string +function Async:_done(result, error) + self.running = false + self.result = result + self.error = error + for _, callback in ipairs(self.callbacks) do + callback(result, error) + end +end + +function Async:_step() + local ok, res = coroutine.resume(self.thread) + if not ok then + return self:_done(nil, res) + elseif res == 'abort' then + return self:_done(nil, 'abort') + elseif coroutine.status(self.thread) == 'dead' then + return self:_done(res) + end +end + +function Async:cancel() + self.running = false +end + +---@param cb AsyncCallback +function Async:await(cb) + if not cb then + error('callback is required') + end + if self.running then + table.insert(self.callbacks, cb) + else + cb(self.result, self.error) + end +end + +function Async:sync() + while self.running do + vim.wait(10) + end + return self.error and error(self.error) or self.result +end + +--- @return boolean +function async.is_async(obj) + return obj and type(obj) == 'table' and getmetatable(obj) == Async +end + +--- @return fun(...): Async +function async.wrap(fn) + return function(...) + local args = { ... } + return Async.new(function() + return fn(unpack(args)) + end) + end +end + +-- This will yield when called from a coroutine +function async.yield(...) + if not coroutine.isyieldable() then + error('Trying to yield from a non-yieldable context') + return ... + end + return coroutine.yield(...) +end + +function async.abort() + return async.yield('abort') +end + return async diff --git a/lua/cmp/view.lua b/lua/cmp/view.lua index 42f5664..ef79142 100644 --- a/lua/cmp/view.lua +++ b/lua/cmp/view.lua @@ -47,6 +47,7 @@ end ---Open menu ---@param ctx cmp.Context ---@param sources cmp.Source[] +---@return boolean did_open view.open = function(self, ctx, sources) local source_group_map = {} for _, s in ipairs(sources) do @@ -104,6 +105,8 @@ view.open = function(self, ctx, sources) end end end) + local max_item_count = config.get().performance.max_view_entries or 200 + entries = vim.list_slice(entries, 1, max_item_count) -- open if #entries > 0 then @@ -119,6 +122,7 @@ view.open = function(self, ctx, sources) if #entries == 0 then self:close() end + return #entries > 0 end ---Close menu