Skip to content

Commit

Permalink
fix: fix TypedDict test for 3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Oct 19, 2023
1 parent 77cbb25 commit ab4a0e5
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 43 deletions.
18 changes: 0 additions & 18 deletions apischema/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,6 @@ def is_type_var(tp: Any) -> bool:
return isinstance(tp, TypeVar)


# Don't use sys.version_info because it can also depend of typing_extensions version
def required_keys(typed_dict: Type) -> Collection[str]:
assert is_typed_dict(typed_dict)
if hasattr(typed_dict, "__required_keys__"):
return typed_dict.__required_keys__
else:
required: Set[str] = set()
bases_annotations: Set = set()
for base in typed_dict.__bases__:
if not is_typed_dict(base):
continue
bases_annotations.update(base.__annotations__)
required.update(required_keys(base))
if typed_dict.__total__:
required.update(typed_dict.__annotations__.keys() - bases_annotations)
return required


# py38 get_origin of builtin wrapped generics return the unsubscriptable builtin
# type.
if (3, 8) <= sys.version_info < (3, 9):
Expand Down
6 changes: 2 additions & 4 deletions apischema/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
is_type_var,
is_typed_dict,
is_union,
required_keys,
resolve_type_hints,
)
from apischema.utils import PREFIX, get_origin_or_type, has_type_vars
Expand Down Expand Up @@ -190,9 +189,8 @@ def visit(self, tp: AnyType) -> Result:
types = {f: Any for f in origin._fields} # noqa: E501
return self.named_tuple(origin, types, origin._field_defaults)
if is_typed_dict(origin):
return self.typed_dict(
origin, resolve_type_hints(origin), required_keys(origin)
)
required_keys = getattr(origin, "__required_keys__", ()) # py38
return self.typed_dict(origin, resolve_type_hints(origin), required_keys)
if is_type_var(origin):
if origin.__constraints__:
return self.visit(Union[origin.__constraints__])
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/test_dict.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from datetime import date
from typing import Any, Dict, Mapping, TypedDict

Expand All @@ -8,6 +9,9 @@
from apischema.metadata import flatten
from apischema.typing import Annotated

if sys.version_info < (3, 9):
from typing_extensions import TypedDict # type: ignore


class MyDict(dict):
pass
Expand Down
23 changes: 2 additions & 21 deletions tests/unit/test_typing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from types import new_class
from typing import Generic, TypedDict, TypeVar
from unittest.mock import Mock
from typing import Generic, TypeVar

import pytest

from apischema.typing import Annotated, generic_mro, required_keys, resolve_type_hints
from apischema.typing import Annotated, generic_mro, resolve_type_hints

T = TypeVar("T")
U = TypeVar("U")
Expand Down Expand Up @@ -51,20 +49,3 @@ def test_generic_mro(tp, result, _):
@pytest.mark.parametrize("tp, _, result", test_cases)
def test_resolve_type_hints(tp, _, result):
assert resolve_type_hints(tp) == result


def test_required_keys():
_TypedDictMeta = type(new_class("_TypedDictMeta", (TypedDict,)))
td1, td2, td3 = Mock(_TypedDictMeta), Mock(_TypedDictMeta), Mock(_TypedDictMeta)
td1.__annotations__ = {"key": str}
td1.__total__ = False
td1.__bases__ = ()
td2.__annotations__ = {"key": str, "other": int}
td2.__total__ = True
td2.__bases__ = (td1,)
td3.__annotations__ = {"key": str, "other": int, "last": bool}
td3.__total__ = False
td3.__bases__ = (td2, object)
assert required_keys(td1) == set()
assert required_keys(td2) == {"other"}
assert required_keys(td3) == {"other"}

0 comments on commit ab4a0e5

Please sign in to comment.