Skip to content

Commit

Permalink
Merge branch 'master' into add-fast-serialization-settings-docs
Browse files Browse the repository at this point in the history
  • Loading branch information
pingsutw authored Dec 27, 2022
2 parents 29ddce0 + 26cd39d commit ad040de
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 12 deletions.
14 changes: 13 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,12 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A
from flytekit.types.schema.types import FlyteSchema
from flytekit.types.structured.structured_dataset import StructuredDataset

# Handle Optional
if get_origin(python_type) is typing.Union and type(None) in get_args(python_type):
if python_val is None:
return None
return self._serialize_flyte_type(python_val, get_args(python_type)[0])

if hasattr(python_type, "__origin__") and python_type.__origin__ is list:
return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val]

Expand Down Expand Up @@ -400,12 +406,18 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A
python_val.__setattr__(v.name, self._serialize_flyte_type(val, field_type))
return python_val

def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> T:
def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> Optional[T]:
from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine

# Handle Optional
if get_origin(expected_python_type) is typing.Union and type(None) in get_args(expected_python_type):
if python_val is None:
return None
return self._deserialize_flyte_type(python_val, get_args(expected_python_type)[0])

if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is list:
return [self._deserialize_flyte_type(v, expected_python_type.__args__[0]) for v in python_val] # type: ignore

