Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(neotest): nested modules + position updates when switching buffers #223

Merged
merged 5 commits into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 97 additions & 44 deletions lua/rustaceanvim/neotest/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ NeotestAdapter.root = function(file_name)
return cargo.get_root_dir(file_name)
end

-- ---@param name string Name of directory
---@param rel_path string Path to directory, relative to root
-- ---@param root string Root directory of project
---@return boolean
NeotestAdapter.filter_dir = function(_, rel_path, _)
return rel_path ~= 'target'
end

---@param file_path string
---@return boolean
NeotestAdapter.is_test_file = function(file_path)
Expand Down Expand Up @@ -69,80 +77,125 @@ NeotestAdapter.discover_positions = function(file_path)
---@diagnostic disable-next-line: missing-parameter
return lib.positions.parse_tree(positions)
end
---@diagnostic disable-next-line: cast-type-mismatch
---@cast runnables RARunnable[]
-- We need a runnable for the 'file' position, so we pick the first 'namespace' one
-- Typically, a file only has one test module
---@type RARunnable
local crate_runnable = nil

local max_end_row = 0
for _, runnable in pairs(runnables) do
local pos = trans.runnable_to_position(file_path, runnable)
if pos then
max_end_row = math.max(max_end_row, pos.range[3])
if pos.type == 'dir' then
crate_runnable = runnable
else
if pos.type ~= 'dir' then
table.insert(positions, pos)
end
end
end
-- parse_tree expects sorted positions, with the parent first
---@type rustaceanvim.neotest.Position[]
local sorted_positions = {}
---@diagnostic disable-next-line: cast-type-mismatch
---@cast runnables RARunnable[]

---@type { [string]: neotest.Position }
local tests_by_name = {}
-- If there's only one module in a file, we use it as the file runnable
local namespace_runnable
local namespace_count = 0
---@type rustaceanvim.neotest.Position[]
local namespaces = {}
for _, pos in pairs(positions) do
if pos.type == 'test' then
tests_by_name[pos.name] = pos
elseif pos.type == 'namespace' then
table.insert(namespaces, pos)
end
end
local test_names = vim.tbl_keys(tests_by_name)
for _, namespace_pos in pairs(positions) do
if namespace_pos.type == 'namespace' then
namespace_runnable = namespace_runnable or namespace_pos.runnable
namespace_count = namespace_count + 1
table.insert(sorted_positions, namespace_pos)

-- sort namespaces by name from longest to shortest
table.sort(namespaces, function(a, b)
return #a.name > #b.name
end)

---@type { [string]: rustaceanvim.neotest.Position[] }
local positions_by_namespace = {}
-- group tests by their longest matching namespace
for _, namespace in ipairs(namespaces) do
if namespace.name ~= '' then
---@type string[]
local child_keys = vim.tbl_filter(function(name)
return vim.startswith(name, namespace_pos.name)
end, test_names)
for _, key in pairs(child_keys) do
return vim.startswith(name, namespace.name .. '::')
end, vim.tbl_keys(tests_by_name))
local children = { namespace }
for _, key in ipairs(child_keys) do
local child_pos = tests_by_name[key]
--- strip the namespace and "::" from the name so neotest can build the Tree
child_pos.name = child_pos.name:sub(namespace_pos.name:len() + 3, child_pos.name:len())
table.insert(sorted_positions, child_pos)
tests_by_name[key] = nil
--- strip the namespace and "::" from the name
child_pos.name = child_pos.name:sub(#namespace.name + 3, #child_pos.name)
table.insert(children, child_pos)
end
positions_by_namespace[namespace.name] = children
end
end
if namespace_runnable then

-- nest child namespaces in their parent namespace
for i, namespace in ipairs(namespaces) do
---@type rustaceanvim.neotest.Position?
local parent = nil
-- search remaning namespaces for the longest matching parent namespace
for _, other_namespace in ipairs { unpack(namespaces, i + 1) } do
if vim.startswith(namespace.name, other_namespace.name .. '::') then
parent = other_namespace
break
end
end
if parent ~= nil then
local namespace_name = namespace.name
local children = positions_by_namespace[namespace_name]
-- strip parent namespace + "::"
children[1].name = children[1].name:sub(#parent.name + 3, #namespace_name)
table.insert(positions_by_namespace[parent.name], children)
positions_by_namespace[namespace_name] = nil
end
end

local sorted_positions = {}
for _, namespace_positions in pairs(positions_by_namespace) do
table.insert(sorted_positions, namespace_positions)
end
-- any remaning tests had no parent namespace
vim.list_extend(sorted_positions, vim.tbl_values(tests_by_name))

-- sort positions by their start range
local function sort_positions(to_sort)
for _, item in ipairs(to_sort) do
if vim.tbl_islist(item) then
sort_positions(item)
end
end

-- pop header from the list before sorting since it's used to sort in its parent's context
local header = table.remove(to_sort, 1)
table.sort(to_sort, function(a, b)
local a_item = vim.tbl_islist(a) and a[1] or a
local b_item = vim.tbl_islist(b) and b[1] or b
if a_item.range[1] == b_item.range[1] then
return a_item.name < b_item.name
else
return a_item.range[1] < b_item.range[1]
end
end)
table.insert(to_sort, 1, header)
end
sort_positions(sorted_positions)

if #namespaces > 0 then
local file_pos = {
id = file_path,
name = file_path,
name = vim.fn.fnamemodify(file_path, ':t'),
type = 'file',
path = file_path,
range = { 0, 0, max_end_row, 0 },
runnable = namespace_runnable,
-- use the shortest namespace for the file runnable
runnable = namespaces[#namespaces].runnable,
}
table.insert(sorted_positions, 1, file_pos)
end
if crate_runnable and #sorted_positions > 0 then
-- Only insert a crate runnable position if there exist positions
local crate_pos = {
id = 'rustaceanvim:' .. crate_runnable.args.workspaceRoot,
name = 'suite',
type = 'dir',
path = crate_runnable.args.workspaceRoot,
range = { 0, 0, 0, 0 },
runnable = crate_runnable,
}
table.insert(sorted_positions, 1, crate_pos)
end
---@diagnostic disable-next-line: missing-parameter
return lib.positions.parse_tree(sorted_positions)

return require('neotest.types.tree').from_list(sorted_positions, function(x)
return x.name
end)
end

---@class rustaceanvim.neotest.RunSpec: neotest.RunSpec
Expand Down
7 changes: 7 additions & 0 deletions lua/rustaceanvim/neotest/trans.lua
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ function M.runnable_to_position(file_path, runnable)
end_col = location.targetRange['end'].character
end
local test_path = get_test_path(runnable)
-- strip the file module prefix from the name
if test_path then
local mod_name = vim.fn.fnamemodify(file_path, ':t:r')
if vim.startswith(test_path, mod_name .. '::') then
test_path = test_path:sub(#mod_name + 3, #test_path)
end
end
---@type rustaceanvim.neotest.Position
local pos = {
id = M.get_position_id(file_path, runnable),
Expand Down
Loading