Skip to content

Commit

Permalink
Signaling (#1133)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
  • Loading branch information
wild-endeavor authored and eapolinario committed Feb 22, 2023
1 parent 20ea162 commit 9abc7f9
Show file tree
Hide file tree
Showing 10 changed files with 726 additions and 15 deletions.
1 change: 1 addition & 0 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.gate import approve, sleep, wait_for_input
from flytekit.core.hash import HashMethod
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.map_task import map_task
Expand Down
195 changes: 195 additions & 0 deletions flytekit/core/gate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from __future__ import annotations

import datetime
import typing
from typing import Tuple, Union

import click

from flytekit.core import interface as flyte_interface
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.promise import Promise, VoidPromise, flyte_entity_call_handler
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.user import FlyteDisapprovalException
from flytekit.interaction.parse_stdin import parse_stdin_to_literal
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.types import LiteralType

DEFAULT_TIMEOUT = datetime.timedelta(hours=1)


class Gate(object):
"""
A node type that waits for user input before proceeding with a workflow.
A gate is a type of node that behaves like a task, but instead of running code, it either needs to wait
for user input to proceed or wait for a timer to complete running.
"""

def __init__(
self,
name: str,
input_type: typing.Optional[typing.Type] = None,
upstream_item: typing.Optional[typing.Any] = None,
sleep_duration: typing.Optional[datetime.timedelta] = None,
timeout: typing.Optional[datetime.timedelta] = None,
):
self._name = name
self._input_type = input_type
self._sleep_duration = sleep_duration
self._timeout = timeout or DEFAULT_TIMEOUT
self._upstream_item = upstream_item
self._literal_type = TypeEngine.to_literal_type(input_type) if input_type else None

# Determine the python interface if we can
if self._sleep_duration:
# Just a sleep so there is no interface
self._python_interface = flyte_interface.Interface()
elif input_type:
# Waiting for user input, so the output of the node is whatever input the user provides.
self._python_interface = flyte_interface.Interface(
outputs={
"o0": self.input_type,
}
)
else:
# We don't know how to find the python interface here, approve() sets it below, See the code.
self._python_interface = None

@property
def name(self) -> str:
# Part of SupportsNodeCreation interface
return self._name

@property
def input_type(self) -> typing.Optional[typing.Type]:
return self._input_type

@property
def literal_type(self) -> typing.Optional[LiteralType]:
return self._literal_type

@property
def sleep_duration(self) -> typing.Optional[datetime.timedelta]:
return self._sleep_duration

@property
def python_interface(self) -> flyte_interface.Interface:
"""
This will not be valid during local execution
Part of SupportsNodeCreation interface
"""
# If this is just a sleep node, or user input node, then it will have a Python interface upon construction.
if self._python_interface:
return self._python_interface

raise ValueError("You can't check for a Python interface for an approval node outside of compilation")

def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
# Part of SupportsNodeCreation interface
return _workflow_model.NodeMetadata(
name=self.name,
timeout=self._timeout,
)

# This is to satisfy the LocallyExecutable protocol
def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
if self.sleep_duration:
print(f"Mock sleeping for {self.sleep_duration}")
return VoidPromise(self.name)

# Trigger stdin
if self.input_type:
msg = f"Execution stopped for gate {self.name}...\n"
literal = parse_stdin_to_literal(ctx, self.input_type, msg)
p = Promise(var="o0", val=literal)
return p

# Assume this is an approval operation since that's the only remaining option.
msg = f"Pausing execution for {self.name}, literal value is:\n{self._upstream_item.val}\nContinue?"
proceed = click.confirm(msg, default=True)
if proceed:
# We need to return a promise here, and a promise is what should've been passed in by the call in approve()
# Only one element should be in this map. Rely on kwargs instead of the stored _upstream_item even though
# they should be the same to be cleaner
output_name = list(kwargs.keys())[0]
return kwargs[output_name]
else:
raise FlyteDisapprovalException(f"User did not approve the transaction for gate node {self.name}")


def wait_for_input(name: str, timeout: datetime.timedelta, expected_type: typing.Type):
"""
Create a Gate object. This object will function like a task. Note that unlike a task,
each time this function is called, a new Python object is created. If a workflow
calls a subworkflow twice, and the subworkflow has a signal, then two Gate
objects are created. This shouldn't be a problem as long as the objects are identical.
:param name: The name of the gate node.
:param timeout: How long to wait for before Flyte fails the workflow.
:param expected_type: What is the type that the user will be inputting?
:return:
"""

g = Gate(name, input_type=expected_type, timeout=timeout)

return flyte_entity_call_handler(g)


def sleep(duration: datetime.timedelta):
"""
:param duration: How long to sleep for
:return:
"""
g = Gate("sleep-gate", sleep_duration=duration)

return flyte_entity_call_handler(g)


def approve(upstream_item: Union[Tuple[Promise], Promise, VoidPromise], name: str, timeout: datetime.timedelta):
"""
Create a Gate object. This object will function like a task. Note that unlike a task,
each time this function is called, a new Python object is created. If a workflow
calls a subworkflow twice, and the subworkflow has a signal, then two Gate
objects are created. This shouldn't be a problem as long as the objects are identical.
:param upstream_item: This should be the output, one output, of a previous task, that you want to gate execution
on. This is the value that you want a human to check before moving on.
:param name: The name of the gate node.
:param timeout: How long to wait before Flyte fails the workflow.
:return:
"""
g = Gate(name, upstream_item=upstream_item, timeout=timeout)

if upstream_item is None or isinstance(upstream_item, VoidPromise):
raise ValueError("You can't use approval on a task that doesn't return anything.")

ctx = FlyteContextManager.current_context()
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
if not upstream_item.ref.node.flyte_entity.python_interface:
raise ValueError(
f"Upstream node doesn't have a Python interface. Node entity is: "
f"{upstream_item.ref.node.flyte_entity}"
)

# We have reach back up to the entity that this promise came from, to get the python type, since
# the approve function itself doesn't have a python interface.
io_type = upstream_item.ref.node.flyte_entity.python_interface.outputs[upstream_item.var]
io_var_name = upstream_item.var
else:
# We don't know the python type here. in local execution, downstream doesn't really use the type
# so we should be okay. But use None instead of type() so that errors are more obvious hopefully.
io_type = None
io_var_name = "o0"

# In either case, we need a python interface
g._python_interface = flyte_interface.Interface(
inputs={
io_var_name: io_type,
},
outputs={
io_var_name: io_type,
},
)
kwargs = {io_var_name: upstream_item}

return flyte_entity_call_handler(g, **kwargs)
2 changes: 1 addition & 1 deletion flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
...


def flyte_entity_call_handler(entity: Union[SupportsNodeCreation], *args, **kwargs):
def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs):
"""
This function is the call handler for tasks, workflows, and launch plans (which redirects to the underlying
workflow). The logic is the same for all three, but we did not want to create base class, hence this separate
Expand Down
4 changes: 4 additions & 0 deletions flytekit/exceptions/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ class FlyteValidationException(FlyteAssertion):
_ERROR_CODE = "USER:ValidationError"


class FlyteDisapprovalException(FlyteAssertion):
_ERROR_CODE = "USER:ResultNotApproved"


class FlyteEntityAlreadyExistsException(FlyteAssertion):
_ERROR_CODE = "USER:EntityAlreadyExists"

Expand Down
Empty file.
36 changes: 36 additions & 0 deletions flytekit/interaction/parse_stdin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

import typing

import click

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine
from flytekit.loggers import logger
from flytekit.models.literals import Literal


# TODO: Move the improved click parsing here. https://github.com/flyteorg/flyte/issues/3124
def parse_stdin_to_literal(ctx: FlyteContext, t: typing.Type, message_prefix: typing.Optional[str]) -> Literal:

message = message_prefix or ""
message += f"Please enter value for type {t} to continue"
if issubclass(t, bool):
user_input = click.prompt(message, type=bool)
l = TypeEngine.to_literal(ctx, user_input, bool, TypeEngine.to_literal_type(bool)) # noqa
elif issubclass(t, int):
user_input = click.prompt(message, type=int)
l = TypeEngine.to_literal(ctx, user_input, int, TypeEngine.to_literal_type(int)) # noqa
elif issubclass(t, float):
user_input = click.prompt(message, type=float)
l = TypeEngine.to_literal(ctx, user_input, float, TypeEngine.to_literal_type(float)) # noqa
elif issubclass(t, str):
user_input = click.prompt(message, type=str)
l = TypeEngine.to_literal(ctx, user_input, str, TypeEngine.to_literal_type(str)) # noqa
else:
# Todo: We should implement the rest by way of importing the code in pyflyte run
# that parses text from the command line
raise Exception("Only bool, int/float, or strings are accepted for now.")

logger.debug(f"Parsed literal {l} from user input {user_input}")
return l
51 changes: 41 additions & 10 deletions flytekit/models/core/identifier.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from flyteidl.core import identifier_pb2 as _identifier_pb2
from flyteidl.core import identifier_pb2 as identifier_pb2

from flytekit.models import common as _common_models


class ResourceType(object):
UNSPECIFIED = _identifier_pb2.UNSPECIFIED
TASK = _identifier_pb2.TASK
WORKFLOW = _identifier_pb2.WORKFLOW
LAUNCH_PLAN = _identifier_pb2.LAUNCH_PLAN
UNSPECIFIED = identifier_pb2.UNSPECIFIED
TASK = identifier_pb2.TASK
WORKFLOW = identifier_pb2.WORKFLOW
LAUNCH_PLAN = identifier_pb2.LAUNCH_PLAN


class Identifier(_common_models.FlyteIdlEntity):
Expand All @@ -34,7 +34,7 @@ def resource_type(self):
return self._resource_type

def resource_type_name(self) -> str:
return _identifier_pb2.ResourceType.Name(self.resource_type)
return identifier_pb2.ResourceType.Name(self.resource_type)

@property
def project(self):
Expand Down Expand Up @@ -68,7 +68,7 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.core.identifier_pb2.Identifier
"""
return _identifier_pb2.Identifier(
return identifier_pb2.Identifier(
resource_type=self.resource_type,
project=self.project,
domain=self.domain,
Expand Down Expand Up @@ -133,7 +133,7 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.core.identifier_pb2.WorkflowExecutionIdentifier
"""
return _identifier_pb2.WorkflowExecutionIdentifier(
return identifier_pb2.WorkflowExecutionIdentifier(
project=self.project,
domain=self.domain,
name=self.name,
Expand Down Expand Up @@ -179,7 +179,7 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.core.identifier_pb2.NodeExecutionIdentifier
"""
return _identifier_pb2.NodeExecutionIdentifier(
return identifier_pb2.NodeExecutionIdentifier(
node_id=self.node_id,
execution_id=self.execution_id.to_flyte_idl(),
)
Expand Down Expand Up @@ -232,7 +232,7 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.core.identifier_pb2.TaskExecutionIdentifier
"""
return _identifier_pb2.TaskExecutionIdentifier(
return identifier_pb2.TaskExecutionIdentifier(
task_id=self.task_id.to_flyte_idl(),
node_execution_id=self.node_execution_id.to_flyte_idl(),
retry_attempt=self.retry_attempt,
Expand All @@ -249,3 +249,34 @@ def from_flyte_idl(cls, proto):
node_execution_id=NodeExecutionIdentifier.from_flyte_idl(proto.node_execution_id),
retry_attempt=proto.retry_attempt,
)


class SignalIdentifier(_common_models.FlyteIdlEntity):
def __init__(self, signal_id: str, execution_id: WorkflowExecutionIdentifier):
"""
:param signal_id: User provided name for the gate node.
:param execution_id: The workflow execution id this signal is for.
"""
self._signal_id = signal_id
self._execution_id = execution_id

@property
def signal_id(self) -> str:
return self._signal_id

@property
def execution_id(self) -> WorkflowExecutionIdentifier:
return self._execution_id

def to_flyte_idl(self) -> identifier_pb2.SignalIdentifier:
return identifier_pb2.SignalIdentifier(
signal_id=self.signal_id,
execution_id=self.execution_id.to_flyte_idl(),
)

@classmethod
def from_flyte_idl(cls, proto: identifier_pb2.SignalIdentifier) -> "SignalIdentifier":
return cls(
signal_id=proto.signal_id,
execution_id=WorkflowExecutionIdentifier.from_flyte_idl(proto.execution_id),
)
Loading

0 comments on commit 9abc7f9

Please sign in to comment.