From 38703b7452a9ba05479169a0965d0a75c62fe156 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 10 Nov 2022 15:09:29 -0800 Subject: [PATCH 01/22] wip Signed-off-by: Kevin Su --- Makefile | 8 ++++---- flytekit/core/data_persistence.py | 10 +++++----- flytekit/types/file/file.py | 4 ++-- flytekit/types/numpy/ndarray.py | 2 +- flytekit/types/schema/types.py | 9 +++++---- flytekit/types/structured/structured_dataset.py | 13 +++++++------ .../unit/core/flyte_functools/decorator_source.py | 5 +++-- .../unit/core/flyte_functools/nested_function.py | 2 +- .../unit/core/flyte_functools/simple_decorator.py | 2 +- .../unit/core/flyte_functools/stacked_decorators.py | 2 +- .../core/flyte_functools/unwrapped_decorator.py | 2 +- tests/flytekit/unit/core/test_composition.py | 6 ++---- tests/flytekit/unit/core/test_conditions.py | 8 +++----- tests/flytekit/unit/core/test_imperative.py | 6 ++---- tests/flytekit/unit/core/test_interface.py | 6 +++--- tests/flytekit/unit/core/test_launch_plan.py | 2 +- tests/flytekit/unit/core/test_node_creation.py | 12 ++++-------- tests/flytekit/unit/core/test_realworld_examples.py | 2 +- tests/flytekit/unit/core/test_references.py | 2 +- tests/flytekit/unit/core/test_serialization.py | 12 ++++++------ tests/flytekit/unit/core/test_type_engine.py | 8 ++++---- tests/flytekit/unit/core/test_type_hints.py | 6 ++---- tests/flytekit/unit/core/test_typing_annotation.py | 2 +- tests/flytekit/unit/core/test_workflows.py | 10 +++++----- 24 files changed, 66 insertions(+), 75 deletions(-) diff --git a/Makefile b/Makefile index 4b3278bec0..d8364f8f71 100644 --- a/Makefile +++ b/Makefile @@ -35,12 +35,12 @@ fmt: ## Format code with black and isort .PHONY: lint lint: ## Run linters - mypy flytekit/core || true + # mypy flytekit/core || true mypy flytekit/types || true - mypy tests/flytekit/unit/core || true + # mypy tests/flytekit/unit/core || true # Exclude setup.py to fix error: Duplicate module named "setup" - mypy plugins --exclude setup.py || true - pre-commit run --all-files + # mypy plugins --exclude setup.py || true + # pre-commit run --all-files .PHONY: spellcheck spellcheck: ## Runs a spellchecker over all code and documentation diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index a2ad5311f1..652ec87cd3 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -94,7 +94,7 @@ def put(self, from_path: str, to_path: str, recursive: bool = False): pass @abstractmethod - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths: str) -> str: + def construct_path(self, add_protocol: bool, add_prefix: bool, *paths: str) -> os.PathLike: """ if add_protocol is true then is prefixed else Constructs a path in the format *args @@ -138,7 +138,7 @@ def register_plugin(cls, protocol: str, plugin: typing.Type[DataPersistence], fo cls._PLUGINS[protocol] = plugin @staticmethod - def get_protocol(url: str): + def get_protocol(url: str) -> str: # copy from fsspec https://github.com/fsspec/filesystem_spec/blob/fe09da6942ad043622212927df7442c104fe7932/fsspec/utils.py#L387-L391 parts = re.split(r"(\:\:|\://)", url, 1) if len(parts) > 1: @@ -350,7 +350,7 @@ def local_access(self) -> DiskPersistence: def construct_random_path( self, persist: DataPersistence, file_path_or_file_name: typing.Optional[str] = None - ) -> str: + ) -> os.PathLike: """ Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name """ @@ -363,7 +363,7 @@ def construct_random_path( logger.warning(f"No filename detected in {file_path_or_file_name}, generating random path") return persist.construct_path(False, True, key) - def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = None) -> str: + def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = None) -> os.PathLike: """ Constructs a randomized path on the configured raw_output_prefix (persistence layer). the random bit is a UUID and allows for disambiguating paths within the same directory. @@ -375,7 +375,7 @@ def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = def get_random_remote_directory(self): return self.get_random_remote_path(None) - def get_random_local_path(self, file_path_or_file_name: typing.Optional[str] = None) -> str: + def get_random_local_path(self, file_path_or_file_name: typing.Optional[str] = None) -> os.PathLike: """ Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name """ diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 9fc55f76ce..6537f85cae 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -346,13 +346,13 @@ def to_python_value( return FlyteFile(uri) # The rest of the logic is only for FlyteFile types. - if not issubclass(expected_python_type, FlyteFile): + if not issubclass(expected_python_type, FlyteFile): # type: ignore raise TypeError(f"Neither os.PathLike nor FlyteFile specified {expected_python_type}") # This is a local file path, like /usr/local/my_file, don't mess with it. Certainly, downloading it doesn't # make any sense. if not ctx.file_access.is_remote(uri): - return expected_python_type(uri) + return expected_python_type(uri) # type: ignore # For the remote case, return an FlyteFile object that can download local_path = ctx.file_access.get_random_local_path(uri) diff --git a/flytekit/types/numpy/ndarray.py b/flytekit/types/numpy/ndarray.py index 38fedfacca..d766818bfd 100644 --- a/flytekit/types/numpy/ndarray.py +++ b/flytekit/types/numpy/ndarray.py @@ -77,7 +77,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: return np.load( file=local_path, allow_pickle=metadata.get("allow_pickle", False), - mmap_mode=metadata.get("mmap_mode"), + mmap_mode=metadata.get("mmap_mode"), # type: ignore ) def guess_python_type(self, literal_type: LiteralType) -> typing.Type[np.ndarray]: diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 6f01cea085..cd50b4fb62 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -227,10 +227,10 @@ def format(cls) -> SchemaFormat: def __init__( self, - local_path: os.PathLike = None, - remote_path: os.PathLike = None, + local_path: typing.Optional[os.PathLike] = None, + remote_path: typing.Optional[os.PathLike] = None, supported_mode: SchemaOpenMode = SchemaOpenMode.WRITE, - downloader: typing.Callable[[str, os.PathLike], None] = None, + downloader: typing.Optional[typing.Callable] = None, ): if supported_mode == SchemaOpenMode.READ and remote_path is None: @@ -286,6 +286,7 @@ def open( raise AssertionError("downloader cannot be None in read mode!") # Only for readable objects if they are not downloaded already, we should download them # Write objects should already have everything written to + assert self.remote_path is not None self._downloader(self.remote_path, self.local_path) self._downloaded = True if mode == SchemaOpenMode.WRITE: @@ -303,7 +304,7 @@ def as_readonly(self) -> FlyteSchema: s = FlyteSchema.__class_getitem__(self.columns(), self.format())( local_path=self.local_path, # Dummy path is ok, as we will assume data is already downloaded and will not download again - remote_path=self.remote_path if self.remote_path else "", + remote_path=self.remote_path, supported_mode=SchemaOpenMode.READ, ) s._downloaded = True diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 99fcb49d7b..8b11778321 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -44,7 +44,7 @@ class StructuredDataset(object): class (that is just a model, a Python class representation of the protobuf). """ - uri: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) + uri: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) file_format: typing.Optional[str] = field(default=PARQUET, metadata=config(mm_field=fields.String())) DEFAULT_FILE_FORMAT = PARQUET @@ -60,23 +60,23 @@ def column_names(cls) -> typing.List[str]: def __init__( self, dataframe: typing.Optional[typing.Any] = None, - uri: Optional[str, os.PathLike] = None, + uri: typing.Optional[typing.Union[str, os.PathLike]] = None, metadata: typing.Optional[literals.StructuredDatasetMetadata] = None, **kwargs, ): self._dataframe = dataframe # Make these fields public, so that the dataclass transformer can set a value for it # https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298 - self.uri = uri + self.uri = str(uri) # This is a special attribute that indicates if the data was either downloaded or uploaded self._metadata = metadata # This is not for users to set, the transformer will set this. self._literal_sd: Optional[literals.StructuredDataset] = None # Not meant for users to set, will be set by an open() call - self._dataframe_type: Optional[Type[DF]] = None + self._dataframe_type: Optional[DF] = None @property - def dataframe(self) -> Optional[Type[DF]]: + def dataframe(self) -> Optional[DF]: return self._dataframe @property @@ -381,7 +381,7 @@ def register_renderer(cls, python_type: Type, renderer: Renderable): cls.Renderers[python_type] = renderer @classmethod - def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False): + def register(cls, h: Handlers, default_for_type: bool = False, override: bool = False): """ Call this with any Encoder or Decoder to register it with the flytekit type system. If your handler does not specify a protocol (e.g. s3, gs, etc.) field, then @@ -691,6 +691,7 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ # Here we only render column information by default instead of opening the structured dataset. col = typing.cast(StructuredDataset, python_val).columns() df = pd.DataFrame(col, ["column type"]) + assert hasattr(df, "to_html") return df.to_html() else: df = python_val diff --git a/tests/flytekit/unit/core/flyte_functools/decorator_source.py b/tests/flytekit/unit/core/flyte_functools/decorator_source.py index 9c92364649..c0ce833263 100644 --- a/tests/flytekit/unit/core/flyte_functools/decorator_source.py +++ b/tests/flytekit/unit/core/flyte_functools/decorator_source.py @@ -1,10 +1,11 @@ """Script used for testing local execution of functool.wraps-wrapped tasks for stacked decorators""" - +import functools +import typing from functools import wraps from typing import List -def task_setup(function: callable = None, *, integration_requests: List = None) -> None: +def task_setup(function: typing.Callable, *, integration_requests: List = None) -> typing.Callable: integration_requests = integration_requests or [] @wraps(function) diff --git a/tests/flytekit/unit/core/flyte_functools/nested_function.py b/tests/flytekit/unit/core/flyte_functools/nested_function.py index 6a3ccfd9e1..98a39e497a 100644 --- a/tests/flytekit/unit/core/flyte_functools/nested_function.py +++ b/tests/flytekit/unit/core/flyte_functools/nested_function.py @@ -32,4 +32,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/flyte_functools/simple_decorator.py b/tests/flytekit/unit/core/flyte_functools/simple_decorator.py index a51a283be5..3278af1bb0 100644 --- a/tests/flytekit/unit/core/flyte_functools/simple_decorator.py +++ b/tests/flytekit/unit/core/flyte_functools/simple_decorator.py @@ -38,4 +38,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py b/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py index 07c46cd46a..dd445a6fb3 100644 --- a/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py +++ b/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py @@ -48,4 +48,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py b/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py index 9f7e6599c6..6e22ca9840 100644 --- a/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py +++ b/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py @@ -26,4 +26,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 3963c77c8d..8eb105777e 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -35,14 +35,12 @@ def my_wf(a: int, b: str) -> (int, str, str): def test_single_named_output_subwf(): - nt = NamedTuple("SubWfOutput", sub_int=int) + nt = NamedTuple("SubWfOutput", [("sub_int", int)]) @task def t1(a: int) -> nt: a = a + 2 - return nt( - a, - ) # returns a named tuple + return nt(a) @task def t2(a: int, b: int) -> nt: diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index be85918b74..1e0446df98 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -283,7 +283,7 @@ def branching(x: int): def test_subworkflow_condition_named_tuple(): - nt = typing.NamedTuple("SampleNamedTuple", b=int, c=str) + nt = typing.NamedTuple("SampleNamedTuple", [("b", int), ("c", str)]) @task def t() -> nt: @@ -302,13 +302,11 @@ def branching(x: int) -> nt: def test_subworkflow_condition_single_named_tuple(): - nt = typing.NamedTuple("SampleNamedTuple", b=int) + nt = typing.NamedTuple("SampleNamedTuple", [("b", int)]) @task def t() -> nt: - return nt( - 5, - ) + return nt(5) @workflow def wf1() -> nt: diff --git a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py index ab90d991b1..14b84ca0b6 100644 --- a/tests/flytekit/unit/core/test_imperative.py +++ b/tests/flytekit/unit/core/test_imperative.py @@ -68,15 +68,13 @@ def t2(): assert len(wf_spec.template.interface.outputs) == 1 # docs_equivalent_start - nt = typing.NamedTuple("wf_output", from_n0t1=str) + nt = typing.NamedTuple("wf_output", [("from_n0t1", str)]) @workflow def my_workflow(in1: str) -> nt: x = t1(a=in1) t2() - return nt( - x, - ) + return nt(x) # docs_equivalent_end diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index 442851a8a2..a0068d08ca 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -95,10 +95,10 @@ def t(a: int, b: str) -> Dict[str, int]: def test_named_tuples(): - nt1 = typing.NamedTuple("NT1", x_str=str, y_int=int) + nt1 = typing.NamedTuple("NT1", [("x_str", str), ("y_int", int)]) - def x(a: int, b: str) -> typing.NamedTuple("NT1", x_str=str, y_int=int): - return ("hello world", 5) + def x(a: int, b: str) -> typing.NamedTuple("NT1", [("x_str", str), ("y_int", int)]): + return "hello world", 5 def y(a: int, b: str) -> nt1: return nt1("hello world", 5) diff --git a/tests/flytekit/unit/core/test_launch_plan.py b/tests/flytekit/unit/core/test_launch_plan.py index ffaff8daad..3addd13e42 100644 --- a/tests/flytekit/unit/core/test_launch_plan.py +++ b/tests/flytekit/unit/core/test_launch_plan.py @@ -292,7 +292,7 @@ def wf(a: int, c: str) -> (int, str): def test_lp_all_parameters(): - nt = typing.NamedTuple("OutputsBC", t1_int_output=int, c=str) + nt = typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", str)]) @task def t1(a: int) -> nt: diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index bb1773ec5c..23a4bac5a3 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -96,14 +96,12 @@ def empty_wf2(): def test_more_normal_task(): - nt = typing.NamedTuple("OneOutput", t1_str_output=str) + nt = typing.NamedTuple("OneOutput", [("t1_str_output", str)]) @task def t1(a: int) -> nt: # This one returns a regular tuple - return nt( - f"{a + 2}", - ) + return nt(f"{a + 2}") @task def t1_nt(a: int) -> nt: @@ -126,14 +124,12 @@ def my_wf(a: int, b: str) -> (str, str): def test_reserved_keyword(): - nt = typing.NamedTuple("OneOutput", outputs=str) + nt = typing.NamedTuple("OneOutput", [("outputs", str)]) @task def t1(a: int) -> nt: # This one returns a regular tuple - return nt( - f"{a + 2}", - ) + return nt(f"{a + 2}") # Test that you can't name an output "outputs" with pytest.raises(FlyteAssertion): diff --git a/tests/flytekit/unit/core/test_realworld_examples.py b/tests/flytekit/unit/core/test_realworld_examples.py index 83e859c1da..c5b3e374fc 100644 --- a/tests/flytekit/unit/core/test_realworld_examples.py +++ b/tests/flytekit/unit/core/test_realworld_examples.py @@ -105,7 +105,7 @@ def split_traintest_dataset( # We will fake train test split. Just return the same dataset multiple times return x, x, y, y - nt = typing.NamedTuple("Outputs", model=FlyteFile[MODELSER_JOBLIB]) + nt = typing.NamedTuple("Outputs", [("model", FlyteFile[MODELSER_JOBLIB])]) @task(cache_version="1.0", cache=True, limits=Resources(mem="200Mi")) def fit(x: FlyteSchema[FEATURE_COLUMNS], y: FlyteSchema[CLASSES_COLUMNS], hyperparams: dict) -> nt: diff --git a/tests/flytekit/unit/core/test_references.py b/tests/flytekit/unit/core/test_references.py index df6e093b55..75dc431918 100644 --- a/tests/flytekit/unit/core/test_references.py +++ b/tests/flytekit/unit/core/test_references.py @@ -155,7 +155,7 @@ def inner_test(ref_mock): inner_test() - nt1 = typing.NamedTuple("DummyNamedTuple", t1_int_output=int, c=str) + nt1 = typing.NamedTuple("DummyNamedTuple", [("t1_int_output", int), ("c", str)]) @task def t1(a: int) -> nt1: diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index a96a94843b..8deb406fb2 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -406,17 +406,17 @@ def wf() -> typing.NamedTuple("OP", a=str): def test_named_outputs_nested(): - nm = typing.NamedTuple("OP", greet=str) + nm = typing.NamedTuple("OP", [("greet", str)]) @task def say_hello() -> nm: return nm("hello world") - wf_outputs = typing.NamedTuple("OP2", greet1=str, greet2=str) + wf_outputs = typing.NamedTuple("OP2", [("greet1", str), ("greet2", str)]) @workflow def my_wf() -> wf_outputs: - # Note only Namedtuples can be created like this + # Note only Namedtuple can be created like this return wf_outputs(say_hello().greet, say_hello().greet) x, y = my_wf() @@ -425,19 +425,19 @@ def my_wf() -> wf_outputs: def test_named_outputs_nested_fail(): - nm = typing.NamedTuple("OP", greet=str) + nm = typing.NamedTuple("OP", [("greet", str)]) @task def say_hello() -> nm: return nm("hello world") - wf_outputs = typing.NamedTuple("OP2", greet1=str, greet2=str) + wf_outputs = typing.NamedTuple("OP2", [("greet1", str), ("greet2", str)]) with pytest.raises(AssertionError): # this should fail because say_hello returns a tuple, but we do not de-reference it @workflow def my_wf() -> wf_outputs: - # Note only Namedtuples can be created like this + # Note only Namedtuple can be created like this return wf_outputs(say_hello(), say_hello()) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 3e813c0fb7..ead14e9052 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -17,7 +17,7 @@ from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from pandas._testing import assert_frame_equal -from typing_extensions import Annotated +from typing_extensions import Annotated, TypeAlias from flytekit import kwtypes from flytekit.core.annotation import FlyteAnnotation @@ -1374,21 +1374,21 @@ def test_multiple_annotations(): TypeEngine.to_literal_type(t) -TestSchema = FlyteSchema[kwtypes(some_str=str)] +TestSchema = FlyteSchema[kwtypes(some_str=str)] # type: ignore @dataclass_json @dataclass class InnerResult: number: int - schema: TestSchema + schema: TestSchema # type: ignore @dataclass_json @dataclass class Result: result: InnerResult - schema: TestSchema + schema: TestSchema # type: ignore def test_schema_in_dataclass(): diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index a5a9430328..739adff86e 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -184,13 +184,11 @@ def my_wf(a: int, b: str) -> (int, str): assert my_wf._output_bindings[0].var == "o0" assert my_wf._output_bindings[0].binding.promise.var == "t1_int_output" - nt = typing.NamedTuple("SingleNT", t1_int_output=float) + nt = typing.NamedTuple("SingleNT", [("t1_int_output", float)]) @task def t3(a: int) -> nt: - return nt( - a + 2, - ) + return nt(a + 2) assert t3.python_interface.output_tuple_name == "SingleNT" assert t3.interface.outputs["t1_int_output"] is not None diff --git a/tests/flytekit/unit/core/test_typing_annotation.py b/tests/flytekit/unit/core/test_typing_annotation.py index 9c2d09c145..2937d9f978 100644 --- a/tests/flytekit/unit/core/test_typing_annotation.py +++ b/tests/flytekit/unit/core/test_typing_annotation.py @@ -18,7 +18,7 @@ env=None, image_config=ImageConfig(default_image=default_img, images=[default_img]), ) -entity_mapping = OrderedDict() +entity_mapping: OrderedDict = OrderedDict() @task diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index 46389daed2..eb5c10f719 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -94,7 +94,7 @@ def list_output_wf() -> typing.List[int]: def test_sub_wf_single_named_tuple(): - nt = typing.NamedTuple("SingleNamedOutput", named1=int) + nt = typing.NamedTuple("SingleNamedOutput", [("named1", int)]) @task def t1(a: int) -> nt: @@ -115,7 +115,7 @@ def wf(b: int) -> nt: def test_sub_wf_multi_named_tuple(): - nt = typing.NamedTuple("Multi", named1=int, named2=int) + nt = typing.NamedTuple("Multi", [("named1", int), ("named2", int)]) @task def t1(a: int) -> nt: @@ -153,7 +153,7 @@ def no_outputs_wf(): with pytest.raises(AssertionError): @workflow - def one_output_wf() -> int: # noqa + def one_output_wf() -> int: # type: ignore t1(a=3) @@ -309,10 +309,10 @@ def sd_to_schema_wf() -> pd.DataFrame: @workflow -def schema_to_sd_wf() -> (pd.DataFrame, pd.DataFrame): +def schema_to_sd_wf() -> typing.Tuple[pd.DataFrame, pd.DataFrame]: # schema -> StructuredDataset df = t4() - return t2(df=df), t5(sd=df) + return t2(df=df), t5(sd=df) # type: ignore def test_structured_dataset_wf(): From a90efb2e4dc64d46a466541a597732ccaa2c0195 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 11 Nov 2022 17:40:50 -0800 Subject: [PATCH 02/22] Fix mypy errors Signed-off-by: Kevin Su --- Makefile | 8 +- flytekit/core/base_sql_task.py | 6 +- flytekit/core/base_task.py | 38 +++--- flytekit/core/class_based_resolver.py | 4 +- flytekit/core/condition.py | 6 +- flytekit/core/container_task.py | 4 +- flytekit/core/context_manager.py | 17 +-- flytekit/core/data_persistence.py | 12 +- flytekit/core/interface.py | 81 +++++++------ flytekit/core/launch_plan.py | 8 +- flytekit/core/map_task.py | 10 +- flytekit/core/node.py | 2 +- flytekit/core/node_creation.py | 9 +- flytekit/core/promise.py | 44 ++++--- flytekit/core/python_auto_container.py | 13 +- .../core/python_customized_container_task.py | 6 +- flytekit/core/python_function_task.py | 2 +- flytekit/core/reference.py | 2 +- flytekit/core/reference_entity.py | 4 +- flytekit/core/resources.py | 4 +- flytekit/core/shim_task.py | 14 ++- flytekit/core/task.py | 4 +- flytekit/core/testing.py | 7 +- flytekit/core/tracked_abc.py | 2 +- flytekit/core/tracker.py | 2 +- flytekit/core/type_engine.py | 111 ++++++++++-------- flytekit/core/utils.py | 4 +- flytekit/core/workflow.py | 27 +++-- flytekit/models/literals.py | 2 +- flytekit/types/schema/types.py | 35 +++--- .../types/structured/structured_dataset.py | 23 ++-- tests/flytekit/unit/core/test_type_engine.py | 2 +- 32 files changed, 282 insertions(+), 231 deletions(-) diff --git a/Makefile b/Makefile index d8364f8f71..484d4dafb4 100644 --- a/Makefile +++ b/Makefile @@ -35,12 +35,12 @@ fmt: ## Format code with black and isort .PHONY: lint lint: ## Run linters - # mypy flytekit/core || true - mypy flytekit/types || true - # mypy tests/flytekit/unit/core || true + mypy flytekit/core + mypy flytekit/types + mypy tests/flytekit/unit/core # Exclude setup.py to fix error: Duplicate module named "setup" # mypy plugins --exclude setup.py || true - # pre-commit run --all-files + pre-commit run --all-files .PHONY: spellcheck spellcheck: ## Runs a spellchecker over all code and documentation diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index d2e4838ed8..09b35dec5d 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Any, Dict, Optional, Tuple, Type, TypeVar from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.interface import Interface @@ -22,10 +22,10 @@ def __init__( self, name: str, query_template: str, + task_config: T, task_type="sql_task", - inputs: Optional[Dict[str, Type]] = None, + inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, - task_config: Optional[T] = None, outputs: Dict[str, Type] = None, **kwargs, ): diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index ef345aadda..fa0c1eaf94 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -21,10 +21,16 @@ import datetime from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast from flytekit.configuration import SerializationSettings -from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, FlyteEntities +from flytekit.core.context_manager import ( + ExecutionParameters, + ExecutionState, + FlyteContext, + FlyteContextManager, + FlyteEntities, +) from flytekit.core.interface import Interface, transform_interface_to_typed_interface from flytekit.core.local_cache import LocalTaskCache from flytekit.core.promise import ( @@ -168,7 +174,7 @@ def __init__( FlyteEntities.entities.append(self) @property - def interface(self) -> Optional[_interface_models.TypedInterface]: + def interface(self) -> _interface_models.TypedInterface: return self._interface @property @@ -232,8 +238,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr kwargs = translate_inputs_to_literals( ctx, incoming_values=kwargs, - flyte_interface_types=self.interface.inputs, # type: ignore - native_types=self.get_input_types(), + flyte_interface_types=self.interface.inputs, + native_types=self.get_input_types(), # type: ignore ) input_literal_map = _literal_models.LiteralMap(literals=kwargs) @@ -258,8 +264,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr else: logger.info("Cache hit") else: - es = ctx.execution_state - b = es.user_space_params.with_task_sandbox() + es = cast(ExecutionState, ctx.execution_state) + b = cast(ExecutionParameters, es.user_space_params).with_task_sandbox() ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build() outputs_literal_map = self.dispatch_execute(ctx, input_literal_map) outputs_literals = outputs_literal_map.literals @@ -279,8 +285,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr vals = [Promise(var, outputs_literals[var]) for var in output_names] return create_task_output(vals, self.python_interface) - def __call__(self, *args, **kwargs): - return flyte_entity_call_handler(self, *args, **kwargs) + def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: + return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore def compile(self, ctx: FlyteContext, *args, **kwargs): raise Exception("not implemented") @@ -406,19 +412,19 @@ def task_config(self) -> T: """ return self._task_config - def get_type_for_input_var(self, k: str, v: Any) -> Optional[Type[Any]]: + def get_type_for_input_var(self, k: str, v: Any) -> Type[Any]: """ Returns the python type for an input variable by name. """ return self._python_interface.inputs[k] - def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]: + def get_type_for_output_var(self, k: str, v: Any) -> Type[Any]: """ Returns the python type for the specified output variable by name. """ return self._python_interface.outputs[k] - def get_input_types(self) -> Optional[Dict[str, type]]: + def get_input_types(self) -> Dict[str, type]: """ Returns the names and python types as a dictionary for the inputs of this task. """ @@ -464,7 +470,9 @@ def dispatch_execute( # Create another execution context with the new user params, but let's keep the same working dir with FlyteContextManager.with_context( - ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params)) + ctx.with_execution_state( + cast(ExecutionState, ctx.execution_state).with_params(user_space_params=new_user_params) + ) # type: ignore ) as exec_ctx: # TODO We could support default values here too - but not part of the plan right now @@ -544,7 +552,7 @@ def dispatch_execute( # After the execute has been successfully completed return outputs_literal_map - def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: + def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: # type: ignore """ This is the method that will be invoked directly before executing the task method and before all the inputs are converted. One particular case where this is useful is if the context is to be modified for the user process @@ -562,7 +570,7 @@ def execute(self, **kwargs) -> Any: """ pass - def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + def post_execute(self, user_params: Optional[ExecutionParameters], rval: Any) -> Any: """ Post execute is called after the execution has completed, with the user_params and can be used to clean-up, or alter the outputs to match the intended tasks outputs. If not overridden, then this function is a No-op diff --git a/flytekit/core/class_based_resolver.py b/flytekit/core/class_based_resolver.py index d47820f811..49970d5623 100644 --- a/flytekit/core/class_based_resolver.py +++ b/flytekit/core/class_based_resolver.py @@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs): def name(self) -> str: return "ClassStorageTaskResolver" - def get_all_tasks(self) -> List[PythonAutoContainerTask]: + def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type:ignore return self.mapping def add(self, t: PythonAutoContainerTask): @@ -33,7 +33,7 @@ def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: idx = int(loader_args[0]) return self.mapping[idx] - def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]: + def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]: # type: ignore """ This is responsible for turning an instance of a task into args that the load_task function can reconstitute. """ diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index b5cae86923..76553db702 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -111,7 +111,7 @@ def end_branch(self) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidP return self._compute_outputs(n) return self._condition - def if_(self, expr: bool) -> Case: + def if_(self, expr: Union[ComparisonExpression, ConjunctionExpression]) -> Case: return self._condition._if(expr) def compute_output_vars(self) -> typing.Optional[typing.List[str]]: @@ -360,7 +360,7 @@ def create_branch_node_promise_var(node_id: str, var: str) -> str: return f"{node_id}.{var}" -def merge_promises(*args: Promise) -> typing.List[Promise]: +def merge_promises(*args: Optional[Promise]) -> typing.List[Promise]: node_vars: typing.Set[typing.Tuple[str, str]] = set() merged_promises: typing.List[Promise] = [] for p in args: @@ -414,7 +414,7 @@ def transform_to_boolexpr( def to_case_block(c: Case) -> Tuple[Union[_core_wf.IfBlock], typing.List[Promise]]: - expr, promises = transform_to_boolexpr(c.expr) + expr, promises = transform_to_boolexpr(cast(Union[ComparisonExpression, ConjunctionExpression], c.expr)) n = c.output_promise.ref.node # type: ignore return _core_wf.IfBlock(condition=expr, then_node=n), promises diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 1194720b0e..5527375c44 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Tuple, Type from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata @@ -35,7 +35,7 @@ def __init__( name: str, image: str, command: List[str], - inputs: Optional[Dict[str, Type]] = None, + inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, arguments: List[str] = None, outputs: Dict[str, Type] = None, diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 0f74c1dee1..db738ed77c 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -27,6 +27,7 @@ from enum import Enum from typing import Generator, List, Optional, Union +from flytekit import LaunchPlan from flytekit.clients import friendly as friendly_client # noqa from flytekit.configuration import Config, SecretsConfig, SerializationSettings from flytekit.core import mock_stats, utils @@ -108,7 +109,7 @@ def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder: def build(self) -> ExecutionParameters: if not isinstance(self.working_dir, utils.AutoDeletingTempDir): - pathlib.Path(self.working_dir).mkdir(parents=True, exist_ok=True) + pathlib.Path(typing.cast(str, self.working_dir)).mkdir(parents=True, exist_ok=True) return ExecutionParameters( execution_date=self.execution_date, stats=self.stats, @@ -130,7 +131,7 @@ def with_task_sandbox(self) -> Builder: prefix = self.working_directory if isinstance(self.working_directory, utils.AutoDeletingTempDir): prefix = self.working_directory.name - task_sandbox_dir = tempfile.mkdtemp(prefix=prefix) + task_sandbox_dir = tempfile.mkdtemp(prefix=prefix) # type: ignore p = pathlib.Path(task_sandbox_dir) cp_dir = p.joinpath("__cp") cp_dir.mkdir(exist_ok=True) @@ -287,7 +288,7 @@ def get(self, key: str) -> typing.Any: """ Returns task specific context if present else raise an error. The returned context will match the key """ - return self.__getattr__(attr_name=key) + return self.__getattr__(attr_name=key) # type: ignore class SecretsManager(object): @@ -467,14 +468,14 @@ class Mode(Enum): LOCAL_TASK_EXECUTION = 3 mode: Optional[ExecutionState.Mode] - working_dir: os.PathLike + working_dir: Union[os.PathLike, str] engine_dir: Optional[Union[os.PathLike, str]] branch_eval_mode: Optional[BranchEvalMode] user_space_params: Optional[ExecutionParameters] def __init__( self, - working_dir: os.PathLike, + working_dir: Union[os.PathLike, str], mode: Optional[ExecutionState.Mode] = None, engine_dir: Optional[Union[os.PathLike, str]] = None, branch_eval_mode: Optional[BranchEvalMode] = None, @@ -607,7 +608,7 @@ def new_execution_state(self, working_dir: Optional[os.PathLike] = None) -> Exec return ExecutionState(working_dir=working_dir, user_space_params=self.user_space_params) @staticmethod - def current_context() -> Optional[FlyteContext]: + def current_context() -> FlyteContext: """ This method exists only to maintain backwards compatibility. Please use ``FlyteContextManager.current_context()`` instead. @@ -639,7 +640,7 @@ def get_deck(self) -> typing.Union[str, "IPython.core.display.HTML"]: # type:ig """ from flytekit.deck.deck import _get_deck - return _get_deck(self.execution_state.user_space_params) + return _get_deck(typing.cast(ExecutionState, self.execution_state).user_space_params) @dataclass class Builder(object): @@ -852,7 +853,7 @@ class FlyteEntities(object): registration process """ - entities = [] + entities: LaunchPlan = [] FlyteContextManager.initialize() diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 652ec87cd3..1ca61f9268 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -54,7 +54,7 @@ class DataPersistence(object): Base abstract type for all DataPersistence operations. This can be extended using the flytekitplugins architecture """ - def __init__(self, name: str, default_prefix: typing.Optional[str] = None, **kwargs): + def __init__(self, name: str = "", default_prefix: typing.Optional[str] = None, **kwargs): self._name = name self._default_prefix = default_prefix @@ -94,7 +94,7 @@ def put(self, from_path: str, to_path: str, recursive: bool = False): pass @abstractmethod - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths: str) -> os.PathLike: + def construct_path(self, add_protocol: bool, add_prefix: bool, *paths: str) -> str: """ if add_protocol is true then is prefixed else Constructs a path in the format *args @@ -350,7 +350,7 @@ def local_access(self) -> DiskPersistence: def construct_random_path( self, persist: DataPersistence, file_path_or_file_name: typing.Optional[str] = None - ) -> os.PathLike: + ) -> str: """ Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name """ @@ -363,7 +363,7 @@ def construct_random_path( logger.warning(f"No filename detected in {file_path_or_file_name}, generating random path") return persist.construct_path(False, True, key) - def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = None) -> os.PathLike: + def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = None) -> str: """ Constructs a randomized path on the configured raw_output_prefix (persistence layer). the random bit is a UUID and allows for disambiguating paths within the same directory. @@ -375,7 +375,7 @@ def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = def get_random_remote_directory(self): return self.get_random_remote_path(None) - def get_random_local_path(self, file_path_or_file_name: typing.Optional[str] = None) -> os.PathLike: + def get_random_local_path(self, file_path_or_file_name: typing.Optional[str] = None) -> str: """ Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name """ @@ -437,7 +437,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart=False): f"Original exception: {str(ex)}" ) - def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart=False): + def put_data(self, local_path: str, remote_path: str, is_multipart=False): """ The implication here is that we're always going to put data to the remote location, so we .remote to ensure we don't use the true local proxy if the remote path is a file:// diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 63d7c8106f..bcf3490a79 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -5,7 +5,7 @@ import inspect import typing from collections import OrderedDict -from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union, cast from typing_extensions import Annotated, get_args, get_origin, get_type_hints @@ -28,8 +28,8 @@ class Interface(object): def __init__( self, - inputs: typing.Optional[typing.Dict[str, Union[Type, Tuple[Type, Any]], None]] = None, - outputs: typing.Optional[typing.Dict[str, Type]] = None, + inputs: Optional[Dict[str, Type]] | Optional[Dict[str, Tuple[Type, Any]]] = None, + outputs: Optional[Dict[str, Type]] = None, output_tuple_name: Optional[str] = None, docstring: Optional[Docstring] = None, ): @@ -43,13 +43,13 @@ def __init__( primarily used when handling one-element NamedTuples. :param docstring: Docstring of the annotated @task or @workflow from which the interface derives from. """ - self._inputs = {} + self._inputs: Dict[str, Tuple[Type, Any]] | Dict[str, Type] if inputs: for k, v in inputs.items(): - if isinstance(v, Tuple) and len(v) > 1: - self._inputs[k] = v + if type(v) is tuple and len(cast(Tuple, v)) > 1: + self._inputs[k] = v # type: ignore else: - self._inputs[k] = (v, None) + self._inputs[k] = (v, None) # type: ignore self._outputs = outputs if outputs else {} self._output_tuple_name = output_tuple_name @@ -57,7 +57,7 @@ def __init__( variables = [k for k in outputs.keys()] # TODO: This class is a duplicate of the one in create_task_outputs. Over time, we should move to this one. - class Output(collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables)): + class Output(collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables)): # type: ignore """ This class can be used in two different places. For multivariate-return entities this class is used to rewrap the outputs so that our with_overrides function can work. @@ -90,7 +90,7 @@ def __rshift__(self, *args, **kwargs): self._docstring = docstring @property - def output_tuple(self) -> Optional[Type[collections.namedtuple]]: + def output_tuple(self) -> Type[collections.namedtuple]: # type: ignore return self._output_tuple_class @property @@ -98,7 +98,7 @@ def output_tuple_name(self) -> Optional[str]: return self._output_tuple_name @property - def inputs(self) -> typing.Dict[str, Type]: + def inputs(self) -> Dict[str, type]: r = {} for k, v in self._inputs.items(): r[k] = v[0] @@ -111,8 +111,8 @@ def output_names(self) -> Optional[List[str]]: return None @property - def inputs_with_defaults(self) -> typing.Dict[str, Tuple[Type, Any]]: - return self._inputs + def inputs_with_defaults(self) -> Dict[str, Tuple[Type, Any]]: + return cast(Dict[str, Tuple[Type, Any]], self._inputs) @property def default_inputs_as_kwargs(self) -> Dict[str, Any]: @@ -126,7 +126,7 @@ def outputs(self) -> typing.Dict[str, type]: def docstring(self) -> Optional[Docstring]: return self._docstring - def remove_inputs(self, vars: List[str]) -> Interface: + def remove_inputs(self, vars: Optional[List[str]]) -> Interface: """ This method is useful in removing some variables from the Flyte backend inputs specification, as these are implicit local only inputs or will be supplied by the library at runtime. For example, spark-session etc @@ -151,7 +151,7 @@ def with_inputs(self, extra_inputs: Dict[str, Type]) -> Interface: for k, v in extra_inputs.items(): if k in new_inputs: raise ValueError(f"Input {k} cannot be added as it already exists in the interface") - new_inputs[k] = v + cast(Dict[str, Type], new_inputs)[k] = v return Interface(new_inputs, self._outputs, docstring=self.docstring) def with_outputs(self, extra_outputs: Dict[str, Type]) -> Interface: @@ -241,7 +241,7 @@ def transform_types_to_list_of_type(m: Dict[str, type]) -> Dict[str, type]: om = {} for k, v in m.items(): - om[k] = typing.List[v] + om[k] = typing.List[v] # type: ignore return om # type: ignore @@ -256,18 +256,20 @@ def transform_interface_to_list_interface(interface: Interface) -> Interface: return Interface(inputs=map_inputs, outputs=map_outputs) -def _change_unrecognized_type_to_pickle(t: Type[T]) -> typing.Union[Tuple[Type[T]], Type[T], Annotated]: +def _change_unrecognized_type_to_pickle(t: Type[T]) -> typing.Union[Tuple[Type[T]], Type[T]]: try: if hasattr(t, "__origin__") and hasattr(t, "__args__"): - if get_origin(t) is list: - return typing.List[_change_unrecognized_type_to_pickle(t.__args__[0])] - elif get_origin(t) is dict and t.__args__[0] == str: - return typing.Dict[str, _change_unrecognized_type_to_pickle(t.__args__[1])] - elif get_origin(t) is typing.Union: - return typing.Union[tuple(_change_unrecognized_type_to_pickle(v) for v in get_args(t))] - elif get_origin(t) is Annotated: + ot = get_origin(t) + args = getattr(t, "__args__") + if ot is list: + return typing.List[_change_unrecognized_type_to_pickle(args[0])] # type: ignore + elif ot is dict and args[0] == str: + return typing.Dict[str, _change_unrecognized_type_to_pickle(args[1])] # type: ignore + elif ot is typing.Union: + return typing.Union[tuple(_change_unrecognized_type_to_pickle(v) for v in get_args(t))] # type: ignore + elif ot is Annotated: base_type, *config = get_args(t) - return Annotated[(_change_unrecognized_type_to_pickle(base_type), *config)] + return Annotated[(_change_unrecognized_type_to_pickle(base_type), *config)] # type: ignore TypeEngine.get_transformer(t) except ValueError: logger.warning( @@ -295,12 +297,12 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc outputs = extract_return_annotation(return_annotation) for k, v in outputs.items(): outputs[k] = _change_unrecognized_type_to_pickle(v) # type: ignore - inputs = OrderedDict() + inputs: Dict[str, Tuple[Type, Any]] = OrderedDict() for k, v in signature.parameters.items(): # type: ignore annotation = type_hints.get(k, None) default = v.default if v.default is not inspect.Parameter.empty else None # Inputs with default values are currently ignored, we may want to look into that in the future - inputs[k] = (_change_unrecognized_type_to_pickle(annotation), default) + inputs[k] = (_change_unrecognized_type_to_pickle(annotation), default) # type: ignore # This is just for typing.NamedTuples - in those cases, the user can select a name to call the NamedTuple. We # would like to preserve that name in our custom collections.namedtuple. @@ -326,18 +328,19 @@ def transform_variable_map( if variable_map: for k, v in variable_map.items(): res[k] = transform_type(v, descriptions.get(k, k)) - sub_type: Type[T] = v + sub_type: type = v if hasattr(v, "__origin__") and hasattr(v, "__args__"): - if v.__origin__ is list: - sub_type = v.__args__[0] - elif v.__origin__ is dict: - sub_type = v.__args__[1] - if hasattr(sub_type, "__origin__") and sub_type.__origin__ is FlytePickle: - if hasattr(sub_type.python_type(), "__name__"): - res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__} - elif hasattr(sub_type.python_type(), "_name"): + if getattr(v, "__origin__") is list: + sub_type = getattr(v, "__args__")[0] + elif getattr(v, "__origin__") is dict: + sub_type = getattr(v, "__args__")[1] + if hasattr(sub_type, "__origin__") and getattr(sub_type, "__origin__") is FlytePickle: + original_type = cast(FlytePickle, sub_type).python_type() + if hasattr(original_type, "__name__"): + res[k].type.metadata = {"python_class_name": original_type.__name__} + elif hasattr(original_type, "_name"): # If the class doesn't have the __name__ attribute, like typing.Sequence, use _name instead. - res[k].type.metadata = {"python_class_name": sub_type.python_type()._name} + res[k].type.metadata = {"python_class_name": original_type._name} return res @@ -394,13 +397,13 @@ def t(a: int, b: str) -> Dict[str, int]: ... # This statement results in true for typing.Namedtuple, single and void return types, so this # handles Options 1, 2. Even though NamedTuple for us is multi-valued, it's a single value for Python - if isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar): + if isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar): # type: ignore # isinstance / issubclass does not work for Namedtuple. # Options 1 and 2 bases = return_annotation.__bases__ # type: ignore if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"): logger.debug(f"Task returns named tuple {return_annotation}") - return dict(get_type_hints(return_annotation, include_extras=True)) + return dict(get_type_hints(cast(Type, return_annotation), include_extras=True)) if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore # Handle option 3 @@ -420,7 +423,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... else: # Handle all other single return types logger.debug(f"Task returns unnamed native tuple {return_annotation}") - return {default_output_name(): return_annotation} + return {default_output_name(): cast(Type, return_annotation)} def remap_shared_output_descriptions(output_descriptions: Dict[str, str], outputs: Dict[str, Type]) -> Dict[str, str]: diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index 0d143e5fe8..ed77574e35 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -74,7 +74,7 @@ def wf(a: int, c: str) -> str: # The reason we cache is simply because users may get the default launch plan twice for a single Workflow. We # don't want to create two defaults, could be confusing. - CACHE = {} + CACHE: typing.Dict[str, LaunchPlan] = {} @staticmethod def get_default_launch_plan(ctx: FlyteContext, workflow: _annotated_workflow.WorkflowBase) -> LaunchPlan: @@ -130,7 +130,7 @@ def create( temp_inputs = {} for k, v in default_inputs.items(): temp_inputs[k] = (workflow.python_interface.inputs[k], v) - temp_interface = Interface(inputs=temp_inputs, outputs={}) + temp_interface = Interface(inputs=temp_inputs, outputs={}) # type: ignore temp_signature = transform_inputs_to_parameters(ctx, temp_interface) wf_signature_parameters._parameters.update(temp_signature.parameters) @@ -313,7 +313,7 @@ def __init__( self._parameters = _interface_models.ParameterMap(parameters=parameters) self._fixed_inputs = fixed_inputs # See create() for additional information - self._saved_inputs = {} + self._saved_inputs: Dict[str, Any] = {} self._schedule = schedule self._notifications = notifications or [] @@ -335,7 +335,6 @@ def clone_with( labels: _common_models.Labels = None, annotations: _common_models.Annotations = None, raw_output_data_config: _common_models.RawOutputDataConfig = None, - auth_role: _common_models.AuthRole = None, max_parallelism: int = None, security_context: typing.Optional[security.SecurityContext] = None, ) -> LaunchPlan: @@ -349,7 +348,6 @@ def clone_with( labels=labels or self.labels, annotations=annotations or self.annotations, raw_output_data_config=raw_output_data_config or self.raw_output_data_config, - auth_role=auth_role or self._auth_role, max_parallelism=max_parallelism or self.max_parallelism, security_context=security_context or self.security_context, ) diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 3b5c0a09ca..327c225a75 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -7,7 +7,7 @@ import typing from contextlib import contextmanager from itertools import count -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional from flytekit.configuration import SerializationSettings from flytekit.core import tracker @@ -149,8 +149,8 @@ def _compute_array_job_index() -> int: environment variable and the offset (if one's set). The offset will be set and used when the user request that the job runs in a number of slots less than the size of the input. """ - return int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET", 0)) + int( - os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME")) + return int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET", "0")) + int( + os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME", "0"), "0") ) @property @@ -168,7 +168,7 @@ def _outputs_interface(self) -> Dict[Any, Variable]: return self.interface.outputs return self._run_task.interface.outputs - def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]: + def get_type_for_output_var(self, k: str, v: Any) -> type: """ We override this method from flytekit.core.base_task Task because the dispatch_execute method uses this interface to construct outputs. Each instance of an container_array task will however produce outputs @@ -181,7 +181,7 @@ def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]: return self._python_interface.outputs[k] return self._run_task._python_interface.outputs[k] - def _execute_map_task(self, ctx: FlyteContext, **kwargs) -> Any: + def _execute_map_task(self, _: FlyteContext, **kwargs) -> Any: """ This is called during ExecutionState.Mode.TASK_EXECUTION executions, that is executions orchestrated by the Flyte platform. Individual instances of the map task, aka array task jobs are passed the full set of inputs but diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 95fdebb4eb..2c967d7e9e 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -113,7 +113,7 @@ def with_overrides(self, *args, **kwargs): def _convert_resource_overrides( resources: typing.Optional[Resources], resource_name: str -) -> [_resources_model.ResourceEntry]: +) -> typing.List[_resources_model.ResourceEntry]: if resources is None: return [] if not isinstance(resources, Resources): diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index de33393c13..62065f6869 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -1,7 +1,6 @@ from __future__ import annotations -import collections -from typing import TYPE_CHECKING, Type, Union +from typing import TYPE_CHECKING, Union from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext @@ -21,7 +20,7 @@ def create_node( entity: Union[PythonTask, LaunchPlan, WorkflowBase, RemoteEntity], *args, **kwargs -) -> Union[Node, VoidPromise, Type[collections.namedtuple]]: +) -> Union[Node, VoidPromise]: """ This is the function you want to call if you need to specify dependencies between tasks that don't consume and/or don't produce outputs. For example, if you have t1() and t2(), both of which do not take in nor produce any @@ -173,9 +172,9 @@ def sub_wf(): if len(output_names) == 1: # See explanation above for why we still tupletize a single element. - return entity.python_interface.output_tuple(results) + return entity.python_interface.output_tuple(results) # type: ignore - return entity.python_interface.output_tuple(*results) + return entity.python_interface.output_tuple(*results) # type: ignore else: raise Exception(f"Cannot use explicit run to call Flyte entities {entity.name}") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index ab7ff23931..6256b90874 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -11,7 +11,13 @@ from flytekit.core import context_manager as _flyte_context from flytekit.core import interface as flyte_interface from flytekit.core import type_engine -from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext, FlyteContextManager +from flytekit.core.context_manager import ( + BranchEvalMode, + ExecutionParameters, + ExecutionState, + FlyteContext, + FlyteContextManager, +) from flytekit.core.interface import Interface from flytekit.core.node import Node from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine @@ -80,7 +86,7 @@ def extract_value( if lt.collection_type is None: raise TypeError(f"Not a collection type {flyte_literal_type} but got a list {input_val}") try: - sub_type = ListTransformer.get_sub_type(python_type) + sub_type: type = ListTransformer.get_sub_type(python_type) except ValueError: if len(input_val) == 0: raise @@ -348,7 +354,7 @@ def __hash__(self): return hash(id(self)) def __rshift__(self, other: typing.Union[Promise, VoidPromise]): - if not self.is_ready: + if self.is_ready and other.ref: self.ref.node.runs_before(other.ref.node) return other @@ -408,10 +414,10 @@ def is_false(self) -> ComparisonExpression: def is_true(self): return self.is_(True) - def __eq__(self, other) -> ComparisonExpression: + def __eq__(self, other) -> ComparisonExpression: # type: ignore return ComparisonExpression(self, ComparisonOps.EQ, other) - def __ne__(self, other) -> ComparisonExpression: + def __ne__(self, other) -> ComparisonExpression: # type: ignore return ComparisonExpression(self, ComparisonOps.NE, other) def __gt__(self, other) -> ComparisonExpression: @@ -454,7 +460,7 @@ def __str__(self): def create_native_named_tuple( - ctx: FlyteContext, promises: Optional[Union[Promise, typing.List[Promise]]], entity_interface: Interface + ctx: FlyteContext, promises: Union[Tuple[Promise], Promise, VoidPromise, None], entity_interface: Interface ) -> Optional[Tuple]: """ Creates and returns a Named tuple with all variables that match the expected named outputs. this makes @@ -474,7 +480,7 @@ def create_native_named_tuple( except Exception as e: raise AssertionError(f"Failed to convert value of output {k}, expected type {v}.") from e - if len(promises) == 0: + if len(cast(Tuple[Promise], promises)) == 0: return None named_tuple_name = "DefaultNamedTupleOutput" @@ -482,7 +488,7 @@ def create_native_named_tuple( named_tuple_name = entity_interface.output_tuple_name outputs = {} - for p in promises: + for p in cast(Tuple[Promise], promises): if not isinstance(p, Promise): raise AssertionError( "Workflow outputs can only be promises that are returned by tasks. Found a value of" @@ -495,8 +501,8 @@ def create_native_named_tuple( raise AssertionError(f"Failed to convert value of output {p.var}, expected type {t}.") from e # Should this class be part of the Interface? - t = collections.namedtuple(named_tuple_name, list(outputs.keys())) - return t(**outputs) + nt = collections.namedtuple(named_tuple_name, list(outputs.keys())) # type: ignore + return nt(**outputs) # To create a class that is a named tuple, we might have to create namedtuplemeta and manipulate the tuple @@ -539,7 +545,7 @@ def create_task_output( named_tuple_name = entity_interface.output_tuple_name # Should this class be part of the Interface? - class Output(collections.namedtuple(named_tuple_name, variables)): + class Output(collections.namedtuple(named_tuple_name, variables)): # type: ignore def with_overrides(self, *args, **kwargs): val = self.__getattribute__(self._fields[0]) val.with_overrides(*args, **kwargs) @@ -598,7 +604,7 @@ def binding_data_from_python_std( if expected_literal_type.collection_type is None: raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}") - sub_type = ListTransformer.get_sub_type(t_value_type) if t_value_type else None + sub_type: Optional[type] = ListTransformer.get_sub_type(t_value_type) if t_value_type else None collection = _literals_models.BindingDataCollection( bindings=[ binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type) for t in t_value @@ -677,7 +683,7 @@ def ref(self) -> typing.Optional[NodeOutput]: return self._ref def __rshift__(self, other: typing.Union[Promise, VoidPromise]): - if self.ref: + if self.ref and other.ref: self.ref.node.runs_before(other.ref.node) return other @@ -977,11 +983,13 @@ def create_and_link_node( class LocallyExecutable(Protocol): - def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: ... -def flyte_entity_call_handler(entity: Union[SupportsNodeCreation], *args, **kwargs): +def flyte_entity_call_handler( + entity: Union[SupportsNodeCreation], *args, **kwargs +) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: """ This function is the call handler for tasks, workflows, and launch plans (which redirects to the underlying workflow). The logic is the same for all three, but we did not want to create base class, hence this separate @@ -1034,7 +1042,7 @@ def flyte_entity_call_handler(entity: Union[SupportsNodeCreation], *args, **kwar ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) ) ) as child_ctx: - cast(FlyteContext, child_ctx).user_space_params._decks = [] + cast(ExecutionParameters, child_ctx).user_space_params._decks = [] result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs) expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs) @@ -1044,7 +1052,9 @@ def flyte_entity_call_handler(entity: Union[SupportsNodeCreation], *args, **kwar else: raise Exception(f"Received an output when workflow local execution expected None. Received: {result}") - if (1 < expected_outputs == len(result)) or (result is not None and expected_outputs == 1): + if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or ( + result is not None and expected_outputs == 1 + ): return create_native_named_tuple(ctx, result, cast(SupportsNodeCreation, entity).python_interface) raise ValueError( diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 06133d9784..ae317a1a3e 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -3,8 +3,7 @@ import importlib import re from abc import ABC -from types import ModuleType -from typing import Callable, Dict, List, Optional, TypeVar, Union +from typing import Callable, Dict, List, Optional, TypeVar from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask, TaskResolverMixin @@ -99,7 +98,7 @@ def __init__( self._get_command_fn = self.get_default_command @property - def task_resolver(self) -> Optional[TaskResolverMixin]: + def task_resolver(self) -> TaskResolverMixin: return self._task_resolver @property @@ -188,14 +187,14 @@ class DefaultTaskResolver(TrackedInstance, TaskResolverMixin): def name(self) -> str: return "DefaultTaskResolver" - def load_task(self, loader_args: List[Union[T, ModuleType]]) -> PythonAutoContainerTask: + def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: _, task_module, _, task_name, *_ = loader_args - task_module = importlib.import_module(task_module) + task_module = importlib.import_module(name=task_module) # type: ignore task_def = getattr(task_module, task_name) return task_def - def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: + def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: # type:ignore from flytekit.core.python_function_task import PythonFunctionTask if isinstance(task, PythonFunctionTask): @@ -205,7 +204,7 @@ def loader_args(self, settings: SerializationSettings, task: PythonAutoContainer _, m, t, _ = extract_task_module(task) return ["task-module", m, "task-name", t] - def get_all_tasks(self) -> List[PythonAutoContainerTask]: + def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore raise Exception("should not be needed") diff --git a/flytekit/core/python_customized_container_task.py b/flytekit/core/python_customized_container_task.py index 235ddaa2cc..8aa49a7f20 100644 --- a/flytekit/core/python_customized_container_task.py +++ b/flytekit/core/python_customized_container_task.py @@ -21,7 +21,7 @@ TC = TypeVar("TC") -class PythonCustomizedContainerTask(ExecutableTemplateShimTask, PythonTask[TC]): +class PythonCustomizedContainerTask(ExecutableTemplateShimTask, PythonTask[TC]): # type: ignore """ Please take a look at the comments for :py:class`flytekit.extend.ExecutableTemplateShimTask` as well. This class should be subclassed and a custom Executor provided as a default to this parent class constructor @@ -227,7 +227,7 @@ def name(self) -> str: # The return type of this function is different, it should be a Task, but it's not because it doesn't make # sense for ExecutableTemplateShimTask to inherit from Task. - def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: + def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: # type: ignore logger.info(f"Task template loader args: {loader_args}") ctx = FlyteContext.current_context() task_template_local_path = os.path.join(ctx.execution_state.working_dir, "task_template.pb") # type: ignore @@ -238,7 +238,7 @@ def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: executor_class = load_object_from_module(loader_args[1]) return ExecutableTemplateShimTask(task_template_model, executor_class) - def loader_args(self, settings: SerializationSettings, t: PythonCustomizedContainerTask) -> List[str]: + def loader_args(self, settings: SerializationSettings, t: PythonCustomizedContainerTask) -> List[str]: # type: ignore return ["{{.taskTemplatePath}}", f"{t.executor_type.__module__}.{t.executor_type.__name__}"] def get_all_tasks(self) -> List[Task]: diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index fb428a89a2..d8bd4b27d0 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -192,7 +192,7 @@ def compile_into_workflow( self._wf.compile(**kwargs) wf = self._wf - model_entities = OrderedDict() + model_entities: OrderedDict = OrderedDict() # See comment on reference entity checking a bit down below in this function. # This is the only circular dependency between the translator.py module and the rest of the flytekit # authoring experience. diff --git a/flytekit/core/reference.py b/flytekit/core/reference.py index 6a88549c43..cad44268ff 100644 --- a/flytekit/core/reference.py +++ b/flytekit/core/reference.py @@ -15,7 +15,7 @@ def get_reference_entity( domain: str, name: str, version: str, - inputs: Dict[str, Type], + inputs: Dict[str, type], outputs: Dict[str, Type], ): """ diff --git a/flytekit/core/reference_entity.py b/flytekit/core/reference_entity.py index 77b96e6892..d8a8f620f7 100644 --- a/flytekit/core/reference_entity.py +++ b/flytekit/core/reference_entity.py @@ -21,7 +21,7 @@ from flytekit.models.core import workflow as _workflow_model -@dataclass +@dataclass # type: ignore class Reference(ABC): project: str domain: str @@ -66,7 +66,7 @@ class ReferenceEntity(object): def __init__( self, reference: Union[WorkflowReference, TaskReference, LaunchPlanReference], - inputs: Optional[Dict[str, Union[Type[Any], Tuple[Type[Any], Any]]]], + inputs: Dict[str, type], outputs: Dict[str, Type], ): if ( diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index 7b46cbe05c..2b6cebead9 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -33,5 +33,5 @@ class Resources(object): @dataclass class ResourceSpec(object): - requests: Optional[Resources] = None - limits: Optional[Resources] = None + requests: Resources + limits: Resources diff --git a/flytekit/core/shim_task.py b/flytekit/core/shim_task.py index d8d18293c5..f96db3e49c 100644 --- a/flytekit/core/shim_task.py +++ b/flytekit/core/shim_task.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Any, Generic, Type, TypeVar, Union +from typing import Any, Generic, Optional, Type, TypeVar, Union, cast -from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager +from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.tracker import TrackedInstance from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger @@ -47,7 +47,7 @@ def name(self) -> str: if self._task_template is not None: return self._task_template.id.name # if not access the subclass's name - return self._name + return self._name # type: ignore @property def task_template(self) -> _task_model.TaskTemplate: @@ -67,13 +67,13 @@ def execute(self, **kwargs) -> Any: """ return self.executor.execute_from_model(self.task_template, **kwargs) - def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: + def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: """ This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. """ return user_params - def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + def post_execute(self, _: Optional[ExecutionParameters], rval: Any) -> Any: """ This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. """ @@ -92,7 +92,9 @@ def dispatch_execute( # Create another execution context with the new user params, but let's keep the same working dir with FlyteContextManager.with_context( - ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params)) + ctx.with_execution_state( + cast(ExecutionState, ctx.execution_state).with_params(user_space_params=new_user_params) + ) ) as exec_ctx: # Added: Have to reverse the Python interface from the task template Flyte interface # See docstring for more details. diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 6e5b0a6b6a..a772775d82 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -87,7 +87,7 @@ def task( requests: Optional[Resources] = None, limits: Optional[Resources] = None, secret_requests: Optional[List[Secret]] = None, - execution_mode: Optional[PythonFunctionTask.ExecutionBehavior] = PythonFunctionTask.ExecutionBehavior.DEFAULT, + execution_mode: PythonFunctionTask.ExecutionBehavior = PythonFunctionTask.ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, disable_deck: bool = True, ) -> Union[Callable, PythonFunctionTask]: @@ -222,7 +222,7 @@ class ReferenceTask(ReferenceEntity, PythonFunctionTask): """ def __init__( - self, project: str, domain: str, name: str, version: str, inputs: Dict[str, Type], outputs: Dict[str, Type] + self, project: str, domain: str, name: str, version: str, inputs: Dict[str, type], outputs: Dict[str, Type] ): super().__init__(TaskReference(project, domain, name, version), inputs, outputs) diff --git a/flytekit/core/testing.py b/flytekit/core/testing.py index 772a4b6df6..055e47efd4 100644 --- a/flytekit/core/testing.py +++ b/flytekit/core/testing.py @@ -1,3 +1,4 @@ +import typing from contextlib import contextmanager from typing import Union from unittest.mock import MagicMock @@ -9,7 +10,7 @@ @contextmanager -def task_mock(t: PythonTask) -> MagicMock: +def task_mock(t: PythonTask) -> typing.Iterator[MagicMock]: """ Use this method to mock a task declaration. It can mock any Task in Flytekit as long as it has a python native interface associated with it. @@ -41,9 +42,9 @@ def _log(*args, **kwargs): return m(*args, **kwargs) _captured_fn = t.execute - t.execute = _log + t.execute = _log # type: ignore yield m - t.execute = _captured_fn + t.execute = _captured_fn # type: ignore def patch(target: Union[PythonTask, WorkflowBase, ReferenceEntity]): diff --git a/flytekit/core/tracked_abc.py b/flytekit/core/tracked_abc.py index bad4f8c555..3c39d3725c 100644 --- a/flytekit/core/tracked_abc.py +++ b/flytekit/core/tracked_abc.py @@ -3,7 +3,7 @@ from flytekit.core.tracker import TrackedInstance -class FlyteTrackedABC(type(TrackedInstance), type(ABC)): +class FlyteTrackedABC(type(TrackedInstance), type(ABC)): # type: ignore """ This class exists because if you try to inherit from abc.ABC and TrackedInstance by itself, you'll get the well-known ``TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 2a203d4861..e6a8645c15 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -179,7 +179,7 @@ class _ModuleSanitizer(object): def __init__(self): self._module_cache = {} - def _resolve_abs_module_name(self, path: str, package_root: str) -> str: + def _resolve_abs_module_name(self, path: str, package_root: typing.Optional[str]) -> str: """ Recursively finds the root python package under-which basename exists """ diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 98969e41b3..ec0d5fc25e 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -22,6 +22,7 @@ from google.protobuf.struct_pb2 import Struct from marshmallow_enum import EnumField, LoadDumpOptions from marshmallow_jsonschema import JSONSchema +from proto import Message from typing_extensions import Annotated, get_args, get_origin from flytekit.core.annotation import FlyteAnnotation @@ -117,7 +118,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp raise NotImplementedError(f"Conversion to Literal for python type {python_type} not implemented") @abstractmethod - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]: """ Converts the given Literal to a Python Type. If the conversion cannot be done an AssertionError should be raised :param ctx: FlyteContext @@ -361,11 +362,14 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A from flytekit.types.schema.types import FlyteSchema from flytekit.types.structured.structured_dataset import StructuredDataset - if hasattr(python_type, "__origin__") and python_type.__origin__ is list: - return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val] + has_attr = hasattr(python_type, "__origin__") + ot = getattr(python_type, "__origin__") + args = getattr(python_type, "__args__") + if has_attr and ot is list: + return [self._serialize_flyte_type(v, args[0]) for v in cast(list, python_val)] - if hasattr(python_type, "__origin__") and python_type.__origin__ is dict: - return {k: self._serialize_flyte_type(v, python_type.__args__[1]) for k, v in python_val.items()} + if has_attr and ot is dict: + return {k: self._serialize_flyte_type(v, args[1]) for k, v in cast(dict, python_val).items()} if not dataclasses.is_dataclass(python_type): return python_val @@ -417,7 +421,13 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> t = FlyteSchemaTransformer() return t.to_python_value( FlyteContext.current_context(), - Literal(scalar=Scalar(schema=Schema(python_val.remote_path, t._get_schema_type(expected_python_type)))), + Literal( + scalar=Scalar( + schema=Schema( + cast(FlyteSchema, python_val).remote_path, t._get_schema_type(expected_python_type) + ) + ) + ), expected_python_type, ) elif issubclass(expected_python_type, FlyteFile): @@ -431,7 +441,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE ) ), - uri=python_val.path, + uri=cast(FlyteFile, python_val).path, ) ) ), @@ -448,7 +458,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART ) ), - uri=python_val.path, + uri=cast(FlyteDirectory, python_val).path, ) ) ), @@ -461,9 +471,11 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> scalar=Scalar( structured_dataset=StructuredDataset( metadata=StructuredDatasetMetadata( - structured_dataset_type=StructuredDatasetType(format=python_val.file_format) + structured_dataset_type=StructuredDatasetType( + format=cast(StructuredDataset, python_val).file_format + ) ), - uri=python_val.uri, + uri=cast(StructuredDataset, python_val).uri, ) ) ), @@ -502,7 +514,9 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: if isinstance(val, dict): ktype, vtype = DictTransformer.get_dict_types(t) # Handle nested Dict. e.g. {1: {2: 3}, 4: {5: 6}}) - return {self._fix_val_int(ktype, k): self._fix_val_int(vtype, v) for k, v in val.items()} + return { + self._fix_val_int(cast(type, ktype), k): self._fix_val_int(cast(type, vtype), v) for k, v in val.items() + } if dataclasses.is_dataclass(t): return self._fix_dataclass_int(t, val) # type: ignore @@ -542,7 +556,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: # calls to guess_python_type would result in a logically equivalent (but new) dataclass, which # TypeEngine.assert_type would not be happy about. @lru_cache(typed=True) - def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: # type: ignore if literal_type.simple == SimpleType.STRUCT: if literal_type.metadata is not None and DEFINITIONS in literal_type.metadata: schema_name = literal_type.metadata["$ref"].split("/")[-1] @@ -567,7 +581,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: struct = Struct() try: - struct.update(_MessageToDict(python_val)) + struct.update(_MessageToDict(cast(Message, python_val))) except Exception: raise TypeTransformerFailedError("Failed to convert to generic protobuf struct") return Literal(scalar=Scalar(generic=struct)) @@ -578,7 +592,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: pb_obj = expected_python_type() dictionary = _MessageToDict(lv.scalar.generic) - pb_obj = _ParseDict(dictionary, pb_obj) + pb_obj = _ParseDict(dictionary, pb_obj) # type: ignore return pb_obj def guess_python_type(self, literal_type: LiteralType) -> Type[T]: @@ -601,7 +615,7 @@ class TypeEngine(typing.Generic[T]): _REGISTRY: typing.Dict[type, TypeTransformer[T]] = {} _RESTRICTED_TYPES: typing.List[type] = [] - _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() + _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore @classmethod def register( @@ -626,10 +640,10 @@ def register( def register_restricted_type( cls, name: str, - type: Type, + type: Type[T], ): cls._RESTRICTED_TYPES.append(type) - cls.register(RestrictedTypeTransformer(name, type)) + cls.register(RestrictedTypeTransformer(name, type)) # type: ignore @classmethod def register_additional_type(cls, transformer: TypeTransformer, additional_type: Type, override=False): @@ -886,8 +900,8 @@ def get_sub_type(t: Type[T]) -> Type[T]: if get_origin(t) is Annotated: return ListTransformer.get_sub_type(get_args(t)[0]) - if t.__origin__ is list and hasattr(t, "__args__"): - return t.__args__[0] + if getattr(t, "__origin__") is list and hasattr(t, "__args__"): + return getattr(t, "__args__")[0] raise ValueError("Only generic univariate typing.List[T] type is supported.") @@ -909,7 +923,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore return Literal(collection=LiteralCollection(literals=lit_list)) - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[T]: + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore try: lits = lv.collection.literals except AttributeError: @@ -918,10 +932,10 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: st = self.get_sub_type(expected_python_type) return [TypeEngine.to_python_value(ctx, x, st) for x in lits] - def guess_python_type(self, literal_type: LiteralType) -> Type[list]: + def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore if literal_type.collection_type: - ct = TypeEngine.guess_python_type(literal_type.collection_type) - return typing.List[ct] + ct: Type = TypeEngine.guess_python_type(literal_type.collection_type) + return typing.List[ct] # type: ignore raise ValueError(f"List transformer cannot reverse {literal_type}") @@ -1034,7 +1048,9 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: t = get_args(t)[0] try: - trans = [(TypeEngine.get_transformer(x), x) for x in get_args(t)] + trans: typing.List[typing.Tuple[TypeTransformer, typing.Any]] = [ + (TypeEngine.get_transformer(x), x) for x in get_args(t) + ] # must go through TypeEngine.to_literal_type instead of trans.get_literal_type # to handle Annotated variants = [_add_tag_to_type(TypeEngine.to_literal_type(x), t.name) for (t, x) in trans] @@ -1051,7 +1067,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp res_type = None for t in get_args(python_type): try: - trans = TypeEngine.get_transformer(t) + trans: TypeTransformer[T] = TypeEngine.get_transformer(t) res = trans.to_literal(ctx, python_val, t, expected) res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name) @@ -1084,7 +1100,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: res_tag = None for v in get_args(expected_python_type): try: - trans = TypeEngine.get_transformer(v) + trans: TypeTransformer[T] = TypeEngine.get_transformer(v) if union_tag is not None: if trans.name != union_tag: continue @@ -1167,10 +1183,10 @@ def get_literal_type(self, t: Type[dict]) -> LiteralType: Transforms a native python dictionary to a flyte-specific ``LiteralType`` """ tp = self.get_dict_types(t) - if tp: + if tp is not None: if tp[0] == str: try: - sub_type = TypeEngine.to_literal_type(tp[1]) + sub_type = TypeEngine.to_literal_type(cast(type, tp[1])) return _type_models.LiteralType(map_value_type=sub_type) except Exception as e: raise ValueError(f"Type of Generic List type is not supported, {e}") @@ -1191,7 +1207,7 @@ def to_literal( raise ValueError("Flyte MapType expects all keys to be strings") # TODO: log a warning for Annotated objects that contain HashMethod k_type, v_type = self.get_dict_types(python_type) - lit_map[k] = TypeEngine.to_literal(ctx, v, v_type, expected.map_value_type) + lit_map[k] = TypeEngine.to_literal(ctx, v, cast(type, v_type), expected.map_value_type) return Literal(map=LiteralMap(literals=lit_map)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict: @@ -1207,7 +1223,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: raise TypeError("TypeMismatch. Destination dictionary does not accept 'str' key") py_map = {} for k, v in lv.map.literals.items(): - py_map[k] = TypeEngine.to_python_value(ctx, v, tp[1]) + py_map[k] = TypeEngine.to_python_value(ctx, v, cast(Type, tp[1])) return py_map # for empty generic we have to explicitly test for lv.scalar.generic is not None as empty dict @@ -1245,10 +1261,8 @@ def _blob_type(self) -> _core_types.BlobType: dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) - def get_literal_type(self, t: typing.TextIO) -> LiteralType: - return _type_models.LiteralType( - blob=self._blob_type(), - ) + def get_literal_type(self, t: typing.TextIO) -> LiteralType: # type: ignore + return _type_models.LiteralType(blob=self._blob_type()) def to_literal( self, ctx: FlyteContext, python_val: typing.TextIO, python_type: Type[typing.TextIO], expected: LiteralType @@ -1319,7 +1333,9 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: raise TypeTransformerFailedError("Only EnumTypes with value of string are supported") return LiteralType(enum_type=_core_types.EnumType(values=values)) - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + def to_literal( + self, ctx: FlyteContext, python_val: enum.Enum, python_type: Type[T], expected: LiteralType + ) -> Literal: if type(python_val).__class__ != enum.EnumMeta: raise TypeTransformerFailedError("Expected an enum") if type(python_val.value) != str: @@ -1328,11 +1344,12 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value))) # type: ignore def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: - return expected_python_type(lv.scalar.primitive.string_value) + return expected_python_type(lv.scalar.primitive.string_value) # type: ignore -def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: - """Generate a model class based on the provided JSON Schema +def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: # type: ignore + """ + Generate a model class based on the provided JSON Schema :param schema: dict representing valid JSON schema :param schema_name: dataclass name of return type """ @@ -1341,7 +1358,7 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac property_type = property_val["type"] # Handle list if property_val["type"] == "array": - attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) + attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore # Handle dataclass and dict elif property_type == "object": if property_val.get("$ref"): @@ -1349,13 +1366,13 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac attribute_list.append((property_key, convert_json_schema_to_python_class(schema, name))) elif property_val.get("additionalProperties"): attribute_list.append( - (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) + (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore ) else: - attribute_list.append((property_key, typing.Dict[str, _get_element_type(property_val)])) + attribute_list.append((property_key, typing.Dict[str, _get_element_type(property_val)])) # type: ignore # Handle int, float, bool or str else: - attribute_list.append([property_key, _get_element_type(property_val)]) + attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) @@ -1529,8 +1546,8 @@ def __init__( raise ValueError("Cannot instantiate LiteralsResolver without a map of Literals.") self._literals = literals self._variable_map = variable_map - self._native_values = {} - self._type_hints = {} + self._native_values: Dict[str, type] = {} + self._type_hints: Dict[str, type] = {} self._ctx = ctx def __str__(self) -> str: @@ -1583,7 +1600,7 @@ def __getitem__(self, key: str): return self.get(key) - def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: + def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: # type: ignore """ This will get the ``attr`` value from the Literal map, and invoke the TypeEngine to convert it into a Python native value. A Python type can optionally be supplied. If successful, the native value will be cached and @@ -1610,7 +1627,9 @@ def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: raise e else: ValueError("as_type argument not supplied and Variable map not specified in LiteralsResolver") - val = TypeEngine.to_python_value(self._ctx or FlyteContext.current_context(), self._literals[attr], as_type) + val = TypeEngine.to_python_value( + self._ctx or FlyteContext.current_context(), self._literals[attr], cast(Type, as_type) + ) self._native_values[attr] = val return val diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index d23aae3fbb..647c25556c 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -50,8 +50,8 @@ def _dnsify(value: str) -> str: def _get_container_definition( image: str, - command: List[str], - args: List[str], + command: Optional[List[str]] = None, + args: Optional[List[str]] = None, data_loading_config: Optional[_task_models.DataLoadingConfig] = None, storage_request: Optional[str] = None, ephemeral_storage_request: Optional[str] = None, diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 468a5aa7ea..c67a21427f 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum from functools import update_wrapper -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask @@ -175,9 +175,9 @@ def __init__( self._workflow_metadata_defaults = workflow_metadata_defaults self._python_interface = python_interface self._interface = transform_interface_to_typed_interface(python_interface) - self._inputs = {} - self._unbound_inputs = set() - self._nodes = [] + self._inputs: Dict[str, Promise] = {} + self._unbound_inputs: set = set() + self._nodes: List[Node] = [] self._output_bindings: List[_literal_models.Binding] = [] FlyteEntities.entities.append(self) super().__init__(**kwargs) @@ -191,11 +191,11 @@ def short_name(self) -> str: return extract_obj_name(self._name) @property - def workflow_metadata(self) -> Optional[WorkflowMetadata]: + def workflow_metadata(self) -> WorkflowMetadata: return self._workflow_metadata @property - def workflow_metadata_defaults(self): + def workflow_metadata_defaults(self) -> WorkflowMetadataDefaults: return self._workflow_metadata_defaults @property @@ -228,7 +228,7 @@ def construct_node_metadata(self) -> _workflow_model.NodeMetadata: interruptible=self.workflow_metadata_defaults.interruptible, ) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: """ Workflow needs to fill in default arguments before invoking the call handler. """ @@ -489,7 +489,7 @@ def get_input_values(input_value): self._unbound_inputs.remove(input_value) return n # type: ignore - def add_workflow_input(self, input_name: str, python_type: Type) -> Interface: + def add_workflow_input(self, input_name: str, python_type: Type) -> Promise: """ Adds an input to the workflow. """ @@ -516,7 +516,8 @@ def add_workflow_output( f"If specifying a list or dict of Promises, you must specify the python_type type for {output_name}" f" starting with the container type (e.g. List[int]" ) - python_type = p.ref.node.flyte_entity.python_interface.outputs[p.var] + promise = cast(Promise, p) + python_type = promise.ref.node.flyte_entity.python_interface.outputs[promise.var] logger.debug(f"Inferring python type for wf output {output_name} from Promise provided {python_type}") flyte_type = TypeEngine.to_literal_type(python_type=python_type) @@ -569,8 +570,8 @@ class PythonFunctionWorkflow(WorkflowBase, ClassStorageTaskResolver): def __init__( self, workflow_function: Callable, - metadata: Optional[WorkflowMetadata], - default_metadata: Optional[WorkflowMetadataDefaults], + metadata: WorkflowMetadata, + default_metadata: WorkflowMetadataDefaults, docstring: Docstring = None, ): name, _, _, _ = extract_task_module(workflow_function) @@ -592,7 +593,7 @@ def __init__( def function(self): return self._workflow_function - def task_name(self, t: PythonAutoContainerTask) -> str: + def task_name(self, t: PythonAutoContainerTask) -> str: # type: ignore return f"{self.name}.{t.__module__}.{t.name}" def compile(self, **kwargs): @@ -741,7 +742,7 @@ def wrapper(fn): return wrapper -class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): +class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): # type: ignore """ A reference workflow is a pointer to a workflow that already exists on your Flyte installation. This object will not initiate a network call to Admin, which is why the user is asked to provide the expected interface. diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index 4f06c3d3c6..e0a864e31e 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -628,7 +628,7 @@ def uri(self) -> str: return self._uri @property - def metadata(self) -> StructuredDatasetMetadata: + def metadata(self) -> Optional[StructuredDatasetMetadata]: return self._metadata def to_flyte_idl(self) -> _literals_pb2.StructuredDataset: diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index cd50b4fb62..5325b80ec4 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -51,13 +51,13 @@ class SchemaReader(typing.Generic[T]): Use the simplified base LocalIOSchemaReader for non distributed dataframes """ - def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, from_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): self._from_path = from_path self._fmt = fmt self._columns = cols @property - def from_path(self) -> str: + def from_path(self) -> os.PathLike: return self._from_path @property @@ -76,7 +76,7 @@ def all(self, **kwargs) -> T: class SchemaWriter(typing.Generic[T]): - def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, to_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): self._to_path = to_path self._fmt = fmt self._columns = cols @@ -84,7 +84,7 @@ def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], self._file_name_gen = generate_ordered_files(Path(self._to_path), 1024) @property - def to_path(self) -> str: + def to_path(self) -> os.PathLike: return self._to_path @property @@ -100,31 +100,37 @@ def write(self, *dfs, **kwargs): class LocalIOSchemaReader(SchemaReader[T]): def __init__(self, from_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): - super().__init__(str(from_path), cols, fmt) + super().__init__(from_path, cols, fmt) @abstractmethod def _read(self, *path: os.PathLike, **kwargs) -> T: pass def iter(self, **kwargs) -> typing.Generator[T, None, None]: - with os.scandir(self._from_path) as it: + with os.scandir(self._from_path) as it: # type: ignore for entry in it: - if not entry.name.startswith(".") and entry.is_file(): - yield self._read(Path(entry.path), **kwargs) + if ( + not typing.cast(os.DirEntry, entry).name.startswith(".") + and typing.cast(os.DirEntry, entry).is_file() + ): + yield self._read(Path(typing.cast(os.DirEntry, entry).path), **kwargs) def all(self, **kwargs) -> T: files: typing.List[os.PathLike] = [] - with os.scandir(self._from_path) as it: + with os.scandir(self._from_path) as it: # type: ignore for entry in it: - if not entry.name.startswith(".") and entry.is_file(): - files.append(Path(entry.path)) + if ( + not typing.cast(os.DirEntry, entry).name.startswith(".") + and typing.cast(os.DirEntry, entry).is_file() + ): + files.append(Path(typing.cast(os.DirEntry, entry).path)) return self._read(*files, **kwargs) class LocalIOSchemaWriter(SchemaWriter[T]): def __init__(self, to_local_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): - super().__init__(str(to_local_path), cols, fmt) + super().__init__(to_local_path, cols, fmt) @abstractmethod def _write(self, df: T, path: os.PathLike, **kwargs): @@ -290,10 +296,11 @@ def open( self._downloader(self.remote_path, self.local_path) self._downloaded = True if mode == SchemaOpenMode.WRITE: - return h.writer(typing.cast(str, self.local_path), self.columns(), self.format()) - return h.reader(typing.cast(str, self.local_path), self.columns(), self.format()) + return h.writer(self.local_path, self.columns(), self.format()) + return h.reader(self.local_path, self.columns(), self.format()) # Remote IO is handled. So we will just pass the remote reference to the object + assert self.remote_path is not None if mode == SchemaOpenMode.WRITE: return h.writer(self.remote_path, self.columns(), self.format()) return h.reader(self.remote_path, self.columns(), self.format()) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 8b11778321..8a678b7700 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -73,7 +73,7 @@ def __init__( # This is not for users to set, the transformer will set this. self._literal_sd: Optional[literals.StructuredDataset] = None # Not meant for users to set, will be set by an open() call - self._dataframe_type: Optional[DF] = None + self._dataframe_type: Optional[DF] = None # type: ignore @property def dataframe(self) -> Optional[DF]: @@ -254,7 +254,7 @@ def decode( ctx: FlyteContext, flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata, - ) -> Union[DF, Generator[DF, None, None]]: + ) -> Union[DF, typing.Iterator[DF]]: """ This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal value into a Python instance. @@ -561,7 +561,7 @@ def encode( # least as good as the type of the interface. if sd_model.metadata is None: sd_model._metadata = StructuredDatasetMetadata(structured_literal_type) - if sd_model.metadata.structured_dataset_type is None: + if sd_model.metadata and sd_model.metadata.structured_dataset_type is None: sd_model.metadata._structured_dataset_type = structured_literal_type # Always set the format here to the format of the handler. # Note that this will always be the same as the incoming format except for when the fallback handler @@ -691,8 +691,9 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ # Here we only render column information by default instead of opening the structured dataset. col = typing.cast(StructuredDataset, python_val).columns() df = pd.DataFrame(col, ["column type"]) - assert hasattr(df, "to_html") - return df.to_html() + if hasattr(df, "to_html"): + return df.to_html() # type: ignore + return "" else: df = python_val @@ -728,10 +729,10 @@ def iter_as( sd: literals.StructuredDataset, df_type: Type[DF], updated_metadata: StructuredDatasetMetadata, - ) -> Generator[DF, None, None]: + ) -> typing.Iterator[DF]: protocol = protocol_prefix(sd.uri) decoder = self.DECODERS[df_type][protocol][sd.metadata.structured_dataset_type.format] - result = decoder.decode(ctx, sd, updated_metadata) + result: Union[DF, typing.Iterator[DF]] = decoder.decode(ctx, sd, updated_metadata) if not isinstance(result, types.GeneratorType): raise ValueError(f"Decoder {decoder} didn't return iterator {result} but should have from {sd}") return result @@ -746,7 +747,7 @@ def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType: raise AssertionError(f"type {t} is currently not supported by StructuredDataset") def _convert_ordered_dict_of_columns_to_list( - self, column_map: typing.OrderedDict[str, Type] + self, column_map: typing.Optional[typing.OrderedDict[str, Type]] ) -> typing.List[StructuredDatasetType.DatasetColumn]: converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = [] if column_map is None or len(column_map) == 0: @@ -757,10 +758,12 @@ def _convert_ordered_dict_of_columns_to_list( return converted_cols def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]) -> StructuredDatasetType: - original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) + original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) # type: ignore # Get the column information - converted_cols = self._convert_ordered_dict_of_columns_to_list(column_map) + converted_cols: typing.List[ + StructuredDatasetType.DatasetColumn + ] = self._convert_ordered_dict_of_columns_to_list(column_map) # Get the format default_format = ( diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index ead14e9052..743c2f0976 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -17,7 +17,7 @@ from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from pandas._testing import assert_frame_equal -from typing_extensions import Annotated, TypeAlias +from typing_extensions import Annotated from flytekit import kwtypes from flytekit.core.annotation import FlyteAnnotation From 61f6d548a5a89643e0a2ac0de2166187839d68dc Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 11 Nov 2022 19:42:29 -0800 Subject: [PATCH 03/22] Fix mypy errors Signed-off-by: Kevin Su --- Makefile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Makefile b/Makefile index 484d4dafb4..aaa736fe3f 100644 --- a/Makefile +++ b/Makefile @@ -38,8 +38,6 @@ lint: ## Run linters mypy flytekit/core mypy flytekit/types mypy tests/flytekit/unit/core - # Exclude setup.py to fix error: Duplicate module named "setup" - # mypy plugins --exclude setup.py || true pre-commit run --all-files .PHONY: spellcheck From 7325dffe0f5dad2dcafbd5aa4b6b3eda2a8c9a3d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 11 Nov 2022 19:57:53 -0800 Subject: [PATCH 04/22] Fix tests Signed-off-by: Kevin Su --- flytekit/core/context_manager.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index db738ed77c..7abe0af033 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -27,13 +27,14 @@ from enum import Enum from typing import Generator, List, Optional, Union -from flytekit import LaunchPlan from flytekit.clients import friendly as friendly_client # noqa from flytekit.configuration import Config, SecretsConfig, SerializationSettings from flytekit.core import mock_stats, utils from flytekit.core.checkpointer import Checkpoint, SyncCheckpoint from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider +from flytekit.core.launch_plan import LaunchPlan from flytekit.core.node import Node +from flytekit.core.workflow import WorkflowBase from flytekit.interfaces.cli_identifiers import WorkflowExecutionIdentifier from flytekit.interfaces.stats import taggable from flytekit.loggers import logger, user_space_logger @@ -49,7 +50,7 @@ flyte_context_Var: ContextVar[typing.List[FlyteContext]] = ContextVar("", default=[]) if typing.TYPE_CHECKING: - from flytekit.core.base_task import TaskResolverMixin + from flytekit.core.base_task import Task, TaskResolverMixin # Identifier fields use placeholders for registration-time substitution. @@ -853,7 +854,7 @@ class FlyteEntities(object): registration process """ - entities: LaunchPlan = [] + entities: List[LaunchPlan | Task | WorkflowBase] = [] FlyteContextManager.initialize() From b8428f0878b98ff4de9e24cef210009d7331d39b Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 11 Nov 2022 20:13:10 -0800 Subject: [PATCH 05/22] Fix tests Signed-off-by: Kevin Su --- flytekit/core/context_manager.py | 4 +--- flytekit/core/interface.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 7abe0af033..2175168872 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -32,9 +32,7 @@ from flytekit.core import mock_stats, utils from flytekit.core.checkpointer import Checkpoint, SyncCheckpoint from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider -from flytekit.core.launch_plan import LaunchPlan from flytekit.core.node import Node -from flytekit.core.workflow import WorkflowBase from flytekit.interfaces.cli_identifiers import WorkflowExecutionIdentifier from flytekit.interfaces.stats import taggable from flytekit.loggers import logger, user_space_logger @@ -854,7 +852,7 @@ class FlyteEntities(object): registration process """ - entities: List[LaunchPlan | Task | WorkflowBase] = [] + entities: List["LaunchPlan" | Task | "WorkflowBase"] = [] # type: ignore FlyteContextManager.initialize() diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index bcf3490a79..051d01d1fb 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -43,7 +43,7 @@ def __init__( primarily used when handling one-element NamedTuples. :param docstring: Docstring of the annotated @task or @workflow from which the interface derives from. """ - self._inputs: Dict[str, Tuple[Type, Any]] | Dict[str, Type] + self._inputs: Dict[str, Tuple[Type, Any]] | Dict[str, Type] = {} # type: ignore if inputs: for k, v in inputs.items(): if type(v) is tuple and len(cast(Tuple, v)) > 1: From 90852263f6bb287c2f3adac74359cbaf5432bd30 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 11 Nov 2022 21:54:58 -0800 Subject: [PATCH 06/22] Fix tests Signed-off-by: Kevin Su --- flytekit/core/promise.py | 2 +- flytekit/core/type_engine.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 6256b90874..a37bce0a77 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1042,7 +1042,7 @@ def flyte_entity_call_handler( ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) ) ) as child_ctx: - cast(ExecutionParameters, child_ctx).user_space_params._decks = [] + cast(ExecutionParameters, child_ctx.user_space_params)._decks = [] result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs) expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index ec0d5fc25e..053405eaa8 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -362,14 +362,13 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A from flytekit.types.schema.types import FlyteSchema from flytekit.types.structured.structured_dataset import StructuredDataset - has_attr = hasattr(python_type, "__origin__") - ot = getattr(python_type, "__origin__") - args = getattr(python_type, "__args__") - if has_attr and ot is list: - return [self._serialize_flyte_type(v, args[0]) for v in cast(list, python_val)] - - if has_attr and ot is dict: - return {k: self._serialize_flyte_type(v, args[1]) for k, v in cast(dict, python_val).items()} + if hasattr(python_type, "__origin__"): + ot = getattr(python_type, "__origin__") + args = getattr(python_type, "__args__") + if ot is list: + return [self._serialize_flyte_type(v, args[0]) for v in cast(list, python_val)] + if ot is dict: + return {k: self._serialize_flyte_type(v, args[1]) for k, v in cast(dict, python_val).items()} if not dataclasses.is_dataclass(python_type): return python_val From c46b045eefc9b58c51c9c9a3faf9f2d2a28b551e Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 12 Nov 2022 01:07:06 -0800 Subject: [PATCH 07/22] wip Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 12 ++++++++---- flytekit/types/schema/types.py | 2 -- flytekit/types/structured/structured_dataset.py | 4 +--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 053405eaa8..cd65f91e72 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -364,11 +364,15 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A if hasattr(python_type, "__origin__"): ot = getattr(python_type, "__origin__") - args = getattr(python_type, "__args__") if ot is list: - return [self._serialize_flyte_type(v, args[0]) for v in cast(list, python_val)] + return [ + self._serialize_flyte_type(v, getattr(python_type, "__args__")[0]) for v in cast(list, python_val) + ] if ot is dict: - return {k: self._serialize_flyte_type(v, args[1]) for k, v in cast(dict, python_val).items()} + return { + k: self._serialize_flyte_type(v, getattr(python_type, "__args__")[1]) + for k, v in cast(dict, python_val).items() + } if not dataclasses.is_dataclass(python_type): return python_val @@ -1182,7 +1186,7 @@ def get_literal_type(self, t: Type[dict]) -> LiteralType: Transforms a native python dictionary to a flyte-specific ``LiteralType`` """ tp = self.get_dict_types(t) - if tp is not None: + if tp: if tp[0] == str: try: sub_type = TypeEngine.to_literal_type(cast(type, tp[1])) diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 5325b80ec4..607099cb13 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -292,7 +292,6 @@ def open( raise AssertionError("downloader cannot be None in read mode!") # Only for readable objects if they are not downloaded already, we should download them # Write objects should already have everything written to - assert self.remote_path is not None self._downloader(self.remote_path, self.local_path) self._downloaded = True if mode == SchemaOpenMode.WRITE: @@ -300,7 +299,6 @@ def open( return h.reader(self.local_path, self.columns(), self.format()) # Remote IO is handled. So we will just pass the remote reference to the object - assert self.remote_path is not None if mode == SchemaOpenMode.WRITE: return h.writer(self.remote_path, self.columns(), self.format()) return h.reader(self.remote_path, self.columns(), self.format()) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 8a678b7700..f41481f823 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -691,9 +691,7 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ # Here we only render column information by default instead of opening the structured dataset. col = typing.cast(StructuredDataset, python_val).columns() df = pd.DataFrame(col, ["column type"]) - if hasattr(df, "to_html"): - return df.to_html() # type: ignore - return "" + return df.to_html() # type: ignore else: df = python_val From 8b1c1aa7a8089682c677483a00dd06353c7f3a92 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 12 Nov 2022 10:34:17 -0800 Subject: [PATCH 08/22] wip Signed-off-by: Kevin Su --- flytekit/core/promise.py | 2 +- flytekit/core/tracker.py | 2 +- flytekit/types/schema/types.py | 14 +++++----- flytekit/types/structured/basic_dfs.py | 3 ++- .../test_structured_dataset_workflow.py | 27 ++++++++++--------- 5 files changed, 25 insertions(+), 23 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index a37bce0a77..6fbc60d642 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -354,7 +354,7 @@ def __hash__(self): return hash(id(self)) def __rshift__(self, other: typing.Union[Promise, VoidPromise]): - if self.is_ready and other.ref: + if not self.is_ready: self.ref.node.runs_before(other.ref.node) return other diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index e6a8645c15..2a203d4861 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -179,7 +179,7 @@ class _ModuleSanitizer(object): def __init__(self): self._module_cache = {} - def _resolve_abs_module_name(self, path: str, package_root: typing.Optional[str]) -> str: + def _resolve_abs_module_name(self, path: str, package_root: str) -> str: """ Recursively finds the root python package under-which basename exists """ diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 607099cb13..a108fb7d45 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -76,7 +76,7 @@ def all(self, **kwargs) -> T: class SchemaWriter(typing.Generic[T]): - def __init__(self, to_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): self._to_path = to_path self._fmt = fmt self._columns = cols @@ -84,7 +84,7 @@ def __init__(self, to_path: os.PathLike, cols: typing.Optional[typing.Dict[str, self._file_name_gen = generate_ordered_files(Path(self._to_path), 1024) @property - def to_path(self) -> os.PathLike: + def to_path(self) -> str: return self._to_path @property @@ -129,7 +129,7 @@ def all(self, **kwargs) -> T: class LocalIOSchemaWriter(SchemaWriter[T]): - def __init__(self, to_local_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, to_local_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): super().__init__(to_local_path, cols, fmt) @abstractmethod @@ -233,8 +233,8 @@ def format(cls) -> SchemaFormat: def __init__( self, - local_path: typing.Optional[os.PathLike] = None, - remote_path: typing.Optional[os.PathLike] = None, + local_path: typing.Optional[str] = None, + remote_path: typing.Optional[str] = None, supported_mode: SchemaOpenMode = SchemaOpenMode.WRITE, downloader: typing.Optional[typing.Callable] = None, ): @@ -259,7 +259,7 @@ def __init__( self._downloader = downloader @property - def local_path(self) -> os.PathLike: + def local_path(self) -> str: return self._local_path @property @@ -309,7 +309,7 @@ def as_readonly(self) -> FlyteSchema: s = FlyteSchema.__class_getitem__(self.columns(), self.format())( local_path=self.local_path, # Dummy path is ok, as we will assume data is already downloaded and will not download again - remote_path=self.remote_path, + remote_path=self.remote_path if self.remote_path else "", supported_mode=SchemaOpenMode.READ, ) s._downloaded = True diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 71dff61c5e..ef8886fe51 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -73,7 +73,8 @@ def encode( structured_dataset: StructuredDataset, structured_dataset_type: StructuredDatasetType, ) -> literals.StructuredDataset: - path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_path() + path = structured_dataset.uri or ctx.file_access.get_random_remote_path() + print("arrow", path) df = structured_dataset.dataframe local_dir = ctx.file_access.get_random_local_directory() local_path = os.path.join(local_dir, f"{0:05}") diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py index 5d04a12e7b..f8b52d43bb 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py @@ -81,7 +81,8 @@ def encode( structured_dataset: StructuredDataset, structured_dataset_type: StructuredDatasetType, ) -> literals.StructuredDataset: - path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() + path = structured_dataset.uri or ctx.file_access.get_random_remote_directory() + print("numpy", path) df = typing.cast(np.ndarray, structured_dataset.dataframe) name = ["col" + str(i) for i in range(len(df))] table = pa.Table.from_arrays(df, name) @@ -218,19 +219,19 @@ def wf(): df = generate_pandas() np_array = generate_numpy() arrow_df = generate_arrow() - t1(dataframe=df) - t1a(dataframe=df) - t2(dataframe=df) - t3(dataset=StructuredDataset(uri=PANDAS_PATH)) - t3a(dataset=StructuredDataset(uri=PANDAS_PATH)) - t4(dataset=StructuredDataset(uri=PANDAS_PATH)) - t5(dataframe=df) - t6(dataset=StructuredDataset(uri=BQ_PATH)) - t7(df1=df, df2=df) - t8(dataframe=arrow_df) - t8a(dataframe=arrow_df) + # t1(dataframe=df) + # t1a(dataframe=df) + # t2(dataframe=df) + # t3(dataset=StructuredDataset(uri=PANDAS_PATH)) + # t3a(dataset=StructuredDataset(uri=PANDAS_PATH)) + # t4(dataset=StructuredDataset(uri=PANDAS_PATH)) + # t5(dataframe=df) + # t6(dataset=StructuredDataset(uri=BQ_PATH)) + # t7(df1=df, df2=df) + # t8(dataframe=arrow_df) + # t8a(dataframe=arrow_df) t9(dataframe=np_array) - t10(dataset=StructuredDataset(uri=NUMPY_PATH)) + # t10(dataset=StructuredDataset(uri=NUMPY_PATH)) def test_structured_dataset_wf(): From 9317e89270da47fe6c178ac2aa405dd318368c4d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 15 Nov 2022 10:11:39 -0800 Subject: [PATCH 09/22] fix tests Signed-off-by: Kevin Su --- flytekit/core/promise.py | 2 +- flytekit/core/tracker.py | 2 +- flytekit/types/structured/basic_dfs.py | 3 +-- .../types/structured/structured_dataset.py | 2 +- .../test_structured_dataset_workflow.py | 27 +++++++++---------- 5 files changed, 17 insertions(+), 19 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 6fbc60d642..0bb2cf2db1 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -354,7 +354,7 @@ def __hash__(self): return hash(id(self)) def __rshift__(self, other: typing.Union[Promise, VoidPromise]): - if not self.is_ready: + if not self.is_ready and other.ref: self.ref.node.runs_before(other.ref.node) return other diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 2a203d4861..23ff7c9222 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -179,7 +179,7 @@ class _ModuleSanitizer(object): def __init__(self): self._module_cache = {} - def _resolve_abs_module_name(self, path: str, package_root: str) -> str: + def _resolve_abs_module_name(self, path: str, package_root: typing.Optional[str] = None) -> str: """ Recursively finds the root python package under-which basename exists """ diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index ef8886fe51..71dff61c5e 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -73,8 +73,7 @@ def encode( structured_dataset: StructuredDataset, structured_dataset_type: StructuredDatasetType, ) -> literals.StructuredDataset: - path = structured_dataset.uri or ctx.file_access.get_random_remote_path() - print("arrow", path) + path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_path() df = structured_dataset.dataframe local_dir = ctx.file_access.get_random_local_directory() local_path = os.path.join(local_dir, f"{0:05}") diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index f41481f823..6b4bca8be2 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -67,7 +67,7 @@ def __init__( self._dataframe = dataframe # Make these fields public, so that the dataclass transformer can set a value for it # https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298 - self.uri = str(uri) + self.uri = uri # This is a special attribute that indicates if the data was either downloaded or uploaded self._metadata = metadata # This is not for users to set, the transformer will set this. diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py index f8b52d43bb..5d04a12e7b 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py @@ -81,8 +81,7 @@ def encode( structured_dataset: StructuredDataset, structured_dataset_type: StructuredDatasetType, ) -> literals.StructuredDataset: - path = structured_dataset.uri or ctx.file_access.get_random_remote_directory() - print("numpy", path) + path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() df = typing.cast(np.ndarray, structured_dataset.dataframe) name = ["col" + str(i) for i in range(len(df))] table = pa.Table.from_arrays(df, name) @@ -219,19 +218,19 @@ def wf(): df = generate_pandas() np_array = generate_numpy() arrow_df = generate_arrow() - # t1(dataframe=df) - # t1a(dataframe=df) - # t2(dataframe=df) - # t3(dataset=StructuredDataset(uri=PANDAS_PATH)) - # t3a(dataset=StructuredDataset(uri=PANDAS_PATH)) - # t4(dataset=StructuredDataset(uri=PANDAS_PATH)) - # t5(dataframe=df) - # t6(dataset=StructuredDataset(uri=BQ_PATH)) - # t7(df1=df, df2=df) - # t8(dataframe=arrow_df) - # t8a(dataframe=arrow_df) + t1(dataframe=df) + t1a(dataframe=df) + t2(dataframe=df) + t3(dataset=StructuredDataset(uri=PANDAS_PATH)) + t3a(dataset=StructuredDataset(uri=PANDAS_PATH)) + t4(dataset=StructuredDataset(uri=PANDAS_PATH)) + t5(dataframe=df) + t6(dataset=StructuredDataset(uri=BQ_PATH)) + t7(df1=df, df2=df) + t8(dataframe=arrow_df) + t8a(dataframe=arrow_df) t9(dataframe=np_array) - # t10(dataset=StructuredDataset(uri=NUMPY_PATH)) + t10(dataset=StructuredDataset(uri=NUMPY_PATH)) def test_structured_dataset_wf(): From af0ff79a5bc8d6517b726203a0045c0bcd47501e Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 15 Nov 2022 11:04:07 -0800 Subject: [PATCH 10/22] fix tests Signed-off-by: Kevin Su --- flytekit/core/base_sql_task.py | 2 +- flytekit/core/base_task.py | 4 ++-- flytekit/types/schema/types.py | 14 +++++++------- flytekit/types/schema/types_pandas.py | 4 ++-- flytekit/types/structured/structured_dataset.py | 3 +-- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index 09b35dec5d..9727284123 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -22,7 +22,7 @@ def __init__( self, name: str, query_template: str, - task_config: T, + task_config: Optional[T] = None, task_type="sql_task", inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index fa0c1eaf94..0c5a4f562c 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -367,7 +367,7 @@ def __init__( self, task_type: str, name: str, - task_config: T, + task_config: Optional[T], interface: Optional[Interface] = None, environment: Optional[Dict[str, str]] = None, disable_deck: bool = True, @@ -406,7 +406,7 @@ def python_interface(self) -> Interface: return self._python_interface @property - def task_config(self) -> T: + def task_config(self) -> Optional[T]: """ Returns the user-specified task config which is used for plugin-specific handling of the task. """ diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index a108fb7d45..417fc045ed 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -51,13 +51,13 @@ class SchemaReader(typing.Generic[T]): Use the simplified base LocalIOSchemaReader for non distributed dataframes """ - def __init__(self, from_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): self._from_path = from_path self._fmt = fmt self._columns = cols @property - def from_path(self) -> os.PathLike: + def from_path(self) -> str: return self._from_path @property @@ -99,7 +99,7 @@ def write(self, *dfs, **kwargs): class LocalIOSchemaReader(SchemaReader[T]): - def __init__(self, from_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): super().__init__(from_path, cols, fmt) @abstractmethod @@ -181,7 +181,7 @@ def get_handler(cls, t: Type) -> SchemaHandler: @dataclass_json @dataclass class FlyteSchema(object): - remote_path: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) + remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) """ This is the main schema class that users should use. """ @@ -300,8 +300,8 @@ def open( # Remote IO is handled. So we will just pass the remote reference to the object if mode == SchemaOpenMode.WRITE: - return h.writer(self.remote_path, self.columns(), self.format()) - return h.reader(self.remote_path, self.columns(), self.format()) + return h.writer(typing.cast(str, self.remote_path), self.columns(), self.format()) + return h.reader(typing.cast(str, self.remote_path), self.columns(), self.format()) def as_readonly(self) -> FlyteSchema: if self._supported_mode == SchemaOpenMode.READ: @@ -309,7 +309,7 @@ def as_readonly(self) -> FlyteSchema: s = FlyteSchema.__class_getitem__(self.columns(), self.format())( local_path=self.local_path, # Dummy path is ok, as we will assume data is already downloaded and will not download again - remote_path=self.remote_path if self.remote_path else "", + remote_path=typing.cast(str, self.remote_path) if self.remote_path else "", supported_mode=SchemaOpenMode.READ, ) s._downloaded = True diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py index e4c6078e94..55b3f23533 100644 --- a/flytekit/types/schema/types_pandas.py +++ b/flytekit/types/schema/types_pandas.py @@ -56,7 +56,7 @@ def write( class PandasSchemaReader(LocalIOSchemaReader[pandas.DataFrame]): - def __init__(self, local_dir: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, local_dir: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): super().__init__(local_dir, cols, fmt) self._parquet_engine = ParquetIO() @@ -65,7 +65,7 @@ def _read(self, *path: os.PathLike, **kwargs) -> pandas.DataFrame: class PandasSchemaWriter(LocalIOSchemaWriter[pandas.DataFrame]): - def __init__(self, local_dir: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, local_dir: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): super().__init__(local_dir, cols, fmt) self._parquet_engine = ParquetIO() diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 6b4bca8be2..a4b9dc9d4a 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -1,7 +1,6 @@ from __future__ import annotations import collections -import os import types import typing from abc import ABC, abstractmethod @@ -60,7 +59,7 @@ def column_names(cls) -> typing.List[str]: def __init__( self, dataframe: typing.Optional[typing.Any] = None, - uri: typing.Optional[typing.Union[str, os.PathLike]] = None, + uri: typing.Optional[str] = None, metadata: typing.Optional[literals.StructuredDatasetMetadata] = None, **kwargs, ): From 842f50c90f6a28cd28b300fc7d409b532aeef821 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 18 Nov 2022 10:50:41 -0800 Subject: [PATCH 11/22] fix test Signed-off-by: Kevin Su --- Makefile | 2 +- flytekit/core/base_sql_task.py | 2 +- flytekit/core/container_task.py | 10 +-- flytekit/core/context_manager.py | 2 +- flytekit/core/docstring.py | 2 +- flytekit/core/interface.py | 2 +- flytekit/core/launch_plan.py | 76 +++++++++---------- flytekit/core/map_task.py | 4 +- flytekit/core/python_function_task.py | 4 +- flytekit/core/schedule.py | 9 ++- flytekit/core/task.py | 2 +- flytekit/core/type_engine.py | 4 +- flytekit/core/workflow.py | 4 +- flytekit/types/directory/types.py | 7 +- flytekit/types/schema/types.py | 2 +- flytekit/types/schema/types_pandas.py | 4 +- .../types/structured/structured_dataset.py | 2 +- .../core/flyte_functools/decorator_source.py | 2 +- .../unit/core/test_dynamic_conditional.py | 2 +- tests/flytekit/unit/core/test_type_hints.py | 6 +- tests/flytekit/unit/core/test_workflows.py | 4 +- 21 files changed, 82 insertions(+), 70 deletions(-) diff --git a/Makefile b/Makefile index aaa736fe3f..f1973dc769 100644 --- a/Makefile +++ b/Makefile @@ -37,7 +37,7 @@ fmt: ## Format code with black and isort lint: ## Run linters mypy flytekit/core mypy flytekit/types - mypy tests/flytekit/unit/core + mypy --allow-empty-bodies --allow-redefinition tests/flytekit/unit/core pre-commit run --all-files .PHONY: spellcheck diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index 9727284123..0274ea2e1f 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -26,7 +26,7 @@ def __init__( task_type="sql_task", inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, - outputs: Dict[str, Type] = None, + outputs: Optional[Dict[str, Type]] = None, **kwargs, ): """ diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 5527375c44..6aab550ffa 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -37,14 +37,14 @@ def __init__( command: List[str], inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, - arguments: List[str] = None, - outputs: Dict[str, Type] = None, + arguments: Optional[List[str]] = None, + outputs: Optional[Dict[str, Type]] = None, requests: Optional[Resources] = None, limits: Optional[Resources] = None, - input_data_dir: str = None, - output_data_dir: str = None, + input_data_dir: Optional[str] = None, + output_data_dir: Optional[str] = None, metadata_format: MetadataFormat = MetadataFormat.JSON, - io_strategy: IOStrategy = None, + io_strategy: Optional[IOStrategy] = None, secret_requests: Optional[List[Secret]] = None, **kwargs, ): diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 2175168872..dca3cfec34 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -123,7 +123,7 @@ def build(self) -> ExecutionParameters: ) @staticmethod - def new_builder(current: ExecutionParameters = None) -> Builder: + def new_builder(current: Optional[ExecutionParameters] = None) -> Builder: return ExecutionParameters.Builder(current=current) def with_task_sandbox(self) -> Builder: diff --git a/flytekit/core/docstring.py b/flytekit/core/docstring.py index 420f26f8f5..fa9d9caec2 100644 --- a/flytekit/core/docstring.py +++ b/flytekit/core/docstring.py @@ -4,7 +4,7 @@ class Docstring(object): - def __init__(self, docstring: str = None, callable_: Callable = None): + def __init__(self, docstring: Optional[str] = None, callable_: Optional[Callable] = None): if docstring is not None: self._parsed_docstring = parse(docstring) else: diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 051d01d1fb..4b52be0301 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -345,7 +345,7 @@ def transform_variable_map( return res -def transform_type(x: type, description: str = None) -> _interface_models.Variable: +def transform_type(x: type, description: Optional[str] = None) -> _interface_models.Variable: return _interface_models.Variable(type=TypeEngine.to_literal_type(x), description=description) diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index ed77574e35..9d880302e5 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -107,16 +107,16 @@ def create( cls, name: str, workflow: _annotated_workflow.WorkflowBase, - default_inputs: Dict[str, Any] = None, - fixed_inputs: Dict[str, Any] = None, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: int = None, - security_context: typing.Optional[security.SecurityContext] = None, - auth_role: _common_models.AuthRole = None, + default_inputs: Optional[Dict[str, Any]] = None, + fixed_inputs: Optional[Dict[str, Any]] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, + auth_role: Optional[_common_models.AuthRole] = None, ) -> LaunchPlan: ctx = FlyteContextManager.current_context() default_inputs = default_inputs or {} @@ -185,16 +185,16 @@ def get_or_create( cls, workflow: _annotated_workflow.WorkflowBase, name: Optional[str] = None, - default_inputs: Dict[str, Any] = None, - fixed_inputs: Dict[str, Any] = None, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: int = None, - security_context: typing.Optional[security.SecurityContext] = None, - auth_role: _common_models.AuthRole = None, + default_inputs: Optional[Dict[str, Any]] = None, + fixed_inputs: Optional[Dict[str, Any]] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, + auth_role: Optional[_common_models.AuthRole] = None, ) -> LaunchPlan: """ This function offers a friendlier interface for creating launch plans. If the name for the launch plan is not @@ -298,13 +298,13 @@ def __init__( workflow: _annotated_workflow.WorkflowBase, parameters: _interface_models.ParameterMap, fixed_inputs: _literal_models.LiteralMap, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: typing.Optional[int] = None, - security_context: typing.Optional[security.SecurityContext] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, ): self._name = name self._workflow = workflow @@ -328,15 +328,15 @@ def __init__( def clone_with( self, name: str, - parameters: _interface_models.ParameterMap = None, - fixed_inputs: _literal_models.LiteralMap = None, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: int = None, - security_context: typing.Optional[security.SecurityContext] = None, + parameters: Optional[_interface_models.ParameterMap] = None, + fixed_inputs: Optional[_literal_models.LiteralMap] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, ) -> LaunchPlan: return LaunchPlan( name=name, @@ -405,11 +405,11 @@ def raw_output_data_config(self) -> Optional[_common_models.RawOutputDataConfig] return self._raw_output_data_config @property - def max_parallelism(self) -> typing.Optional[int]: + def max_parallelism(self) -> Optional[int]: return self._max_parallelism @property - def security_context(self) -> typing.Optional[security.SecurityContext]: + def security_context(self) -> Optional[security.SecurityContext]: return self._security_context def construct_node_metadata(self) -> _workflow_model.NodeMetadata: diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 327c225a75..48d0f0b335 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -36,8 +36,8 @@ class MapPythonTask(PythonTask): def __init__( self, python_function_task: PythonFunctionTask, - concurrency: int = None, - min_success_ratio: float = None, + concurrency: Optional[int] = None, + min_success_ratio: Optional[float] = None, **kwargs, ): """ diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index d8bd4b27d0..7a7cbf78da 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -43,7 +43,7 @@ T = TypeVar("T") -class PythonInstanceTask(PythonAutoContainerTask[T], ABC): +class PythonInstanceTask(PythonAutoContainerTask[T], ABC): # type: ignore """ This class should be used as the base class for all Tasks that do not have a user defined function body, but have a platform defined execute method. (Execute needs to be overridden). This base class ensures that the module loader @@ -72,7 +72,7 @@ def __init__( super().__init__(name=name, task_config=task_config, task_type=task_type, task_resolver=task_resolver, **kwargs) -class PythonFunctionTask(PythonAutoContainerTask[T]): +class PythonFunctionTask(PythonAutoContainerTask[T]): # type: ignore """ A Python Function task should be used as the base for all extensions that have a python function. It will automatically detect interface of the python function and when serialized on the hosted Flyte platform handles the diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py index 7addc89197..93116d0720 100644 --- a/flytekit/core/schedule.py +++ b/flytekit/core/schedule.py @@ -6,6 +6,7 @@ import datetime import re as _re +from typing import Optional import croniter as _croniter @@ -52,7 +53,11 @@ class CronSchedule(_schedule_models.Schedule): _OFFSET_PATTERN = _re.compile("([-+]?)P([-+0-9YMWD]+)?(T([-+0-9HMS.,]+)?)?") def __init__( - self, cron_expression: str = None, schedule: str = None, offset: str = None, kickoff_time_input_arg: str = None + self, + cron_expression: Optional[str] = None, + schedule: Optional[str] = None, + offset: Optional[str] = None, + kickoff_time_input_arg: Optional[str] = None, ): """ :param str cron_expression: This should be a cron expression in AWS style.Shouldn't be used in case of native scheduler. @@ -161,7 +166,7 @@ class FixedRate(_schedule_models.Schedule): See the :std:ref:`fixed rate intervals` chapter in the cookbook for additional usage examples. """ - def __init__(self, duration: datetime.timedelta, kickoff_time_input_arg: str = None): + def __init__(self, duration: datetime.timedelta, kickoff_time_input_arg: Optional[str] = None): """ :param datetime.timedelta duration: :param str kickoff_time_input_arg: diff --git a/flytekit/core/task.py b/flytekit/core/task.py index a772775d82..fc4db83b14 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -214,7 +214,7 @@ def wrapper(fn) -> PythonFunctionTask: return wrapper -class ReferenceTask(ReferenceEntity, PythonFunctionTask): +class ReferenceTask(ReferenceEntity, PythonFunctionTask): # type: ignore """ This is a reference task, the body of the function passed in through the constructor will never be used, only the signature of the function will be. The signature should also match the signature of the task you're referencing, diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index cd65f91e72..8e57bc74d0 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -162,7 +162,7 @@ def __init__( self._to_literal_transformer = to_literal_transformer self._from_literal_transformer = from_literal_transformer - def get_literal_type(self, t: Type[T] = None) -> LiteralType: + def get_literal_type(self, t: Optional[Type[T]] = None) -> LiteralType: return LiteralType.from_flyte_idl(self._lt.to_flyte_idl()) def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: @@ -207,7 +207,7 @@ class RestrictedTypeTransformer(TypeTransformer[T], ABC): def __init__(self, name: str, t: Type[T]): super().__init__(name, t) - def get_literal_type(self, t: Type[T] = None) -> LiteralType: + def get_literal_type(self, t: Optional[Type[T]] = None) -> LiteralType: raise RestrictedTypeError(f"Transformer for type {self.python_type} is restricted currently") def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index c67a21427f..e9d64ab1ff 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -392,7 +392,7 @@ def execute(self, **kwargs): raise FlyteValidationException(f"Workflow not ready, wf is currently {self}") # Create a map that holds the outputs of each node. - intermediate_node_outputs = {GLOBAL_START_NODE: {}} # type: Dict[Node, Dict[str, Promise]] + intermediate_node_outputs: Dict[Node, Dict[str, Promise]] = {GLOBAL_START_NODE: {}} # Start things off with the outputs of the global input node, i.e. the inputs to the workflow. # local_execute should've already ensured that all the values in kwargs are Promise objects @@ -572,7 +572,7 @@ def __init__( workflow_function: Callable, metadata: WorkflowMetadata, default_metadata: WorkflowMetadataDefaults, - docstring: Docstring = None, + docstring: Optional[Docstring] = None, ): name, _, _, _ = extract_task_module(workflow_function) self._workflow_function = workflow_function diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index afb59d58d0..7d576f9353 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -115,7 +115,12 @@ def t1(in1: FlyteDirectory["svg"]): field in the ``BlobType``. """ - def __init__(self, path: typing.Union[str, os.PathLike], downloader: typing.Callable = None, remote_directory=None): + def __init__( + self, + path: typing.Union[str, os.PathLike], + downloader: typing.Optional[typing.Callable] = None, + remote_directory: typing.Optional[str] = None, + ): """ :param path: The source path that users are expected to call open() on :param downloader: Optional function that can be passed that used to delay downloading of the actual fil diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 417fc045ed..16e33f63c2 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -267,7 +267,7 @@ def supported_mode(self) -> SchemaOpenMode: return self._supported_mode def open( - self, dataframe_fmt: type = pandas.DataFrame, override_mode: SchemaOpenMode = None + self, dataframe_fmt: type = pandas.DataFrame, override_mode: typing.Optional[SchemaOpenMode] = None ) -> typing.Union[SchemaReader, SchemaWriter]: """ Returns a reader or writer depending on the mode of the object when created. This mode can be diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py index 55b3f23533..ca6cab8030 100644 --- a/flytekit/types/schema/types_pandas.py +++ b/flytekit/types/schema/types_pandas.py @@ -17,7 +17,9 @@ class ParquetIO(object): def _read(self, chunk: os.PathLike, columns: typing.Optional[typing.List[str]], **kwargs) -> pandas.DataFrame: return pandas.read_parquet(chunk, columns=columns, engine=self.PARQUET_ENGINE, **kwargs) - def read(self, *files: os.PathLike, columns: typing.List[str] = None, **kwargs) -> pandas.DataFrame: + def read( + self, *files: os.PathLike, columns: typing.Optional[typing.List[str]] = None, **kwargs + ) -> pandas.DataFrame: frames = [self._read(chunk=f, columns=columns, **kwargs) for f in files if os.path.getsize(f) > 0] if len(frames) == 1: return frames[0] diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index a4b9dc9d4a..bad72339f1 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -90,7 +90,7 @@ def open(self, dataframe_type: Type[DF]): self._dataframe_type = dataframe_type return self - def all(self) -> DF: + def all(self) -> DF: # type: ignore if self._dataframe_type is None: raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.") ctx = FlyteContextManager.current_context() diff --git a/tests/flytekit/unit/core/flyte_functools/decorator_source.py b/tests/flytekit/unit/core/flyte_functools/decorator_source.py index c0ce833263..5790d5d358 100644 --- a/tests/flytekit/unit/core/flyte_functools/decorator_source.py +++ b/tests/flytekit/unit/core/flyte_functools/decorator_source.py @@ -5,7 +5,7 @@ from typing import List -def task_setup(function: typing.Callable, *, integration_requests: List = None) -> typing.Callable: +def task_setup(function: typing.Callable, *, integration_requests: typing.Optional[List] = None) -> typing.Callable: integration_requests = integration_requests or [] @wraps(function) diff --git a/tests/flytekit/unit/core/test_dynamic_conditional.py b/tests/flytekit/unit/core/test_dynamic_conditional.py index 8c34f34759..1e68fd3d69 100644 --- a/tests/flytekit/unit/core/test_dynamic_conditional.py +++ b/tests/flytekit/unit/core/test_dynamic_conditional.py @@ -15,7 +15,7 @@ def test_dynamic_conditional(): @task - def split(in1: typing.List[int]) -> (typing.List[int], typing.List[int], int): + def split(in1: typing.List[int]) -> tuple[list[int], list[int], float]: return in1[0 : int(len(in1) / 2)], in1[int(len(in1) / 2) + 1 :], len(in1) / 2 # One sample implementation for merging. In a more real world example, this might merge file streams and only load diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 739adff86e..2d24e4ae22 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -880,7 +880,7 @@ def t2(a: str, b: str) -> str: return b + a @workflow - def my_subwf(a: int) -> (str, str): + def my_subwf(a: int) -> typing.Tuple[str, str]: x, y = t1(a=a) u, v = t1(a=x) return y, v @@ -1404,7 +1404,7 @@ def t2(a: str, b: str) -> str: return b + a @workflow - def my_wf(a: int, b: str) -> (str, typing.List[str]): + def my_wf(a: int, b: str) -> typing.Tuple[str, typing.List[str]]: @dynamic def my_subwf(a: int) -> typing.List[str]: s = [] @@ -1444,7 +1444,7 @@ def t1() -> str: return "Hello" @workflow - def wf() -> typing.NamedTuple("OP", a=str, b=str): + def wf() -> typing.NamedTuple("OP", [("a", str), ("b", str)]): # type: ignore return t1(), t1() assert wf() == ("Hello", "Hello") diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index eb5c10f719..23b9d0631e 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -44,12 +44,12 @@ def test_default_metadata_values(): def test_workflow_values(): @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def t1(a: int) -> typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", str)]): a = a + 2 return a, "world-" + str(a) @workflow(interruptible=True, failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) - def wf(a: int) -> (str, str): + def wf(a: int) -> typing.Tuple[str, str]: x, y = t1(a=a) u, v = t1(a=x) return y, v From 085779fcf43f9df96e0ceda39673c761955dab4f Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 18 Nov 2022 13:34:19 -0800 Subject: [PATCH 12/22] nit Signed-off-by: Kevin Su --- tests/flytekit/unit/core/test_dynamic_conditional.py | 2 +- tests/flytekit/unit/core/test_interface.py | 8 ++++---- tests/flytekit/unit/core/test_node_creation.py | 10 +++++----- tests/flytekit/unit/core/test_realworld_examples.py | 4 ++-- tests/flytekit/unit/core/test_references.py | 4 ++-- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/flytekit/unit/core/test_dynamic_conditional.py b/tests/flytekit/unit/core/test_dynamic_conditional.py index 1e68fd3d69..8c34f34759 100644 --- a/tests/flytekit/unit/core/test_dynamic_conditional.py +++ b/tests/flytekit/unit/core/test_dynamic_conditional.py @@ -15,7 +15,7 @@ def test_dynamic_conditional(): @task - def split(in1: typing.List[int]) -> tuple[list[int], list[int], float]: + def split(in1: typing.List[int]) -> (typing.List[int], typing.List[int], int): return in1[0 : int(len(in1) / 2)], in1[int(len(in1) / 2) + 1 :], len(in1) / 2 # One sample implementation for merging. In a more real world example, this might merge file streams and only load diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index a0068d08ca..83459719eb 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -95,13 +95,13 @@ def t(a: int, b: str) -> Dict[str, int]: def test_named_tuples(): - nt1 = typing.NamedTuple("NT1", [("x_str", str), ("y_int", int)]) + nt1 = typing.NamedTuple("NT1", x_str=str, y_int=int) - def x(a: int, b: str) -> typing.NamedTuple("NT1", [("x_str", str), ("y_int", int)]): - return "hello world", 5 + def x(a: int, b: str) -> typing.NamedTuple("NT1", x_str=str, y_int=int): + return ("hello world", 5) def y(a: int, b: str) -> nt1: - return nt1("hello world", 5) + return nt1("hello world", 5) # type: ignore result = transform_variable_map(extract_return_annotation(typing.get_type_hints(x).get("return", None))) assert result["x_str"].type.simple == 3 diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 23a4bac5a3..59b841a5f4 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -96,17 +96,17 @@ def empty_wf2(): def test_more_normal_task(): - nt = typing.NamedTuple("OneOutput", [("t1_str_output", str)]) + nt = typing.NamedTuple("OneOutput", t1_str_output=str) @task def t1(a: int) -> nt: # This one returns a regular tuple - return nt(f"{a + 2}") + return nt(f"{a + 2}") # type: ignore @task def t1_nt(a: int) -> nt: # This one returns an instance of the named tuple. - return nt(f"{a + 2}") + return nt(f"{a + 2}") # type: ignore @task def t2(a: typing.List[str]) -> str: @@ -124,12 +124,12 @@ def my_wf(a: int, b: str) -> (str, str): def test_reserved_keyword(): - nt = typing.NamedTuple("OneOutput", [("outputs", str)]) + nt = typing.NamedTuple("OneOutput", outputs=str) @task def t1(a: int) -> nt: # This one returns a regular tuple - return nt(f"{a + 2}") + return nt(f"{a + 2}") # type: ignore # Test that you can't name an output "outputs" with pytest.raises(FlyteAssertion): diff --git a/tests/flytekit/unit/core/test_realworld_examples.py b/tests/flytekit/unit/core/test_realworld_examples.py index c5b3e374fc..779ba3334c 100644 --- a/tests/flytekit/unit/core/test_realworld_examples.py +++ b/tests/flytekit/unit/core/test_realworld_examples.py @@ -105,7 +105,7 @@ def split_traintest_dataset( # We will fake train test split. Just return the same dataset multiple times return x, x, y, y - nt = typing.NamedTuple("Outputs", [("model", FlyteFile[MODELSER_JOBLIB])]) + nt = typing.NamedTuple("Outputs", model=FlyteFile[MODELSER_JOBLIB]) @task(cache_version="1.0", cache=True, limits=Resources(mem="200Mi")) def fit(x: FlyteSchema[FEATURE_COLUMNS], y: FlyteSchema[CLASSES_COLUMNS], hyperparams: dict) -> nt: @@ -126,7 +126,7 @@ def fit(x: FlyteSchema[FEATURE_COLUMNS], y: FlyteSchema[CLASSES_COLUMNS], hyperp fname = "model.joblib.dat" with open(fname, "w") as f: f.write("Some binary data") - return nt(model=fname) + return nt(model=fname) # type: ignore @task(cache_version="1.0", cache=True, limits=Resources(mem="200Mi")) def predict(x: FlyteSchema[FEATURE_COLUMNS], model_ser: FlyteFile[MODELSER_JOBLIB]) -> FlyteSchema[CLASSES_COLUMNS]: diff --git a/tests/flytekit/unit/core/test_references.py b/tests/flytekit/unit/core/test_references.py index 75dc431918..7486422fd9 100644 --- a/tests/flytekit/unit/core/test_references.py +++ b/tests/flytekit/unit/core/test_references.py @@ -155,12 +155,12 @@ def inner_test(ref_mock): inner_test() - nt1 = typing.NamedTuple("DummyNamedTuple", [("t1_int_output", int), ("c", str)]) + nt1 = typing.NamedTuple("DummyNamedTuple", t1_int_output=int, c=str) @task def t1(a: int) -> nt1: a = a + 2 - return nt1(a, "world-" + str(a)) + return nt1(a, "world-" + str(a)) # type: ignore @workflow def wf2(a: int): From cd048874c82ce548933fdf3539c60f71b8794d95 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 27 Dec 2022 16:36:56 -0800 Subject: [PATCH 13/22] Update type Signed-off-by: Kevin Su --- Makefile | 5 ++++- flytekit/core/gate.py | 5 +++-- flytekit/core/interface.py | 6 +++--- flytekit/core/promise.py | 6 ++++-- flytekit/core/python_function_task.py | 8 ++++---- flytekit/core/type_engine.py | 1 - flytekit/types/directory/__init__.py | 2 +- flytekit/types/structured/structured_dataset.py | 9 ++++++++- tests/flytekit/unit/core/test_gate.py | 2 +- 9 files changed, 28 insertions(+), 16 deletions(-) diff --git a/Makefile b/Makefile index c0f979fdab..f1f932cd69 100644 --- a/Makefile +++ b/Makefile @@ -37,7 +37,10 @@ fmt: ## Format code with black and isort lint: ## Run linters mypy flytekit/core mypy flytekit/types - mypy --allow-empty-bodies --allow-redefinition tests/flytekit/unit/core + # allow-empty-bodies: Allow empty body in function. + # disable-error-code="annotation-unchecked": Remove the warning "By default the bodies of untyped functions are not checked". + # Mypy raises a warning because it cannot determine the type from the dataclass, despite we specified the type in the dataclass. + mypy --allow-empty-bodies --disable-error-code="annotation-unchecked" tests/flytekit/unit/core pre-commit run --all-files .PHONY: spellcheck diff --git a/flytekit/core/gate.py b/flytekit/core/gate.py index f3d90ebef8..38f024d7a7 100644 --- a/flytekit/core/gate.py +++ b/flytekit/core/gate.py @@ -53,7 +53,7 @@ def __init__( ) else: # We don't know how to find the python interface here, approve() sets it below, See the code. - self._python_interface = None + self._python_interface = None # type: ignore @property def name(self) -> str: @@ -105,7 +105,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr return p # Assume this is an approval operation since that's the only remaining option. - msg = f"Pausing execution for {self.name}, literal value is:\n{self._upstream_item.val}\nContinue?" + msg = f"Pausing execution for {self.name}, literal value is:\n{typing.cast(Promise, self._upstream_item).val}\nContinue?" proceed = click.confirm(msg, default=True) if proceed: # We need to return a promise here, and a promise is what should've been passed in by the call in approve() @@ -164,6 +164,7 @@ def approve(upstream_item: Union[Tuple[Promise], Promise, VoidPromise], name: st raise ValueError("You can't use approval on a task that doesn't return anything.") ctx = FlyteContextManager.current_context() + upstream_item = typing.cast(Promise, upstream_item) if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: if not upstream_item.ref.node.flyte_entity.python_interface: raise ValueError( diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 4b52be0301..293fc7d86e 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -29,7 +29,7 @@ class Interface(object): def __init__( self, inputs: Optional[Dict[str, Type]] | Optional[Dict[str, Tuple[Type, Any]]] = None, - outputs: Optional[Dict[str, Type]] = None, + outputs: Optional[Dict[str, Type]] | Optional[Dict[str, Optional[Type]]] = None, output_tuple_name: Optional[str] = None, docstring: Optional[Docstring] = None, ): @@ -50,7 +50,7 @@ def __init__( self._inputs[k] = v # type: ignore else: self._inputs[k] = (v, None) # type: ignore - self._outputs = outputs if outputs else {} + self._outputs = outputs if outputs else {} # type: ignore self._output_tuple_name = output_tuple_name if outputs: @@ -120,7 +120,7 @@ def default_inputs_as_kwargs(self) -> Dict[str, Any]: @property def outputs(self) -> typing.Dict[str, type]: - return self._outputs + return self._outputs # type: ignore @property def docstring(self) -> Optional[Docstring]: diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 78e5cdda3f..878a8c7275 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -464,7 +464,7 @@ def __str__(self): def create_native_named_tuple( ctx: FlyteContext, - promises: Union[Tuple[Promise],Promise, VoidPromise, None], + promises: Union[Tuple[Promise], Promise, VoidPromise, None], entity_interface: Interface, ) -> Optional[Tuple]: """ @@ -999,7 +999,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr ... -def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: +def flyte_entity_call_handler( + entity: SupportsNodeCreation, *args, **kwargs +) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: """ This function is the call handler for tasks, workflows, and launch plans (which redirects to the underlying workflow). The logic is the same for all three, but we did not want to create base class, hence this separate diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 8f7bac94f6..90b10cbc36 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -17,7 +17,7 @@ from abc import ABC from collections import OrderedDict from enum import Enum -from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, List, Optional, TypeVar, Union, cast from flytekit.core.base_task import Task, TaskResolverMixin from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager @@ -193,7 +193,7 @@ def compile_into_workflow( from flytekit.tools.translator import get_serializable self._create_and_cache_dynamic_workflow() - self._wf.compile(**kwargs) + cast(PythonFunctionWorkflow, self._wf).compile(**kwargs) wf = self._wf model_entities: OrderedDict = OrderedDict() @@ -263,12 +263,12 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any: # local_execute directly though since that converts inputs into Promises. logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}") self._create_and_cache_dynamic_workflow() - function_outputs = self._wf.execute(**kwargs) + function_outputs = cast(PythonFunctionWorkflow, self._wf).execute(**kwargs) if isinstance(function_outputs, VoidPromise) or function_outputs is None: return VoidPromise(self.name) - if len(self._wf.python_interface.outputs) == 0: + if len(cast(PythonFunctionWorkflow, self._wf).python_interface.outputs) == 0: raise FlyteValueException(function_outputs, "Interface output should've been VoidPromise or None.") # TODO: This will need to be cleaned up when we revisit top-level tuple support. diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 0df80488a1..31102ea6f8 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -22,7 +22,6 @@ from google.protobuf.struct_pb2 import Struct from marshmallow_enum import EnumField, LoadDumpOptions from marshmallow_jsonschema import JSONSchema -from proto import Message from typing_extensions import Annotated, get_args, get_origin from flytekit.core.annotation import FlyteAnnotation diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py index c2ab8fd438..87b494d0ae 100644 --- a/flytekit/types/directory/__init__.py +++ b/flytekit/types/directory/__init__.py @@ -28,7 +28,7 @@ TensorBoard. """ -tfrecords_dir = typing.TypeVar("tfrecord") +tfrecords_dir = typing.TypeVar("tfrecords_dir") TFRecordsDirectory = FlyteDirectory[tfrecords_dir] """ This type can be used to denote that the output is a folder that contains tensorflow record files. diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index e9d6b307cc..65029b89ea 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -406,7 +406,14 @@ def register_renderer(cls, python_type: Type, renderer: Renderable): cls.Renderers[python_type] = renderer @classmethod - def register(cls, h: Handlers, default_for_type: bool = False, override: bool = False, default_format_for_type: bool = False, default_storage_for_type: bool = False): + def register( + cls, + h: Handlers, + default_for_type: bool = False, + override: bool = False, + default_format_for_type: bool = False, + default_storage_for_type: bool = False, + ): """ Call this with any Encoder or Decoder to register it with the flytekit type system. If your handler does not specify a protocol (e.g. s3, gs, etc.) field, then diff --git a/tests/flytekit/unit/core/test_gate.py b/tests/flytekit/unit/core/test_gate.py index a4689ed814..8687b15ff3 100644 --- a/tests/flytekit/unit/core/test_gate.py +++ b/tests/flytekit/unit/core/test_gate.py @@ -218,7 +218,7 @@ def wf_dyn(a: int) -> typing.Tuple[int, int]: def test_subwf(): - nt = typing.NamedTuple("Multi", named1=int, named2=int) + nt = typing.NamedTuple("Multi", [("named1", int), ("named2", int)]) @task def nt1(a: int) -> nt: From a36ba267ecba6973190d2579968c66094f4920c9 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 27 Dec 2022 17:13:37 -0800 Subject: [PATCH 14/22] Fix tests Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 31102ea6f8..4a01845b0f 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -361,17 +361,19 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A from flytekit.types.schema.types import FlyteSchema from flytekit.types.structured.structured_dataset import StructuredDataset - if hasattr(python_type, "__origin__"): - ot = getattr(python_type, "__origin__") - if ot is list: - return [ - self._serialize_flyte_type(v, getattr(python_type, "__args__")[0]) for v in cast(list, python_val) - ] - if ot is dict: - return { - k: self._serialize_flyte_type(v, getattr(python_type, "__args__")[1]) - for k, v in cast(dict, python_val).items() - } + # Handle Optional + if get_origin(python_type) is typing.Union and type(None) in get_args(python_type): + if python_val is None: + return None + return self._serialize_flyte_type(python_val, get_args(python_type)[0]) + + if hasattr(python_type, "__origin__") and get_origin(python_type) is list: + return [self._serialize_flyte_type(v, get_args(python_type)[0]) for v in cast(list, python_val)] + + if hasattr(python_type, "__origin__") and get_origin(python_type) is dict: + return { + k: self._serialize_flyte_type(v, get_args(python_type)[1]) for k, v in cast(dict, python_val).items() + } if not dataclasses.is_dataclass(python_type): return python_val From b41a2d633826bcf974a0a259c105641a6635f949 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 27 Dec 2022 21:42:34 -0800 Subject: [PATCH 15/22] Fix tests Signed-off-by: Kevin Su --- dev-requirements.in | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dev-requirements.in b/dev-requirements.in index a02c8fa144..030c5453a9 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -17,3 +17,5 @@ tensorflow==2.8.1 # we put this constraint while we do not have per-environment requirements files torch<=1.12.1 scikit-learn +types-croniter +types-protobuf From 1b2e0bd148d122e688070bb2bb21f249908b0f19 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 12 Jan 2023 15:59:41 -0800 Subject: [PATCH 16/22] Fix tests Signed-off-by: Kevin Su --- flytekit/core/resources.py | 6 +++--- flytekit/core/utils.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index dccd18a5ee..4cf2523f6a 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -35,15 +35,15 @@ class Resources(object): @dataclass class ResourceSpec(object): - requests: Optional[Resources] = None - limits: Optional[Resources] = None + requests: Resources + limits: Resources _ResourceName = task_models.Resources.ResourceName _ResourceEntry = task_models.Resources.ResourceEntry -def _convert_resources_to_resource_entries(resources: Resources) -> List[_ResourceEntry]: +def _convert_resources_to_resource_entries(resources: Resources) -> List[_ResourceEntry]: # type: ignore resource_entries = [] if resources.cpu is not None: resource_entries.append(_ResourceEntry(name=_ResourceName.CPU, value=resources.cpu)) diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index ae8b89a109..ee2c841465 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -51,7 +51,7 @@ def _dnsify(value: str) -> str: def _get_container_definition( image: str, command: List[str], - args: List[str], + args: Optional[List[str]] = None, data_loading_config: Optional[task_models.DataLoadingConfig] = None, storage_request: Optional[str] = None, ephemeral_storage_request: Optional[str] = None, From 676e61fe43988340d15c9637b671aa60a149bc9c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 12 Jan 2023 16:21:31 -0800 Subject: [PATCH 17/22] nit Signed-off-by: Kevin Su --- dev-requirements.in | 2 -- flytekit/types/structured/structured_dataset.py | 4 +--- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/dev-requirements.in b/dev-requirements.in index 030c5453a9..a02c8fa144 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -17,5 +17,3 @@ tensorflow==2.8.1 # we put this constraint while we do not have per-environment requirements files torch<=1.12.1 scikit-learn -types-croniter -types-protobuf diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index ad74ef2600..90755c8cc5 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -45,9 +45,7 @@ class (that is just a model, a Python class representation of the protobuf). """ uri: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) - file_format: typing.Optional[str] = field(default=PARQUET, metadata=config(mm_field=fields.String())) - - DEFAULT_FILE_FORMAT = PARQUET + file_format: typing.Optional[str] = field(default=GENERIC_FORMAT, metadata=config(mm_field=fields.String())) @classmethod def columns(cls) -> typing.Dict[str, typing.Type]: From e50e7c5ddb9c836931c32add963c3ba86b169e03 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 20 Jan 2023 13:36:01 -0800 Subject: [PATCH 18/22] update dev-requirements.txt Signed-off-by: Kevin Su --- dev-requirements.in | 3 +++ dev-requirements.txt | 48 +++++++++++++++++--------------------------- 2 files changed, 21 insertions(+), 30 deletions(-) diff --git a/dev-requirements.in b/dev-requirements.in index a02c8fa144..9e9f3b73be 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -17,3 +17,6 @@ tensorflow==2.8.1 # we put this constraint while we do not have per-environment requirements files torch<=1.12.1 scikit-learn +types-protobuf +types-croniter +types-mock diff --git a/dev-requirements.txt b/dev-requirements.txt index ea0e19354a..75e6b84f38 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,8 +1,8 @@ # -# This file is autogenerated by pip-compile with Python 3.7 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # -# make dev-requirements.txt +# pip-compile dev-requirements.in # -e file:.#egg=flytekit # via @@ -12,6 +12,8 @@ absl-py==1.3.0 # via # tensorboard # tensorflow +appnope==0.1.3 + # via ipython arrow==1.2.3 # via # -c requirements.txt @@ -32,8 +34,6 @@ binaryornot==0.4.4 # via # -c requirements.txt # cookiecutter -cached-property==1.5.2 - # via docker-compose cachetools==5.2.0 # via google-auth certifi==2022.12.7 @@ -83,7 +83,6 @@ cryptography==38.0.4 # -c requirements.txt # paramiko # pyopenssl - # secretstorage dataclasses-json==0.5.7 # via # -c requirements.txt @@ -136,6 +135,10 @@ flyteidl==1.3.1 # flytekit gast==0.5.3 # via tensorflow +gitdb==4.0.10 + # via gitpython +gitpython==3.1.30 + # via flytekit google-api-core[grpc]==2.11.0 # via # google-cloud-bigquery @@ -167,6 +170,7 @@ googleapis-common-protos==1.57.0 # via # -c requirements.txt # flyteidl + # flytekit # google-api-core # grpcio-status grpcio==1.51.1 @@ -194,15 +198,9 @@ idna==3.4 importlib-metadata==5.1.0 # via # -c requirements.txt - # click # flytekit - # jsonschema # keyring # markdown - # pluggy - # pre-commit - # pytest - # virtualenv iniconfig==1.1.1 # via pytest ipython==7.34.0 @@ -213,11 +211,6 @@ jaraco-classes==3.2.3 # keyring jedi==0.18.2 # via ipython -jeepney==0.8.0 - # via - # -c requirements.txt - # keyring - # secretstorage jinja2==3.1.2 # via # -c requirements.txt @@ -470,14 +463,6 @@ scikit-learn==1.0.2 # via -r dev-requirements.in scipy==1.7.3 # via scikit-learn -secretstorage==3.3.3 - # via - # -c requirements.txt - # keyring -singledispatchmethod==1.0 - # via - # -c requirements.txt - # flytekit six==1.16.0 # via # -c requirements.txt @@ -491,6 +476,8 @@ six==1.16.0 # python-dateutil # tensorflow # websocket-client +smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via # -c requirements.txt @@ -531,14 +518,18 @@ tomli==2.0.1 # coverage # mypy # pytest -torch==1.13.1 +torch==1.12.1 # via -r dev-requirements.in traitlets==5.6.0 # via # ipython # matplotlib-inline -typed-ast==1.5.4 - # via mypy +types-croniter==1.3.2.2 + # via -r dev-requirements.in +types-mock==5.0.0.2 + # via -r dev-requirements.in +types-protobuf==4.21.0.3 + # via -r dev-requirements.in types-toml==0.10.8.1 # via # -c requirements.txt @@ -546,11 +537,8 @@ types-toml==0.10.8.1 typing-extensions==4.4.0 # via # -c requirements.txt - # arrow # flytekit - # importlib-metadata # mypy - # responses # tensorflow # torch # typing-inspect From 37925a56cf4c0367682cf7557a098e2f8b97ca7c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 23 Jan 2023 01:22:31 -0800 Subject: [PATCH 19/22] Address comment Signed-off-by: Kevin Su --- flytekit/core/base_task.py | 2 +- flytekit/core/interface.py | 4 ++-- flytekit/core/reference.py | 2 +- flytekit/core/reference_entity.py | 2 +- flytekit/core/testing.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 90436c402f..a1c8ec037e 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -159,7 +159,7 @@ def __init__( self, task_type: str, name: str, - interface: Optional[_interface_models.TypedInterface] = None, + interface: _interface_models.TypedInterface, metadata: Optional[TaskMetadata] = None, task_type_version=0, security_ctx: Optional[SecurityContext] = None, diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index edc4203842..f650090153 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -28,8 +28,8 @@ class Interface(object): def __init__( self, - inputs: Optional[Dict[str, Type]] | Optional[Dict[str, Tuple[Type, Any]]] = None, - outputs: Optional[Dict[str, Type]] | Optional[Dict[str, Optional[Type]]] = None, + inputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Tuple[Type, Any]]]] = None, + outputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Optional[Type]]]] = None, output_tuple_name: Optional[str] = None, docstring: Optional[Docstring] = None, ): diff --git a/flytekit/core/reference.py b/flytekit/core/reference.py index cad44268ff..6a88549c43 100644 --- a/flytekit/core/reference.py +++ b/flytekit/core/reference.py @@ -15,7 +15,7 @@ def get_reference_entity( domain: str, name: str, version: str, - inputs: Dict[str, type], + inputs: Dict[str, Type], outputs: Dict[str, Type], ): """ diff --git a/flytekit/core/reference_entity.py b/flytekit/core/reference_entity.py index d8a8f620f7..b1ab790e90 100644 --- a/flytekit/core/reference_entity.py +++ b/flytekit/core/reference_entity.py @@ -66,7 +66,7 @@ class ReferenceEntity(object): def __init__( self, reference: Union[WorkflowReference, TaskReference, LaunchPlanReference], - inputs: Dict[str, type], + inputs: Dict[str, Type], outputs: Dict[str, Type], ): if ( diff --git a/flytekit/core/testing.py b/flytekit/core/testing.py index 055e47efd4..f1a0fec7de 100644 --- a/flytekit/core/testing.py +++ b/flytekit/core/testing.py @@ -10,7 +10,7 @@ @contextmanager -def task_mock(t: PythonTask) -> typing.Iterator[MagicMock]: +def task_mock(t: PythonTask) -> typing.Generator[MagicMock, None, None]: """ Use this method to mock a task declaration. It can mock any Task in Flytekit as long as it has a python native interface associated with it. From 0bd5affac095d2a99ece9099596ccd9abc767c70 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 23 Jan 2023 01:26:15 -0800 Subject: [PATCH 20/22] upgrade torch Signed-off-by: Kevin Su --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 75e6b84f38..057db14721 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -518,7 +518,7 @@ tomli==2.0.1 # coverage # mypy # pytest -torch==1.12.1 +torch==1.13.1 # via -r dev-requirements.in traitlets==5.6.0 # via From 6b4bcaf182a380ae6df04bff6eea2a81fb9c33a7 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 15 Feb 2023 14:01:29 -0800 Subject: [PATCH 21/22] nit Signed-off-by: Kevin Su --- Makefile | 4 ++-- flytekit/core/context_manager.py | 2 +- flytekit/core/interface.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index f1f932cd69..f800cb474a 100644 --- a/Makefile +++ b/Makefile @@ -38,8 +38,8 @@ lint: ## Run linters mypy flytekit/core mypy flytekit/types # allow-empty-bodies: Allow empty body in function. - # disable-error-code="annotation-unchecked": Remove the warning "By default the bodies of untyped functions are not checked". - # Mypy raises a warning because it cannot determine the type from the dataclass, despite we specified the type in the dataclass. + # disable-error-code="annotation-unchecked": Remove the warning "By default the bodies of untyped functions are not checked". + # Mypy raises a warning because it cannot determine the type from the dataclass, despite we specified the type in the dataclass. mypy --allow-empty-bodies --disable-error-code="annotation-unchecked" tests/flytekit/unit/core pre-commit run --all-files diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 154f0bea77..fc8915e338 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -852,7 +852,7 @@ class FlyteEntities(object): registration process """ - entities: List["LaunchPlan" | Task | "WorkflowBase"] = [] # type: ignore + entities: List[Union["LaunchPlan", Task, "WorkflowBase"]] = [] # type: ignore FlyteContextManager.initialize() diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index f650090153..3c24e65db2 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -43,7 +43,7 @@ def __init__( primarily used when handling one-element NamedTuples. :param docstring: Docstring of the annotated @task or @workflow from which the interface derives from. """ - self._inputs: Dict[str, Tuple[Type, Any]] | Dict[str, Type] = {} # type: ignore + self._inputs: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]] = {} # type: ignore if inputs: for k, v in inputs.items(): if type(v) is tuple and len(cast(Tuple, v)) > 1: From ed8129415e9bb31a2d91bd37597504990b6cbac6 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 15 Feb 2023 14:12:54 -0800 Subject: [PATCH 22/22] lint Signed-off-by: Kevin Su --- flytekit/core/base_task.py | 4 ++-- flytekit/core/python_auto_container.py | 14 +++++++------- flytekit/core/type_engine.py | 2 +- flytekit/core/workflow.py | 2 +- .../unit/core/test_python_function_task.py | 2 +- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index bd7a4446bb..f163e891e1 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -340,8 +340,8 @@ def sandbox_execute( """ Call dispatch_execute, in the context of a local sandbox execution. Not invoked during runtime. """ - es = ctx.execution_state - b = es.user_space_params.with_task_sandbox() + es = cast(ExecutionState, ctx.execution_state) + b = cast(ExecutionParameters, es.user_space_params).with_task_sandbox() ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build() return self.dispatch_execute(ctx, input_literal_map) diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 9a6e376132..113f94a998 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -3,7 +3,7 @@ import importlib import re from abc import ABC -from typing import Any, Callable, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, TypeVar, cast from flyteidl.core import tasks_pb2 as _core_task from kubernetes.client import ApiClient @@ -207,23 +207,23 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain ) def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]: - containers = self.pod_template.pod_spec.containers + containers = cast(PodTemplate, self.pod_template).pod_spec.containers primary_exists = False for container in containers: - if container.name == self.pod_template.primary_container_name: + if container.name == cast(PodTemplate, self.pod_template).primary_container_name: primary_exists = True break if not primary_exists: # insert a placeholder primary container if it is not defined in the pod spec. - containers.append(V1Container(name=self.pod_template.primary_container_name)) + containers.append(V1Container(name=cast(PodTemplate, self.pod_template).primary_container_name)) final_containers = [] for container in containers: # In the case of the primary container, we overwrite specific container attributes # with the default values used in the regular Python task. # The attributes include: image, command, args, resource, and env (env is unioned) - if container.name == self.pod_template.primary_container_name: + if container.name == cast(PodTemplate, self.pod_template).primary_container_name: sdk_default_container = self._get_container(settings) container.image = sdk_default_container.image # clear existing commands @@ -243,9 +243,9 @@ def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any] container.env or [] ) final_containers.append(container) - self.pod_template.pod_spec.containers = final_containers + cast(PodTemplate, self.pod_template).pod_spec.containers = final_containers - return ApiClient().sanitize_for_serialization(self.pod_template.pod_spec) + return ApiClient().sanitize_for_serialization(cast(PodTemplate, self.pod_template).pod_spec) def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: if self.pod_template is None: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e4806ccfb4..9cbfc1fe7c 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1152,7 +1152,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: def guess_python_type(self, literal_type: LiteralType) -> type: if literal_type.union_type is not None: - return typing.Union[tuple(TypeEngine.guess_python_type(v) for v in literal_type.union_type.variants)] + return typing.Union[tuple(TypeEngine.guess_python_type(v) for v in literal_type.union_type.variants)] # type: ignore raise ValueError(f"Union transformer cannot reverse {literal_type}") diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 79adb55816..f8ba257d7e 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -766,7 +766,7 @@ def wrapper(fn): if _workflow_function: return wrapper(_workflow_function) else: - return wrapper + return wrapper # type: ignore class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): # type: ignore diff --git a/tests/flytekit/unit/core/test_python_function_task.py b/tests/flytekit/unit/core/test_python_function_task.py index 02e04a302f..7bbdd23a21 100644 --- a/tests/flytekit/unit/core/test_python_function_task.py +++ b/tests/flytekit/unit/core/test_python_function_task.py @@ -145,7 +145,7 @@ def test_pod_template(): pod_template_name="A", ) def func_with_pod_template(i: str): - print(i + 3) + print(i + "a") default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") default_image_config = ImageConfig(default_image=default_image)