Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <pingsutw@apache.org>
  • Loading branch information
pingsutw committed Aug 20, 2024
1 parent 0924e54 commit 5ca2b60
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
21 changes: 15 additions & 6 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from functools import update_wrapper
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload

from typing_inspect import is_optional_type

try:
from typing import ParamSpec
except ImportError:
Expand Down Expand Up @@ -693,12 +695,19 @@ def _validate_add_on_failure_handler(self, ctx: FlyteContext, prefix: str, wf_ar
) as inner_comp_ctx:
# Now lets compile the failure-node if it exists
if self.on_failure:
if (
self.on_failure.python_interface
and self.python_interface
and self.on_failure.python_interface.inputs != self.python_interface.inputs
):
raise FlyteFailureNodeInputMismatchException(self.on_failure, self)
if self.on_failure.python_interface and self.python_interface:
workflow_inputs = self.python_interface.inputs
failure_node_inputs = self.on_failure.python_interface.inputs

# Workflow inputs should be a subset of failure node inputs.
if (failure_node_inputs | workflow_inputs) != failure_node_inputs:
raise FlyteFailureNodeInputMismatchException(self.on_failure, self)
additional_keys = failure_node_inputs.keys() - workflow_inputs.keys()
# Raising an error if the additional inputs in the failure node are not optional.
for k in additional_keys:
if not is_optional_type(failure_node_inputs[k]):
raise FlyteFailureNodeInputMismatchException(self.on_failure, self)

Check warning on line 709 in flytekit/core/workflow.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/workflow.py#L709

Added line #L709 was not covered by tests

c = wf_args.copy()
exception_scopes.user_entry_point(self.on_failure)(**c)
inner_nodes = None
Expand Down
24 changes: 17 additions & 7 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,12 +1689,12 @@ def wf2(a: int, b: str) -> typing.Tuple[int, str]:

def test_failure_node_mismatch_inputs():
@task()
def t2(a: int) -> int:
def t1(a: int) -> int:
return a + 3

@workflow(on_failure=t2)
def wf(a: int = 3, b: str = "hello"):
t2(a=a)
@workflow(on_failure=t1)
def wf1(a: int = 3, b: str = "hello"):
t1(a=a)

# pytest-xdist uses `__channelexec__` as the top-level module
running_xdist = os.environ.get("PYTEST_XDIST_WORKER") is not None
Expand All @@ -1703,14 +1703,24 @@ def wf(a: int = 3, b: str = "hello"):
with pytest.raises(
FlyteFailureNodeInputMismatchException,
match="Mismatched Inputs Detected\n"
f"The failure node `{prefix}tests.flytekit.unit.core.test_type_hints.t2` has "
"inputs that do not align with those expected by the workflow `tests.flytekit.unit.core.test_type_hints.wf`.\n"
f"The failure node `{prefix}tests.flytekit.unit.core.test_type_hints.t1` has "
"inputs that do not align with those expected by the workflow `tests.flytekit.unit.core.test_type_hints.wf1`.\n"
"Failure Node's Inputs: {'a': <class 'int'>}\n"
"Workflow's Inputs: {'a': <class 'int'>, 'b': <class 'str'>}\n"
"Action Required:\n"
"Please ensure that all input arguments in the failure node are provided and match the expected arguments specified in the workflow.",
):
wf()
wf1()

@task()
def t2(a: int, b: typing.Optional[int] = None) -> int:
return a + 3

@workflow(on_failure=t2)
def wf2(a: int = 3):
t2(a=a)

wf2()


@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.")
Expand Down

0 comments on commit 5ca2b60

Please sign in to comment.