From 1a28915a6b5ee2533e0445cf8486b40192b3160c Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Mon, 30 Jan 2023 22:14:23 -0800 Subject: [PATCH] Fixes guess type bug in UnionTransformer Signed-off-by: Ketan Umare --- flytekit/core/type_engine.py | 2 +- tests/flytekit/unit/core/test_type_engine.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 6ddeb5c58c..61c448b365 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1138,7 +1138,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: def guess_python_type(self, literal_type: LiteralType) -> type: if literal_type.union_type is not None: - return typing.Union[tuple(TypeEngine.guess_python_type(v.type) for v in literal_type.union_type.variants)] # type: ignore + return typing.Union[tuple(TypeEngine.guess_python_type(v) for v in literal_type.union_type.variants)] raise ValueError(f"Union transformer cannot reverse {literal_type}") diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index bbe46845fd..eb38a8d80b 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -45,7 +45,7 @@ from flytekit.models.annotation import TypeAnnotation from flytekit.models.core.types import BlobType from flytekit.models.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Void -from flytekit.models.types import LiteralType, SimpleType, TypeStructure +from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType from flytekit.types.directory import TensorboardLogs from flytekit.types.directory.types import FlyteDirectory from flytekit.types.file import FileExt, JPEGImageFile @@ -941,6 +941,18 @@ def test_union_transformer(): assert UnionTransformer.get_sub_type_in_optional(typing.Optional[int]) == int +def test_union_guess_type(): + ut = UnionTransformer() + t = ut.guess_python_type( + LiteralType( + union_type=UnionType( + variants=[LiteralType(simple=SimpleType.STRING), LiteralType(simple=SimpleType.INTEGER)] + ) + ) + ) + assert t == typing.Union[str, int] + + def test_union_type_with_annotated(): pt = typing.Union[ Annotated[str, FlyteAnnotation({"hello": "world"})], Annotated[int, FlyteAnnotation({"test": 123})]