feat: cmp async (#1583)

This commit is contained in:
Folke Lemaitre
2023-05-25 19:46:53 +02:00
committed by GitHub
parent 950d0e3a93
commit abb5c7519d
9 changed files with 220 additions and 51 deletions

View File

@@ -418,6 +418,17 @@ performance.throttle~
Sets the timeout of candidate fetching process. Sets the timeout of candidate fetching process.
The nvim-cmp will wait to display the most prioritized source. The nvim-cmp will wait to display the most prioritized source.
*cmp-config.performance.async_budget*
performance.async_budget~
`number`
Maximum time (in ms) an async function is allowed to run during
one step of the event loop.
*cmp-config.performance.max_view_entries*
performance.max_view_entries~
`number`
Maximum number of items to show in the entries list.
*cmp-config.preselect* *cmp-config.preselect*
preselect~ preselect~
`cmp.PreselectMode` `cmp.PreselectMode`

View File

@@ -19,6 +19,8 @@ return function()
debounce = 60, debounce = 60,
throttle = 30, throttle = 30,
fetching_timeout = 500, fetching_timeout = 500,
async_budget = 1,
max_view_entries = 200,
}, },
preselect = types.cmp.PreselectMode.Item, preselect = types.cmp.PreselectMode.Item,
@@ -88,7 +90,10 @@ return function()
}, },
view = { view = {
entries = { name = 'custom', selection_order = 'top_down' }, entries = {
name = 'custom',
selection_order = 'top_down',
},
}, },
window = { window = {

View File

@@ -16,6 +16,7 @@ local api = require('cmp.utils.api')
---@field public cursor_line string ---@field public cursor_line string
---@field public cursor_after_line string ---@field public cursor_after_line string
---@field public cursor_before_line string ---@field public cursor_before_line string
---@field public aborted boolean
local context = {} local context = {}
---Create new empty context ---Create new empty context
@@ -55,9 +56,14 @@ context.new = function(prev_context, option)
self.cursor.character = misc.to_utfindex(self.cursor_line, self.cursor.col) self.cursor.character = misc.to_utfindex(self.cursor_line, self.cursor.col)
self.cursor_before_line = string.sub(self.cursor_line, 1, self.cursor.col - 1) self.cursor_before_line = string.sub(self.cursor_line, 1, self.cursor.col - 1)
self.cursor_after_line = string.sub(self.cursor_line, self.cursor.col) self.cursor_after_line = string.sub(self.cursor_line, self.cursor.col)
self.aborted = false
return self return self
end end
context.abort = function(self)
self.aborted = true
end
---Return context creation reason. ---Return context creation reason.
---@return cmp.ContextReason ---@return cmp.ContextReason
context.get_reason = function(self) context.get_reason = function(self)

View File

@@ -56,6 +56,7 @@ end
---@param option? cmp.ContextOption ---@param option? cmp.ContextOption
---@return cmp.Context ---@return cmp.Context
core.get_context = function(self, option) core.get_context = function(self, option)
self.context:abort()
local prev = self.context:clone() local prev = self.context:clone()
prev.prev_context = nil prev.prev_context = nil
prev.cache = nil prev.cache = nil
@@ -296,7 +297,7 @@ core.complete = function(self, ctx)
end end
---Update completion menu ---Update completion menu
core.filter = async.throttle(function(self) local async_filter = async.wrap(function(self)
self.filter.timeout = config.get().performance.throttle self.filter.timeout = config.get().performance.throttle
-- Check invalid condition. -- Check invalid condition.
@@ -323,20 +324,17 @@ core.filter = async.throttle(function(self)
local ctx = self:get_context() local ctx = self:get_context()
-- Display completion results. -- Display completion results.
self.view:open(ctx, sources) local did_open = self.view:open(ctx, sources)
local fetching = #self:get_sources(function(s)
return s.status == source.SourceStatus.FETCHING
end)
-- Check onetime config. -- Check onetime config.
if #self:get_sources(function(s) if not did_open and fetching == 0 then
if s.status == source.SourceStatus.FETCHING then
return true
elseif #s:get_entries(ctx) > 0 then
return true
end
return false
end) == 0 then
config.set_onetime({}) config.set_onetime({})
end end
end, config.get().performance.throttle) end)
core.filter = async.throttle(async_filter, config.get().performance.throttle)
---Confirm completion. ---Confirm completion.
---@param e cmp.Entry ---@param e cmp.Entry

View File

@@ -376,7 +376,10 @@ entry.match = function(self, input, matching_config)
} }
local score, matches, filter_text, _ local score, matches, filter_text, _
local checked = {} ---@type string[]
filter_text = self:get_filter_text() filter_text = self:get_filter_text()
checked[filter_text] = true
score, matches = matcher.match(input, filter_text, option) score, matches = matcher.match(input, filter_text, option)
-- Support the language server that doesn't respect VSCode's behaviors. -- Support the language server that doesn't respect VSCode's behaviors.
@@ -390,17 +393,22 @@ entry.match = function(self, input, matching_config)
accept = accept or string.find(self:get_completion_item().textEdit.newText, prefix, 1, true) accept = accept or string.find(self:get_completion_item().textEdit.newText, prefix, 1, true)
if accept then if accept then
filter_text = prefix .. self:get_filter_text() filter_text = prefix .. self:get_filter_text()
if not checked[filter_text] then
checked[filter_text] = true
score, matches = matcher.match(input, filter_text, option) score, matches = matcher.match(input, filter_text, option)
end end
end end
end end
end end
end
if score == 0 then
local vim_item = self:get_vim_item(self:get_offset()) local vim_item = self:get_vim_item(self:get_offset())
if filter_text ~= vim_item.abbr then
filter_text = vim_item.abbr or vim_item.word filter_text = vim_item.abbr or vim_item.word
if not checked[filter_text] then
_, matches = matcher.match(input, filter_text, option) _, matches = matcher.match(input, filter_text, option)
end end
end
return { score = score, matches = matches } return { score = score, matches = matches }
end) end)

