diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 7e4ff02645..f645df8f9d 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -10,6 +10,7 @@ import flyteidl.admin.execution_pb2 as _execution_pb2 import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 import flyteidl.admin.task_execution_pb2 as _task_execution_pb2 +from google.protobuf import wrappers_pb2 as _google_wrappers_pb2 import flytekit from flytekit.models import common as _common_models @@ -179,6 +180,7 @@ def __init__( max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, overwrite_cache: Optional[bool] = None, + interruptible: Optional[bool] = None, envs: Optional[_common_models.Envs] = None, tags: Optional[typing.List[str]] = None, cluster_assignment: Optional[ClusterAssignment] = None, @@ -198,6 +200,7 @@ def __init__( parallelism/concurrency of MapTasks is independent from this. :param security_context: Optional security context to use for this execution. :param overwrite_cache: Optional flag to overwrite the cache for this execution. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: flytekit.models.common.Envs environment variables to set for this execution. :param tags: Optional list of tags to apply to the execution. :param execution_cluster_label: Optional execution cluster label to use for this execution. @@ -213,6 +216,7 @@ def __init__( self._max_parallelism = max_parallelism self._security_context = security_context self._overwrite_cache = overwrite_cache + self._interruptible = interruptible self._envs = envs self._tags = tags self._cluster_assignment = cluster_assignment @@ -287,6 +291,10 @@ def security_context(self) -> typing.Optional[security.SecurityContext]: def overwrite_cache(self) -> Optional[bool]: return self._overwrite_cache + @property + def interruptible(self) -> Optional[bool]: + return self._interruptible + @property def envs(self) -> Optional[_common_models.Envs]: return self._envs @@ -321,6 +329,9 @@ def to_flyte_idl(self): max_parallelism=self.max_parallelism, security_context=self.security_context.to_flyte_idl() if self.security_context else None, overwrite_cache=self.overwrite_cache, + interruptible=_google_wrappers_pb2.BoolValue(value=self.interruptible) + if self.interruptible is not None + else None, envs=self.envs.to_flyte_idl() if self.envs else None, tags=self.tags, cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None, @@ -351,6 +362,7 @@ def from_flyte_idl(cls, p): if p.security_context else None, overwrite_cache=p.overwrite_cache, + interruptible=p.interruptible.value if p.HasField("interruptible") else None, envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None, tags=p.tags, cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 7eda76ddfa..7c3f94fd6b 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -1290,6 +1290,7 @@ def _execute( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1308,6 +1309,7 @@ def _execute( :param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten for a single execution. If enabled, all calculations are performed even if cached results would be available, overwriting the stored data once execution finishes successfully. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: Environment variables to set for the execution. :param tags: Tags to set for the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. @@ -1379,6 +1381,7 @@ def _execute( 0, ), overwrite_cache=overwrite_cache, + interruptible=interruptible, notifications=notifications, disable_all=options.disable_notifications, labels=options.labels, @@ -1455,6 +1458,7 @@ def execute( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1495,6 +1499,7 @@ def execute( :param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten for a single execution. If enabled, all calculations are performed even if cached results would be available, overwriting the stored data once execution finishes successfully. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: Environment variables to be set for the execution. :param tags: Tags to be set for the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. @@ -1519,6 +1524,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1536,6 +1542,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1551,6 +1558,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1566,6 +1574,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1581,6 +1590,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1599,6 +1609,7 @@ def execute( image_config=image_config, wait=wait, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1618,6 +1629,7 @@ def execute( options=options, wait=wait, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1636,6 +1648,7 @@ def execute( options=options, wait=wait, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1658,6 +1671,7 @@ def execute_remote_task_lp( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1678,6 +1692,7 @@ def execute_remote_task_lp( options=options, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1696,6 +1711,7 @@ def execute_remote_wf( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1717,6 +1733,7 @@ def execute_remote_wf( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1735,6 +1752,7 @@ def execute_reference_task( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1766,6 +1784,7 @@ def execute_reference_task( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1782,6 +1801,7 @@ def execute_reference_workflow( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1827,6 +1847,7 @@ def execute_reference_workflow( options=options, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1843,6 +1864,7 @@ def execute_reference_launch_plan( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1874,6 +1896,7 @@ def execute_reference_launch_plan( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1896,6 +1919,7 @@ def execute_local_task( image_config: typing.Optional[ImageConfig] = None, wait: bool = False, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1914,6 +1938,7 @@ def execute_local_task( :param image_config: If provided, will use this image config in the pod. :param wait: If True, will wait for the execution to complete before returning. :param overwrite_cache: If True, will overwrite the cache. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: Environment variables to set for the execution. :param tags: Tags to set for the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. @@ -1954,6 +1979,7 @@ def execute_local_task( wait=wait, type_hints=entity.python_interface.inputs, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1974,6 +2000,7 @@ def execute_local_workflow( options: typing.Optional[Options] = None, wait: bool = False, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1981,22 +2008,23 @@ def execute_local_workflow( ) -> FlyteWorkflowExecution: """ Execute an @workflow decorated function. - :param entity: - :param inputs: - :param project: - :param domain: - :param name: - :param version: - :param execution_name: - :param image_config: - :param options: - :param wait: - :param overwrite_cache: - :param envs: - :param tags: - :param cluster_pool: - :param execution_cluster_label: - :return: + :param entity: The workflow to execute + :param inputs: Input dictionary + :param project: Project to execute in + :param domain: Domain to execute in + :param name: Optional name override for the workflow + :param version: Optional version for the workflow + :param execution_name: Optional name for the execution + :param image_config: Optional image config override + :param options: Optional Options object + :param wait: Whether to wait for execution completion + :param overwrite_cache: If True, will overwrite the cache + :param interruptible: Optional flag to override the default interruptible flag of the executed entity + :param envs: Environment variables to set for the execution + :param tags: Tags to set for the execution + :param cluster_pool: Specify cluster pool on which newly created execution should be placed + :param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed + :return: FlyteWorkflowExecution object """ if not image_config: image_config = ImageConfig.auto_default_image() @@ -2052,6 +2080,7 @@ def execute_local_workflow( options=options, type_hints=entity.python_interface.inputs, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -2071,12 +2100,14 @@ def execute_local_launch_plan( options: typing.Optional[Options] = None, wait: bool = False, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, execution_cluster_label: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ + Execute a locally defined `LaunchPlan`. :param entity: The locally defined launch plan object :param inputs: Inputs to be passed into the execution as a dict with Python native values. @@ -2088,6 +2119,7 @@ def execute_local_launch_plan( :param options: Options to be passed into the execution. :param wait: If True, will wait for the execution to complete before returning. :param overwrite_cache: If True, will overwrite the cache. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: Environment variables to be passed into the execution. :param tags: Tags to be passed into the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. @@ -2119,6 +2151,7 @@ def execute_local_launch_plan( wait=wait, type_hints=entity.python_interface.inputs, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 4d77e1b610..d24c1ffbb3 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -528,6 +528,25 @@ def test_execute_workflow_with_maptask(register): ) assert execution.outputs["o0"] == [4, 5, 6] +def test_executes_nested_workflow_dictating_interruptible(register): + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + flyte_launch_plan = remote.fetch_launch_plan(name="basic.child_workflow.parent_wf", version=VERSION) + # The values we want to test for + interruptible_values = [True, False, None] + executions = [] + for creation_interruptible in interruptible_values: + execution = remote.execute(flyte_launch_plan, inputs={"a": 10}, wait=False, interruptible=creation_interruptible) + executions.append(execution) + # Wait for all executions to complete + for execution, expected_interruptible in zip(executions, interruptible_values): + execution = remote.wait(execution, timeout=300) + # Check that the parent workflow is interruptible as expected + assert execution.spec.interruptible == expected_interruptible + # Check that the child workflow is interruptible as expected + subwf_execution_id = execution.node_executions["n1"].closure.workflow_node_metadata.execution_id.name + subwf_execution = remote.fetch_execution(project=PROJECT, domain=DOMAIN, name=subwf_execution_id) + assert subwf_execution.spec.interruptible == expected_interruptible + @pytest.mark.lftransfers class TestLargeFileTransfers: