Skip to content

Commit

Permalink
feat: Allow recursive types
Browse files Browse the repository at this point in the history
Like this.
```
@DataClass
class Foo:
    f: Optional['Foo']

serde(Foo)
```
  • Loading branch information
yukinarit committed Dec 4, 2022
1 parent e3f5c47 commit 3ff1cde
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 101 deletions.
17 changes: 17 additions & 0 deletions examples/recursive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from dataclasses import dataclass
from typing import Optional

from serde import from_dict, serde, to_dict


@dataclass
class Foo:
f: Optional['Foo']


serde(Foo)


f = Foo(Foo(Foo(None)))
print(to_dict(f))
print(from_dict(Foo, {'f': {'f': {'f': None}}}))
2 changes: 2 additions & 0 deletions examples/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import literal
import newtype
import pep681
import recursive
import rename
import rename_all
import simple
Expand Down Expand Up @@ -73,6 +74,7 @@ def run_all():
run(init_var)
run(class_var)
run(alias)
run(recursive)
if PY310:
import union_operator

Expand Down
209 changes: 120 additions & 89 deletions serde/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,109 +304,140 @@ def dataclass_fields(cls: Type[Any]) -> Iterator[dataclasses.Field]:
return iter(raw_fields)


def iter_types(cls: Type) -> Iterator[Union[Type, typing.Any]]:
TypeLike = Union[Type[Any], typing.Any]


def iter_types(cls: TypeLike) -> List[TypeLike]:
"""
Iterate field types recursively.
The correct return type is `Iterator[Union[Type, typing._specialform]],
but `typing._specialform` doesn't exist for python 3.6. Use `Any` instead.
"""
if is_dataclass(cls):
yield cls
for f in dataclass_fields(cls):
yield from iter_types(f.type)
elif isinstance(cls, str):
yield cls
elif is_opt(cls):
yield Optional
arg = type_args(cls)
if arg:
yield from iter_types(arg[0])
elif is_union(cls):
yield Union
for arg in type_args(cls):
yield from iter_types(arg)
elif is_list(cls) or is_set(cls):
yield List
arg = type_args(cls)
if arg:
yield from iter_types(arg[0])
elif is_set(cls):
yield Set
arg = type_args(cls)
if arg:
yield from iter_types(arg[0])
elif is_tuple(cls):
yield Tuple
for arg in type_args(cls):
yield from iter_types(arg)
elif is_dict(cls):
yield Dict
arg = type_args(cls)
if arg and len(arg) >= 2:
yield from iter_types(arg[0])
yield from iter_types(arg[1])
else:
yield cls
lst: Set[TypeLike] = set()

def recursive(cls: TypeLike) -> None:
if cls in lst:
return

if is_dataclass(cls):
lst.add(cls)
for f in dataclass_fields(cls):
recursive(f.type)
elif isinstance(cls, str):
lst.add(cls)
elif is_opt(cls):
lst.add(Optional)
arg = type_args(cls)
if arg:
recursive(arg[0])
elif is_union(cls):
lst.add(Union)
for arg in type_args(cls):
recursive(arg)
elif is_list(cls) or is_set(cls):
lst.add(List)
arg = type_args(cls)
if arg:
recursive(arg[0])
elif is_set(cls):
lst.add(Set)
arg = type_args(cls)
if arg:
recursive(arg[0])
elif is_tuple(cls):
lst.add(Tuple)
for arg in type_args(cls):
recursive(arg)
elif is_dict(cls):
lst.add(Dict)
arg = type_args(cls)
if arg and len(arg) >= 2:
recursive(arg[0])
recursive(arg[1])
else:
lst.add(cls)

recursive(cls)
return list(lst)


def iter_unions(cls: Type) -> Iterator[Type]:
def iter_unions(cls: TypeLike) -> List[TypeLike]:
"""
Iterate over all unions that are used in the dataclass
"""
if is_union(cls):
yield cls
for arg in type_args(cls):
yield from iter_unions(arg)
if is_dataclass(cls):
for f in dataclass_fields(cls):
yield from iter_unions(f.type)
elif is_opt(cls):
arg = type_args(cls)
if arg:
yield from iter_unions(arg[0])
elif is_list(cls) or is_set(cls):
arg = type_args(cls)
if arg:
yield from iter_unions(arg[0])
elif is_tuple(cls):
for arg in type_args(cls):
yield from iter_unions(arg)
elif is_dict(cls):
arg = type_args(cls)
if arg and len(arg) >= 2:
yield from iter_unions(arg[0])
yield from iter_unions(arg[1])