View File

@@ -18,7 +18,6 @@ local char = require('cmp.utils.char')
---@field public incomplete boolean ---@field public incomplete boolean
---@field public is_triggered_by_symbol boolean ---@field public is_triggered_by_symbol boolean
---@field public entries cmp.Entry[] ---@field public entries cmp.Entry[]
---@field public filtered {entries: cmp.Entry[], ctx: cmp.Context}
---@field public offset integer ---@field public offset integer
---@field public request_offset integer ---@field public request_offset integer
---@field public context cmp.Context ---@field public context cmp.Context
@@ -54,7 +53,6 @@ source.reset = function(self)
self.is_triggered_by_symbol = false self.is_triggered_by_symbol = false
self.incomplete = false self.incomplete = false
self.entries = {} self.entries = {}
self.filtered = {}
self.offset = -1 self.offset = -1
self.request_offset = -1 self.request_offset = -1
self.completion_context = nil self.completion_context = nil
@@ -90,28 +88,26 @@ source.get_entries = function(self, ctx)
return {} return {}
end end
if self.filtered.ctx and self.filtered.ctx.id == ctx.id then local target_entries = self.entries
return self.filtered.entries
end
local target_entries = (function() local prev = self.cache:get({ 'get_entries', self.revision })
local key = { 'get_entries', self.revision }
for i = ctx.cursor.col, self.offset, -1 do if prev and ctx.cursor.row == prev.ctx.cursor.row then
key[3] = string.sub(ctx.cursor_before_line, 1, i) if ctx.cursor.col == prev.ctx.cursor.col then
local prev_entries = self.cache:get(key) return prev.entries
if prev_entries then end
return prev_entries -- only use prev entries when cursor is moved forward.
-- and the pattern offset is the same.
if ctx.cursor.col >= prev.ctx.cursor.col and ctx.offset == prev.ctx.offset then
target_entries = prev.entries
end end
end end
return self.entries
end)()
local entry_filter = self:get_entry_filter() local entry_filter = self:get_entry_filter()
local inputs = {} local inputs = {}
---@type cmp.Entry[] ---@type cmp.Entry[]
local entries = {} local entries = {}
local max_item_count = self:get_source_config().max_item_count or 200
local matching_config = self:get_matching_config() local matching_config = self:get_matching_config()
for _, e in ipairs(target_entries) do for _, e in ipairs(target_entries) do
local o = e:get_offset() local o = e:get_offset()
@@ -128,20 +124,16 @@ source.get_entries = function(self, ctx)
if entry_filter(e, ctx) then if entry_filter(e, ctx) then
entries[#entries + 1] = e entries[#entries + 1] = e
if max_item_count and #entries >= max_item_count then
break
end end
end end
async.yield()
if ctx.aborted then
async.abort()
end end
end end
-- only save to cache, when there are no additional entries that could match the filter self.cache:set({ 'get_entries', self.revision }, { entries = entries, ctx = ctx })
-- This also prevents too much memory usage
if #entries < max_item_count then
self.cache:set({ 'get_entries', tostring(self.revision), ctx.cursor_before_line }, entries)
end
self.filtered = { entries = entries, ctx = ctx }
return entries return entries
end end
@@ -337,7 +329,10 @@ source.complete = function(self, ctx, callback)
context = ctx, context = ctx,
completion_context = completion_context, completion_context = completion_context,
}), }),
self.complete_dedup(vim.schedule_wrap(function(response) self.complete_dedup(function(response)
if self.context ~= ctx then
return
end
---@type lsp.CompletionResponse ---@type lsp.CompletionResponse
response = response or {} response = response or {}
@@ -358,7 +353,7 @@ source.complete = function(self, ctx, callback)
end end
end end
self.revision = self.revision + 1 self.revision = self.revision + 1
if #self:get_entries(ctx) == 0 then if #self.entries == 0 then
self.offset = old_offset self.offset = old_offset
self.entries = old_entries self.entries = old_entries
self.revision = self.revision + 1 self.revision = self.revision + 1
@@ -372,7 +367,7 @@ source.complete = function(self, ctx, callback)
self.status = prev_status self.status = prev_status
end end
callback() callback()
end)) end)
) )
return true return true
end end

