diff --git a/lua/telescope/algos/fzy.lua b/lua/telescope/algos/fzy.lua new file mode 100644 index 0000000..2b4b8ec --- /dev/null +++ b/lua/telescope/algos/fzy.lua @@ -0,0 +1,191 @@ +-- The fzy matching algorithm +-- +-- by Seth Warn +-- a lua port of John Hawthorn's fzy +-- +-- > fzy tries to find the result the user intended. It does this by favouring +-- > matches on consecutive letters and starts of words. This allows matching +-- > using acronyms or different parts of the path." - J Hawthorn + +local path = require('telescope.path') + +local SCORE_GAP_LEADING = -0.005 +local SCORE_GAP_TRAILING = -0.005 +local SCORE_GAP_INNER = -0.01 +local SCORE_MATCH_CONSECUTIVE = 1.0 +local SCORE_MATCH_SLASH = 0.9 +local SCORE_MATCH_WORD = 0.8 +local SCORE_MATCH_CAPITAL = 0.7 +local SCORE_MATCH_DOT = 0.6 +local SCORE_MAX = math.huge +local SCORE_MIN = -math.huge +local MATCH_MAX_LENGTH = 1024 + +local fzy = {} + +function fzy.has_match(needle, haystack) + needle = string.lower(needle) + haystack = string.lower(haystack) + + local j = 1 + for i = 1, string.len(needle) do + j = string.find(haystack, needle:sub(i, i), j, true) + if not j then + return false + else + j = j + 1 + end + end + + return true +end + +local function is_lower(c) + return c:match("%l") +end + +local function is_upper(c) + return c:match("%u") +end + +local function precompute_bonus(haystack) + local match_bonus = {} + + local last_char = path.separator + for i = 1, string.len(haystack) do + local this_char = haystack:sub(i, i) + if last_char == path.separator then + match_bonus[i] = SCORE_MATCH_SLASH + elseif last_char == "-" or last_char == "_" or last_char == " " then + match_bonus[i] = SCORE_MATCH_WORD + elseif last_char == "." then + match_bonus[i] = SCORE_MATCH_DOT + elseif is_lower(last_char) and is_upper(this_char) then + match_bonus[i] = SCORE_MATCH_CAPITAL + else + match_bonus[i] = 0 + end + + last_char = this_char + end + + return match_bonus +end + +local function compute(needle, haystack, D, M) + local match_bonus = precompute_bonus(haystack) + local n = string.len(needle) + local m = string.len(haystack) + local lower_needle = string.lower(needle) + local lower_haystack = string.lower(haystack) + + -- Because lua only grants access to chars through substring extraction, + -- get all the characters from the haystack once now, to reuse below. + local haystack_chars = {} + for i = 1, m do + haystack_chars[i] = lower_haystack:sub(i, i) + end + + for i=1,n do + D[i] = {} + M[i] = {} + + local prev_score = SCORE_MIN + local gap_score = i == n and SCORE_GAP_TRAILING or SCORE_GAP_INNER + local needle_char = lower_needle:sub(i, i) + + for j = 1, m do + if needle_char == haystack_chars[j] then + local score = SCORE_MIN + if i == 1 then + score = ((j - 1) * SCORE_GAP_LEADING) + match_bonus[j] + elseif j > 1 then + local a = M[i - 1][j - 1] + match_bonus[j] + local b = D[i - 1][j - 1] + SCORE_MATCH_CONSECUTIVE + score = math.max(a, b) + end + D[i][j] = score + prev_score = math.max(score, prev_score + gap_score) + M[i][j] = prev_score + else + D[i][j] = SCORE_MIN + prev_score = prev_score + gap_score + M[i][j] = prev_score + end + end + end +end + +function fzy.score(needle, haystack) + local n = string.len(needle) + local m = string.len(haystack) + + if n == 0 or m == 0 or m > MATCH_MAX_LENGTH or n > MATCH_MAX_LENGTH then + return SCORE_MIN + elseif n == m then + return SCORE_MAX + else + local D = {} + local M = {} + compute(needle, haystack, D, M) + return M[n][m] + end +end + +function fzy.positions(needle, haystack) + local n = string.len(needle) + local m = string.len(haystack) + + if n == 0 or m == 0 or m > MATCH_MAX_LENGTH or n > MATCH_MAX_LENGTH then + return {} + elseif n == m then + local consecutive = {} + for i=1,n do + consecutive[i] = i + end + return consecutive + end + + local D = {} + local M = {} + compute(needle, haystack, D, M) + + local positions = {} + local match_required = false + local j = m + for i=n,1,-1 do + while j >= 1 do + if D[i][j] ~= SCORE_MIN and (match_required or D[i][j] == M[i][j]) then + match_required = (i ~= 1) and (j ~= 1) and ( + M[i][j] == D[i - 1][j - 1] + SCORE_MATCH_CONSECUTIVE) + positions[i] = j + j = j - 1 + break + else + j = j - 1 + end + end + end + + return positions +end + +-- If strings a or b are empty or too long, `fzy.score(a, b) == fzy.get_score_min()`. +function fzy.get_score_min() + return SCORE_MIN +end + +-- For exact matches, `fzy.score(s, s) == fzy.get_score_max()`. +function fzy.get_score_max() + return SCORE_MAX +end + +-- For all strings a and b that +-- - are not covered by either `fzy.get_score_min()` or fzy.get_score_max()`, and +-- - are matched, such that `fzy.has_match(a, b) == true`, +-- then `fzy.score(a, b) > fzy.get_score_floor()` will be true. +function fzy.get_score_floor() + return (MATCH_MAX_LENGTH + 1) * SCORE_GAP_INNER +end + +return fzy diff --git a/lua/telescope/sorters.lua b/lua/telescope/sorters.lua index 4088ce0..4c268e2 100644 --- a/lua/telescope/sorters.lua +++ b/lua/telescope/sorters.lua @@ -307,6 +307,45 @@ sorters.fuzzy_with_index_bias = function(opts) } end +-- Sorter using the fzy algorithm +sorters.get_fzy_sorter = function() + local fzy = require('telescope.algos.fzy') + local OFFSET = -fzy.get_score_floor() + + return sorters.Sorter:new{ + scoring_function = function(_, prompt, line) + -- Check for actual matches before running the scoring alogrithm. + if not fzy.has_match(prompt, line) then + return -1 + end + + local fzy_score = fzy.score(prompt, line) + + -- The fzy score is -inf for empty queries and overlong strings. Since + -- this function converts all scores into the range (0, 1), we can + -- convert these to 1 as a suitable "worst score" value. + if fzy_score == fzy.get_score_min() then + return 1 + end + + -- Poor non-empty matches can also have negative values. Offset the score + -- so that all values are positive, then invert to match the + -- telescope.Sorter "smaller is better" convention. Note that for exact + -- matches, fzy returns +inf, which when inverted becomes 0. + return 1 / (fzy_score + OFFSET) + end, + + -- The fzy.positions function, which returns an array of string indices, is + -- compatible with telescope's conventions. It's moderately wasteful to + -- call call fzy.score(x,y) followed by fzy.positions(x,y): both call the + -- fzy.compute function, which does all the work. But, this doesn't affect + -- perceived performance. + highlighter = function(_, prompt, display) + return fzy.positions(prompt, display) + end, + } +end + -- Bad & Dumb Sorter sorters.get_levenshtein_sorter = function() return Sorter:new { diff --git a/lua/tests/automated/telescope_spec.lua b/lua/tests/automated/telescope_spec.lua index bf296f5..933cfc5 100644 --- a/lua/tests/automated/telescope_spec.lua +++ b/lua/tests/automated/telescope_spec.lua @@ -268,6 +268,102 @@ describe('Sorters', function() end) end) + describe('fzy', function() + local sorter = require'telescope.sorters'.get_fzy_sorter() + local function score(prompt, line) + return sorter:score(prompt, {ordinal = line}) + end + + describe("matches", function() + it("exact matches", function() + assert.True(score("a", "a") >= 0) + assert.True(score("a.bb", "a.bb") >= 0) + end) + it("ignore case", function() + assert.True(score("AbB", "abb") >= 0) + assert.True(score("abb", "ABB") >= 0) + end) + it("partial matches", function() + assert.True(score("a", "ab") >= 0) + assert.True(score("a", "ba") >= 0) + assert.True(score("aba", "baabbaab") >= 0) + end) + it("with delimiters between", function() + assert.True(score("abc", "a|b|c") >= 0) + end) + it("with empty query", function() + assert.True(score("", "") >= 0) + assert.True(score("", "a") >= 0) + end) + it("rejects non-matches", function() + assert.True(score("a", "") < 0) + assert.True(score("a", "b") < 0) + assert.True(score("aa", "a") < 0) + assert.True(score("ba", "a") < 0) + assert.True(score("ab", "a") < 0) + end) + end) + + describe("scoring", function() + it("prefers beginnings of words", function() + assert.True(score("amor", "app/models/order") < score("amor", "app/models/zrder")) + end) + it("prefers consecutive letters", function() + assert.True(score("amo", "app/models/foo") < score("amo", "app/m/foo")) + assert.True(score("erf", "perfect") < score("erf", "terrific")) + end) + it("prefers contiguous over letter following period", function() + assert.True(score("gemfil", "Gemfile") < score("gemfil", "Gemfile.lock")) + end) + it("prefers shorter matches", function() + assert.True(score("abce", "abcdef") < score("abce", "abc de")); + assert.True(score("abc", " a b c ") < score("abc", " a b c ")); + assert.True(score("abc", " a b c ") < score("abc", " a b c ")); + end) + it("prefers shorter candidates", function() + assert.True(score("test", "tests") < score("test", "testing")) + end) + it("prefers matches at the beginning", function() + assert.True(score("ab", "abbb") < score("ab", "babb")) + assert.True(score("test", "testing") < score("test", "/testing")) + end) + it("prefers matches at some locations", function() + assert.True(score("a", "/a") < score("a", "ba")) + assert.True(score("a", "bA") < score("a", "ba")) + assert.True(score("a", ".a") < score("a", "ba")) + end) + end) + + local function positions(prompt, line) + return sorter:highlighter(prompt, line) + end + + describe("positioning", function() + it("favors consecutive positions", function() + assert.same({1, 5, 6}, positions("amo", "app/models/foo")) + end) + it("favors word beginnings", function() + assert.same({1, 5, 12, 13}, positions("amor", "app/models/order")) + end) + it("works when there are no bonuses", function() + assert.same({2, 4}, positions("as", "tags")) + assert.same({3, 8}, positions("as", "examples.txt")) + end) + it("favors smaller groupings of positions", function() + assert.same({3, 5, 7}, positions("abc", "a/a/b/c/c")) + assert.same({3, 5}, positions("ab", "caacbbc")) + end) + it("handles exact matches", function() + assert.same({1, 2, 3}, positions("foo", "foo")) + end) + it("ignores empty requests", function() + assert.same({}, positions("", "")) + assert.same({}, positions("", "foo")) + assert.same({}, positions("foo", "")) + end) + end) + end) + describe('layout_strategies', function() describe('center', function() it('should handle large terminals', function()