diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index f82b13a..7e00b04 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -61,20 +61,23 @@ class User: TypeVar, Union, cast, - get_type_hints, overload, ) import marshmallow -import typing_extensions import typing_inspect +from marshmallow_dataclass.generic_resolver import ( + UnboundTypeVarError, + get_resolved_dataclass_fields, + is_generic_alias, +) from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute if sys.version_info >= (3, 9): - from typing import Annotated + from typing import Annotated, get_args, get_origin else: - from typing_extensions import Annotated + from typing_extensions import Annotated, get_args, get_origin if sys.version_info >= (3, 11): from typing import dataclass_transform @@ -139,6 +142,18 @@ def _maybe_get_callers_frame( del frame +def _check_decorated_type(cls: object) -> None: + if not isinstance(cls, type): + raise TypeError(f"expected a class not {cls!r}") + if is_generic_alias(cls): + # A .Schema attribute doesn't make sense on a generic alias — there's + # no way for it to know the generic parameters at run time. + raise TypeError( + "decorator does not support generic aliases " + "(hint: use class_schema directly instead)" + ) + + @overload def dataclass( _cls: Type[_U], @@ -214,12 +229,15 @@ def dataclass( ) def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: + _check_decorated_type(cls) + return add_schema( dc(cls), base_schema, cls_frame=cls_frame, stacklevel=stacklevel + 1 ) if _cls is None: return decorator + return decorator(_cls, stacklevel=stacklevel + 1) @@ -268,6 +286,8 @@ def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1): """ def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]: + _check_decorated_type(clazz) + if cls_frame is not None: frame = cls_frame else: @@ -453,7 +473,7 @@ def class_schema( >>> class_schema(Custom)().load({}) Custom(name=None) """ - if not dataclasses.is_dataclass(clazz): + if not dataclasses.is_dataclass(clazz) and not is_generic_alias_of_dataclass(clazz): clazz = dataclasses.dataclass(clazz) if localns is None: if clazz_frame is None: @@ -514,17 +534,21 @@ def _internal_class_schema( ) -> Type[marshmallow.Schema]: schema_ctx = _schema_ctx_stack.top - if typing_extensions.get_origin(clazz) is Annotated and sys.version_info < (3, 10): + if get_origin(clazz) is Annotated and sys.version_info < (3, 10): # https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977 class_name = clazz._name or clazz.__origin__.__name__ # type: ignore[attr-defined] else: - class_name = clazz.__name__ + # generic aliases do not have a __name__ prior python 3.10 + class_name = getattr(clazz, "__name__", repr(clazz)) schema_ctx.seen_classes[clazz] = class_name try: - # noinspection PyDataclass - fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz) + fields = get_resolved_dataclass_fields( + clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns + ) + except UnboundTypeVarError: + raise except TypeError: # Not a dataclass try: warnings.warn( @@ -540,6 +564,8 @@ def _internal_class_schema( ) created_dataclass: type = dataclasses.dataclass(clazz) return _internal_class_schema(created_dataclass, base_schema) + except UnboundTypeVarError: + raise except Exception as exc: raise TypeError( f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one." @@ -556,23 +582,11 @@ def _internal_class_schema( include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False) # Update the schema members to contain marshmallow fields instead of dataclass fields - - if sys.version_info >= (3, 9): - type_hints = get_type_hints( - clazz, - globalns=schema_ctx.globalns, - localns=schema_ctx.localns, - include_extras=True, - ) - else: - type_hints = get_type_hints( - clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns - ) attributes.update( ( field.name, _field_for_schema( - type_hints[field.name], + field.type, _get_field_default(field), field.metadata, base_schema, @@ -582,7 +596,7 @@ def _internal_class_schema( if field.init or include_non_init ) - schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes) + schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes) return cast(Type[marshmallow.Schema], schema_class) @@ -633,8 +647,8 @@ def _field_by_supertype( ) -def _generic_type_add_any(typ: type) -> type: - """if typ is generic type without arguments, replace them by Any.""" +def _container_type_add_any(typ: type) -> type: + """if typ is container type without arguments, replace them by Any.""" if typ is list or typ is List: typ = List[Any] elif typ is dict or typ is Dict: @@ -650,18 +664,20 @@ def _generic_type_add_any(typ: type) -> type: return typ -def _field_for_generic_type( +def _field_for_container_type( typ: type, base_schema: Optional[Type[marshmallow.Schema]], **metadata: Any, ) -> Optional[marshmallow.fields.Field]: """ - If the type is a generic interface, resolve the arguments and construct the appropriate Field. + If the type is a container interface, resolve the arguments and construct the appropriate Field. + + We use the term 'container' to differentiate from the Generic support """ - origin = typing_extensions.get_origin(typ) - arguments = typing_extensions.get_args(typ) + origin = get_origin(typ) + arguments = get_args(typ) if origin: - # Override base_schema.TYPE_MAPPING to change the class used for generic types below + # Override base_schema.TYPE_MAPPING to change the class used for container types below type_mapping = base_schema.TYPE_MAPPING if base_schema else {} if origin in (list, List): @@ -705,7 +721,7 @@ def _field_for_generic_type( ), ) return tuple_type(children, **metadata) - elif origin in (dict, Dict, collections.abc.Mapping, Mapping): + if origin in (dict, Dict, collections.abc.Mapping, Mapping): dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( keys=_field_for_schema(arguments[0], base_schema=base_schema), @@ -723,14 +739,18 @@ def _field_for_annotated_type( """ If the type is an Annotated interface, resolve the arguments and construct the appropriate Field. """ - origin = typing_extensions.get_origin(typ) - arguments = typing_extensions.get_args(typ) + origin = get_origin(typ) + arguments = get_args(typ) if origin and origin is Annotated: marshmallow_annotations = [ arg for arg in arguments[1:] - if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field)) - or isinstance(arg, marshmallow.fields.Field) + if _is_marshmallow_field(arg) + # Support `CustomGenericField[mf.String]` + or ( + typing_inspect.is_generic_type(arg) + and _is_marshmallow_field(get_origin(arg)) + ) ] if marshmallow_annotations: if len(marshmallow_annotations) > 1: @@ -752,7 +772,7 @@ def _field_for_union_type( base_schema: Optional[Type[marshmallow.Schema]], **metadata: Any, ) -> Optional[marshmallow.fields.Field]: - arguments = typing_extensions.get_args(typ) + arguments = get_args(typ) if typing_inspect.is_union_type(typ): if typing_inspect.is_optional_type(typ): metadata["allow_none"] = metadata.get("allow_none", True) @@ -838,6 +858,9 @@ def _field_for_schema( """ + if isinstance(typ, TypeVar): + raise UnboundTypeVarError(f"can not resolve type variable {typ.__name__}") + metadata = {} if metadata is None else dict(metadata) if default is not marshmallow.missing: @@ -853,8 +876,8 @@ def _field_for_schema( if predefined_field: return predefined_field - # Generic types specified without type arguments - typ = _generic_type_add_any(typ) + # Container types (generics like List) specified without type arguments + typ = _container_type_add_any(typ) # Base types field = _field_by_type(typ, base_schema) @@ -867,7 +890,7 @@ def _field_for_schema( # i.e.: Literal['abc'] if typing_inspect.is_literal_type(typ): - arguments = typing_inspect.get_args(typ) + arguments = get_args(typ) return marshmallow.fields.Raw( validate=( marshmallow.validate.Equal(arguments[0]) @@ -879,7 +902,7 @@ def _field_for_schema( # i.e.: Final[str] = 'abc' if typing_inspect.is_final_type(typ): - arguments = typing_inspect.get_args(typ) + arguments = get_args(typ) if arguments: subtyp = arguments[0] elif default is not marshmallow.missing: @@ -920,10 +943,10 @@ def _field_for_schema( if union_field: return union_field - # Generic types - generic_field = _field_for_generic_type(typ, base_schema, **metadata) - if generic_field: - return generic_field + # Container types + container_field = _field_for_container_type(typ, base_schema, **metadata) + if container_field: + return container_field # typing.NewType returns a function (in python <= 3.9) or a class (python >= 3.10) with a # __supertype__ attribute @@ -952,7 +975,7 @@ def _field_for_schema( nested_schema or forward_reference or _schema_ctx_stack.top.seen_classes.get(typ) - or _internal_class_schema(typ, base_schema) # type: ignore[arg-type] # FIXME + or _internal_class_schema(typ, base_schema) # type: ignore [arg-type] ) return marshmallow.fields.Nested(nested, **metadata) @@ -996,6 +1019,20 @@ def _get_field_default(field: dataclasses.Field): return field.default +def is_generic_alias_of_dataclass(clazz: type) -> bool: + """ + Check if given class is a generic alias of a dataclass, if the dataclass is + defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed + """ + return is_generic_alias(clazz) and dataclasses.is_dataclass(get_origin(clazz)) + + +def _is_marshmallow_field(obj) -> bool: + return ( + inspect.isclass(obj) and issubclass(obj, marshmallow.fields.Field) + ) or isinstance(obj, marshmallow.fields.Field) + + def NewType( name: str, typ: Type[_U], diff --git a/marshmallow_dataclass/generic_resolver.py b/marshmallow_dataclass/generic_resolver.py new file mode 100644 index 0000000..6726640 --- /dev/null +++ b/marshmallow_dataclass/generic_resolver.py @@ -0,0 +1,282 @@ +import copy +import dataclasses +import inspect +import sys +from typing import ( + Any, + Dict, + ForwardRef, + Generic, + List, + Optional, + Tuple, + TypeVar, +) + +import typing_inspect + +if sys.version_info >= (3, 9): + from typing import Annotated, get_args, get_origin + + def eval_forward_ref(t: ForwardRef, globalns, localns, recursive_guard=frozenset()): + return t._evaluate(globalns, localns, recursive_guard=recursive_guard) + +else: + from typing_extensions import Annotated, get_args, get_origin + + def eval_forward_ref(t: ForwardRef, globalns, localns): + return t._evaluate(globalns, localns) + + +_U = TypeVar("_U") + + +class UnboundTypeVarError(TypeError): + """TypeVar instance can not be resolved to a type spec. + + This exception is raised when an unbound TypeVar is encountered. + + """ + + +class InvalidStateError(Exception): + """Raised when an operation is performed on a future that is not + allowed in the current state. + """ + + +class _Future(Generic[_U]): + """The _Future class allows deferred access to a result that is not + yet available. + """ + + _done: bool + _result: _U + + def __init__(self) -> None: + self._done = False + + def done(self) -> bool: + """Return ``True`` if the value is available""" + return self._done + + def result(self) -> _U: + """Return the deferred value. + + Raises ``InvalidStateError`` if the value has not been set. + """ + if self.done(): + return self._result + raise InvalidStateError("result has not been set") + + def set_result(self, result: _U) -> None: + if self.done(): + raise InvalidStateError("result has already been set") + self._result = result + self._done = True + + +def is_generic_alias(clazz: type) -> bool: + """ + Check if given object is a Generic Alias. + + A `generic alias`__ is a generic type bound to generic parameters. + + E.g., given + + class A(Generic[T]): + pass + + ``A[int]`` is a _generic alias_ (while ``A`` is a *generic type*, but not a *generic alias*). + """ + is_generic = typing_inspect.is_generic_type(clazz) + type_arguments = get_args(clazz) + return is_generic and len(type_arguments) > 0 + + +def may_contain_typevars(clazz: type) -> bool: + """ + Check if the class can contain typevars. This includes Special Forms. + + Different from typing_inspect.is_generic_type as that explicitly ignores Union and Tuple. + + We still need to resolve typevars for Union and Tuple + """ + origin = get_origin(clazz) + return origin is not Annotated and ( + (isinstance(clazz, type) and issubclass(clazz, Generic)) # type: ignore[arg-type] + or isinstance(clazz, typing_inspect.typingGenericAlias) + ) + + +def _get_namespaces( + clazz: type, + globalns: Optional[Dict[str, Any]] = None, + localns: Optional[Dict[str, Any]] = None, +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + # region - Copied from typing.get_type_hints + if globalns is None: + base_globals = getattr(sys.modules.get(clazz.__module__, None), "__dict__", {}) + else: + base_globals = globalns + base_locals = dict(vars(clazz)) if localns is None else localns + if localns is None and globalns is None: + # This is surprising, but required. Before Python 3.10, + # get_type_hints only evaluated the globalns of + # a class. To maintain backwards compatibility, we reverse + # the globalns and localns order so that eval() looks into + # *base_globals* first rather than *base_locals*. + # This only affects ForwardRefs. + base_globals, base_locals = base_locals, base_globals + # endregion - Copied from typing.get_type_hints + + return base_globals, base_locals + + +def _resolve_typevars( + clazz: type, + globalns: Optional[Dict[str, Any]] = None, + localns: Optional[Dict[str, Any]] = None, +) -> Dict[type, Dict[TypeVar, _Future]]: + """ + Attemps to resolves all TypeVars in the class bases. Allows us to resolve inherited and aliased generics. + + Returns a dict of each base class and the resolved generics. + """ + # Use Tuples so can zip (order matters) + args_by_class: Dict[type, Tuple[Tuple[TypeVar, _Future], ...]] = {} + parent_class: Optional[type] = None + # Loop in reversed order and iteratively resolve types + for subclass in reversed(clazz.mro()): + base_globals, base_locals = _get_namespaces(subclass, globalns, localns) + if issubclass(subclass, Generic) and hasattr(subclass, "__orig_bases__"): # type: ignore[arg-type] + args = get_args(subclass.__orig_bases__[0]) + + if parent_class and args_by_class.get(parent_class): + subclass_generic_params_to_args: List[Tuple[TypeVar, _Future]] = [] + for (_arg, future), potential_type in zip( + args_by_class[parent_class], args + ): + if isinstance(potential_type, TypeVar): + subclass_generic_params_to_args.append((potential_type, future)) + else: + future.set_result( + eval_forward_ref( + potential_type, + globalns=base_globals, + localns=base_locals, + ) + if isinstance(potential_type, ForwardRef) + else potential_type + ) + + args_by_class[subclass] = tuple(subclass_generic_params_to_args) + else: + args_by_class[subclass] = tuple((arg, _Future()) for arg in args) + + parent_class = subclass + + # clazz itself is a generic alias i.e.: A[int]. So it hold the last types. + if is_generic_alias(clazz): + origin = get_origin(clazz) + args = get_args(clazz) + for (_arg, future), potential_type in zip(args_by_class[origin], args): # type: ignore[index] + if not isinstance(potential_type, TypeVar): + future.set_result( + eval_forward_ref(potential_type, globalns=globalns, localns=localns) + if isinstance(potential_type, ForwardRef) + else potential_type + ) + + # Convert to nested dict for easier lookup + return {k: {typ: fut for typ, fut in args} for k, args in args_by_class.items()} + + +def _replace_typevars( + clazz: type, resolved_generics: Optional[Dict[TypeVar, _Future]] = None +) -> type: + if ( + not resolved_generics + or inspect.isclass(clazz) + or not may_contain_typevars(clazz) + ): + return clazz + + return clazz.copy_with( # type: ignore + tuple( + ( + _replace_typevars(arg, resolved_generics) + if may_contain_typevars(arg) + else ( + resolved_generics[arg].result() if arg in resolved_generics else arg + ) + ) + for arg in get_args(clazz) + ) + ) + + +def get_resolved_dataclass_fields( + clazz: type, + globalns: Optional[Dict[str, Any]] = None, + localns: Optional[Dict[str, Any]] = None, +) -> Tuple[dataclasses.Field, ...]: + unbound_fields = set() + # Need to manually resolve fields because `dataclasses.fields` doesn't handle generics and + # looses the source class. Thus I don't know how to resolve this at later on. + # Instead we recreate the type but with all known TypeVars resolved to their actual types. + resolved_typevars = _resolve_typevars(clazz, globalns=globalns, localns=localns) + # Dict[field_name, Tuple[original_field, resolved_field]] + fields: Dict[str, Tuple[dataclasses.Field, dataclasses.Field]] = {} + + for subclass in reversed(clazz.mro()): + if not dataclasses.is_dataclass(subclass): + continue + + for field in dataclasses.fields(subclass): + try: + if field.name in fields and fields[field.name][0] == field: + continue # identical, so already resolved. + + # Either the first time we see this field, or it got overridden + # If it's a class we handle it later as a Nested. Nothing to resolve now. + new_field = field + if not inspect.isclass(field.type) and may_contain_typevars(field.type): + new_field = copy.copy(field) + new_field.type = _replace_typevars( + field.type, resolved_typevars.get(subclass) + ) + elif isinstance(field.type, TypeVar): + new_field = copy.copy(field) + new_field.type = resolved_typevars[subclass][field.type].result() + elif isinstance(field.type, ForwardRef): + base_globals, base_locals = _get_namespaces( + subclass, globalns, localns + ) + new_field = copy.copy(field) + new_field.type = eval_forward_ref( + field.type, globalns=base_globals, localns=base_locals + ) + elif isinstance(field.type, str): + base_globals, base_locals = _get_namespaces( + subclass, globalns, localns + ) + new_field = copy.copy(field) + new_field.type = eval_forward_ref( + ForwardRef(field.type, is_argument=False, is_class=True) + if sys.version_info >= (3, 9) + else ForwardRef(field.type, is_argument=False), + globalns=base_globals, + localns=base_locals, + ) + + fields[field.name] = (field, new_field) + except (InvalidStateError, KeyError): + unbound_fields.add(field.name) + + if unbound_fields: + raise UnboundTypeVarError( + f"{clazz.__name__} has unbound fields: {', '.join(unbound_fields)}" + ) + + return tuple(v[1] for v in fields.values()) diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index 28185a4..8a1beb7 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -1,7 +1,7 @@ import inspect import typing import unittest -from typing import Any, cast, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, cast from uuid import UUID try: @@ -10,11 +10,14 @@ from typing_extensions import Final, Literal # type: ignore[assignment] import dataclasses + from marshmallow import Schema, ValidationError -from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer +from marshmallow.fields import UUID as UUIDField +from marshmallow.fields import Field, Integer +from marshmallow.fields import List as ListField from marshmallow.validate import Validator -from marshmallow_dataclass import class_schema, NewType +from marshmallow_dataclass import NewType, class_schema class TestClassSchema(unittest.TestCase): diff --git a/tests/test_generics.py b/tests/test_generics.py new file mode 100644 index 0000000..ee853b7 --- /dev/null +++ b/tests/test_generics.py @@ -0,0 +1,350 @@ +import dataclasses +import inspect +import sys +import typing +import unittest +from typing_inspect import is_generic_type + +import marshmallow.fields +from marshmallow import ValidationError + +from marshmallow_dataclass import ( + UnboundTypeVarError, + add_schema, + class_schema, + dataclass, + is_generic_alias_of_dataclass, +) + +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + + +def get_orig_class(obj): + """ + Allows you got get the runtime origin class inside __init__ + + Near duplicate of https://github.com/Stewori/pytypes/blob/master/pytypes/type_util.py#L182 + """ + try: + # See https://github.com/Stewori/pytypes/pull/53: + # Returns `obj.__orig_class__` protecting from infinite recursion in `__getattr[ibute]__` + # wrapped in a `checker_tp`. + # (See `checker_tp` in `typechecker._typeinspect_func for context) + # Necessary if: + # - we're wrapping a method (`obj` is `self`/`cls`) and either + # - the object's class defines __getattribute__ + # or + # - the object doesn't have an `__orig_class__` attribute + # and the object's class defines __getattr__. + # In such a situation, `parent_class = obj.__orig_class__` + # would call `__getattr[ibute]__`. But that method is wrapped in a `checker_tp` too, + # so then we'd go into the wrapped `__getattr[ibute]__` and do + # `parent_class = obj.__orig_class__`, which would call `__getattr[ibute]__` + # again, and so on. So to bypass `__getattr[ibute]__` we do this: + return object.__getattribute__(obj, "__orig_class__") + except AttributeError: + cls = object.__getattribute__(obj, "__class__") + if is_generic_type(cls): + # Searching from index 1 is sufficient: At 0 is get_orig_class, at 1 is the caller. + frame = inspect.currentframe().f_back + try: + while frame: + try: + res = frame.f_locals["self"] + if res.__origin__ is cls: + return res + except (KeyError, AttributeError): + frame = frame.f_back + finally: + del frame + + raise + + +class TestGenerics(unittest.TestCase): + def test_generic_dataclass(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class SimpleGeneric(typing.Generic[T]): + data: T + + @dataclasses.dataclass + class NestedFixed: + data: SimpleGeneric[int] + + @dataclasses.dataclass + class NestedGeneric(typing.Generic[T]): + data: SimpleGeneric[T] + + self.assertTrue(is_generic_alias_of_dataclass(SimpleGeneric[int])) + self.assertFalse(is_generic_alias_of_dataclass(SimpleGeneric)) + + schema_s = class_schema(SimpleGeneric[str])() + self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + schema_nested = class_schema(NestedFixed)() + self.assertEqual( + NestedFixed(data=SimpleGeneric(1)), + schema_nested.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data": {"data": "str"}}) + + schema_nested_generic = class_schema(NestedGeneric[int])() + self.assertEqual( + NestedGeneric(data=SimpleGeneric(1)), + schema_nested_generic.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested_generic.load({"data": {"data": "str"}}) + + def test_generic_dataclass_repeated_fields(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class AA: + a: int + + @dataclasses.dataclass + class BB(typing.Generic[T]): + b: T + + @dataclasses.dataclass + class Nested: + x: BB[float] + z: BB[float] + # if y is the first field in this class, deserialisation will fail. + # see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027 + y: BB[AA] + + schema_nested = class_schema(Nested)() + self.assertEqual( + Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))), + schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}), + ) + + def test_marshmallow_dataclass_decorator_raises_on_generic_alias(self): + """ + We can't support `GenClass[int].Schema` because the class function was created on `GenClass` + Therefore the function does not know about the `int` type. + This is a Python limitation, not a marshmallow_dataclass limitation. + """ + import marshmallow_dataclass + + T = typing.TypeVar("T") + + class GenClass(typing.Generic[T]): + pass + + with self.assertRaisesRegex(TypeError, "generic"): + marshmallow_dataclass.dataclass(GenClass[int]) + + def test_add_schema_raises_on_generic_alias(self): + """ + We can't support `GenClass[int].Schema` because the class function was created on `GenClass` + Therefore the function does not know about the `int` type. + This is a Python limitation, not a marshmallow_dataclass limitation. + """ + T = typing.TypeVar("T") + + class GenClass(typing.Generic[T]): + pass + + with self.assertRaisesRegex(TypeError, "generic"): + add_schema(GenClass[int]) + + def test_deep_generic(self): + T = typing.TypeVar("T") + U = typing.TypeVar("U") + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U]): + pairs: typing.List[typing.Tuple[T, U]] + + test_schema = class_schema(TestClass[str, int])() + + self.assertEqual( + test_schema.load({"pairs": [("first", "1")]}), TestClass([("first", 1)]) + ) + + def test_deep_generic_with_union(self): + T = typing.TypeVar("T") + U = typing.TypeVar("U") + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U]): + either: typing.List[typing.Union[T, U]] + + test_schema = class_schema(TestClass[str, int])() + + self.assertEqual( + test_schema.load({"either": ["first", 1]}), TestClass(["first", 1]) + ) + + def test_deep_generic_with_overrides(self): + T = typing.TypeVar("T") + U = typing.TypeVar("U") + V = typing.TypeVar("V") + W = typing.TypeVar("W") + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U, V]): + pairs: typing.List[typing.Tuple[T, U]] + gen: V + override: int + + # Don't only override typevar, but switch order to further confuse things + @dataclasses.dataclass + class TestClass2(TestClass[str, W, U]): + override: str # type: ignore # Want to test that it works, even if incompatible types + + TestAlias = TestClass2[int, T] + + # inherit from alias + @dataclasses.dataclass + class TestClass3(TestAlias[typing.List[int]]): + pass + + test_schema = class_schema(TestClass3)() + + self.assertEqual( + test_schema.load( + {"pairs": [("first", "1")], "gen": ["1", 2], "override": "overridden"} + ), + TestClass3([("first", 1)], [1, 2], "overridden"), + ) + + def test_generic_bases(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base1(typing.Generic[T]): + answer: T + + @dataclasses.dataclass + class TestClass(Base1[T]): + pass + + test_schema = class_schema(TestClass[int])() + + self.assertEqual(test_schema.load({"answer": "1"}), TestClass(1)) + + def test_bound_generic_base(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base1(typing.Generic[T]): + answer: T + + @dataclasses.dataclass + class TestClass(Base1[int]): + pass + + with self.assertRaisesRegex( + UnboundTypeVarError, "Base1 has unbound fields: answer" + ): + class_schema(Base1) + + test_schema = class_schema(TestClass)() + self.assertEqual(test_schema.load({"answer": "1"}), TestClass(1)) + + def test_unbound_type_var(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base: + answer: T # type: ignore[valid-type] + + with self.assertRaises(UnboundTypeVarError): + class_schema(Base) + + with self.assertRaises(TypeError): + class_schema(Base) + + def test_annotated_generic_mf_field(self) -> None: + T = typing.TypeVar("T") + + class GenericList(marshmallow.fields.List, typing.Generic[T]): + """ + Generic Marshmallow List Field that can be used in Annotated and still get all kwargs + from marshmallow_dataclass. + """ + + def __init__( + self, + **kwargs, + ): + cls_or_instance = get_orig_class(self).__args__[0] + + super().__init__(cls_or_instance, **kwargs) + + @dataclass + class AnnotatedValue: + emails: Annotated[ + typing.List[str], GenericList[marshmallow.fields.Email] + ] = dataclasses.field(default_factory=lambda: ["default@email.com"]) + + schema = AnnotatedValue.Schema() # type: ignore[attr-defined] + + self.assertEqual( + schema.load({}), + AnnotatedValue(emails=["default@email.com"]), + ) + self.assertEqual( + schema.load({"emails": ["test@test.com"]}), + AnnotatedValue( + emails=["test@test.com"], + ), + ) + + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load({"emails": "notavalidemail"}) + + def test_generic_dataclass_with_forwardref(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class ForwardGeneric(typing.Generic[T]): + data: T + + schema_s = class_schema(ForwardGeneric["str"])() + self.assertEqual(ForwardGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(ForwardGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + def test_generic_dataclass_with_optional(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class OptionalGeneric(typing.Generic[T]): + data: typing.Optional[T] + + schema_s = class_schema(OptionalGeneric["str"])() + self.assertEqual(OptionalGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(OptionalGeneric(data="a")), {"data": "a"}) + + self.assertEqual(OptionalGeneric(data=None), schema_s.load({})) + self.assertEqual(schema_s.dump(OptionalGeneric(data=None)), {"data": None}) + + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + +if __name__ == "__main__": + unittest.main()