Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Generic dataclasses #259

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 82 additions & 45 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,23 @@ class User:
TypeVar,
Union,
cast,
get_type_hints,
overload,
)

import marshmallow
import typing_extensions
import typing_inspect

from marshmallow_dataclass.generic_resolver import (
UnboundTypeVarError,
get_resolved_dataclass_fields,
is_generic_alias,
)
from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute

if sys.version_info >= (3, 9):
from typing import Annotated
from typing import Annotated, get_args, get_origin
else:
from typing_extensions import Annotated
from typing_extensions import Annotated, get_args, get_origin

if sys.version_info >= (3, 11):
from typing import dataclass_transform
Expand Down Expand Up @@ -139,6 +142,18 @@ def _maybe_get_callers_frame(
del frame


def _check_decorated_type(cls: object) -> None:
    """Reject objects that the schema decorators cannot be applied to.

    :param cls: the object handed to the decorator.
    :raises TypeError: if ``cls`` is not a class, or if ``is_generic_alias``
        reports it as a generic alias.
    """
    if not isinstance(cls, type):
        raise TypeError(f"expected a class not {cls!r}")

    if not is_generic_alias(cls):
        return

    # Attaching a .Schema attribute to a generic alias is meaningless: the
    # attribute has no way to learn the generic parameters at run time.
    raise TypeError(
        "decorator does not support generic aliases "
        "(hint: use class_schema directly instead)"
    )


@overload
def dataclass(
_cls: Type[_U],
Expand Down Expand Up @@ -214,12 +229,15 @@ def dataclass(
)

def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]:
_check_decorated_type(cls)

return add_schema(
dc(cls), base_schema, cls_frame=cls_frame, stacklevel=stacklevel + 1
)

if _cls is None:
return decorator

return decorator(_cls, stacklevel=stacklevel + 1)


Expand Down Expand Up @@ -268,6 +286,8 @@ def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1):
"""

def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]:
_check_decorated_type(clazz)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_check_decorated_type requires only that isinstance(clazz, type).
Do we want to require that clazz is a dataclass (isinstance(clazz, type) and dataclasses.is_dataclass(clazz)) here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, because we already have a big warning in _internal_class_schema when a class is not a dataclass.
So non-dataclasses are allowed, but not supported.

Copy link
Collaborator

@dairiki dairiki Sep 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just run a simple test.

You are correct in that decorating non-dataclass classes with add_schema works — at least in simple cases. (That doesn't necessarily mean that we should allow it.)

It appears, however, that, as things stand in this PR, no big warning is emitted in that case. Further investigation reveals that get_resolved_dataclass_fields "just works" (with no warnings emitted) even if its argument is not a dataclass. (I have some suspicion that perhaps fields from the class's __mro__ are not properly handled in that case, but I haven't looked closely enough to say for sure.)

In any case, I think that either:

  • We should disallow using the add_schema decorator on non-dataclasses. (Why allow it if it's unsupported/untested?)
  • Or, we need a test to ensure the warning is, in fact, emitted when it is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've had to do some digging.

  1. In today's version, 8.7.1, I can call class_schema(NonDataclass) just fine. It only shows the warning if one of its fields is not a dataclass, i.e., for nested non-dataclasses.
  2. This goes back to when the warning was originally added: e31faa8
  3. The behaviour still works the same with this PR.

I don't disagree with removing support for non-dataclasses, but don't see why that should be part of this PR.


if cls_frame is not None:
frame = cls_frame
else:
Expand Down Expand Up @@ -453,7 +473,7 @@ def class_schema(
>>> class_schema(Custom)().load({})
Custom(name=None)
"""
if not dataclasses.is_dataclass(clazz):
if not dataclasses.is_dataclass(clazz) and not is_generic_alias_of_dataclass(clazz):
clazz = dataclasses.dataclass(clazz)
if localns is None:
if clazz_frame is None:
Expand Down Expand Up @@ -514,17 +534,21 @@ def _internal_class_schema(
) -> Type[marshmallow.Schema]:
schema_ctx = _schema_ctx_stack.top

