diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 5d2ef6f2a5..4abd07a007 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -8,6 +8,8 @@ from functools import update_wrapper from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload +from typing_inspect import is_optional_type + try: from typing import ParamSpec except ImportError: @@ -47,7 +49,11 @@ from flytekit.core.tracker import extract_task_module from flytekit.core.type_engine import TypeEngine from flytekit.exceptions import scopes as exception_scopes -from flytekit.exceptions.user import FlyteValidationException, FlyteValueException +from flytekit.exceptions.user import ( + FlyteFailureNodeInputMismatchException, + FlyteValidationException, + FlyteValueException, +) from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models @@ -689,6 +695,19 @@ def _validate_add_on_failure_handler(self, ctx: FlyteContext, prefix: str, wf_ar ) as inner_comp_ctx: # Now lets compile the failure-node if it exists if self.on_failure: + if self.on_failure.python_interface and self.python_interface: + workflow_inputs = self.python_interface.inputs + failure_node_inputs = self.on_failure.python_interface.inputs + + # Workflow inputs should be a subset of failure node inputs. + if (failure_node_inputs | workflow_inputs) != failure_node_inputs: + raise FlyteFailureNodeInputMismatchException(self.on_failure, self) + additional_keys = failure_node_inputs.keys() - workflow_inputs.keys() + # Raising an error if the additional inputs in the failure node are not optional. + for k in additional_keys: + if not is_optional_type(failure_node_inputs[k]): + raise FlyteFailureNodeInputMismatchException(self.on_failure, self) + c = wf_args.copy() exception_scopes.user_entry_point(self.on_failure)(**c) inner_nodes = None diff --git a/flytekit/exceptions/user.py b/flytekit/exceptions/user.py index 645754dc35..6637c8d573 100644 --- a/flytekit/exceptions/user.py +++ b/flytekit/exceptions/user.py @@ -3,6 +3,10 @@ from flytekit.exceptions.base import FlyteException as _FlyteException from flytekit.exceptions.base import FlyteRecoverableException as _Recoverable +if typing.TYPE_CHECKING: + from flytekit.core.base_task import Task + from flytekit.core.workflow import WorkflowBase + class FlyteUserException(_FlyteException): _ERROR_CODE = "USER:Unknown" @@ -68,6 +72,24 @@ class FlyteValidationException(FlyteAssertion): _ERROR_CODE = "USER:ValidationError" +class FlyteFailureNodeInputMismatchException(FlyteAssertion): + _ERROR_CODE = "USER:FailureNodeInputMismatch" + + def __init__(self, failure_node_node: typing.Union["WorkflowBase", "Task"], workflow: "WorkflowBase"): + self.failure_node_node = failure_node_node + self.workflow = workflow + + def __str__(self): + return ( + f"Mismatched Inputs Detected\n" + f"The failure node `{self.failure_node_node.name}` has inputs that do not align with those expected by the workflow `{self.workflow.name}`.\n" + f"Failure Node's Inputs: {self.failure_node_node.python_interface.inputs}\n" + f"Workflow's Inputs: {self.workflow.python_interface.inputs}\n" + "Action Required:\n" + "Please ensure that all input arguments in the failure node are provided and match the expected arguments specified in the workflow." + ) + + class FlyteDisapprovalException(FlyteAssertion): _ERROR_CODE = "USER:ResultNotApproved" diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 0a3501665c..9601ab6763 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -33,7 +33,7 @@ from flytekit.core.testing import patch, task_mock from flytekit.core.type_engine import RestrictedTypeError, SimpleTransformer, TypeEngine from flytekit.core.workflow import workflow -from flytekit.exceptions.user import FlyteValidationException +from flytekit.exceptions.user import FlyteValidationException, FlyteFailureNodeInputMismatchException from flytekit.models import literals as _literal_models from flytekit.models.core import types as _core_types from flytekit.models.interface import Parameter @@ -1635,6 +1635,7 @@ def foo4(input: DC1=DC1(1, 'a')) -> DC2: ): foo4() + def test_failure_node(): @task def run(a: int, b: str) -> typing.Tuple[int, str]: @@ -1686,6 +1687,42 @@ def wf2(a: int, b: str) -> typing.Tuple[int, str]: assert wf2.failure_node.flyte_entity == failure_handler +def test_failure_node_mismatch_inputs(): + @task() + def t1(a: int) -> int: + return a + 3 + + @workflow(on_failure=t1) + def wf1(a: int = 3, b: str = "hello"): + t1(a=a) + + # pytest-xdist uses `__channelexec__` as the top-level module + running_xdist = os.environ.get("PYTEST_XDIST_WORKER") is not None + prefix = "__channelexec__." if running_xdist else "" + + with pytest.raises( + FlyteFailureNodeInputMismatchException, + match="Mismatched Inputs Detected\n" + f"The failure node `{prefix}tests.flytekit.unit.core.test_type_hints.t1` has " + "inputs that do not align with those expected by the workflow `tests.flytekit.unit.core.test_type_hints.wf1`.\n" + "Failure Node's Inputs: {'a': }\n" + "Workflow's Inputs: {'a': , 'b': }\n" + "Action Required:\n" + "Please ensure that all input arguments in the failure node are provided and match the expected arguments specified in the workflow.", + ): + wf1() + + @task() + def t2(a: int, b: typing.Optional[int] = None) -> int: + return a + 3 + + @workflow(on_failure=t2) + def wf2(a: int = 3): + t2(a=a) + + wf2() + + @pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") def test_union_type(): import pandas as pd