Skip to content

Commit

Permalink
feat: allow extensions to override snapshot equality check (#548)
Browse files Browse the repository at this point in the history
  • Loading branch information
Noah authored Jan 14, 2022
1 parent ee8edaa commit a44f1b9
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ See examples of how syrupy can be used and extended in the [test examples](https
- [Custom snapshot directory](https://github.com/tophat/syrupy/tree/master/tests/examples/test_custom_snapshot_directory.py)
- [Custom snapshot name](https://github.com/tophat/syrupy/tree/master/tests/examples/test_custom_snapshot_name.py)
- [Custom object snapshots](https://github.com/tophat/syrupy/tree/master/tests/examples/test_custom_object_repr.py)
- [Custom comparator](https://github.com/tophat/syrupy/tree/master/tests/integration/test_custom_comparator.py)
- [JPEG image extension](https://github.com/tophat/syrupy/tree/master/tests/examples/test_custom_image_extension.py)
- [Built-in image extensions](https://github.com/tophat/syrupy/blob/master/tests/syrupy/extensions/image/test_image_svg.py)

Expand Down
4 changes: 3 additions & 1 deletion src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ def _assert(self, data: "SerializableData") -> bool:
try:
snapshot_data = self._recall_data()
serialized_data = self._serialize(data)
matches = snapshot_data is not None and serialized_data == snapshot_data
matches = snapshot_data is not None and self.extension.matches(
serialized_data=serialized_data, snapshot_data=snapshot_data
)
assertion_success = matches
if not matches and self._update_snapshots:
self.extension.write_snapshot(
Expand Down
18 changes: 17 additions & 1 deletion src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,23 @@ def __strip_ends(self, line: str) -> str:
return line.rstrip("".join(self._ends.keys()))


class AbstractSyrupyExtension(SnapshotSerializer, SnapshotFossilizer, SnapshotReporter):
class SnapshotComparator(ABC):
def matches(
self,
*,
serialized_data: "SerializableData",
snapshot_data: "SerializableData",
) -> bool:
"""
Compares serialized data and snapshot data and returns
whether they match.
"""
return bool(serialized_data == snapshot_data)


class AbstractSyrupyExtension(
SnapshotSerializer, SnapshotFossilizer, SnapshotReporter, SnapshotComparator
):
def __init__(self, test_location: "PyTestLocation"):
self._test_location = test_location

Expand Down
84 changes: 84 additions & 0 deletions tests/integration/test_custom_comparator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import pytest


@pytest.fixture
def testcases_initial(testdir):
testdir.makeconftest(
"""
import pytest
import math
from syrupy.extensions.amber import AmberSnapshotExtension
class CustomSnapshotExtension(AmberSnapshotExtension):
def matches(self, *, serialized_data, snapshot_data):
try:
a = float(serialized_data)
b = float(snapshot_data)
return math.isclose(a, b, rel_tol=1e-5)
except:
return False
@pytest.fixture
def snapshot_custom(snapshot):
return snapshot.use_extension(CustomSnapshotExtension)
"""
)
return {
"passed": (
"""
def test_passed_custom(snapshot_custom):
assert snapshot_custom == 3.0
"""
),
"failed": (
"""
def test_passed_custom(snapshot_custom):
# this comment is required or the test breaks
assert snapshot_custom == 4.0
"""
),
}


@pytest.fixture
def generate_snapshots(testdir, testcases_initial):
testdir.makepyfile(test_file=testcases_initial["passed"])
result = testdir.runpytest("-v", "--snapshot-update")
return result, testdir, testcases_initial


def test_generated_snapshots(generate_snapshots):
result = generate_snapshots[0]
result.stdout.re_match_lines((r"1 snapshot generated\."))
assert "snapshots unused" not in result.stdout.str()
assert result.ret == 0


def test_approximate_match(generate_snapshots):
testdir = generate_snapshots[1]
testdir.makepyfile(
test_file="""
def test_passed_custom(snapshot_custom):
assert snapshot_custom == 3.2
"""
)
result = testdir.runpytest("-v")
result.stdout.re_match_lines((r"test_file.py::test_passed_custom PASSED"))
assert result.ret == 0


def test_failed_snapshots(generate_snapshots):
testdir = generate_snapshots[1]
testdir.makepyfile(test_file=generate_snapshots[2]["failed"])
result = testdir.runpytest("-v")
result.stdout.re_match_lines((r"1 snapshot failed\."))
assert result.ret == 1


def test_updated_snapshots(generate_snapshots):
_, testdir, initial = generate_snapshots
testdir.makepyfile(test_file=initial["failed"])
result = testdir.runpytest("-v", "--snapshot-update")
result.stdout.re_match_lines((r"1 snapshot updated\."))
assert result.ret == 0

0 comments on commit a44f1b9

Please sign in to comment.