Skip to content

Commit

Permalink
Using ParamSpec to show underlying typehinting (flyteorg#2617)
Browse files Browse the repository at this point in the history
Signed-off-by: JackUrb <jack@datologyai.com>
  • Loading branch information
JackUrb authored and mao3267 committed Aug 1, 2024
1 parent ee664f1 commit d90b6ec
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
18 changes: 12 additions & 6 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from functools import update_wrapper
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload

try:
from typing import ParamSpec
except ImportError:
from typing_extensions import ParamSpec # type: ignore

from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
from flytekit.core.base_task import TaskMetadata, TaskResolverMixin
Expand Down Expand Up @@ -80,6 +85,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction
return PythonFunctionTask


P = ParamSpec("P")
T = TypeVar("T")
FuncOut = TypeVar("FuncOut")

Expand Down Expand Up @@ -124,7 +130,7 @@ def task(

@overload
def task(
_task_function: Callable[..., FuncOut],
_task_function: Callable[P, FuncOut],
task_config: Optional[T] = ...,
cache: bool = ...,
cache_serialize: bool = ...,
Expand Down Expand Up @@ -157,11 +163,11 @@ def task(
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
) -> Union[PythonFunctionTask[T], Callable[..., FuncOut]]: ...
) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ...


def task(
_task_function: Optional[Callable[..., FuncOut]] = None,
_task_function: Optional[Callable[P, FuncOut]] = None,
task_config: Optional[T] = None,
cache: bool = False,
cache_serialize: bool = False,
Expand Down Expand Up @@ -201,9 +207,9 @@ def task(
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
) -> Union[
Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]],
Callable[P, FuncOut],
Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]],
PythonFunctionTask[T],
Callable[..., FuncOut],
]:
"""
This is the core decorator to use for any task type in flytekit.
Expand Down Expand Up @@ -324,7 +330,7 @@ def launch_dynamically():
:param accelerator: The accelerator to use for this task.
"""

def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
_metadata = TaskMetadata(
cache=cache,
cache_serialize=cache_serialize,
Expand Down
16 changes: 11 additions & 5 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from functools import update_wrapper
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload

try:
from typing import ParamSpec
except ImportError:
from typing_extensions import ParamSpec # type: ignore

from flytekit.core import constants as _common_constants
from flytekit.core import launch_plan as _annotated_launch_plan
from flytekit.core.base_task import PythonTask, Task
Expand Down Expand Up @@ -58,6 +63,7 @@
flyte_entity=None,
)

P = ParamSpec("P")
T = typing.TypeVar("T")
FuncOut = typing.TypeVar("FuncOut")

Expand Down Expand Up @@ -809,21 +815,21 @@ def workflow(

@overload
def workflow(
_workflow_function: Callable[..., FuncOut],
_workflow_function: Callable[P, FuncOut],
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
) -> Union[PythonFunctionWorkflow, Callable[..., FuncOut]]: ...
) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ...


def workflow(
_workflow_function: Optional[Callable[..., Any]] = None,
_workflow_function: Optional[Callable[P, FuncOut]] = None,
failure_policy: Optional[WorkflowFailurePolicy] = None,
interruptible: bool = False,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow, Callable[..., FuncOut]]:
) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
of tasks using the data flow between tasks.
Expand Down Expand Up @@ -856,7 +862,7 @@ def workflow(
:param docs: Description entity for the workflow
"""

def wrapper(fn: Callable[..., Any]) -> PythonFunctionWorkflow:
def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY)

workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible)
Expand Down

0 comments on commit d90b6ec

Please sign in to comment.