diff --git a/README.md b/README.md index 6e90407..74dd888 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,9 @@ require("neotest").setup({ is_test_file = function(file_path) ... end, - + -- !!EXPERIMENTAL!! Enable shelling out to `pytest` to discover test + -- instances for files containing a parametrize mark (default: false) + pytest_discover_instances = true, }) } }) diff --git a/lua/neotest-python/init.lua b/lua/neotest-python/init.lua index bb781c6..566a3a6 100644 --- a/lua/neotest-python/init.lua +++ b/lua/neotest-python/init.lua @@ -1,6 +1,7 @@ local async = require("neotest.async") local lib = require("neotest.lib") local base = require("neotest-python.base") +local pytest = require("neotest-python.pytest") local function get_script() local paths = vim.api.nvim_get_runtime_file("neotest.py", true) @@ -15,6 +16,7 @@ end local dap_args local is_test_file = base.is_test_file +local pytest_discover_instances = false local function get_strategy_config(strategy, python, program, args) local config = { @@ -80,8 +82,13 @@ function PythonNeotestAdapter.filter_dir(name) end ---@async ----@return Tree | nil +---@return neotest.Tree | nil function PythonNeotestAdapter.discover_positions(path) + local root = PythonNeotestAdapter.root(path) or vim.loop.cwd() + local python = get_python(root) + local runner = get_runner(python) + + -- Parse the file while pytest is running local query = [[ ;; Match undecorated functions ((function_definition @@ -109,12 +116,15 @@ function PythonNeotestAdapter.discover_positions(path) (#not-has-parent? @namespace.definition decorated_definition) ) ]] - local root = PythonNeotestAdapter.root(path) - local python = get_python(root) - local runner = get_runner(python) - return lib.treesitter.parse_positions(path, query, { + local positions = lib.treesitter.parse_positions(path, query, { require_namespaces = runner == "unittest", }) + + if runner == "pytest" and pytest_discover_instances then + pytest.augment_positions(python, get_script(), path, positions, root) + end + + return positions end ---@async @@ -232,6 +242,9 @@ setmetatable(PythonNeotestAdapter, { if type(opts.dap) == "table" then dap_args = opts.dap end + if opts.pytest_discover_instances ~= nil then + pytest_discover_instances = opts.pytest_discover_instances + end return PythonNeotestAdapter end, }) diff --git a/lua/neotest-python/pytest.lua b/lua/neotest-python/pytest.lua new file mode 100644 index 0000000..d462158 --- /dev/null +++ b/lua/neotest-python/pytest.lua @@ -0,0 +1,101 @@ +local lib = require("neotest.lib") +local logger = require("neotest.logging") + +local M = {} + +---@async +---Add test instances for path in root to positions +---@param positions neotest.Tree +---@param test_params table +local function add_test_instances(positions, test_params) + for _, node in positions:iter_nodes() do + local position = node:data() + if position.type == "test" then + local pos_params = test_params[position.id] or {} + for _, params_str in ipairs(pos_params) do + local new_data = vim.tbl_extend("force", position, { + id = string.format("%s[%s]", position.id, params_str), + name = string.format("%s[%s]", position.name, params_str), + }) + new_data.range = nil + + local new_pos = node:new(new_data, {}, node._key, {}, {}) + node:add_child(new_data.id, new_pos) + end + end + end +end + +---@async +---@param path string +---@return boolean +local function has_parametrize(path) + local query = [[ + ;; Detect parametrize decorators + (decorator + (call + function: + (attribute + attribute: (identifier) @parametrize + (#eq? @parametrize "parametrize")))) + ]] + local content = lib.files.read(path) + local ts_root, lang = lib.treesitter.get_parse_root(path, content, { fast = true }) + local built_query = lib.treesitter.normalise_query(lang, query) + return built_query:iter_matches(ts_root, content)() ~= nil +end + +---@async +---Discover test instances for path (by running script using python) +---@param python string[] +---@param script string +---@param path string +---@param positions neotest.Tree +---@param root string +local function discover_params(python, script, path, positions, root) + local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", path }) + logger.debug("Running test instance discovery:", cmd) + + local test_params = {} + local res, data = lib.process.run(cmd, { stdout = true, stderr = true }) + if res ~= 0 then + logger.warn("Pytest discovery failed") + if data.stderr then + logger.debug(data.stderr) + end + return {} + end + + for line in vim.gsplit(data.stdout, "\n", true) do + local param_index = string.find(line, "[", nil, true) + if param_index then + local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1) + local param_id = string.sub(line, param_index + 1, #line - 1) + + if positions:get_key(test_id) then + if not test_params[test_id] then + test_params[test_id] = { param_id } + else + table.insert(test_params[test_id], param_id) + end + end + end + end + return test_params +end + +---@async +---Launch pytest to discover test instances for path, if configured +---@param python string[] +---@param script string +---@param path string +---@param positions neotest.Tree +---@param root string +function M.augment_positions(python, script, path, positions, root) + if has_parametrize(path) then + local test_params = discover_params(python, script, path, positions, root) + add_test_instances(positions, test_params) + end +end + +return M diff --git a/neotest_python/__init__.py b/neotest_python/__init__.py index a3d1f56..e8c9871 100644 --- a/neotest_python/__init__.py +++ b/neotest_python/__init__.py @@ -41,6 +41,12 @@ def get_adapter(runner: TestRunner) -> NeotestAdapter: def main(argv: List[str]): + if "--pytest-collect" in argv: + argv.remove("--pytest-collect") + from .pytest import collect + collect(argv) + return + args = parser.parse_args(argv) adapter = get_adapter(TestRunner(args.runner)) diff --git a/neotest_python/pytest.py b/neotest_python/pytest.py index f64fb0a..603d82c 100644 --- a/neotest_python/pytest.py +++ b/neotest_python/pytest.py @@ -106,6 +106,7 @@ def pytest_runtest_makereport(self, item: "pytest.Item", call: "pytest.CallInfo" if getattr(item, "callspec", None) is not None: # Parametrized test msg_prefix = f"[{item.callspec.id}] " + pos_id += f"[{item.callspec.id}]" if report.outcome == "failed": exc_repr = report.longrepr # Test fails due to condition outside of test e.g. xfail @@ -176,3 +177,7 @@ def maybe_debugpy_postmortem(excinfo): py_db.stop_on_unhandled_exception(py_db, thread, additional_info, excinfo) finally: additional_info.is_tracing -= 1 + + +def collect(args): + pytest.main(['--collect-only', '-q'] + args)