Skip to content

Commit

Permalink
fix(neotest): nested modules + position updates when switching buffers (
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshurst authored Feb 11, 2024
1 parent e016e69 commit f8a33fb
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 44 deletions.
141 changes: 97 additions & 44 deletions lua/rustaceanvim/neotest/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ NeotestAdapter.root = function(file_name)
return cargo.get_root_dir(file_name)
end

-- ---@param name string Name of directory
---@param rel_path string Path to directory, relative to root
-- ---@param root string Root directory of project
---@return boolean
NeotestAdapter.filter_dir = function(_, rel_path, _)
return rel_path ~= 'target'
end

---@param file_path string
---@return boolean
NeotestAdapter.is_test_file = function(file_path)
Expand Down Expand Up @@ -69,80 +77,125 @@ NeotestAdapter.discover_positions = function(file_path)
---@diagnostic disable-next-line: missing-parameter
return lib.positions.parse_tree(positions)
end
---@diagnostic disable-next-line: cast-type-mismatch
---@cast runnables RARunnable[]
-- We need a runnable for the 'file' position, so we pick the first 'namespace' one
-- Typically, a file only has one test module
---@type RARunnable
local crate_runnable = nil

local max_end_row = 0
for _, runnable in pairs(runnables) do
local pos = trans.runnable_to_position(file_path, runnable)
if pos then
max_end_row = math.max(max_end_row, pos.range[3])
if pos.type == 'dir' then
crate_runnable = runnable
else
if pos.type ~= 'dir' then
table.insert(positions, pos)
end
end
end
-- parse_tree expects sorted positions, with the parent first
---@type rustaceanvim.neotest.Position[]
local sorted_positions = {}
---@diagnostic disable-next-line: cast-type-mismatch
---@cast runnables RARunnable[]

---@type { [string]: neotest.Position }
local tests_by_name = {}
-- If there's only one module in a file, we use it as the file runnable
local namespace_runnable
local namespace_count = 0
---@type rustaceanvim.neotest.Position[]
local namespaces = {}
for _, pos in pairs(positions) do
if pos.type == 'test' then
tests_by_name[pos.name] = pos
elseif pos.type == 'namespace' then
table.insert(namespaces, pos)
end
end
local test_names = vim.tbl_keys(tests_by_name)
for _, namespace_pos in pairs(positions) do
if namespace_pos.type == 'namespace' then
namespace_runnable = namespace_runnable or namespace_pos.runnable
namespace_count = namespace_count + 1
table.insert(sorted_positions, namespace_pos)

-- sort namespaces by name from longest to shortest
table.sort(namespaces, function(a, b)
return #a.name > #b.name
end)

---@type { [string]: rustaceanvim.neotest.Position[] }
local positions_by_namespace = {}
-- group tests by their longest matching namespace
for _, namespace in ipairs(namespaces) do
if namespace.name ~= '' then
---@type string[]
local child_keys = vim.tbl_filter(function(name)
return vim.startswith(name, namespace_pos.name)
end, test_names)
for _, key in pairs(child_keys) do
return vim.startswith(name, namespace.name .. '::')
end, vim.tbl_keys(tests_by_name))
local children = { namespace }
for _, key in ipairs(child_keys) do
local child_pos = tests_by_name[key]
--- strip the namespace and "::" from the name so neotest can build the Tree
child_pos.name = child_pos.name:sub(namespace_pos.name:len() + 3, child_pos.name:len())
table.insert(sorted_positions, child_pos)
tests_by_name[key] = nil
--- strip the namespace and "::" from the name
child_pos.name = child_pos.name:sub(#namespace.name + 3, #child_pos.name)
table.insert(children, child_pos)
end
positions_by_namespace[namespace.name] = children
end
end
if namespace_runnable then

-- nest child namespaces in their parent namespace
for i, namespace in ipairs(namespaces) do
---@type rustaceanvim.neotest.Position?
local parent = nil
-- search remaning namespaces for the longest matching parent namespace
for _, other_namespace in ipairs { unpack(namespaces, i + 1) } do
if vim.startswith(namespace.name, other_namespace.name .. '::') then
parent = other_namespace
break
end
end
if parent ~= nil then
local namespace_name = namespace.name
local children = positions_by_namespace[namespace_name]
-- strip parent namespace + "::"
children[1].name = children[1].name:sub(#parent.name + 3, #namespace_name)
table.insert(positions_by_namespace[parent.name], children)
positions_by_namespace[namespace_name] = nil
end
end

local sorted_positions = {}
for _, namespace_positions in pairs(positions_by_namespace) do
table.insert(sorted_positions, namespace_positions)
end
-- any remaning tests had no parent namespace
vim.list_extend(sorted_positions, vim.tbl_values(tests_by_name))

-- sort positions by their start range
local function sort_positions(to_sort)
for _, item in ipairs(to_sort) do
if vim.tbl_islist(item) then
sort_positions(item)
end
end

-- pop header from the list before sorting since it's used to sort in its parent's context
local header = table.remove(to_sort, 1)
table.sort(to_sort, function(a, b)
local a_item = vim.tbl_islist(a) and a[1] or a
local b_item = vim.tbl_islist(b) and b[1] or b
if a_item.range[1] == b_item.range[1] then
return a_item.name < b_item.name
else
return a_item.range[1] < b_item.range[1]
end
end)
table.insert(to_sort, 1, header)
end
sort_positions(sorted_positions)

if #namespaces > 0 then
local file_pos = {
id = file_path,
name = file_path,
name = vim.fn.fnamemodify(file_path, ':t'),
type = 'file',
path = file_path,
range = { 0, 0, max_end_row, 0 },
runnable = namespace_runnable,
-- use the shortest namespace for the file runnable
runnable = namespaces[#namespaces].runnable,
}
table.insert(sorted_positions, 1, file_pos)
end
if crate_runnable and #sorted_positions > 0 then
-- Only insert a crate runnable position if there exist positions
local crate_pos = {
id = 'rustaceanvim:' .. crate_runnable.args.workspaceRoot,
name = 'suite',
type = 'dir',
path = crate_runnable.args.workspaceRoot,
range = { 0, 0, 0, 0 },
runnable = crate_runnable,
}
table.insert(sorted_positions, 1, crate_pos)
end
---@diagnostic disable-next-line: missing-parameter
return lib.positions.parse_tree(sorted_positions)

return require('neotest.types.tree').from_list(sorted_positions, function(x)
return x.name
end)
end

---@class rustaceanvim.neotest.RunSpec: neotest.RunSpec
Expand Down
7 changes: 7 additions & 0 deletions lua/rustaceanvim/neotest/trans.lua
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ function M.runnable_to_position(file_path, runnable)
end_col = location.targetRange['end'].character
end
local test_path = get_test_path(runnable)
-- strip the file module prefix from the name
if test_path then
local mod_name = vim.fn.fnamemodify(file_path, ':t:r')
if vim.startswith(test_path, mod_name .. '::') then
test_path = test_path:sub(#mod_name + 3, #test_path)
end
end
---@type rustaceanvim.neotest.Position
local pos = {
id = M.get_position_id(file_path, runnable),
Expand Down

0 comments on commit f8a33fb

Please sign in to comment.