def iter_literals(cls: Type) -> Iterator[Type]:
lst: Set[TypeLike] = set()

def recursive(cls: TypeLike) -> None:
if cls in lst:
return

if is_union(cls):
lst.add(cls)
for arg in type_args(cls):
recursive(arg)
if is_dataclass(cls):
for f in dataclass_fields(cls):
recursive(f.type)
elif is_opt(cls):
arg = type_args(cls)
if arg:
recursive(arg[0])
elif is_list(cls) or is_set(cls):
arg = type_args(cls)
if arg:
recursive(arg[0])
elif is_tuple(cls):
for arg in type_args(cls):
recursive(arg)
elif is_dict(cls):
arg = type_args(cls)
if arg and len(arg) >= 2:
recursive(arg[0])
recursive(arg[1])

recursive(cls)
return list(lst)


def iter_literals(cls: TypeLike) -> List[TypeLike]:
"""
Iterate over all literals that are used in the dataclass
"""
if is_literal(cls):
yield cls
if is_union(cls):
for arg in type_args(cls):
yield from iter_literals(arg)
if is_dataclass(cls):
for f in dataclass_fields(cls):
yield from iter_literals(f.type)
elif is_opt(cls):
arg = type_args(cls)
if arg:
yield from iter_literals(arg[0])
elif is_list(cls) or is_set(cls):
arg = type_args(cls)
if arg:
yield from iter_literals(arg[0])
elif is_tuple(cls):
for arg in type_args(cls):
yield from iter_literals(arg)
elif is_dict(cls):
arg = type_args(cls)
if arg and len(arg) >= 2:
yield from iter_literals(arg[0])
yield from iter_literals(arg[1])
lst: Set[TypeLike] = set()

def recursive(cls: TypeLike) -> None:
if cls in lst:
return

if is_literal(cls):
lst.add(cls)
if is_union(cls):
for arg in type_args(cls):
recursive(arg)
if is_dataclass(cls):
lst.add(cls)
for f in dataclass_fields(cls):
recursive(f.type)
elif is_opt(cls):
arg = type_args(cls)
if arg:
recursive(arg[0])
elif is_list(cls) or is_set(cls):
arg = type_args(cls)
if arg:
recursive(arg[0])
elif is_tuple(cls):
for arg in type_args(cls):
recursive(arg)
elif is_dict(cls):
arg = type_args(cls)
if arg and len(arg) >= 2:
recursive(arg[0])
recursive(arg[1])

recursive(cls)
return list(lst)


def is_union(typ) -> bool:
Expand Down
13 changes: 11 additions & 2 deletions serde/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,10 @@ class Foo:
serializer: Optional[Func] = None # Custom field serializer.
deserializer: Optional[Func] = None # Custom field deserializer.
flatten: Optional[FlattenOpts] = None
parent: Optional[Type] = None

@classmethod
def from_dataclass(cls, f: dataclasses.Field) -> 'Field':
def from_dataclass(cls, f: dataclasses.Field, parent: Optional[Type] = None) -> 'Field':
"""
Create `Field` object from `dataclasses.Field`.
"""
Expand Down Expand Up @@ -496,6 +497,7 @@ def from_dataclass(cls, f: dataclasses.Field) -> 'Field':
serializer=serializer,
deserializer=deserializer,
flatten=flatten,
parent=parent,
)

def to_dataclass(self) -> dataclasses.Field:
Expand All @@ -513,6 +515,13 @@ def to_dataclass(self) -> dataclasses.Field:
f.type = self.type
return f

def is_self_referencing(self) -> bool:
if self.type is None:
return False
if self.parent is None:
return False
return self.type == self.parent

