Skip to content

Commit

Permalink
refactor: simplify data serializer for ambr (#676)
Browse files Browse the repository at this point in the history
* refactor: simplify data serializer for ambr

* feat: introduce concept of a tainted snapshot

BREAKING CHANGE: Serializers may now throw a TaintedSnapshotError which will tell the user to regenerate the snapshot even if the underlying data has not changed. This is to support rolling out more subtle changes to the serializers, such as the introduction of serializer metadata.

BREAKING CHANGE: Renamed DataSerializer to AmberDataSerializer.
  • Loading branch information
noahnu authored Jan 26, 2023
1 parent 69f04ab commit 3d296e1
Show file tree
Hide file tree
Showing 20 changed files with 209 additions and 87 deletions.
96 changes: 62 additions & 34 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@
Dict,
List,
Optional,
Tuple,
Type,
)

from .exceptions import SnapshotDoesNotExist
from .exceptions import (
SnapshotDoesNotExist,
TaintedSnapshotError,
)
from .extensions.amber.serializer import Repr

if TYPE_CHECKING:
Expand Down Expand Up @@ -125,13 +129,15 @@ def __repr(self) -> "SerializableData":
SnapshotAssertionRepr = namedtuple( # type: ignore
"SnapshotAssertion", ["name", "num_executions"]
)
assertion_result = self.executions.get(
(self._custom_index and self._execution_name_index.get(self._custom_index))
or self.num_executions - 1
)
execution_index = (
self._custom_index and self._execution_name_index.get(self._custom_index)
) or self.num_executions - 1
assertion_result = self.executions.get(execution_index)
return (
Repr(str(assertion_result.final_data))
if assertion_result
if execution_index in self.executions
and assertion_result
and assertion_result.final_data is not None
else SnapshotAssertionRepr(
name=self.name,
num_executions=self.num_executions,
Expand Down Expand Up @@ -179,15 +185,23 @@ def _serialize(self, data: "SerializableData") -> "SerializedData":
def get_assert_diff(self) -> List[str]:
assertion_result = self._execution_results[self.num_executions - 1]
if assertion_result.exception:
lines = [
line
for lines in traceback.format_exception(
assertion_result.exception.__class__,
assertion_result.exception,
assertion_result.exception.__traceback__,
)
for line in lines.splitlines()
]
if isinstance(assertion_result.exception, (TaintedSnapshotError,)):
lines = [
gettext(
"This snapshot needs to be regenerated. "
"This is typically due to a major Syrupy update."
)
]
else:
lines = [
line
for lines in traceback.format_exception(
assertion_result.exception.__class__,
assertion_result.exception,
assertion_result.exception.__traceback__,
)
for line in lines.splitlines()
]
# Rotate to place exception with message at first line
return lines[-1:] + lines[:-1]
snapshot_data = assertion_result.recalled_data
Expand Down Expand Up @@ -232,7 +246,7 @@ def __call__(
return self

def __repr__(self) -> str:
return str(self._serialize(self.__repr))
return str(self.__repr)

def __eq__(self, other: "SerializableData") -> bool:
return self._assert(other)
Expand All @@ -250,29 +264,36 @@ def _assert(self, data: "SerializableData") -> bool:
assertion_success = False
assertion_exception = None
try:
snapshot_data = self._recall_data(index=self.index)
snapshot_data, tainted = self._recall_data(index=self.index)
serialized_data = self._serialize(data)
snapshot_diff = getattr(self, "_snapshot_diff", None)
if snapshot_diff is not None:
snapshot_data_diff = self._recall_data(index=snapshot_diff)
snapshot_data_diff, _ = self._recall_data(index=snapshot_diff)
if snapshot_data_diff is None:
raise SnapshotDoesNotExist()
serialized_data = self.extension.diff_snapshots(
serialized_data=serialized_data,
snapshot_data=snapshot_data_diff,
)
matches = snapshot_data is not None and self.extension.matches(
serialized_data=serialized_data, snapshot_data=snapshot_data
matches = (
not tainted
and snapshot_data is not None
and self.extension.matches(
serialized_data=serialized_data, snapshot_data=snapshot_data
)
)
assertion_success = matches
if not matches and self.update_snapshots:
self.session.queue_snapshot_write(
extension=self.extension,
test_location=self.test_location,
data=serialized_data,
index=self.index,
)
assertion_success = True
if not matches:
if self.update_snapshots:
self.session.queue_snapshot_write(
extension=self.extension,
test_location=self.test_location,
data=serialized_data,
index=self.index,
)
assertion_success = True
elif tainted:
raise TaintedSnapshotError
return assertion_success
except Exception as e:
assertion_exception = e
Expand Down Expand Up @@ -301,12 +322,19 @@ def _post_assert(self) -> None:
while self._post_assert_actions:
self._post_assert_actions.pop()()

def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]:
def _recall_data(
self, index: "SnapshotIndex"
) -> Tuple[Optional["SerializableData"], bool]:
try:
return self.extension.read_snapshot(
test_location=self.test_location,
index=index,
session_id=str(id(self.session)),
return (
self.extension.read_snapshot(
test_location=self.test_location,
index=index,
session_id=str(id(self.session)),
),
False,
)
except SnapshotDoesNotExist:
return None
return None, False
except TaintedSnapshotError as e:
return e.snapshot_data, True
5 changes: 5 additions & 0 deletions src/syrupy/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
class Snapshot:
name: str
data: Optional["SerializedData"] = None
# A tainted snapshot needs to be regenerated
tainted: Optional[bool] = field(default=None)


@dataclass(frozen=True)
Expand All @@ -42,6 +44,9 @@ class SnapshotCollection:
location: str
_snapshots: Dict[str, "Snapshot"] = field(default_factory=dict)

# A tainted collection needs to be regenerated
tainted: Optional[bool] = field(default=None)

@property
def has_snapshots(self) -> bool:
return bool(self._snapshots)
Expand Down
15 changes: 15 additions & 0 deletions src/syrupy/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
from typing import Optional

from syrupy.types import SerializedData


class SnapshotDoesNotExist(Exception):
"""Snapshot does not exist"""


class FailedToLoadModuleMember(Exception):
"""Failed to load specific member in a module"""


class TaintedSnapshotError(Exception):
"""The snapshot needs to be regenerated."""

snapshot_data: Optional["SerializedData"]

def __init__(self, snapshot_data: Optional["SerializedData"] = None) -> None:
super().__init__()
self.snapshot_data = snapshot_data
23 changes: 14 additions & 9 deletions src/syrupy/extensions/amber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
)

from syrupy.data import SnapshotCollection
from syrupy.exceptions import TaintedSnapshotError
from syrupy.extensions.base import AbstractSyrupyExtension

from .serializer import DataSerializer
from .serializer import AmberDataSerializer

if TYPE_CHECKING:
from syrupy.types import SerializableData
Expand All @@ -28,29 +29,29 @@ def serialize(self, data: "SerializableData", **kwargs: Any) -> str:
Returns the serialized form of 'data' to be compared
with the snapshot data written to disk.
"""
return DataSerializer.serialize(data, **kwargs)
return AmberDataSerializer.serialize(data, **kwargs)

def delete_snapshots(
self, snapshot_location: str, snapshot_names: Set[str]
) -> None:
snapshot_collection_to_update = DataSerializer.read_file(snapshot_location)
snapshot_collection_to_update = AmberDataSerializer.read_file(snapshot_location)
for snapshot_name in snapshot_names:
snapshot_collection_to_update.remove(snapshot_name)

if snapshot_collection_to_update.has_snapshots:
DataSerializer.write_file(snapshot_collection_to_update)
AmberDataSerializer.write_file(snapshot_collection_to_update)
else:
Path(snapshot_location).unlink()

def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection":
return DataSerializer.read_file(snapshot_location)
return AmberDataSerializer.read_file(snapshot_location)

@staticmethod
@lru_cache()
def __cacheable_read_snapshot(
snapshot_location: str, cache_key: str
) -> "SnapshotCollection":
return DataSerializer.read_file(snapshot_location)
return AmberDataSerializer.read_file(snapshot_location)

def _read_snapshot_data_from_location(
self, snapshot_location: str, snapshot_name: str, session_id: str
Expand All @@ -59,13 +60,17 @@ def _read_snapshot_data_from_location(
snapshot_location=snapshot_location, cache_key=session_id
)
snapshot = snapshots.get(snapshot_name)
return snapshot.data if snapshot else None
tainted = bool(snapshots.tainted or (snapshot and snapshot.tainted))
data = snapshot.data if snapshot else None
if tainted:
raise TaintedSnapshotError(snapshot_data=data)
return data

@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
) -> None:
DataSerializer.write_file(snapshot_collection, merge=True)
AmberDataSerializer.write_file(snapshot_collection, merge=True)


__all__ = ["AmberSnapshotExtension", "DataSerializer"]
__all__ = ["AmberSnapshotExtension", "AmberDataSerializer"]
Loading

1 comment on commit 3d296e1

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 3d296e1 Previous: 02abef5 Ratio
benchmarks/test_1000x.py::test_1000x_reads 0.8414131647996508 iter/sec (stddev: 0.04099141291505383) 0.8381195242511715 iter/sec (stddev: 0.04240394140227035) 1.00
benchmarks/test_1000x.py::test_1000x_writes 0.8196875242073718 iter/sec (stddev: 0.047827666637548824) 0.8626650008455868 iter/sec (stddev: 0.05153168408309042) 1.05
benchmarks/test_standard.py::test_standard 0.7870675206649624 iter/sec (stddev: 0.0533161189311347) 0.7465173870618954 iter/sec (stddev: 0.1502009356924296) 0.95

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.