Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add annotation option for serialization #1615

Merged
merged 7 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,24 +715,32 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
d = dictionary of registered transformers, where each key is a python `type`
v = lookup type
Step 1:
find a transformer that matches v exactly
If the type is annotated with a TypeTransformer instance, use that.

Step 2:
find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc
find a transformer that matches v exactly

Step 3:
find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc

Step 4:
Walk the inheritance hierarchy of v and find a transformer that matches the first base class.
This is potentially non-deterministic - will depend on the registration pattern.

TODO: let's make this deterministic by using an ordered dict

Step 4:
Step 5:
if v is of type data class, use the dataclass transformer
"""

# Step 1
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]
args = get_args(python_type)
for annotation in args:
if isinstance(annotation, TypeTransformer):
return annotation

python_type = args[0]

if python_type in cls._REGISTRY:
return cls._REGISTRY[python_type]
Expand Down
47 changes: 47 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import datetime
import json
import os
import tempfile
import typing
from dataclasses import asdict, dataclass, field
from datetime import timedelta
from enum import Enum
from typing import Optional, Type

import mock
import pandas as pd
Expand Down Expand Up @@ -170,6 +172,51 @@ class Foo(object):
assert pv[0].b == Bar(v=[1, 2, 99], w=[3.1415, 2.7182])


def test_annotated_type():
    """Check that a TypeTransformer carried as ``Annotated`` metadata is used by the TypeEngine.

    A JSON-string transformer is attached to ``Dict[str, int]`` through ``Annotated``.
    The test then checks transformer lookup, literal-type resolution, and conversion
    in both directions (literal -> python value, python value -> literal).
    """

    class JsonTypeTransformer(TypeTransformer[T]):
        # Literal type shared by every instance: a STRING tagged with a "json" protocol
        # annotation. Note: this attribute name shadows the imported ``LiteralType``
        # inside this class body.
        LiteralType = LiteralType(
            simple=SimpleType.STRING,
            annotation=TypeAnnotation(annotations={"protocol": "json"}),
        )

        def get_literal_type(self, t: Type[T]) -> LiteralType:
            # Always the same class-level literal type, regardless of ``t``.
            return self.LiteralType

        def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]:
            # The literal stores the value as a JSON document in its string primitive.
            raw_json = lv.scalar.primitive.string_value
            return json.loads(raw_json)

        def to_literal(
            self, ctx: FlyteContext, python_val: T, python_type: typing.Type[T], expected: LiteralType
        ) -> Literal:
            # Serialize the python value to JSON and wrap it in a string primitive literal.
            serialized = json.dumps(python_val)
            return Literal(scalar=Scalar(primitive=Primitive(string_value=serialized)))

    class JSONSerialized:
        # Subscripting this class yields ``Annotated[item, <JsonTypeTransformer instance>]``.
        def __class_getitem__(cls, item: Type[T]):
            transformer = JsonTypeTransformer(name=f"json[{item}]", t=item)
            return Annotated[item, transformer]

    MyJsonDict = JSONSerialized[typing.Dict[str, int]]
    # The second Annotated argument is the transformer instance created above.
    _, attached_transformer = get_args(MyJsonDict)

    # Lookup must return the very instance embedded in the annotation.
    assert TypeEngine.get_transformer(MyJsonDict) is attached_transformer
    assert TypeEngine.to_literal_type(MyJsonDict) == JsonTypeTransformer.LiteralType

    payload = {"foo": 1}
    payload_literal = Literal(scalar=Scalar(primitive=Primitive(string_value=json.dumps(payload))))

    # Literal -> python value goes through the annotated transformer.
    decoded = TypeEngine.to_python_value(
        FlyteContext.current_context(),
        payload_literal,
        MyJsonDict,
    )
    assert decoded == payload

    # Python value -> literal goes through the annotated transformer as well.
    encoded = TypeEngine.to_literal(
        FlyteContext.current_context(), payload, MyJsonDict, JsonTypeTransformer.LiteralType
    )
    assert encoded == payload_literal


def test_list_of_dataclass_getting_python_value():
@dataclass_json
@dataclass()
Expand Down