View File

@@ -99,6 +99,8 @@ cmp.ItemField = {
---@field public debounce integer ---@field public debounce integer
---@field public throttle integer ---@field public throttle integer
---@field public fetching_timeout integer ---@field public fetching_timeout integer
---@field public async_budget integer Maximum time (in ms) an async function is allowed to run during one step of the event loop.
---@field public max_view_entries integer
---@class cmp.WindowConfig ---@class cmp.WindowConfig
---@field completion cmp.WindowConfig ---@field completion cmp.WindowConfig

View File

@@ -1,4 +1,5 @@
local feedkeys = require('cmp.utils.feedkeys') local feedkeys = require('cmp.utils.feedkeys')
local config = require('cmp.config')
local async = {} local async = {}
@@ -9,6 +10,7 @@ local async = {}
---@field public stop function ---@field public stop function
---@field public __call function ---@field public __call function
---@type uv_timer_t[]
local timers = {} local timers = {}
vim.api.nvim_create_autocmd('VimLeavePre', { vim.api.nvim_create_autocmd('VimLeavePre', {
@@ -27,7 +29,8 @@ vim.api.nvim_create_autocmd('VimLeavePre', {
---@return cmp.AsyncThrottle ---@return cmp.AsyncThrottle
async.throttle = function(fn, timeout) async.throttle = function(fn, timeout)
local time = nil local time = nil
local timer = vim.loop.new_timer() local timer = assert(vim.loop.new_timer())
local _async = nil ---@type Async?
timers[#timers + 1] = timer timers[#timers + 1] = timer
return setmetatable({ return setmetatable({
running = false, running = false,
@@ -37,9 +40,15 @@ async.throttle = function(fn, timeout)
return not self.running return not self.running
end) end)
end, end,
stop = function() stop = function(reset_time)
if reset_time ~= false then
time = nil time = nil
end
timer:stop() timer:stop()
if _async then
_async:cancel()
_async = nil
end
end, end,
}, { }, {
__call = function(self, ...) __call = function(self, ...)
@@ -50,12 +59,23 @@ async.throttle = function(fn, timeout)
end end
self.running = true self.running = true
timer:stop() self.stop(false)
timer:start(math.max(1, self.timeout - (vim.loop.now() - time)), 0, function() timer:start(math.max(1, self.timeout - (vim.loop.now() - time)), 0, function()
vim.schedule(function() vim.schedule(function()
time = nil time = nil
fn(unpack(args)) local ret = fn(unpack(args))
if async.is_async(ret) then
---@cast ret Async
_async = ret
_async:await(function(_, error)
self.running = false self.running = false
if error and error ~= 'abort' then
vim.notify(error, vim.log.levels.ERROR)
end
end)
else
self.running = false
end
end) end)
end) end)
end, end,
@@ -147,4 +167,124 @@ async.debounce_next_tick_by_keymap = function(callback)
end end
end end
local Scheduler = {}
Scheduler._queue = {}
Scheduler._executor = assert(vim.loop.new_check())
function Scheduler.step()
local budget = config.get().performance.async_budget * 1e6
local start = vim.loop.hrtime()
while #Scheduler._queue > 0 and vim.loop.hrtime() - start < budget do
local a = table.remove(Scheduler._queue, 1)
a:_step()
if a.running then
table.insert(Scheduler._queue, a)
end
end
if #Scheduler._queue == 0 then
return Scheduler._executor:stop()
end
end
---@param a Async
function Scheduler.add(a)
table.insert(Scheduler._queue, a)
if not Scheduler._executor:is_active() then
Scheduler._executor:start(vim.schedule_wrap(Scheduler.step))
end
end
--- @alias AsyncCallback fun(result?:any, error?:string)
--- @class Async
--- @field running boolean
--- @field result? any
--- @field error? string
--- @field callbacks AsyncCallback[]
--- @field thread thread
local Async = {}
Async.__index = Async
function Async.new(fn)
local self = setmetatable({}, Async)
self.callbacks = {}
self.running = true
self.thread = coroutine.create(fn)
Scheduler.add(self)
return self
end
---@param result? any
---@param error? string
function Async:_done(result, error)
self.running = false
self.result = result
self.error = error
for _, callback in ipairs(self.callbacks) do
callback(result, error)
end
end
function Async:_step()
local ok, res = coroutine.resume(self.thread)
if not ok then
return self:_done(nil, res)
elseif res == 'abort' then
return self:_done(nil, 'abort')
elseif coroutine.status(self.thread) == 'dead' then
return self:_done(res)
end
end
function Async:cancel()
self.running = false
end
---@param cb AsyncCallback
function Async:await(cb)
if not cb then
error('callback is required')
end
if self.running then
table.insert(self.callbacks, cb)
else
cb(self.result, self.error)
end
end
function Async:sync()
while self.running do
vim.wait(10)
end
return self.error and error(self.error) or self.result
end
--- @return boolean
function async.is_async(obj)
return obj and type(obj) == 'table' and getmetatable(obj) == Async
end
--- @return fun(...): Async
function async.wrap(fn)
return function(...)
local args = { ... }
return Async.new(function()
return fn(unpack(args))
end)
end
end
-- This will yield when called from a coroutine
function async.yield(...)
if not coroutine.isyieldable() then
error('Trying to yield from a non-yieldable context')
return ...
end
return coroutine.yield(...)
end
function async.abort()
return async.yield('abort')
end
return async return async

View File

@@ -47,6 +47,7 @@ end
---Open menu ---Open menu
---@param ctx cmp.Context ---@param ctx cmp.Context
---@param sources cmp.Source[] ---@param sources cmp.Source[]
---@return boolean did_open
view.open = function(self, ctx, sources) view.open = function(self, ctx, sources)
local source_group_map = {} local source_group_map = {}
for _, s in ipairs(sources) do for _, s in ipairs(sources) do
@@ -104,6 +105,8 @@ view.open = function(self, ctx, sources)
end end
end end
end) end)
local max_item_count = config.get().performance.max_view_entries or 200
entries = vim.list_slice(entries, 1, max_item_count)
-- open -- open
if #entries > 0 then if #entries > 0 then
@@ -119,6 +122,7 @@ view.open = function(self, ctx, sources)
if #entries == 0 then if #entries == 0 then
self:close() self:close()
end end
return #entries > 0
end end
---Close menu ---Close menu