Skip to content

Commit

Permalink
wip: only support pytest functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Noah Negin-Ulster committed Jun 1, 2020
1 parent 7b6b330 commit 4ca0d8d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/syrupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def pytest_sessionstart(session: Any) -> None:
config._syrupy.start()


def pytest_collection_modifyitems(session: Any, config: Any, items: List[Any]) -> None:
def pytest_collection_modifyitems(
session: Any, config: Any, items: List["pytest.Item"]
) -> None:
"""
After tests are collected and before any modification is performed.
https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_collection_modifyitems
Expand Down
10 changes: 5 additions & 5 deletions src/syrupy/session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Expand All @@ -10,6 +9,7 @@
)

import attr
import pytest

from .constants import EXIT_STATUS_FAIL_UNUSED
from .data import SnapshotFossils
Expand All @@ -29,8 +29,8 @@ class SnapshotSession:
is_providing_paths: bool = attr.ib()
is_providing_nodes: bool = attr.ib()
report: Optional["SnapshotReport"] = attr.ib(default=None)
_all_items: Set[Any] = attr.ib(factory=set)
_ran_items: Set[Any] = attr.ib(factory=set)
_all_items: Set["pytest.Item"] = attr.ib(factory=set)
_ran_items: Set["pytest.Item"] = attr.ib(factory=set)
_assertions: List["SnapshotAssertion"] = attr.ib(factory=list)
_extensions: Dict[str, "AbstractSyrupyExtension"] = attr.ib(factory=dict)

Expand Down Expand Up @@ -92,5 +92,5 @@ def remove_unused_snapshots(
Path(snapshot_location).unlink()

@staticmethod
def filter_valid_items(items: List[Any]) -> Iterable[Any]:
return (item for item in items if hasattr(item, "obj"))
def filter_valid_items(items: List["pytest.Item"]) -> Iterable["pytest.Item"]:
return (item for item in items if isinstance(item, pytest.Function))
3 changes: 3 additions & 0 deletions stubs/pytest.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ from typing import Any, Callable, TypeVar
ReturnType = TypeVar("ReturnType")

def fixture(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]: ...

class Function: ...
class Item: ...

0 comments on commit 4ca0d8d

Please sign in to comment.