Skip to content

Commit

Permalink
feat(pytest): populate parameterized test instances (#36)
Browse files Browse the repository at this point in the history
Co-authored-by: Rónán Carrigan <rcarriga@tcd.ie>
  • Loading branch information
OddBloke and rcarriga authored Nov 12, 2023
1 parent 81d2265 commit ff20740
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 6 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ require("neotest").setup({
is_test_file = function(file_path)
...
end,

-- !!EXPERIMENTAL!! Enable shelling out to `pytest` to discover test
-- instances for files containing a parametrize mark (default: false)
pytest_discover_instances = true,
})
}
})
Expand Down
23 changes: 18 additions & 5 deletions lua/neotest-python/init.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
local async = require("neotest.async")
local lib = require("neotest.lib")
local base = require("neotest-python.base")
local pytest = require("neotest-python.pytest")

local function get_script()
local paths = vim.api.nvim_get_runtime_file("neotest.py", true)
Expand All @@ -15,6 +16,7 @@ end

local dap_args
local is_test_file = base.is_test_file
local pytest_discover_instances = false

local function get_strategy_config(strategy, python, program, args)
local config = {
Expand Down Expand Up @@ -80,8 +82,13 @@ function PythonNeotestAdapter.filter_dir(name)
end

---@async
---@return Tree | nil
---@return neotest.Tree | nil
function PythonNeotestAdapter.discover_positions(path)
local root = PythonNeotestAdapter.root(path) or vim.loop.cwd()
local python = get_python(root)
local runner = get_runner(python)

-- Parse the file while pytest is running
local query = [[
;; Match undecorated functions
((function_definition
Expand Down Expand Up @@ -109,12 +116,15 @@ function PythonNeotestAdapter.discover_positions(path)
(#not-has-parent? @namespace.definition decorated_definition)
)
]]
local root = PythonNeotestAdapter.root(path)
local python = get_python(root)
local runner = get_runner(python)
return lib.treesitter.parse_positions(path, query, {
local positions = lib.treesitter.parse_positions(path, query, {
require_namespaces = runner == "unittest",
})

if runner == "pytest" and pytest_discover_instances then
pytest.augment_positions(python, get_script(), path, positions, root)
end

return positions
end

---@async
Expand Down Expand Up @@ -232,6 +242,9 @@ setmetatable(PythonNeotestAdapter, {
if type(opts.dap) == "table" then
dap_args = opts.dap
end
if opts.pytest_discover_instances ~= nil then
pytest_discover_instances = opts.pytest_discover_instances
end
return PythonNeotestAdapter
end,
})
Expand Down
101 changes: 101 additions & 0 deletions lua/neotest-python/pytest.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
local lib = require("neotest.lib")
local logger = require("neotest.logging")

local M = {}

---@async
---Add test instances for path in root to positions
---@param positions neotest.Tree
---@param test_params table<string, string[]>
local function add_test_instances(positions, test_params)
for _, node in positions:iter_nodes() do
local position = node:data()
if position.type == "test" then
local pos_params = test_params[position.id] or {}
for _, params_str in ipairs(pos_params) do
local new_data = vim.tbl_extend("force", position, {
id = string.format("%s[%s]", position.id, params_str),
name = string.format("%s[%s]", position.name, params_str),
})
new_data.range = nil

local new_pos = node:new(new_data, {}, node._key, {}, {})
node:add_child(new_data.id, new_pos)
end
end
end
end

---@async
---@param path string
---@return boolean
local function has_parametrize(path)
local query = [[
;; Detect parametrize decorators
(decorator
(call
function:
(attribute
attribute: (identifier) @parametrize
(#eq? @parametrize "parametrize"))))
]]
local content = lib.files.read(path)
local ts_root, lang = lib.treesitter.get_parse_root(path, content, { fast = true })
local built_query = lib.treesitter.normalise_query(lang, query)
return built_query:iter_matches(ts_root, content)() ~= nil
end

---@async
---Discover test instances for path (by running script using python)
---@param python string[]
---@param script string
---@param path string
---@param positions neotest.Tree
---@param root string
local function discover_params(python, script, path, positions, root)
local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", path })
logger.debug("Running test instance discovery:", cmd)

local test_params = {}
local res, data = lib.process.run(cmd, { stdout = true, stderr = true })
if res ~= 0 then
logger.warn("Pytest discovery failed")
if data.stderr then
logger.debug(data.stderr)
end
return {}
end

for line in vim.gsplit(data.stdout, "\n", true) do
local param_index = string.find(line, "[", nil, true)
if param_index then
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
local param_id = string.sub(line, param_index + 1, #line - 1)

if positions:get_key(test_id) then
if not test_params[test_id] then
test_params[test_id] = { param_id }
else
table.insert(test_params[test_id], param_id)
end
end
end
end
return test_params
end

---@async
---Launch pytest to discover test instances for path, if configured
---@param python string[]
---@param script string
---@param path string
---@param positions neotest.Tree
---@param root string
function M.augment_positions(python, script, path, positions, root)
if has_parametrize(path) then
local test_params = discover_params(python, script, path, positions, root)
add_test_instances(positions, test_params)
end
end

return M
6 changes: 6 additions & 0 deletions neotest_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def get_adapter(runner: TestRunner) -> NeotestAdapter:


def main(argv: List[str]):
if "--pytest-collect" in argv:
argv.remove("--pytest-collect")
from .pytest import collect
collect(argv)
return

args = parser.parse_args(argv)
adapter = get_adapter(TestRunner(args.runner))

Expand Down
5 changes: 5 additions & 0 deletions neotest_python/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def pytest_runtest_makereport(self, item: "pytest.Item", call: "pytest.CallInfo"
if getattr(item, "callspec", None) is not None:
# Parametrized test
msg_prefix = f"[{item.callspec.id}] "
pos_id += f"[{item.callspec.id}]"
if report.outcome == "failed":
exc_repr = report.longrepr
# Test fails due to condition outside of test e.g. xfail
Expand Down Expand Up @@ -176,3 +177,7 @@ def maybe_debugpy_postmortem(excinfo):
py_db.stop_on_unhandled_exception(py_db, thread, additional_info, excinfo)
finally:
additional_info.is_tracing -= 1


def collect(args):
pytest.main(['--collect-only', '-q'] + args)

0 comments on commit ff20740

Please sign in to comment.