Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: simplify data serializer for ambr #676

Merged
merged 2 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 62 additions & 34 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@
Dict,
List,
Optional,
Tuple,
Type,
)

from .exceptions import SnapshotDoesNotExist
from .exceptions import (
SnapshotDoesNotExist,
TaintedSnapshotError,
)
from .extensions.amber.serializer import Repr

if TYPE_CHECKING:
Expand Down Expand Up @@ -125,13 +129,15 @@ def __repr(self) -> "SerializableData":
SnapshotAssertionRepr = namedtuple( # type: ignore
"SnapshotAssertion", ["name", "num_executions"]
)
assertion_result = self.executions.get(
(self._custom_index and self._execution_name_index.get(self._custom_index))
or self.num_executions - 1
)
execution_index = (
self._custom_index and self._execution_name_index.get(self._custom_index)
) or self.num_executions - 1
assertion_result = self.executions.get(execution_index)
return (
Repr(str(assertion_result.final_data))
if assertion_result
if execution_index in self.executions
and assertion_result
and assertion_result.final_data is not None
else SnapshotAssertionRepr(
name=self.name,
num_executions=self.num_executions,
Expand Down Expand Up @@ -179,15 +185,23 @@ def _serialize(self, data: "SerializableData") -> "SerializedData":
def get_assert_diff(self) -> List[str]:
assertion_result = self._execution_results[self.num_executions - 1]
if assertion_result.exception:
lines = [
line
for lines in traceback.format_exception(
assertion_result.exception.__class__,
assertion_result.exception,
assertion_result.exception.__traceback__,
)
for line in lines.splitlines()
]
if isinstance(assertion_result.exception, (TaintedSnapshotError,)):
lines = [
gettext(
"This snapshot needs to be regenerated. "
"This is typically due to a major Syrupy update."
)
]
else:
lines = [
line
for lines in traceback.format_exception(
assertion_result.exception.__class__,
assertion_result.exception,
assertion_result.exception.__traceback__,
)
for line in lines.splitlines()
]
# Rotate to place exception with message at first line
return lines[-1:] + lines[:-1]
snapshot_data = assertion_result.recalled_data
Expand Down Expand Up @@ -232,7 +246,7 @@ def __call__(
return self

def __repr__(self) -> str:
return str(self._serialize(self.__repr))
return str(self.__repr)

def __eq__(self, other: "SerializableData") -> bool:
return self._assert(other)
Expand All @@ -250,29 +264,36 @@ def _assert(self, data: "SerializableData") -> bool:
assertion_success = False
assertion_exception = None
try:
snapshot_data = self._recall_data(index=self.index)
snapshot_data, tainted = self._recall_data(index=self.index)
serialized_data = self._serialize(data)
snapshot_diff = getattr(self, "_snapshot_diff", None)
if snapshot_diff is not None:
snapshot_data_diff = self._recall_data(index=snapshot_diff)
snapshot_data_diff, _ = self._recall_data(index=snapshot_diff)
if snapshot_data_diff is None:
raise SnapshotDoesNotExist()
serialized_data = self.extension.diff_snapshots(
serialized_data=serialized_data,
snapshot_data=snapshot_data_diff,
)
matches = snapshot_data is not None and self.extension.matches(
serialized_data=serialized_data, snapshot_data=snapshot_data
matches = (
not tainted
and snapshot_data is not None
and self.extension.matches(
serialized_data=serialized_data, snapshot_data=snapshot_data
)
)
assertion_success = matches
if not matches and self.update_snapshots:
self.session.queue_snapshot_write(
extension=self.extension,
test_location=self.test_location,
data=serialized_data,
index=self.index,
)
assertion_success = True
if not matches:
if self.update_snapshots:
self.session.queue_snapshot_write(
extension=self.extension,
test_location=self.test_location,
data=serialized_data,
index=self.index,
)
assertion_success = True
elif tainted:
raise TaintedSnapshotError
return assertion_success
except Exception as e:
assertion_exception = e
Expand Down Expand Up @@ -301,12 +322,19 @@ def _post_assert(self) -> None:
while self._post_assert_actions:
self._post_assert_actions.pop()()

def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]:
def _recall_data(
self, index: "SnapshotIndex"
) -> Tuple[Optional["SerializableData"], bool]:
try:
return self.extension.read_snapshot(
test_location=self.test_location,
index=index,
session_id=str(id(self.session)),
return (
self.extension.read_snapshot(
test_location=self.test_location,
index=index,
session_id=str(id(self.session)),
),
False,
)
except SnapshotDoesNotExist:
return None
return None, False
except TaintedSnapshotError as e:
return e.snapshot_data, True
5 changes: 5 additions & 0 deletions src/syrupy/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
class Snapshot:
name: str
data: Optional["SerializedData"] = None
# A tainted snapshot needs to be regenerated
tainted: Optional[bool] = field(default=None)


@dataclass(frozen=True)
Expand All @@ -42,6 +44,9 @@ class SnapshotCollection:
location: str
_snapshots: Dict[str, "Snapshot"] = field(default_factory=dict)

# A tainted collection needs to be regenerated
tainted: Optional[bool] = field(default=None)

@property
def has_snapshots(self) -> bool:
return bool(self._snapshots)
Expand Down
15 changes: 15 additions & 0 deletions src/syrupy/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
from typing import Optional

from syrupy.types import SerializedData


class SnapshotDoesNotExist(Exception):
"""Snapshot does not exist"""


class FailedToLoadModuleMember(Exception):
"""Failed to load specific member in a module"""


class TaintedSnapshotError(Exception):
"""The snapshot needs to be regenerated."""

snapshot_data: Optional["SerializedData"]

def __init__(self, snapshot_data: Optional["SerializedData"] = None) -> None:
super().__init__()
self.snapshot_data = snapshot_data
23 changes: 14 additions & 9 deletions src/syrupy/extensions/amber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
)

from syrupy.data import SnapshotCollection
from syrupy.exceptions import TaintedSnapshotError
from syrupy.extensions.base import AbstractSyrupyExtension

from .serializer import DataSerializer
from .serializer import AmberDataSerializer

if TYPE_CHECKING:
from syrupy.types import SerializableData
Expand All @@ -28,29 +29,29 @@ def serialize(self, data: "SerializableData", **kwargs: Any) -> str:
Returns the serialized form of 'data' to be compared
with the snapshot data written to disk.
"""
return DataSerializer.serialize(data, **kwargs)
return AmberDataSerializer.serialize(data, **kwargs)

def delete_snapshots(
self, snapshot_location: str, snapshot_names: Set[str]
) -> None:
snapshot_collection_to_update = DataSerializer.read_file(snapshot_location)
snapshot_collection_to_update = AmberDataSerializer.read_file(snapshot_location)
for snapshot_name in snapshot_names:
snapshot_collection_to_update.remove(snapshot_name)

if snapshot_collection_to_update.has_snapshots:
DataSerializer.write_file(snapshot_collection_to_update)
AmberDataSerializer.write_file(snapshot_collection_to_update)
else:
Path(snapshot_location).unlink()

def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection":
return DataSerializer.read_file(snapshot_location)
return AmberDataSerializer.read_file(snapshot_location)

@staticmethod
@lru_cache()
def __cacheable_read_snapshot(
snapshot_location: str, cache_key: str
) -> "SnapshotCollection":
return DataSerializer.read_file(snapshot_location)
return AmberDataSerializer.read_file(snapshot_location)

def _read_snapshot_data_from_location(
self, snapshot_location: str, snapshot_name: str, session_id: str
Expand All @@ -59,13 +60,17 @@ def _read_snapshot_data_from_location(
snapshot_location=snapshot_location, cache_key=session_id
)
snapshot = snapshots.get(snapshot_name)
return snapshot.data if snapshot else None
tainted = bool(snapshots.tainted or (snapshot and snapshot.tainted))
data = snapshot.data if snapshot else None
if tainted:
raise TaintedSnapshotError(snapshot_data=data)
return data

@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
) -> None:
DataSerializer.write_file(snapshot_collection, merge=True)
AmberDataSerializer.write_file(snapshot_collection, merge=True)


__all__ = ["AmberSnapshotExtension", "DataSerializer"]
__all__ = ["AmberSnapshotExtension", "AmberDataSerializer"]
Loading