feat: cmp async (#1583)

This commit is contained in:
Folke Lemaitre
2023-05-25 19:46:53 +02:00
committed by GitHub
parent 950d0e3a93
commit abb5c7519d
9 changed files with 220 additions and 51 deletions

View File

@@ -19,6 +19,8 @@ return function()
debounce = 60,
throttle = 30,
fetching_timeout = 500,
async_budget = 1,
max_view_entries = 200,
},
preselect = types.cmp.PreselectMode.Item,
@@ -88,7 +90,10 @@ return function()
},
view = {
entries = { name = 'custom', selection_order = 'top_down' },
entries = {
name = 'custom',
selection_order = 'top_down',
},
},
window = {

View File

@@ -16,6 +16,7 @@ local api = require('cmp.utils.api')
---@field public cursor_line string
---@field public cursor_after_line string
---@field public cursor_before_line string
---@field public aborted boolean
local context = {}
---Create new empty context
@@ -55,9 +56,14 @@ context.new = function(prev_context, option)
self.cursor.character = misc.to_utfindex(self.cursor_line, self.cursor.col)
self.cursor_before_line = string.sub(self.cursor_line, 1, self.cursor.col - 1)
self.cursor_after_line = string.sub(self.cursor_line, self.cursor.col)
self.aborted = false
return self
end
context.abort = function(self)
self.aborted = true
end
---Return context creation reason.
---@return cmp.ContextReason
context.get_reason = function(self)

View File

@@ -56,6 +56,7 @@ end
---@param option? cmp.ContextOption
---@return cmp.Context
core.get_context = function(self, option)
self.context:abort()
local prev = self.context:clone()
prev.prev_context = nil
prev.cache = nil
@@ -296,7 +297,7 @@ core.complete = function(self, ctx)
end
---Update completion menu
core.filter = async.throttle(function(self)
local async_filter = async.wrap(function(self)
self.filter.timeout = config.get().performance.throttle
-- Check invalid condition.
@@ -323,20 +324,17 @@ core.filter = async.throttle(function(self)
local ctx = self:get_context()
-- Display completion results.
self.view:open(ctx, sources)
local did_open = self.view:open(ctx, sources)
local fetching = #self:get_sources(function(s)
return s.status == source.SourceStatus.FETCHING
end)
-- Check onetime config.
if #self:get_sources(function(s)
if s.status == source.SourceStatus.FETCHING then
return true
elseif #s:get_entries(ctx) > 0 then
return true
end
return false
end) == 0 then
if not did_open and fetching == 0 then
config.set_onetime({})
end
end, config.get().performance.throttle)
end)
core.filter = async.throttle(async_filter, config.get().performance.throttle)
---Confirm completion.
---@param e cmp.Entry

View File

@@ -376,7 +376,10 @@ entry.match = function(self, input, matching_config)
}
local score, matches, filter_text, _
local checked = {} ---@type string[]
filter_text = self:get_filter_text()
checked[filter_text] = true
score, matches = matcher.match(input, filter_text, option)
-- Support the language server that doesn't respect VSCode's behaviors.
@@ -390,16 +393,21 @@ entry.match = function(self, input, matching_config)
accept = accept or string.find(self:get_completion_item().textEdit.newText, prefix, 1, true)
if accept then
filter_text = prefix .. self:get_filter_text()
score, matches = matcher.match(input, filter_text, option)
if not checked[filter_text] then
checked[filter_text] = true
score, matches = matcher.match(input, filter_text, option)
end
end
end
end
end
local vim_item = self:get_vim_item(self:get_offset())
if filter_text ~= vim_item.abbr then
if score == 0 then
local vim_item = self:get_vim_item(self:get_offset())
filter_text = vim_item.abbr or vim_item.word
_, matches = matcher.match(input, filter_text, option)
if not checked[filter_text] then
_, matches = matcher.match(input, filter_text, option)
end
end
return { score = score, matches = matches }

View File

