diff --git a/pytest_relaxed/classes.py b/pytest_relaxed/classes.py index 10615bc..26edf4e 100644 --- a/pytest_relaxed/classes.py +++ b/pytest_relaxed/classes.py @@ -3,12 +3,15 @@ import six +from pytest import __version__ as pytest_version from pytest import Class, Instance, Module # NOTE: don't see any other way to get access to pytest innards besides using # the underscored name :( from _pytest.python import PyCollector +pytest_version_info = tuple(map(int, pytest_version.split(".")[:3])) + # NOTE: these are defined here for reuse by both pytest's own machinery and our # internal bits. @@ -45,6 +48,13 @@ def istestfunction(self, obj, name): class SpecModule(RelaxedMixin, Module): + @classmethod + def from_parent(cls, parent, fspath): + if pytest_version_info >= (5, 4): + return super(SpecModule, cls).from_parent(parent, fspath=fspath) + else: + return cls(parent=parent, fspath=fspath) + def _is_test_obj(self, test_func, obj, name): # First run our super() test, which should be RelaxedMixin's. good_name = getattr(super(SpecModule, self), test_func)(obj, name) @@ -69,6 +79,7 @@ def collect(self): # Get whatever our parent picked up as valid test items (given our # relaxed name constraints above). It'll be nearly all module contents. items = super(SpecModule, self).collect() + collected = [] for item in items: # Replace Class objects with recursive SpecInstances (via @@ -80,7 +91,7 @@ def collect(self): # them to be handled by pytest's own unittest support) but since # those are almost always in test_prefixed_filenames anyways...meh if isinstance(item, Class): - item = SpecClass(item.name, item.parent) + item = SpecClass.from_parent(item.parent, name=item.name) collected.append(item) return collected @@ -89,6 +100,13 @@ def collect(self): # its lonesome class SpecClass(Class): + @classmethod + def from_parent(cls, parent, name): + if pytest_version_info >= (5, 4): + return super(SpecClass, cls).from_parent(parent, name=name) + else: + return cls(parent=parent, name=name) + def collect(self): items = super(SpecClass, self).collect() collected = [] @@ -96,13 +114,20 @@ def collect(self): # recurse into inner classes. # TODO: is this ever not a one-item list? Meh. for item in items: - item = SpecInstance(name=item.name, parent=item.parent) + item = SpecInstance.from_parent(item.parent, name=item.name) collected.append(item) return collected class SpecInstance(RelaxedMixin, Instance): + @classmethod + def from_parent(cls, parent, name): + if pytest_version_info >= (5, 4): + return super(SpecInstance, cls).from_parent(parent, name=name) + else: + return cls(parent=parent, name=name) + def _getobj(self): # Regular object-making first obj = super(SpecInstance, self)._getobj() @@ -172,5 +197,5 @@ def _makeitem(self, name, obj): # recurse. # TODO: can we unify this with SpecModule's same bits? if isinstance(item, Class): - item = SpecClass(item.name, item.parent) + item = SpecClass.from_parent(item.parent, name=item.name) return item diff --git a/pytest_relaxed/plugin.py b/pytest_relaxed/plugin.py index 0f5a389..8f64358 100644 --- a/pytest_relaxed/plugin.py +++ b/pytest_relaxed/plugin.py @@ -28,7 +28,7 @@ def pytest_collect_file(path, parent): ): # Then use our custom module class which performs modified # function/class selection as well as class recursion - return SpecModule(path, parent) + return SpecModule.from_parent(parent, fspath=path) @pytest.mark.trylast # So we can be sure builtin terminalreporter exists diff --git a/tests/test_display.py b/tests/test_display.py index 5b7f9c8..2884c84 100644 --- a/tests/test_display.py +++ b/tests/test_display.py @@ -1,4 +1,5 @@ from pytest import skip +from pytest import __version__ as pytest_version # Load some fixtures we expose, without actually loading our entire plugin from pytest_relaxed.fixtures import environ # noqa @@ -8,6 +9,8 @@ # (May not be feasible if it has to assume something about how our collection # works?) CLI option (99% sure we can hook into that as a plugin)? +pytest_version_info = tuple(map(int, pytest_version.split(".")[:3])) + def _expect_regular_output(testdir): output = testdir.runpytest().stdout.str() @@ -225,7 +228,14 @@ def behavior_four(self): assert "== FAILURES ==" in output assert "AssertionError" in output # Summary - assert "== 1 failed, 4 passed, 1 skipped in " in output + if pytest_version_info >= (5, 3): + expected_out = ( + "== \x1b[31m\x1b[1m1 failed\x1b[0m, \x1b[32m4 passed\x1b[0m, " + "\x1b[33m1 skipped\x1b[0m\x1b[31m in " + ) + else: + expected_out = "== 1 failed, 4 passed, 1 skipped in " + assert expected_out in output def test_nests_many_levels_deep_no_problem(self, testdir): testdir.makepyfile(