Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better error message for FailureNodeInputMismatch error #2693

Merged
merged 3 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from functools import update_wrapper
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload

from typing_inspect import is_optional_type

try:
from typing import ParamSpec
except ImportError:
Expand Down Expand Up @@ -47,7 +49,11 @@
from flytekit.core.tracker import extract_task_module
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions import scopes as exception_scopes
from flytekit.exceptions.user import FlyteValidationException, FlyteValueException
from flytekit.exceptions.user import (
FlyteFailureNodeInputMismatchException,
FlyteValidationException,
FlyteValueException,
)
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
Expand Down Expand Up @@ -689,6 +695,19 @@
) as inner_comp_ctx:
# Now lets compile the failure-node if it exists
if self.on_failure:
if self.on_failure.python_interface and self.python_interface:
workflow_inputs = self.python_interface.inputs
failure_node_inputs = self.on_failure.python_interface.inputs

# Workflow inputs should be a subset of failure node inputs.
if (failure_node_inputs | workflow_inputs) != failure_node_inputs:
raise FlyteFailureNodeInputMismatchException(self.on_failure, self)
additional_keys = failure_node_inputs.keys() - workflow_inputs.keys()
# Raising an error if the additional inputs in the failure node are not optional.
for k in additional_keys:
if not is_optional_type(failure_node_inputs[k]):
raise FlyteFailureNodeInputMismatchException(self.on_failure, self)

Check warning on line 709 in flytekit/core/workflow.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/workflow.py#L709

Added line #L709 was not covered by tests

c = wf_args.copy()
exception_scopes.user_entry_point(self.on_failure)(**c)
inner_nodes = None
Expand Down
22 changes: 22 additions & 0 deletions flytekit/exceptions/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from flytekit.exceptions.base import FlyteException as _FlyteException
from flytekit.exceptions.base import FlyteRecoverableException as _Recoverable

if typing.TYPE_CHECKING:
from flytekit.core.base_task import Task
from flytekit.core.workflow import WorkflowBase


class FlyteUserException(_FlyteException):
_ERROR_CODE = "USER:Unknown"
Expand Down Expand Up @@ -68,6 +72,24 @@ class FlyteValidationException(FlyteAssertion):
_ERROR_CODE = "USER:ValidationError"


class FlyteFailureNodeInputMismatchException(FlyteAssertion):
_ERROR_CODE = "USER:FailureNodeInputMismatch"

def __init__(self, failure_node_node: typing.Union["WorkflowBase", "Task"], workflow: "WorkflowBase"):
self.failure_node_node = failure_node_node
self.workflow = workflow

def __str__(self):
return (
f"Mismatched Inputs Detected\n"
f"The failure node `{self.failure_node_node.name}` has inputs that do not align with those expected by the workflow `{self.workflow.name}`.\n"
f"Failure Node's Inputs: {self.failure_node_node.python_interface.inputs}\n"
f"Workflow's Inputs: {self.workflow.python_interface.inputs}\n"
"Action Required:\n"
"Please ensure that all input arguments in the failure node are provided and match the expected arguments specified in the workflow."
)


class FlyteDisapprovalException(FlyteAssertion):
_ERROR_CODE = "USER:ResultNotApproved"

Expand Down
39 changes: 38 additions & 1 deletion tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from flytekit.core.testing import patch, task_mock
from flytekit.core.type_engine import RestrictedTypeError, SimpleTransformer, TypeEngine
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteValidationException
from flytekit.exceptions.user import FlyteValidationException, FlyteFailureNodeInputMismatchException
from flytekit.models import literals as _literal_models
from flytekit.models.core import types as _core_types
from flytekit.models.interface import Parameter
Expand Down Expand Up @@ -1635,6 +1635,7 @@ def foo4(input: DC1=DC1(1, 'a')) -> DC2:
):
foo4()


def test_failure_node():
@task
def run(a: int, b: str) -> typing.Tuple[int, str]:
Expand Down Expand Up @@ -1686,6 +1687,42 @@ def wf2(a: int, b: str) -> typing.Tuple[int, str]:
assert wf2.failure_node.flyte_entity == failure_handler


def test_failure_node_mismatch_inputs():
@task()
def t1(a: int) -> int:
return a + 3

@workflow(on_failure=t1)
def wf1(a: int = 3, b: str = "hello"):
t1(a=a)

# pytest-xdist uses `__channelexec__` as the top-level module
running_xdist = os.environ.get("PYTEST_XDIST_WORKER") is not None
prefix = "__channelexec__." if running_xdist else ""

with pytest.raises(
FlyteFailureNodeInputMismatchException,
match="Mismatched Inputs Detected\n"
f"The failure node `{prefix}tests.flytekit.unit.core.test_type_hints.t1` has "
"inputs that do not align with those expected by the workflow `tests.flytekit.unit.core.test_type_hints.wf1`.\n"
"Failure Node's Inputs: {'a': <class 'int'>}\n"
"Workflow's Inputs: {'a': <class 'int'>, 'b': <class 'str'>}\n"
"Action Required:\n"
"Please ensure that all input arguments in the failure node are provided and match the expected arguments specified in the workflow.",
):
wf1()

@task()
def t2(a: int, b: typing.Optional[int] = None) -> int:
return a + 3

@workflow(on_failure=t2)
def wf2(a: int = 3):
t2(a=a)

wf2()


@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.")
def test_union_type():
import pandas as pd
Expand Down
Loading