@staticmethod
def mangle(field: dataclasses.Field, name: str) -> str:
"""
Expand All @@ -538,7 +547,7 @@ def fields(field_cls: Type[F], cls: Type) -> List[F]:
"""
Iterate fields of the dataclass and returns `serde.core.Field`.
"""
return [field_cls.from_dataclass(f) for f in dataclass_fields(cls)]
return [field_cls.from_dataclass(f, parent=cls) for f in dataclass_fields(cls)]


def conv(f: Field, case: Optional[str] = None) -> str:
Expand Down
11 changes: 10 additions & 1 deletion serde/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,9 @@ def __getitem__(self, n) -> 'DeField':
'skip': self.skip,
'skip_if': self.skip_if,
'skip_if_false': self.skip_if_false,
'flatten': self.flatten,
'alias': self.alias,
'parent': self.parent,
}
if is_list(self.type) or is_dict(self.type) or is_set(self.type):
return InnerField(typ, 'v', datavar='v', **opts)
Expand Down Expand Up @@ -651,7 +654,13 @@ def dataclass(self, arg: DeField) -> str:
var = arg.datavar

opts = "maybe_generic=maybe_generic, reuse_instances=reuse_instances"
return f"{typename(arg.type)}.{SERDE_SCOPE}.funcs['{self.func}'](data={var}, {opts})"

if arg.is_self_referencing():
class_name = "cls"
else:
class_name = typename(arg.type)

return f"{class_name}.{SERDE_SCOPE}.funcs['{self.func}'](data={var}, {opts})"

def opt(self, arg: DeField) -> str:
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import collections
import datetime
import decimal
import ipaddress
Expand Down Expand Up @@ -92,6 +91,7 @@ def toml_not_supported(se, de, opt) -> bool:
param(data.Pri(10, 'foo', 100.0, True), data.Pri), # dataclass
param(data.Pri(10, 'foo', 100.0, True), Optional[data.Pri]),
param(None, Optional[data.Pri], toml_not_supported),
param(data.Recur(data.Recur(None, None, None), None, None), data.Recur, toml_not_supported),
param(10, NewType('Int', int)), # NewType
param({'a': 1}, Any), # Any
param(GenericClass[str, int]('foo', 10), GenericClass[str, int]), # Generic
Expand Down
10 changes: 10 additions & 0 deletions tests/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,16 @@ class EnumInClass:
i: imported.IE = imported.IE.V1


@dataclass(unsafe_hash=True)
class Recur:
a: Optional['Recur']
b: Optional[List['Recur']]
c: Optional[Dict[str, 'Recur']]


serde(Recur)


ListPri = List[Pri]

DictPri = Dict[str, Pri]
Expand Down
16 changes: 8 additions & 8 deletions tests/test_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ class Foo(Generic[T]):


def test_iter_types():
assert [Pri, int, str, float, bool] == list(iter_types(Pri))
assert [Dict, str, Pri, int, str, float, bool] == list(iter_types(Dict[str, Pri]))
assert [List, str] == list(iter_types(List[str]))
assert [Tuple, int, str, bool, float] == list(iter_types(Tuple[int, str, bool, float]))
assert [Tuple, int, Ellipsis] == list(iter_types(Tuple[int, ...]))
assert [PriOpt, Optional, int, Optional, str, Optional, float, Optional, bool] == list(iter_types(PriOpt))
assert set([Pri, int, str, float, bool]) == set(iter_types(Pri))
assert set([Dict, str, Pri, int, float, bool]) == set(iter_types(Dict[str, Pri]))
assert set([List, str]) == set(iter_types(List[str]))
assert set([Tuple, int, str, bool, float]) == set(iter_types(Tuple[int, str, bool, float]))
assert set([Tuple, int, Ellipsis]) == set(iter_types(Tuple[int, ...]))
assert set([PriOpt, Optional, int, str, float, bool]) == set(iter_types(PriOpt))

@serde.serde
class Foo:
Expand All @@ -107,7 +107,7 @@ class Foo:
e: Union[str, int] = 10
f: List[int] = serde.field(default_factory=list)

assert [Foo, int, datetime, datetime, Optional, str, Union, str, int, List, int] == list(iter_types(Foo))
assert set([Foo, datetime, Optional, str, Union, List, int]) == set(iter_types(Foo))


def test_iter_unions():
Expand All @@ -124,7 +124,7 @@ class A:
b: Dict[str, List[Union[float, int]]]
C: Dict[Union[bool, str], Union[float, int]]

assert [Union[int, str], Union[float, int], Union[bool, str], Union[float, int]] == list(iter_unions(A))
assert set([Union[int, str], Union[float, int], Union[bool, str], Union[float, int]]) == set(iter_unions(A))


def test_type_args():
Expand Down

0 comments on commit 3ff1cde

Please sign in to comment.