Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support FlyteRemote.execute interruptible flag override #2885

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions flytekit/models/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import flyteidl.admin.execution_pb2 as _execution_pb2
import flyteidl.admin.node_execution_pb2 as _node_execution_pb2
import flyteidl.admin.task_execution_pb2 as _task_execution_pb2
from google.protobuf import wrappers_pb2 as _google_wrappers_pb2

import flytekit
from flytekit.models import common as _common_models
Expand Down Expand Up @@ -179,6 +180,7 @@ def __init__(
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
overwrite_cache: Optional[bool] = None,
interruptible: Optional[bool] = None,
envs: Optional[_common_models.Envs] = None,
tags: Optional[typing.List[str]] = None,
cluster_assignment: Optional[ClusterAssignment] = None,
Expand All @@ -198,6 +200,7 @@ def __init__(
parallelism/concurrency of MapTasks is independent from this.
:param security_context: Optional security context to use for this execution.
:param overwrite_cache: Optional flag to overwrite the cache for this execution.
:param interruptible: Optional flag to override the default interruptible flag of the executed entity.
:param envs: flytekit.models.common.Envs environment variables to set for this execution.
:param tags: Optional list of tags to apply to the execution.
:param execution_cluster_label: Optional execution cluster label to use for this execution.
Expand All @@ -213,6 +216,7 @@ def __init__(
self._max_parallelism = max_parallelism
self._security_context = security_context
self._overwrite_cache = overwrite_cache
self._interruptible = interruptible
self._envs = envs
self._tags = tags
self._cluster_assignment = cluster_assignment
Expand Down Expand Up @@ -287,6 +291,10 @@ def security_context(self) -> typing.Optional[security.SecurityContext]:
def overwrite_cache(self) -> Optional[bool]:
return self._overwrite_cache

@property
def interruptible(self) -> Optional[bool]:
    """Optional override of the executed entity's default interruptible flag.

    Returns None when no override was requested, in which case the
    entity's own default applies (serialized as an absent BoolValue).
    """
    return self._interruptible

@property
def envs(self) -> Optional[_common_models.Envs]:
return self._envs
Expand Down Expand Up @@ -321,6 +329,9 @@ def to_flyte_idl(self):
max_parallelism=self.max_parallelism,
security_context=self.security_context.to_flyte_idl() if self.security_context else None,
overwrite_cache=self.overwrite_cache,
interruptible=_google_wrappers_pb2.BoolValue(value=self.interruptible)
if self.interruptible is not None
else None,
envs=self.envs.to_flyte_idl() if self.envs else None,
tags=self.tags,
cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None,
Expand Down Expand Up @@ -351,6 +362,7 @@ def from_flyte_idl(cls, p):
if p.security_context
else None,
overwrite_cache=p.overwrite_cache,
interruptible=p.interruptible.value if p.HasField("interruptible") else None,
envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None,
tags=p.tags,
cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment)
Expand Down
65 changes: 49 additions & 16 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,7 @@ def _execute(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand All @@ -1308,6 +1309,7 @@ def _execute(
:param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten
for a single execution. If enabled, all calculations are performed even if cached results would
be available, overwriting the stored data once execution finishes successfully.
:param interruptible: Optional flag to override the default interruptible flag of the executed entity.
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
Expand Down Expand Up @@ -1379,6 +1381,7 @@ def _execute(
0,
),
overwrite_cache=overwrite_cache,
interruptible=interruptible,
notifications=notifications,
disable_all=options.disable_notifications,
labels=options.labels,
Expand Down Expand Up @@ -1455,6 +1458,7 @@ def execute(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand Down Expand Up @@ -1495,6 +1499,7 @@ def execute(
:param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten
for a single execution. If enabled, all calculations are performed even if cached results would
be available, overwriting the stored data once execution finishes successfully.
:param interruptible: Optional flag to override the default interruptible flag of the executed entity.
:param envs: Environment variables to be set for the execution.
:param tags: Tags to be set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
Expand All @@ -1519,6 +1524,7 @@ def execute(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1536,6 +1542,7 @@ def execute(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1551,6 +1558,7 @@ def execute(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1566,6 +1574,7 @@ def execute(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1581,6 +1590,7 @@ def execute(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1599,6 +1609,7 @@ def execute(
image_config=image_config,
wait=wait,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1618,6 +1629,7 @@ def execute(
options=options,
wait=wait,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1636,6 +1648,7 @@ def execute(
options=options,
wait=wait,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1658,6 +1671,7 @@ def execute_remote_task_lp(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand All @@ -1678,6 +1692,7 @@ def execute_remote_task_lp(
options=options,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1696,6 +1711,7 @@ def execute_remote_wf(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand All @@ -1717,6 +1733,7 @@ def execute_remote_wf(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1735,6 +1752,7 @@ def execute_reference_task(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand Down Expand Up @@ -1766,6 +1784,7 @@ def execute_reference_task(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1782,6 +1801,7 @@ def execute_reference_workflow(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand Down Expand Up @@ -1827,6 +1847,7 @@ def execute_reference_workflow(
options=options,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1843,6 +1864,7 @@ def execute_reference_launch_plan(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand Down Expand Up @@ -1874,6 +1896,7 @@ def execute_reference_launch_plan(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1896,6 +1919,7 @@ def execute_local_task(
image_config: typing.Optional[ImageConfig] = None,
wait: bool = False,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand All @@ -1914,6 +1938,7 @@ def execute_local_task(
:param image_config: If provided, will use this image config in the pod.
:param wait: If True, will wait for the execution to complete before returning.
:param overwrite_cache: If True, will overwrite the cache.
:param interruptible: Optional flag to override the default interruptible flag of the executed entity.
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
Expand Down Expand Up @@ -1954,6 +1979,7 @@ def execute_local_task(
wait=wait,
type_hints=entity.python_interface.inputs,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1974,29 +2000,31 @@ def execute_local_workflow(
options: typing.Optional[Options] = None,
wait: bool = False,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute an @workflow decorated function.
:param entity:
:param inputs:
:param project:
:param domain:
:param name:
:param version:
:param execution_name:
:param image_config:
:param options:
:param wait:
:param overwrite_cache:
:param envs:
:param tags:
:param cluster_pool:
:param execution_cluster_label:
:return:
:param entity: The workflow to execute
:param inputs: Input dictionary
:param project: Project to execute in
:param domain: Domain to execute in
:param name: Optional name override for the workflow
:param version: Optional version for the workflow
:param execution_name: Optional name for the execution
:param image_config: Optional image config override
:param options: Optional Options object
:param wait: Whether to wait for execution completion
:param overwrite_cache: If True, will overwrite the cache
:param interruptible: Optional flag to override the default interruptible flag of the executed entity
:param envs: Environment variables to set for the execution
:param tags: Tags to set for the execution
:param cluster_pool: Specify cluster pool on which newly created execution should be placed
:param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed
:return: FlyteWorkflowExecution object
"""
if not image_config:
image_config = ImageConfig.auto_default_image()
Expand Down Expand Up @@ -2052,6 +2080,7 @@ def execute_local_workflow(
options=options,
type_hints=entity.python_interface.inputs,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -2071,12 +2100,14 @@ def execute_local_launch_plan(
options: typing.Optional[Options] = None,
wait: bool = False,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute a locally defined `LaunchPlan`.
redartera marked this conversation as resolved.
Show resolved Hide resolved

:param entity: The locally defined launch plan object
:param inputs: Inputs to be passed into the execution as a dict with Python native values.
Expand All @@ -2088,6 +2119,7 @@ def execute_local_launch_plan(
:param options: Options to be passed into the execution.
:param wait: If True, will wait for the execution to complete before returning.
:param overwrite_cache: If True, will overwrite the cache.
:param interruptible: Optional flag to override the default interruptible flag of the executed entity.
:param envs: Environment variables to be passed into the execution.
:param tags: Tags to be passed into the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
Expand Down Expand Up @@ -2119,6 +2151,7 @@ def execute_local_launch_plan(
wait=wait,
type_hints=entity.python_interface.inputs,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand Down
19 changes: 19 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,25 @@ def test_execute_workflow_with_maptask(register):
)
assert execution.outputs["o0"] == [4, 5, 6]

def test_executes_nested_workflow_dictating_interruptible(register):
    """Verify the interruptible override reaches both parent and child workflow executions."""
    remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
    launch_plan = remote.fetch_launch_plan(name="basic.child_workflow.parent_wf", version=VERSION)
    # Exercise every possible override value, including "no override" (None).
    overrides = [True, False, None]
    started = [
        remote.execute(launch_plan, inputs={"a": 10}, wait=False, interruptible=flag)
        for flag in overrides
    ]
    for handle, expected in zip(started, overrides):
        finished = remote.wait(handle, timeout=300)
        # The parent execution's spec must carry the override verbatim.
        assert finished.spec.interruptible == expected
        # The override must also cascade into the launched child workflow.
        child_exec_name = finished.node_executions["n1"].closure.workflow_node_metadata.execution_id.name
        child_execution = remote.fetch_execution(project=PROJECT, domain=DOMAIN, name=child_exec_name)
        assert child_execution.spec.interruptible == expected


@pytest.mark.lftransfers
class TestLargeFileTransfers:
Expand Down
Loading