diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index d8c215a598..ed46a29583 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -7,10 +7,13 @@ import sys import tempfile import typing +import typing as t from dataclasses import dataclass, field, fields from typing import Iterator, get_args import rich_click as click +import yaml +from click import Context from mashumaro.codecs.json import JSONEncoder from rich.progress import Progress from typing_extensions import get_origin @@ -25,7 +28,12 @@ pretty_print_exception, project_option, ) -from flytekit.configuration import DefaultImages, FastSerializationSettings, ImageConfig, SerializationSettings +from flytekit.configuration import ( + DefaultImages, + FastSerializationSettings, + ImageConfig, + SerializationSettings, +) from flytekit.configuration.plugin import get_plugin from flytekit.core import context_manager from flytekit.core.artifact import ArtifactQuery @@ -34,14 +42,24 @@ from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase from flytekit.exceptions.system import FlyteSystemException -from flytekit.interaction.click_types import FlyteLiteralConverter, key_value_callback, labels_callback +from flytekit.interaction.click_types import ( + FlyteLiteralConverter, + key_value_callback, + labels_callback, +) from flytekit.interaction.string_literals import literal_string_repr from flytekit.loggers import logger from flytekit.models import security from flytekit.models.common import RawOutputDataConfig from flytekit.models.interface import Parameter, Variable from flytekit.models.types import SimpleType -from flytekit.remote import FlyteLaunchPlan, FlyteRemote, FlyteTask, FlyteWorkflow, remote_fs +from flytekit.remote import ( + FlyteLaunchPlan, + FlyteRemote, + FlyteTask, + FlyteWorkflow, + remote_fs, +) from flytekit.remote.executions import FlyteWorkflowExecution from flytekit.tools import module_loader from flytekit.tools.script_mode import _find_project_root, compress_scripts, get_all_modules @@ -489,7 +507,8 @@ def _update_flyte_context(params: RunLevelParams) -> FlyteContext.Builder: return ctx.current_context().new_builder() file_access = FileAccessProvider( - local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"), raw_output_prefix=output_prefix + local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"), + raw_output_prefix=output_prefix, ) # The task might run on a remote machine if raw_output_prefix is a remote path, @@ -539,7 +558,10 @@ def _run(*args, **kwargs): entity_type = "workflow" if isinstance(entity, PythonFunctionWorkflow) else "task" logger.debug(f"Running {entity_type} {entity.name} with input {kwargs}") - click.secho(f"Running Execution on {'Remote' if run_level_params.is_remote else 'local'}.", fg="cyan") + click.secho( + f"Running Execution on {'Remote' if run_level_params.is_remote else 'local'}.", + fg="cyan", + ) try: inputs = {} for input_name, v in entity.python_interface.inputs_with_defaults.items(): @@ -576,6 +598,8 @@ def _run(*args, **kwargs): ) if processed_click_value is not None or optional_v: inputs[input_name] = processed_click_value + if processed_click_value is None and v[0] == bool: + inputs[input_name] = False if not run_level_params.is_remote: with FlyteContextManager.with_context(_update_flyte_context(run_level_params)): @@ -755,7 +779,10 @@ def list_commands(self, ctx): run_level_params: RunLevelParams = ctx.obj r = run_level_params.remote_instance() progress = Progress(transient=True) - task = progress.add_task(f"[cyan]Gathering [{run_level_params.limit}] remote LaunchPlans...", total=None) + task = progress.add_task( + f"[cyan]Gathering [{run_level_params.limit}] remote LaunchPlans...", + total=None, + ) with progress: progress.start_task(task) try: @@ -783,6 +810,70 @@ def get_command(self, ctx, name): ) +class YamlFileReadingCommand(click.RichCommand): + def __init__( + self, + name: str, + params: typing.List[click.Option], + help: str, + callback: typing.Callable = None, + ): + params.append( + click.Option( + ["--inputs-file"], + required=False, + type=click.Path(exists=True, dir_okay=False, resolve_path=True), + help="Path to a YAML | JSON file containing inputs for the workflow.", + ) + ) + super().__init__(name=name, params=params, callback=callback, help=help) + + def parse_args(self, ctx: Context, args: t.List[str]) -> t.List[str]: + def load_inputs(f: str) -> t.Dict[str, str]: + try: + inputs = yaml.safe_load(f) + except yaml.YAMLError as e: + yaml_e = e + try: + inputs = json.loads(f) + except json.JSONDecodeError as e: + raise click.BadParameter( + message=f"Could not load the inputs file. Please make sure it is a valid JSON or YAML file." + f"\n json error: {e}," + f"\n yaml error: {yaml_e}", + param_hint="--inputs-file", + ) + + return inputs + + inputs = {} + if "--inputs-file" in args: + idx = args.index("--inputs-file") + args.pop(idx) + f = args.pop(idx) + with open(f, "r") as f: + inputs = load_inputs(f.read()) + elif not sys.stdin.isatty(): + f = sys.stdin.read() + if f != "": + inputs = load_inputs(f) + + new_args = [] + for k, v in inputs.items(): + if isinstance(v, str): + new_args.extend([f"--{k}", v]) + elif isinstance(v, bool): + if v: + new_args.append(f"--{k}") + else: + v = json.dumps(v) + new_args.extend([f"--{k}", v]) + new_args.extend(args) + args = new_args + + return super().parse_args(ctx, args) + + class WorkflowCommand(click.RichGroup): """ click multicommand at the python file layer, subcommands should be all the workflows in the file. @@ -837,11 +928,11 @@ def _create_command( h = f"{click.style(entity_type, bold=True)} ({run_level_params.computed_params.module}.{entity_name})" if loaded_entity.__doc__: h = h + click.style(f"{loaded_entity.__doc__}", dim=True) - cmd = click.RichCommand( + cmd = YamlFileReadingCommand( name=entity_name, params=params, - callback=run_command(ctx, loaded_entity), help=h, + callback=run_command(ctx, loaded_entity), ) return cmd diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 8124f617b3..cbfd08ae2f 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -6,7 +6,18 @@ import sys import typing from collections import OrderedDict -from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union, cast +from typing import ( + Any, + Dict, + Generator, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) from flyteidl.core import artifact_id_pb2 as art_id from typing_extensions import get_args, get_type_hints @@ -370,7 +381,9 @@ def transform_interface_to_list_interface( def transform_function_to_interface( - fn: typing.Callable, docstring: Optional[Docstring] = None, is_reference_entity: bool = False + fn: typing.Callable, + docstring: Optional[Docstring] = None, + is_reference_entity: bool = False, ) -> Interface: """ From the annotations on a task function that the user should have provided, and the output names they want to use @@ -463,7 +476,9 @@ def transform_type(x: type, description: Optional[str] = None) -> _interface_mod if artifact_id: logger.debug(f"Found artifact id spec: {artifact_id}") return _interface_models.Variable( - type=TypeEngine.to_literal_type(x), description=description, artifact_partial_id=artifact_id + type=TypeEngine.to_literal_type(x), + description=description, + artifact_partial_id=artifact_id, ) diff --git a/flytekit/image_spec/default_builder.py b/flytekit/image_spec/default_builder.py index 50fcc4ea8a..32f20d6373 100644 --- a/flytekit/image_spec/default_builder.py +++ b/flytekit/image_spec/default_builder.py @@ -19,25 +19,24 @@ ) from flytekit.tools.ignore import DockerIgnore, GitIgnore, IgnoreGroup, StandardIgnore -UV_PYTHON_INSTALL_COMMAND_TEMPLATE = Template("""\ +UV_PYTHON_INSTALL_COMMAND_TEMPLATE = Template( + """\ RUN --mount=type=cache,sharing=locked,mode=0777,target=/root/.cache/uv,id=uv \ --mount=from=uv,source=/uv,target=/usr/bin/uv \ --mount=type=bind,target=requirements_uv.txt,src=requirements_uv.txt \ /usr/bin/uv \ pip install --python /opt/micromamba/envs/runtime/bin/python $PIP_EXTRA \ --requirement requirements_uv.txt -""") +""" +) -APT_INSTALL_COMMAND_TEMPLATE = Template( - """\ +APT_INSTALL_COMMAND_TEMPLATE = Template("""\ RUN --mount=type=cache,sharing=locked,mode=0777,target=/var/cache/apt,id=apt \ apt-get update && apt-get install -y --no-install-recommends \ $APT_PACKAGES -""" -) +""") -DOCKER_FILE_TEMPLATE = Template( - """\ +DOCKER_FILE_TEMPLATE = Template("""\ #syntax=docker/dockerfile:1.5 FROM ghcr.io/astral-sh/uv:0.2.37 as uv FROM mambaorg/micromamba:1.5.8-bookworm-slim as micromamba @@ -84,8 +83,7 @@ USER flytekit RUN mkdir -p $$HOME && \ echo "export PATH=$$PATH" >> $$HOME/.profile -""" -) +""") def get_flytekit_for_pypi(): diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 966425f901..c50d7f0984 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -206,7 +206,7 @@ def _convert_replica_spec( replicas=replicas, image=replica_config.image, resources=resources.to_flyte_idl() if resources else None, - restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + restart_policy=(replica_config.restart_policy.value if replica_config.restart_policy else None), ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: @@ -289,9 +289,11 @@ def spawn_helper( return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=om) -def _convert_run_policy_to_flyte_idl(run_policy: RunPolicy) -> kubeflow_common.RunPolicy: +def _convert_run_policy_to_flyte_idl( + run_policy: RunPolicy, +) -> kubeflow_common.RunPolicy: return kubeflow_common.RunPolicy( - clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None, + clean_pod_policy=(run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None), ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished, active_deadline_seconds=run_policy.active_deadline_seconds, backoff_limit=run_policy.backoff_limit, @@ -416,7 +418,13 @@ def _execute(self, **kwargs) -> Any: checkpoint_dest = None checkpoint_src = None - launcher_args = (dumped_target_function, ctx.raw_output_prefix, checkpoint_dest, checkpoint_src, kwargs) + launcher_args = ( + dumped_target_function, + ctx.raw_output_prefix, + checkpoint_dest, + checkpoint_src, + kwargs, + ) elif self.task_config.start_method == "fork": """ The torch elastic launcher doesn't support passing kwargs to the target function, @@ -440,7 +448,11 @@ def fn_partial(): if isinstance(e, FlyteRecoverableException): create_recoverable_error_file() raise - return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=om) + return ElasticWorkerResult( + return_value=return_val, + decks=flytekit.current_context().decks, + om=om, + ) launcher_target_func = fn_partial launcher_args = () diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py index 39f1e0bb80..faadc1019f 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -62,7 +62,7 @@ def test_end_to_end(start_method: str) -> None: """Test that the workflow with elastic task runs end to end.""" world_size = 2 - train_task = task(train, task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method)) + train_task = task(train,task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method)) @workflow def wf(config: Config = Config()) -> typing.Tuple[str, Config, torch.nn.Module, int]: @@ -89,9 +89,7 @@ def wf(config: Config = Config()) -> typing.Tuple[str, Config, torch.nn.Module, ("fork", "local", False), ], ) -def test_execution_params( - start_method: str, target_exec_id: str, monkeypatch_exec_id_env_var: bool, monkeypatch -) -> None: +def test_execution_params(start_method: str, target_exec_id: str, monkeypatch_exec_id_env_var: bool, monkeypatch) -> None: """Test that execution parameters are set in the worker processes.""" if monkeypatch_exec_id_env_var: monkeypatch.setenv("FLYTE_INTERNAL_EXECUTION_ID", target_exec_id) @@ -117,7 +115,7 @@ def test_rdzv_configs(start_method: str) -> None: rdzv_configs = {"join_timeout": 10} - @task(task_config=Elastic(nnodes=1, nproc_per_node=2, start_method=start_method, rdzv_configs=rdzv_configs)) + @task(task_config=Elastic(nnodes=1,nproc_per_node=2,start_method=start_method,rdzv_configs=rdzv_configs)) def test_task(): pass @@ -131,15 +129,12 @@ def test_deck(start_method: str) -> None: """Test that decks created in the main worker process are transferred to the parent process.""" world_size = 2 - @task( - task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method), - enable_deck=True, - ) + @task(task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method), enable_deck=True) def train(): import os ctx = flytekit.current_context() - deck = flytekit.Deck("test-deck", f"Hello Flyte Deck viewer from worker process {os.environ.get('RANK')}") + deck = flytekit.Deck("test-deck", f"Hello Flyte Deck viewer from worker process {os.environ.get('RANK')}",) ctx.decks.append(deck) default_deck = ctx.default_deck default_deck.append("Hello from default deck") @@ -189,9 +184,7 @@ def wf(): ctx = FlyteContext.current_context() omt = OutputMetadataTracker() - with FlyteContextManager.with_context( - ctx.with_execution_state(ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_TASK_EXECUTION)).with_output_metadata_tracker(omt) - ) as child_ctx: + with FlyteContextManager.with_context(ctx.with_execution_state(ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_TASK_EXECUTION)).with_output_metadata_tracker(omt)) as child_ctx: cast(ExecutionParameters, child_ctx.user_space_params)._decks = [] # call execute directly so as to be able to get at the same FlyteContext object. res = train2.execute() @@ -215,9 +208,7 @@ def test_recoverable_error(recoverable: bool, start_method: str) -> None: class CustomRecoverableException(FlyteRecoverableException): pass - @task( - task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method), - ) + @task(task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method)) def train(recoverable: bool): if recoverable: raise CustomRecoverableException("Recoverable error") @@ -244,7 +235,6 @@ def test_task(): assert test_task.task_config.rdzv_configs == {"join_timeout": 900, "timeout": 900} - def test_run_policy() -> None: """Test that run policy is propagated to custom spec.""" @@ -268,6 +258,7 @@ def test_task(): "activeDeadlineSeconds": 36000, } + @pytest.mark.parametrize("start_method", ["spawn", "fork"]) def test_omp_num_threads(start_method: str) -> None: """Test that the env var OMP_NUM_THREADS is set by default and not overwritten if set.""" diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 7e0661f808..ef47aa3529 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -100,7 +100,10 @@ def test_fetch_execute_launch_plan_with_args(register): flyte_launch_plan = remote.fetch_launch_plan(name="basic.basic_workflow.my_wf", version=VERSION) execution = remote.execute(flyte_launch_plan, inputs={"a": 10, "b": "foobar"}, wait=True) assert execution.node_executions["n0"].inputs == {"a": 10} - assert execution.node_executions["n0"].outputs == {"t1_int_output": 12, "c": "world"} + assert execution.node_executions["n0"].outputs == { + "t1_int_output": 12, + "c": "world", + } assert execution.node_executions["n1"].inputs == {"a": "world", "b": "foobar"} assert execution.node_executions["n1"].outputs == {"o0": "foobarworld"} assert execution.node_executions["n0"].task_executions[0].inputs == {"a": 10} @@ -130,7 +133,7 @@ def test_monitor_workflow_execution(register): break with pytest.raises( - FlyteAssertion, match="Please wait until the execution has completed before requesting the outputs." + FlyteAssertion, match="Please wait until the execution has completed before requesting the outputs.", ): execution.outputs @@ -241,7 +244,11 @@ def test_execute_python_workflow_and_launch_plan(register): launch_plan = LaunchPlan.get_or_create(workflow=my_wf, name=my_wf.name) execution = remote.execute( - launch_plan, name="basic.basic_workflow.my_wf", inputs={"a": 14, "b": "foobar"}, version=VERSION, wait=True + launch_plan, + name="basic.basic_workflow.my_wf", + inputs={"a": 14, "b": "foobar"}, + version=VERSION, + wait=True, ) assert execution.outputs["o0"] == 16 assert execution.outputs["o1"] == "foobarworld" @@ -269,7 +276,9 @@ def test_fetch_execute_task_list_of_floats(register): def test_fetch_execute_task_convert_dict(register): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) - flyte_task = remote.fetch_task(name="basic.dict_str_wf.convert_to_string", version=VERSION) + flyte_task = remote.fetch_task( + name="basic.dict_str_wf.convert_to_string", version=VERSION + ) d: typing.Dict[str, str] = {"key1": "value1", "key2": "value2"} execution = remote.execute(flyte_task, inputs={"d": d}, wait=True) remote.sync_execution(execution, sync_nodes=True) @@ -374,9 +383,7 @@ def test_execute_with_default_launch_plan(register): from .workflows.basic.subworkflows import parent_wf remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) - execution = remote.execute( - parent_wf, inputs={"a": 101}, version=VERSION, wait=True, image_config=ImageConfig.auto(img_name=IMAGE) - ) + execution = remote.execute(parent_wf, inputs={"a": 101}, version=VERSION, wait=True, image_config=ImageConfig.auto(img_name=IMAGE)) # check node execution inputs and outputs assert execution.node_executions["n0"].inputs == {"a": 101} assert execution.node_executions["n0"].outputs == {"t1_int_output": 103, "c": "world"} diff --git a/tests/flytekit/unit/cli/pyflyte/my_wf_input.json b/tests/flytekit/unit/cli/pyflyte/my_wf_input.json new file mode 100644 index 0000000000..c20081f3b2 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/my_wf_input.json @@ -0,0 +1,47 @@ +{ + "a": 1, + "b": "Hello", + "c": 1.1, + "d": { + "i": 1, + "a": [ + "h", + "e" + ] + }, + "e": [ + 1, + 2, + 3 + ], + "f": { + "x": 1.0, + "y": 2.0 + }, + "g": "tests/flytekit/unit/cli/pyflyte/testdata/df.parquet", + "h": true, + "i": "2020-05-01", + "j": "20H", + "k": "RED", + "l": { + "hello": "world" + }, + "m": { + "a": "b", + "c": "d" + }, + "n": [ + { + "x": "tests/flytekit/unit/cli/pyflyte/testdata/df.parquet" + } + ], + "o": { + "x": [ + "tests/flytekit/unit/cli/pyflyte/testdata/df.parquet" + ] + }, + "p": "None", + "q": "tests/flytekit/unit/cli/pyflyte/testdata", + "remote": "tests/flytekit/unit/cli/pyflyte/testdata", + "image": "tests/flytekit/unit/cli/pyflyte/testdata" +} diff --git a/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml b/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml new file mode 100644 index 0000000000..678f5331c8 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml @@ -0,0 +1,34 @@ +a: 1 +b: Hello +c: 1.1 +d: + i: 1 + a: + - h + - e +e: + - 1 + - 2 + - 3 +f: + x: 1.0 + y: 2.0 +g: tests/flytekit/unit/cli/pyflyte/testdata/df.parquet +h: true +i: '2020-05-01' +j: 20H +k: RED +l: + hello: world +m: + a: b + c: d +n: + - x: tests/flytekit/unit/cli/pyflyte/testdata/df.parquet +o: + x: + - tests/flytekit/unit/cli/pyflyte/testdata/df.parquet +p: 'None' +q: tests/flytekit/unit/cli/pyflyte/testdata +remote: tests/flytekit/unit/cli/pyflyte/testdata +image: tests/flytekit/unit/cli/pyflyte/testdata diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 3eb3062de9..475fb42ff1 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -4,6 +4,7 @@ import pathlib import shutil import sys +import io import mock import pytest @@ -39,6 +40,8 @@ ) DIR_NAME = os.path.dirname(os.path.realpath(__file__)) +monkeypatch = pytest.MonkeyPatch() + class WorkflowFileLocation(enum.Enum): NORMAL = enum.auto() @@ -230,6 +233,92 @@ def test_union_type1(input): assert result.exit_code == 0 +def test_all_types_with_json_input(): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + [ + "run", + os.path.join(DIR_NAME, "workflow.py"), + "my_wf", + "--inputs-file", + os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json"), + ], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.stdout + + +def test_all_types_with_yaml_input(): + runner = CliRunner() + + result = runner.invoke( + pyflyte.main, + ["run", os.path.join(DIR_NAME, "workflow.py"), "my_wf", "--inputs-file", os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.yaml")], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.stdout + + +def test_all_types_with_pipe_input(monkeypatch): + runner = CliRunner() + input= str(json.load(open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json"),"r"))) + monkeypatch.setattr("sys.stdin", io.StringIO(input)) + result = runner.invoke( + pyflyte.main, + [ + "run", + os.path.join(DIR_NAME, "workflow.py"), + "my_wf", + ], + input=input, + catch_exceptions=False, + ) + assert result.exit_code == 0, result.stdout + + +@pytest.mark.parametrize( + "pipe_input, option_input", + [ + ( + str( + json.load( + open( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "my_wf_input.json", + ), + "r", + ) + ) + ), + "GREEN", + ) + ], +) +def test_replace_file_inputs(monkeypatch, pipe_input, option_input): + runner = CliRunner() + monkeypatch.setattr("sys.stdin", io.StringIO(pipe_input)) + result = runner.invoke( + pyflyte.main, + [ + "run", + os.path.join(DIR_NAME, "workflow.py"), + "my_wf", + "--inputs-file", + os.path.join( + os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json" + ), + "--k", + option_input, + ], + input=pipe_input, + ) + + assert result.exit_code == 0 + assert option_input in result.output + + @pytest.mark.parametrize( "input", [2.0, '{"i":1,"a":["h","e"]}', "[1, 2, 3]"], @@ -276,7 +365,9 @@ def test_union_type_with_invalid_input(): assert result.exit_code == 2 -@pytest.mark.skipif(sys.version_info < (3, 9), reason="listing entities requires python>=3.9") +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="listing entities requires python>=3.9" +) @pytest.mark.parametrize( "workflow_file", [ @@ -287,12 +378,13 @@ def test_union_type_with_invalid_input(): ) def test_get_entities_in_file(workflow_file): e = get_entities_in_file(pathlib.Path(workflow_file), False) - assert e.workflows == ["my_wf", "wf_with_env_vars", "wf_with_none"] + assert e.workflows == ["my_wf", "wf_with_env_vars", "wf_with_list", "wf_with_none"] assert e.tasks == [ "get_subset_df", "print_all", "show_sd", "task_with_env_vars", + "task_with_list", "task_with_optional", "test_union1", "test_union2", @@ -300,11 +392,13 @@ def test_get_entities_in_file(workflow_file): assert e.all() == [ "my_wf", "wf_with_env_vars", + "wf_with_list", "wf_with_none", "get_subset_df", "print_all", "show_sd", "task_with_env_vars", + "task_with_list", "task_with_optional", "test_union1", "test_union2", diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index 95535d2fc0..accebf82df 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -125,3 +125,11 @@ def task_with_env_vars(env_vars: typing.List[str]) -> str: @workflow def wf_with_env_vars(env_vars: typing.List[str]) -> str: return task_with_env_vars(env_vars=env_vars) + +@task +def task_with_list(a: typing.List[int]) -> typing.List[int]: + return a + +@workflow +def wf_with_list(a: typing.List[int]) -> typing.List[int]: + return task_with_list(a=a)