@@ -18,7 +18,6 @@ local char = require('cmp.utils.char')
---@field public incomplete boolean
---@field public is_triggered_by_symbol boolean
---@field public entries cmp.Entry[]
---@field public filtered {entries: cmp.Entry[], ctx: cmp.Context}
---@field public offset integer
---@field public request_offset integer
---@field public context cmp.Context
@@ -54,7 +53,6 @@ source.reset = function(self)
self.is_triggered_by_symbol = false
self.incomplete = false
self.entries = {}
self.filtered = {}
self.offset = -1
self.request_offset = -1
self.completion_context = nil
@@ -90,28 +88,26 @@ source.get_entries = function(self, ctx)
return {}
end
if self.filtered.ctx and self.filtered.ctx.id == ctx.id then
return self.filtered.entries
end
local target_entries = self.entries
local target_entries = (function()
local key = { 'get_entries', self.revision }
for i = ctx.cursor.col, self.offset, -1 do
key[3] = string.sub(ctx.cursor_before_line, 1, i)
local prev_entries = self.cache:get(key)
if prev_entries then
return prev_entries
end
local prev = self.cache:get({ 'get_entries', self.revision })
if prev and ctx.cursor.row == prev.ctx.cursor.row then
if ctx.cursor.col == prev.ctx.cursor.col then
return prev.entries
end
return self.entries
end)()
-- only use prev entries when cursor is moved forward.
-- and the pattern offset is the same.
if ctx.cursor.col >= prev.ctx.cursor.col and ctx.offset == prev.ctx.offset then
target_entries = prev.entries
end
end
local entry_filter = self:get_entry_filter()
local inputs = {}
---@type cmp.Entry[]
local entries = {}
local max_item_count = self:get_source_config().max_item_count or 200
local matching_config = self:get_matching_config()
for _, e in ipairs(target_entries) do
local o = e:get_offset()
@@ -128,20 +124,16 @@ source.get_entries = function(self, ctx)
if entry_filter(e, ctx) then
entries[#entries + 1] = e
if max_item_count and #entries >= max_item_count then
break
end
end
end
async.yield()
if ctx.aborted then
async.abort()
end
end
-- only save to cache, when there are no additional entries that could match the filter
-- This also prevents too much memory usage
if #entries < max_item_count then
self.cache:set({ 'get_entries', tostring(self.revision), ctx.cursor_before_line }, entries)
end
self.cache:set({ 'get_entries', self.revision }, { entries = entries, ctx = ctx })
self.filtered = { entries = entries, ctx = ctx }
return entries
end
@@ -337,7 +329,10 @@ source.complete = function(self, ctx, callback)
context = ctx,
completion_context = completion_context,
}),
self.complete_dedup(vim.schedule_wrap(function(response)
self.complete_dedup(function(response)
if self.context ~= ctx then
return
end
---@type lsp.CompletionResponse
response = response or {}
@@ -358,7 +353,7 @@ source.complete = function(self, ctx, callback)
end
end
self.revision = self.revision + 1
if #self:get_entries(ctx) == 0 then
if #self.entries == 0 then
self.offset = old_offset
self.entries = old_entries
self.revision = self.revision + 1
@@ -372,7 +367,7 @@ source.complete = function(self, ctx, callback)
self.status = prev_status
end
callback()
end))
end)
)
return true
end

View File

@@ -99,6 +99,8 @@ cmp.ItemField = {
---@field public debounce integer
---@field public throttle integer
---@field public fetching_timeout integer
---@field public async_budget integer Maximum time (in ms) an async function is allowed to run during one step of the event loop.
---@field public max_view_entries integer
---@class cmp.WindowConfig
---@field completion cmp.WindowConfig

View File

@@ -1,4 +1,5 @@
local feedkeys = require('cmp.utils.feedkeys')
local config = require('cmp.config')
local async = {}
@@ -9,6 +10,7 @@ local async = {}
---@field public stop function
---@field public __call function
---@type uv_timer_t[]
local timers = {}
vim.api.nvim_create_autocmd('VimLeavePre', {
@@ -27,7 +29,8 @@ vim.api.nvim_create_autocmd('VimLeavePre', {
---@return cmp.AsyncThrottle
async.throttle = function(fn, timeout)
local time = nil
local timer = vim.loop.new_timer()
local timer = assert(vim.loop.new_timer())
local _async = nil ---@type Async?
timers[#timers + 1] = timer
return setmetatable({
running = false,
@@ -37,9 +40,15 @@ async.throttle = function(fn, timeout)
return not self.running
end)
end,
stop = function()
time = nil
stop = function(reset_time)
if reset_time ~= false then
time = nil
end
timer:stop()
if _async then
_async:cancel()
_async = nil
end
end,
}, {
__call = function(self, ...)
@@ -50,12 +59,23 @@ async.throttle = function(fn, timeout)
end
self.running = true
timer:stop()
self.stop(false)
timer:start(math.max(1, self.timeout - (vim.loop.now() - time)), 0, function()
vim.schedule(function()
time = nil
fn(unpack(args))
self.running = false
local ret = fn(unpack(args))
if async.is_async(ret) then
---@cast ret Async
_async = ret
_async:await(function(_, error)
self.running = false
if error and error ~= 'abort' then
vim.notify(error, vim.log.levels.ERROR)
end
end)
else
self.running = false
end
end)
end)
end,
@@ -147,4 +167,124 @@ async.debounce_next_tick_by_keymap = function(callback)
end
end
local Scheduler = {}
Scheduler._queue = {}
Scheduler._executor = assert(vim.loop.new_check())
function Scheduler.step()
local budget = config.get().performance.async_budget * 1e6
local start = vim.loop.hrtime()
while #Scheduler._queue > 0 and vim.loop.hrtime() - start < budget do
local a = table.remove(Scheduler._queue, 1)
a:_step()
if a.running then
table.insert(Scheduler._queue, a)
end
end
if #Scheduler._queue == 0 then
return Scheduler._executor:stop()
end
end
---@param a Async
function Scheduler.add(a)
table.insert(Scheduler._queue, a)
if not Scheduler._executor:is_active() then
Scheduler._executor:start(vim.schedule_wrap(Scheduler.step))
end
end
--- @alias AsyncCallback fun(result?:any, error?:string)
--- @class Async
--- @field running boolean
--- @field result? any
--- @field error? string
--- @field callbacks AsyncCallback[]
--- @field thread thread
local Async = {}
Async.__index = Async
function Async.new(fn)
local self = setmetatable({}, Async)
self.callbacks = {}
self.running = true
self.thread = coroutine.create(fn)
Scheduler.add(self)
return self
end
---@param result? any
---@param error? string
function Async:_done(result, error)
self.running = false
self.result = result
self.error = error
for _, callback in ipairs(self.callbacks) do
callback(result, error)
end
end
function Async:_step()
local ok, res = coroutine.resume(self.thread)
if not ok then
return self:_done(nil, res)
elseif res == 'abort' then
return self:_done(nil, 'abort')
elseif coroutine.status(self.thread) == 'dead' then
return self:_done(res)
end
end
function Async:cancel()
self.running = false
end
---@param cb AsyncCallback
function Async:await(cb)
if not cb then
error('callback is required')
end
if self.running then
table.insert(self.callbacks, cb)
else
cb(self.result, self.error)
end
end
function Async:sync()
while self.running do
vim.wait(10)
end
return self.error and error(self.error) or self.result
end
--- @return boolean
function async.is_async(obj)
return obj and type(obj) == 'table' and getmetatable(obj) == Async
end
--- @return fun(...): Async
function async.wrap(fn)
return function(...)
local args = { ... }
return Async.new(function()
return fn(unpack(args))
end)
end
end
-- This will yield when called from a coroutine
function async.yield(...)
if not coroutine.isyieldable() then
error('Trying to yield from a non-yieldable context')
return ...
end
return coroutine.yield(...)
end
function async.abort()
return async.yield('abort')
end
return async

View File

@@ -47,6 +47,7 @@ end
---Open menu
---@param ctx cmp.Context
---@param sources cmp.Source[]
---@return boolean did_open
view.open = function(self, ctx, sources)
local source_group_map = {}
for _, s in ipairs(sources) do
@@ -104,6 +105,8 @@ view.open = function(self, ctx, sources)
end
end
end)
local max_item_count = config.get().performance.max_view_entries or 200
entries = vim.list_slice(entries, 1, max_item_count)
-- open
if #entries > 0 then
@@ -119,6 +122,7 @@ view.open = function(self, ctx, sources)
if #entries == 0 then
self:close()
end
return #entries > 0
end
---Close menu