Skip to content

Commit

Permalink
Break generic functions out into their own file and add support for an…
Browse files Browse the repository at this point in the history
…notated generics, partials, and callables
  • Loading branch information
mvanderlee committed Jun 25, 2024
1 parent 80dab91 commit 4531c35
Show file tree
Hide file tree
Showing 4 changed files with 392 additions and 185 deletions.
222 changes: 41 additions & 181 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class User:
"""

import collections.abc
import copy
import dataclasses
import inspect
import sys
Expand Down Expand Up @@ -64,6 +63,12 @@ class User:
import typing_extensions
import typing_inspect

from marshmallow_dataclass.generic_resolver import (
UnboundTypeVarError,
get_generic_dataclass_fields,
is_generic_alias,
is_generic_type,
)
from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute

if sys.version_info >= (3, 9):
Expand Down Expand Up @@ -134,55 +139,10 @@ def _maybe_get_callers_frame(
del frame


class UnboundTypeVarError(TypeError):
    """Signal that a TypeVar instance could not be resolved to a type spec.

    Raised whenever an unbound TypeVar is encountered.
    """


class InvalidStateError(Exception):
    """Signal that an operation was attempted on a future whose current
    state does not permit it.
    """


class _Future(Generic[_U]):
    """Write-once holder giving deferred access to a value that is not
    available yet.
    """

    _done: bool
    _result: _U

    def __init__(self) -> None:
        # No value has been stored yet.
        self._done = False

    def done(self) -> bool:
        """Tell whether a value has been stored."""
        return self._done

    def result(self) -> _U:
        """Hand back the stored value.

        Raises ``InvalidStateError`` when no value has been stored yet.
        """
        if not self.done():
            raise InvalidStateError("result has not been set")
        return self._result

    def set_result(self, result: _U) -> None:
        """Store the value; this may happen only once.

        Raises ``InvalidStateError`` when a value was already stored.
        """
        if self.done():
            raise InvalidStateError("result has already been set")
        self._result = result
        self._done = True


def _check_decorated_type(cls: object) -> None:
if not isinstance(cls, type):
raise TypeError(f"expected a class not {cls!r}")
if _is_generic_alias(cls):
if is_generic_alias(cls):
# A .Schema attribute doesn't make sense on a generic alias — there's
# no way for it to know the generic parameters at run time.
raise TypeError(
Expand Down Expand Up @@ -513,9 +473,7 @@ def class_schema(
>>> class_schema(Custom)().load({})
Custom(name=None)
"""
if not dataclasses.is_dataclass(clazz) and not _is_generic_alias_of_dataclass(
clazz
):
if not dataclasses.is_dataclass(clazz) and not is_generic_alias_of_dataclass(clazz):
clazz = dataclasses.dataclass(clazz)
if localns is None:
if clazz_frame is None:
Expand Down Expand Up @@ -791,8 +749,16 @@ def _field_for_annotated_type(
marshmallow_annotations = [
arg
for arg in arguments[1:]
if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field))
or isinstance(arg, marshmallow.fields.Field)
if _is_marshmallow_field(arg)
# Support `CustomGenericField[mf.String]`
or (
is_generic_type(arg)
and _is_marshmallow_field(typing_extensions.get_origin(arg))
)
# Support `partial(mf.List, mf.String)`
or (isinstance(arg, partial) and _is_marshmallow_field(arg.func))
# Support `lambda *args, **kwargs: mf.List(mf.String, *args, **kwargs)`
or (_is_callable_marshmallow_field(arg))
]
if marshmallow_annotations:
if len(marshmallow_annotations) > 1:
Expand Down Expand Up @@ -932,7 +898,7 @@ def _field_for_schema(

# i.e.: Literal['abc']
if typing_inspect.is_literal_type(typ):
arguments = typing_inspect.get_args(typ)
arguments = typing_extensions.get_args(typ)
return marshmallow.fields.Raw(
validate=(
marshmallow.validate.Equal(arguments[0])
Expand All @@ -944,7 +910,7 @@ def _field_for_schema(

# i.e.: Final[str] = 'abc'
if typing_inspect.is_final_type(typ):
arguments = typing_inspect.get_args(typ)
arguments = typing_extensions.get_args(typ)
if arguments:
subtyp = arguments[0]
elif default is not marshmallow.missing:
Expand Down Expand Up @@ -1061,14 +1027,14 @@ def _get_field_default(field: dataclasses.Field):
return field.default


def _is_generic_alias_of_dataclass(clazz: type) -> bool:
def is_generic_alias_of_dataclass(clazz: type) -> bool:
"""
Check if given class is a generic alias of a dataclass, if the dataclass is
defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
"""
is_generic = is_generic_type(clazz)
type_arguments = typing_inspect.get_args(clazz)
origin_class = typing_inspect.get_origin(clazz)
type_arguments = typing_extensions.get_args(clazz)
origin_class = typing_extensions.get_origin(clazz)
return (
is_generic
and len(type_arguments) > 0
Expand Down Expand Up @@ -1107,136 +1073,30 @@ class X:
return _get_type_hints(X, schema_ctx)["x"]


def _is_generic_alias(clazz: type) -> bool:
    """
    Tell whether ``clazz`` is a generic alias that carries type arguments.

    For a class defined as `class A(Generic[T])`, passing `A[int]` returns True.
    """
    generic = is_generic_type(clazz)
    args = typing_inspect.get_args(clazz)
    if not generic:
        return False
    return len(args) > 0


def is_generic_type(clazz: type) -> bool:
    """
    Tell whether ``clazz`` is a ``Generic`` subclass or a typing generic alias.

    Used in place of typing_inspect.is_generic_type, which explicitly
    ignores Union, Tuple, Callable, ClassVar.
    """
    # Real classes count when they derive from Generic.
    if isinstance(clazz, type) and issubclass(clazz, Generic):  # type: ignore[arg-type]
        return True
    # Everything else counts only when it is a typing generic alias.
    return isinstance(clazz, typing_inspect.typingGenericAlias)


def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
"""
Attemps to resolves all TypeVars in the class bases. Allows us to resolve inherited and aliased generics.
Returns a dict of each base class and the resolved generics.
"""
# Per class, an ordered tuple of (TypeVar, future) pairs. Tuples are used because
# the pairs are zipped positionally with the type arguments (order matters).
args_by_class: Dict[type, Tuple[Tuple[TypeVar, _Future], ...]] = {}
parent_class: Optional[type] = None
# Walk the MRO in reverse (base classes first) and resolve types iteratively.
for subclass in reversed(clazz.mro()):
if issubclass(subclass, Generic) and hasattr(subclass, "__orig_bases__"): # type: ignore[arg-type]
# Only the first entry of __orig_bases__ is inspected.
args = typing_inspect.get_args(subclass.__orig_bases__[0])

if parent_class and args_by_class.get(parent_class):
subclass_generic_params_to_args: List[Tuple[TypeVar, _Future]] = []
# Pair each of the parent's pending futures with the argument this
# subclass supplies at the same position.
for (_arg, future), potential_type in zip(
args_by_class[parent_class], args
):
if isinstance(potential_type, TypeVar):
# Still a TypeVar: keep the same future, now keyed by this subclass' TypeVar.
subclass_generic_params_to_args.append((potential_type, future))
else:
# A concrete argument: the future is resolved here.
future.set_result(potential_type)

args_by_class[subclass] = tuple(subclass_generic_params_to_args)

else:
# No parent entries to pair with: start a fresh future per argument.
args_by_class[subclass] = tuple((arg, _Future()) for arg in args)

# NOTE(review): indentation is not visible in this view, so it is unclear whether
# this assignment runs only for generic classes or for every class in the MRO — confirm.
parent_class = subclass

# clazz itself is a generic alias, e.g. A[int], so it holds the final type arguments.
if _is_generic_alias(clazz):
origin = typing_inspect.get_origin(clazz)
args = typing_inspect.get_args(clazz)
for (_arg, future), potential_type in zip(args_by_class[origin], args):
if not isinstance(potential_type, TypeVar):
future.set_result(potential_type)

# Convert the tuples of pairs into nested dicts for easier lookup.
return {k: {typ: fut for typ, fut in args} for k, args in args_by_class.items()}


def _replace_typevars(
    clazz: type, resolved_generics: Optional[Dict[TypeVar, _Future]] = None
) -> type:
    """Return ``clazz`` with its type arguments substituted from ``resolved_generics``.

    ``clazz`` is returned untouched when there is nothing to substitute, when
    it is a real class, or when it is not a generic type. Otherwise each type
    argument is handled in order: generic arguments are processed recursively,
    arguments present in ``resolved_generics`` are replaced by the result of
    their future, and anything else is kept as is. The rebuilt type comes
    from ``clazz.copy_with``.
    """
    if not resolved_generics:
        return clazz
    if inspect.isclass(clazz):
        return clazz
    if not is_generic_type(clazz):
        return clazz

    # Bind the constructor first so a missing ``copy_with`` surfaces before
    # any argument is resolved.
    rebuild = clazz.copy_with  # type: ignore

    substituted = []
    for arg in typing_inspect.get_args(clazz):
        if is_generic_type(arg):
            substituted.append(_replace_typevars(arg, resolved_generics))
        elif arg in resolved_generics:
            substituted.append(resolved_generics[arg].result())
        else:
            substituted.append(arg)

    return rebuild(tuple(substituted))


def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
# Non-generic classes can rely on the standard library directly.
if not is_generic_type(clazz):
return dataclasses.fields(clazz)

else:
# Names of fields whose TypeVar could not be resolved.
unbound_fields = set()
# Fields must be resolved manually: `dataclasses.fields` doesn't handle generics and
# loses the class each field was declared on, so TypeVars could not be resolved later.
# Instead we recreate each field type with all known TypeVars replaced by their actual types.
resolved_typevars = _resolve_typevars(clazz)
# Dict[field_name, Tuple[original_field, resolved_field]]
fields: Dict[str, Tuple[dataclasses.Field, dataclasses.Field]] = {}

# Base classes first, so a field redefined further down the hierarchy replaces the earlier entry.
for subclass in reversed(clazz.mro()):
if not dataclasses.is_dataclass(subclass):
continue

for field in dataclasses.fields(subclass):
try:
if field.name in fields and fields[field.name][0] == field:
continue # identical, so already resolved.

# Either the first time we see this field, or it got overridden.
# If it's a class we handle it later as a Nested. Nothing to resolve now.
new_field = field
if not inspect.isclass(field.type) and is_generic_type(field.type):
# Generic annotation: substitute the TypeVars inside it on a copy of the field.
new_field = copy.copy(field)
new_field.type = _replace_typevars(
field.type, resolved_typevars[subclass]
)
elif isinstance(field.type, TypeVar):
# Bare TypeVar annotation: replace it with the resolved type on a copy of the field.
new_field = copy.copy(field)
new_field.type = resolved_typevars[subclass][
field.type
].result()

fields[field.name] = (field, new_field)
except InvalidStateError:
# The TypeVar has no result yet; collect the name and report all of them below.
unbound_fields.add(field.name)

if unbound_fields:
raise UnboundTypeVarError(
f"{clazz.__name__} has unbound fields: {', '.join(unbound_fields)}"
)
# NOTE(review): in this diff view the line below appears to be the replacement body of the
# generic branch (delegating to generic_resolver) for the manual resolution above — confirm
# against the actual file.
return get_generic_dataclass_fields(clazz)


def _is_marshmallow_field(obj) -> bool:
    """Tell whether ``obj`` is a marshmallow field class or a field instance."""
    field_base = marshmallow.fields.Field
    # A class qualifies when it derives from the marshmallow Field base.
    if inspect.isclass(obj) and issubclass(obj, field_base):
        return True
    # Anything else qualifies only as an instance of it.
    return isinstance(obj, field_base)


def _is_callable_marshmallow_field(obj) -> bool:
"""Checks if the object is a callable and if the callable returns a marshmallow field"""
if callable(obj):
try:
# The callable is invoked with no arguments and its return value is inspected.
potential_field = obj()
return _is_marshmallow_field(potential_field)
except Exception:
# Any exception raised by the call is treated as "not a field factory".
return False

# NOTE(review): `fields` is not defined in this function; the next line looks like residue of
# the removed `_dataclass_fields` body interleaved by the diff view — confirm against the actual file.
return tuple(v[1] for v in fields.values())
# Non-callable objects never qualify.
return False


def NewType(
Expand Down
Loading

0 comments on commit 4531c35

Please sign in to comment.