Expand Down
98 changes: 88 additions & 10 deletions plugins/flytekit-deck-standard/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@
#
-e file:.#egg=flytekitplugins-deck-standard
# via -r requirements.in
appnope==0.1.3
# via
# ipykernel
# ipython
arrow==1.2.3
# via jinja2-time
asttokens==2.2.1
# via stack-data
attrs==22.1.0
# via visions
backcall==0.2.0
# via ipython
binaryornot==0.4.4
# via cookiecutter
certifi==2022.12.7
Expand All @@ -26,22 +34,26 @@ click==8.1.3
# flytekit
cloudpickle==2.2.0
# via flytekit
comm==0.1.2
# via ipykernel
contourpy==1.0.6
# via matplotlib
cookiecutter==2.1.1
# via flytekit
croniter==1.3.8
# via flytekit
cryptography==38.0.4
# via
# pyopenssl
# secretstorage
# via pyopenssl
cycler==0.11.0
# via matplotlib
dataclasses-json==0.5.7
# via flytekit
debugpy==1.6.4
# via ipykernel
decorator==5.1.1
# via retry
# via
# ipython
# retry
deprecated==1.2.13
# via flytekit
diskcache==5.4.0
Expand All @@ -52,6 +64,10 @@ docker-image-py==0.1.12
# via flytekit
docstring-parser==0.15
# via flytekit
entrypoints==0.4
# via jupyter-client
executing==1.2.0
# via stack-data
flyteidl==1.3.0
# via flytekit
flytekit==1.3.0b2
Expand Down Expand Up @@ -79,12 +95,18 @@ importlib-metadata==5.1.0
# flytekit
# keyring
# markdown
ipykernel==6.19.4
# via ipywidgets
ipython==8.7.0
# via
# ipykernel
# ipywidgets
ipywidgets==8.0.4
# via flytekitplugins-deck-standard
jaraco-classes==3.2.3
# via keyring
jeepney==0.8.0
# via
# keyring
# secretstorage
jedi==0.18.2
# via ipython
jinja2==3.1.2
# via
# cookiecutter
Expand All @@ -96,6 +118,12 @@ joblib==1.2.0
# via
# flytekit
# phik
jupyter-client==7.4.8
# via ipykernel
jupyter-core==5.1.1
# via jupyter-client
jupyterlab-widgets==3.0.5
# via ipywidgets
keyring==23.11.0
# via flytekit
kiwisolver==1.4.4
Expand All @@ -118,6 +146,10 @@ matplotlib==3.6.2
# pandas-profiling
# phik
# seaborn
matplotlib-inline==0.1.6
# via
# ipykernel
# ipython
more-itertools==9.0.0
# via jaraco-classes
multimethod==1.9
Expand All @@ -128,6 +160,10 @@ mypy-extensions==0.4.3
# via typing-inspect
natsort==8.2.0
# via flytekit
nest-asyncio==1.5.6
# via
# ipykernel
# jupyter-client
networkx==2.8.8
# via visions
numpy==1.23.5
Expand All @@ -148,6 +184,7 @@ numpy==1.23.5
packaging==22.0
# via
# docker
# ipykernel
# marshmallow
# matplotlib
# statsmodels
Expand All @@ -161,17 +198,27 @@ pandas==1.5.2
# visions
pandas-profiling==3.5.0
# via flytekitplugins-deck-standard
parso==0.8.3
# via jedi
patsy==0.5.3
# via statsmodels
pexpect==4.8.0
# via ipython
phik==0.12.3
# via pandas-profiling
pickleshare==0.7.5
# via ipython
pillow==9.3.0
# via
# imagehash
# matplotlib
# visions
platformdirs==2.6.0
# via jupyter-core
plotly==5.11.0
# via flytekitplugins-deck-standard
prompt-toolkit==3.0.36
# via ipython
protobuf==4.21.11
# via
# flyteidl
Expand All @@ -180,6 +227,12 @@ protobuf==4.21.11
# protoc-gen-swagger
protoc-gen-swagger==0.1.0
# via flyteidl
psutil==5.9.4
# via ipykernel
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.2
# via stack-data
py==1.11.0
# via retry
pyarrow==10.0.1
Expand All @@ -188,6 +241,8 @@ pycparser==2.21
# via cffi
pydantic==1.10.2
# via pandas-profiling
pygments==2.13.0
# via ipython
pyopenssl==22.1.0
# via flytekit
pyparsing==3.0.9
Expand All @@ -197,6 +252,7 @@ python-dateutil==2.8.2
# arrow
# croniter
# flytekit
# jupyter-client
# matplotlib
# pandas
python-json-logger==2.0.4
Expand All @@ -216,6 +272,10 @@ pyyaml==6.0
# cookiecutter
# flytekit
# pandas-profiling
pyzmq==24.0.1
# via
# ipykernel
# jupyter-client
regex==2022.10.31
# via docker-image-py
requests==2.28.1
Expand All @@ -237,14 +297,15 @@ scipy==1.9.3
# statsmodels
seaborn==0.12.1
# via pandas-profiling
secretstorage==3.3.3
# via keyring
six==1.16.0
# via
# asttokens
# patsy
# python-dateutil
sortedcontainers==2.4.0
# via flytekit
stack-data==0.6.2
# via ipython
statsd==3.3.0
# via flytekit
statsmodels==0.13.5
Expand All @@ -257,8 +318,21 @@ text-unidecode==1.3
# via python-slugify
toml==0.10.2
# via responses
tornado==6.2
# via
# ipykernel
# jupyter-client
tqdm==4.64.1
# via pandas-profiling
traitlets==5.8.0
# via
# comm
# ipykernel
# ipython
# ipywidgets
# jupyter-client
# jupyter-core
# matplotlib-inline
typeguard==2.13.3
# via pandas-profiling
types-toml==0.10.8.1
Expand All @@ -278,10 +352,14 @@ urllib3==1.26.13
# responses
visions[type_image_path]==0.7.5
# via pandas-profiling
wcwidth==0.2.5
# via prompt-toolkit
websocket-client==1.4.2
# via docker
wheel==0.38.4
# via flytekit
widgetsnbextension==4.0.5
# via ipywidgets
wrapt==1.14.1
# via
# deprecated
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-deck-standard/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}-standard"

plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "markdown", "plotly", "pandas_profiling"]
plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "markdown", "plotly", "pandas_profiling", "ipywidgets"]

__version__ = "0.0.0+develop"

Expand Down
85 changes: 85 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import timedelta
from enum import Enum

import mock
import pandas as pd
import pyarrow as pa
import pytest
Expand Down Expand Up @@ -569,6 +570,90 @@ def test_dataclass_int_preserving():
assert ot == o