if typing_extensions.get_origin(clazz) is Annotated and sys.version_info < (3, 10):
if get_origin(clazz) is Annotated and sys.version_info < (3, 10):
# https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977
class_name = clazz._name or clazz.__origin__.__name__ # type: ignore[attr-defined]
else:
class_name = clazz.__name__
# generic aliases do not have a __name__ prior python 3.10
class_name = getattr(clazz, "__name__", repr(clazz))

schema_ctx.seen_classes[clazz] = class_name

try:
# noinspection PyDataclass
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz)
fields = get_resolved_dataclass_fields(
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
)
except UnboundTypeVarError:
raise
except TypeError: # Not a dataclass
try:
warnings.warn(
Expand All @@ -540,6 +564,8 @@ def _internal_class_schema(
)
created_dataclass: type = dataclasses.dataclass(clazz)
return _internal_class_schema(created_dataclass, base_schema)
except UnboundTypeVarError:
raise
except Exception as exc:
raise TypeError(
f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
Expand All @@ -556,23 +582,11 @@ def _internal_class_schema(
include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False)

# Update the schema members to contain marshmallow fields instead of dataclass fields

if sys.version_info >= (3, 9):
type_hints = get_type_hints(
clazz,
globalns=schema_ctx.globalns,
localns=schema_ctx.localns,
include_extras=True,
)
else:
type_hints = get_type_hints(
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
)
attributes.update(
(
field.name,
_field_for_schema(
type_hints[field.name],
field.type,
_get_field_default(field),
field.metadata,
base_schema,
Expand All @@ -582,7 +596,7 @@ def _internal_class_schema(
if field.init or include_non_init
)

schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes)
schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes)
return cast(Type[marshmallow.Schema], schema_class)


Expand Down Expand Up @@ -633,8 +647,8 @@ def _field_by_supertype(
)


def _generic_type_add_any(typ: type) -> type:
"""if typ is generic type without arguments, replace them by Any."""
def _container_type_add_any(typ: type) -> type:
"""if typ is container type without arguments, replace them by Any."""
if typ is list or typ is List:
typ = List[Any]
elif typ is dict or typ is Dict:
Expand All @@ -650,18 +664,20 @@ def _generic_type_add_any(typ: type) -> type:
return typ


def _field_for_generic_type(
def _field_for_container_type(
typ: type,
base_schema: Optional[Type[marshmallow.Schema]],
**metadata: Any,
) -> Optional[marshmallow.fields.Field]:
"""
If the type is a generic interface, resolve the arguments and construct the appropriate Field.
If the type is a container interface, resolve the arguments and construct the appropriate Field.

We use the term 'container' to differentiate from the Generic support
"""
origin = typing_extensions.get_origin(typ)
arguments = typing_extensions.get_args(typ)
origin = get_origin(typ)
arguments = get_args(typ)
if origin:
# Override base_schema.TYPE_MAPPING to change the class used for generic types below
# Override base_schema.TYPE_MAPPING to change the class used for container types below
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}

if origin in (list, List):
Expand Down Expand Up @@ -705,7 +721,7 @@ def _field_for_generic_type(
),
)
return tuple_type(children, **metadata)
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
if origin in (dict, Dict, collections.abc.Mapping, Mapping):
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
return dict_type(
keys=_field_for_schema(arguments[0], base_schema=base_schema),
Expand All @@ -723,14 +739,18 @@ def _field_for_annotated_type(
"""
If the type is an Annotated interface, resolve the arguments and construct the appropriate Field.
"""
origin = typing_extensions.get_origin(typ)
arguments = typing_extensions.get_args(typ)
origin = get_origin(typ)
arguments = get_args(typ)
if origin and origin is Annotated:
marshmallow_annotations = [
arg
for arg in arguments[1:]
if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field))
or isinstance(arg, marshmallow.fields.Field)
if _is_marshmallow_field(arg)
# Support `CustomGenericField[mf.String]`
or (
typing_inspect.is_generic_type(arg)
and _is_marshmallow_field(get_origin(arg))
)
]
if marshmallow_annotations:
if len(marshmallow_annotations) > 1:
Expand All @@ -752,7 +772,7 @@ def _field_for_union_type(
base_schema: Optional[Type[marshmallow.Schema]],
**metadata: Any,
) -> Optional[marshmallow.fields.Field]:
arguments = typing_extensions.get_args(typ)
arguments = get_args(typ)
if typing_inspect.is_union_type(typ):
if typing_inspect.is_optional_type(typ):
metadata["allow_none"] = metadata.get("allow_none", True)
Expand Down Expand Up @@ -838,6 +858,9 @@ def _field_for_schema(

"""

if isinstance(typ, TypeVar):
raise UnboundTypeVarError(f"can not resolve type variable {typ.__name__}")

metadata = {} if metadata is None else dict(metadata)

if default is not marshmallow.missing:
Expand All @@ -853,8 +876,8 @@ def _field_for_schema(
if predefined_field:
return predefined_field

# Generic types specified without type arguments
typ = _generic_type_add_any(typ)
# Container types (generics like List) specified without type arguments
typ = _container_type_add_any(typ)

# Base types
field = _field_by_type(typ, base_schema)
Expand All @@ -867,7 +890,7 @@ def _field_for_schema(

# i.e.: Literal['abc']
if typing_inspect.is_literal_type(typ):
arguments = typing_inspect.get_args(typ)
arguments = get_args(typ)
return marshmallow.fields.Raw(
validate=(
marshmallow.validate.Equal(arguments[0])
Expand All @@ -879,7 +902,7 @@ def _field_for_schema(

# i.e.: Final[str] = 'abc'
if typing_inspect.is_final_type(typ):
arguments = typing_inspect.get_args(typ)
arguments = get_args(typ)
if arguments:
subtyp = arguments[0]
elif default is not marshmallow.missing:
Expand Down Expand Up @@ -920,10 +943,10 @@ def _field_for_schema(
if union_field:
return union_field

# Generic types
generic_field = _field_for_generic_type(typ, base_schema, **metadata)
if generic_field:
return generic_field
# Container types
container_field = _field_for_container_type(typ, base_schema, **metadata)
if container_field:
return container_field

# typing.NewType returns a function (in python <= 3.9) or a class (python >= 3.10) with a
# __supertype__ attribute
Expand Down Expand Up @@ -952,7 +975,7 @@ def _field_for_schema(
nested_schema
or forward_reference
or _schema_ctx_stack.top.seen_classes.get(typ)
or _internal_class_schema(typ, base_schema) # type: ignore[arg-type] # FIXME
or _internal_class_schema(typ, base_schema) # type: ignore [arg-type]
)

return marshmallow.fields.Nested(nested, **metadata)
Expand Down Expand Up @@ -996,6 +1019,20 @@ def _get_field_default(field: dataclasses.Field):
return field.default


def is_generic_alias_of_dataclass(clazz: type) -> bool:
    """
    Tell whether ``clazz`` is a generic alias whose origin is a dataclass.

    For a dataclass declared as `class A(Generic[T])`, passing `A[int]`
    yields a true result.
    """
    alias_check = is_generic_alias(clazz)
    # Short-circuit: the origin is only inspected for generic aliases.
    return alias_check and dataclasses.is_dataclass(get_origin(clazz))


def _is_marshmallow_field(obj) -> bool:
    """Return True if ``obj`` is a marshmallow ``Field`` subclass or instance."""
    field_base = marshmallow.fields.Field
    # Class check first, then instance check, in the same order as before.
    if inspect.isclass(obj) and issubclass(obj, field_base):
        return True
    return isinstance(obj, field_base)


def NewType(
name: str,
typ: Type[_U],
Expand Down
Loading