Skip to content

Commit

Permalink
Merge pull request #19 from klauer/enh_deps_and_serialization
Browse files Browse the repository at this point in the history
ENH: dependencies, first pass at serialization, and more
  • Loading branch information
klauer authored Feb 26, 2022
2 parents 21e35de + f486129 commit a4f241a
Show file tree
Hide file tree
Showing 11 changed files with 1,735 additions and 1,396 deletions.
150 changes: 150 additions & 0 deletions blark/apischema_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""
Serialization helpers for apischema, an optional dependency.
"""
# Largely based on issue discussions regarding tagged unions.
from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable, Iterator
from types import new_class
from typing import Any, Dict, Generic, List, Tuple, TypeVar, get_type_hints

import lark
from apischema import deserializer, serializer, type_name
from apischema.conversions import Conversion
from apischema.metadata import conversion
from apischema.objects import object_deserialization
from apischema.tagged_unions import Tagged, TaggedUnion, get_tagged
from apischema.typing import get_origin
from apischema.utils import to_pascal_case

_alternative_constructors: Dict[type, List[Callable]] = defaultdict(list)
Func = TypeVar("Func", bound=Callable)


def alternative_constructor(func: Func) -> Func:
"""Alternative constructor for a given type."""
return_type = get_type_hints(func)["return"]
_alternative_constructors[get_origin(return_type) or return_type].append(func)
return func


def get_all_subclasses(cls: type) -> Iterator[type]:
"""Recursive implementation of type.__subclasses__"""
for sub_cls in cls.__subclasses__():
yield sub_cls
yield from get_all_subclasses(sub_cls)


Cls = TypeVar("Cls", bound=type)


def _get_generic_name_factory(cls: type, *args: type):
def _capitalized(name: str) -> str:
return name[0].upper() + name[1:]

return "".join((cls.__name__, *(_capitalized(arg.__name__) for arg in args)))


generic_name = type_name(_get_generic_name_factory)


def as_tagged_union(cls: Cls) -> Cls:
"""
Tagged union decorator, to be used on base class.
Supports generics as well, with names generated by way of
`_get_generic_name_factory`.
"""
params = tuple(getattr(cls, "__parameters__", ()))
tagged_union_bases: Tuple[type, ...] = (TaggedUnion,)

# Generic handling is here:
if params:
tagged_union_bases = (TaggedUnion, Generic[params])
generic_name(cls)
prev_init_subclass = getattr(cls, "__init_subclass__", None)

def __init_subclass__(cls, **kwargs):
if prev_init_subclass is not None:
prev_init_subclass(**kwargs)
generic_name(cls)

cls.__init_subclass__ = classmethod(__init_subclass__)

def with_params(cls: type) -> Any:
"""Specify type of Generic if set."""
return cls[params] if params else cls

def serialization() -> Conversion:
"""
Define the serializer Conversion for the tagged union.
source is the base ``cls`` (or ``cls[T]``).
target is the new tagged union class ``TaggedUnion`` which gets the
dictionary {cls.__name__: obj} as its arguments.
"""
annotations = {
# Assume that subclasses have same generic parameters than cls
sub.__name__: Tagged[with_params(sub)]
for sub in get_all_subclasses(cls)
}
namespace = {"__annotations__": annotations}
tagged_union = new_class(
cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace)
)
return Conversion(
lambda obj: tagged_union(**{obj.__class__.__name__: obj}),
source=with_params(cls),
target=with_params(tagged_union),
# Conversion must not be inherited because it would lead to
# infinite recursion otherwise
inherited=False,
)

def deserialization() -> Conversion:
"""
Define the deserializer Conversion for the tagged union.
Allows for alternative standalone constructors as per the apischema
example.
"""
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {"__annotations__": annotations}
for sub in get_all_subclasses(cls):
annotations[sub.__name__] = Tagged[with_params(sub)]
for constructor in _alternative_constructors.get(sub, ()):
# Build the alias of the field
alias = to_pascal_case(constructor.__name__)
# object_deserialization uses get_type_hints, but the constructor
# return type is stringified and the class not defined yet,
# so it must be assigned manually
constructor.__annotations__["return"] = with_params(sub)
# Use object_deserialization to wrap constructor as deserializer
deserialization = object_deserialization(constructor, generic_name)
# Add constructor tagged field with its conversion
annotations[alias] = Tagged[with_params(sub)]
namespace[alias] = Tagged(conversion(deserialization=deserialization))
# Create the deserialization tagged union class
tagged_union = new_class(
cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace)
)
return Conversion(
lambda obj: get_tagged(obj)[1],
source=with_params(tagged_union),
target=with_params(cls),
)

deserializer(lazy=deserialization, target=cls)
serializer(lazy=serialization, source=cls)
return cls


@serializer
def token_serializer(token: lark.Token) -> List[str]:
return [token.type, token.value]


@deserializer
def token_deserializer(parts: List[str]) -> lark.Token:
return lark.Token(*parts)
3 changes: 3 additions & 0 deletions blark/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import os

BLARK_TWINCAT_ROOT = os.environ.get("BLARK_TWINCAT_ROOT", ".")
Loading

0 comments on commit a4f241a

Please sign in to comment.