@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data")
def test_optional_flytefile_in_dataclass(mock_upload_dir):
    """Round-trip a dataclass whose fields wrap ``FlyteFile`` in Optional/Union/List/Dict.

    The dataclass is converted to a literal with ``DataclassTransformer.to_literal``
    and back with ``to_python_value``. Assertions cover both the intermediate
    literal (``lv.scalar.generic``) and the reconstructed Python object (``ot``),
    for populated values as well as explicit/implicit ``None`` values.

    ``FileAccessProvider.put_data`` is patched and its mock is injected as
    ``mock_upload_dir`` (presumably so that no real upload is attempted while
    serializing the ``FlyteFile`` values -- confirm against ``FlyteFile``).
    """
    # Imported locally so this test does not depend on ``field`` already being
    # imported at module level.
    from dataclasses import field

    mock_upload_dir.return_value = True

    @dataclass_json
    @dataclass
    class A(object):
        a: int

    @dataclass_json
    @dataclass
    class TestFileStruct(object):
        a: FlyteFile
        b: typing.Optional[FlyteFile]
        b_prime: typing.Optional[FlyteFile]
        c: typing.Union[FlyteFile, None]
        d: typing.List[FlyteFile]
        e: typing.List[typing.Optional[FlyteFile]]
        e_prime: typing.List[typing.Optional[FlyteFile]]
        f: typing.Dict[str, FlyteFile]
        g: typing.Dict[str, typing.Optional[FlyteFile]]
        g_prime: typing.Dict[str, typing.Optional[FlyteFile]]
        h: typing.Optional[FlyteFile] = None
        h_prime: typing.Optional[FlyteFile] = None
        i: typing.Optional[A] = None
        # ``A`` defines __eq__ without being frozen, so its instances are
        # unhashable; Python 3.11+ rejects unhashable objects as dataclass
        # defaults. A factory also gives every instance its own ``A``.
        i_prime: typing.Optional[A] = field(default_factory=lambda: A(a=99))

    remote_path = "s3://tmp/file"
    with tempfile.TemporaryFile() as f:
        f.write(b"abc")
        f1 = FlyteFile("f1", remote_path=remote_path)
        # h_prime and i_prime are intentionally omitted to exercise defaults.
        o = TestFileStruct(
            a=f1,
            b=f1,
            b_prime=None,
            c=f1,
            d=[f1],
            e=[f1],
            e_prime=[None],
            f={"a": f1},
            g={"a": f1},
            g_prime={"a": None},
            h=f1,
            i=A(a=42),
        )

        ctx = FlyteContext.current_context()
        tf = DataclassTransformer()
        lt = tf.get_literal_type(TestFileStruct)
        lv = tf.to_literal(ctx, o, TestFileStruct, lt)

        # --- Serialized form -------------------------------------------------
        assert lv.scalar.generic["a"].fields["path"].string_value == remote_path
        assert lv.scalar.generic["b"].fields["path"].string_value == remote_path
        assert lv.scalar.generic["b_prime"] is None
        assert lv.scalar.generic["c"].fields["path"].string_value == remote_path
        assert lv.scalar.generic["d"].values[0].struct_value.fields["path"].string_value == remote_path
        assert lv.scalar.generic["e"].values[0].struct_value.fields["path"].string_value == remote_path
        assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value"
        assert lv.scalar.generic["f"]["a"].fields["path"].string_value == remote_path
        assert lv.scalar.generic["g"]["a"].fields["path"].string_value == remote_path
        assert lv.scalar.generic["g_prime"]["a"] is None
        assert lv.scalar.generic["h"].fields["path"].string_value == remote_path
        assert lv.scalar.generic["h_prime"] is None
        assert lv.scalar.generic["i"].fields["a"].number_value == 42
        assert lv.scalar.generic["i_prime"].fields["a"].number_value == 99

        ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct)

        # --- Deserialized form -----------------------------------------------
        # Every check below inspects ``ot`` (the round-tripped object); checking
        # ``o`` alone would only restate the inputs and could never fail.
        assert o.a.path == ot.a.remote_source
        assert o.b.path == ot.b.remote_source
        assert ot.b_prime is None
        assert o.c.path == ot.c.remote_source
        assert o.d[0].path == ot.d[0].remote_source
        assert o.e[0].path == ot.e[0].remote_source
        assert ot.e_prime == [None]
        assert o.f["a"].path == ot.f["a"].remote_source
        assert o.g["a"].path == ot.g["a"].remote_source
        assert ot.g_prime == {"a": None}
        assert o.h.path == ot.h.remote_source
        assert ot.h_prime is None
        assert o.i == ot.i
        assert ot.i_prime == A(a=99)


def test_flyte_file_in_dataclass():
@dataclass_json
@dataclass
Expand Down

0 comments on commit ad040de

Please sign in to comment.