Skip to content

Commit

Permalink
Improve workflow decorator type hints with overload
Browse files Browse the repository at this point in the history
Previously, the workflow decorator is hinted as always returning a WorkflowBase, which is not true when _workflow_function is None; similar to flyteorg#1631, we propose using typing.overload to differentiate the return type of workflow based on the value of _workflow_function

Signed-off-by: Matthew Hoffman <matthew@protopia.ai>
  • Loading branch information
ringohoffman committed May 11, 2023
1 parent 993201f commit 8548b7d
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from enum import Enum
from functools import update_wrapper
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, overload

from typing_extensions import get_args

Expand Down Expand Up @@ -653,7 +653,7 @@ class PythonFunctionWorkflow(WorkflowBase, ClassStorageTaskResolver):

def __init__(
self,
workflow_function: Callable,
workflow_function: Callable[..., Any],
metadata: WorkflowMetadata,
default_metadata: WorkflowMetadataDefaults,
docstring: Optional[Docstring] = None,
Expand Down Expand Up @@ -777,12 +777,32 @@ def execute(self, **kwargs):
return exception_scopes.user_entry_point(self._workflow_function)(**kwargs)


@overload
def workflow(
_workflow_function=None,
_workflow_function: None = ...,
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
docs: Optional[Documentation] = ...,
) -> Callable[[Callable[..., Any]], PythonFunctionWorkflow]:
...


@overload
def workflow(
_workflow_function: Callable[..., Any],
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
docs: Optional[Documentation] = ...,
) -> PythonFunctionWorkflow:
...


def workflow(
_workflow_function: Optional[Callable[..., Any]] = None,
failure_policy: Optional[WorkflowFailurePolicy] = None,
interruptible: bool = False,
docs: Optional[Documentation] = None,
) -> WorkflowBase:
) -> Union[Callable[[Callable[..., Any]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
of tasks using the data flow between tasks.
Expand Down Expand Up @@ -813,7 +833,7 @@ def workflow(
:param docs: Description entity for the workflow
"""

def wrapper(fn):
def wrapper(fn: Callable[..., Any]) -> PythonFunctionWorkflow:
workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY)

workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible)
Expand All @@ -828,10 +848,10 @@ def wrapper(fn):
update_wrapper(workflow_instance, fn)
return workflow_instance

if _workflow_function:
if _workflow_function is not None:
return wrapper(_workflow_function)
else:
return wrapper # type: ignore
return wrapper


class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): # type: ignore
Expand Down

0 comments on commit 8548b7d

Please sign in to comment.