diff --git a/src/sagemaker/local/pipeline.py b/src/sagemaker/local/pipeline.py index c1d049de4d..4b9209fc0b 100644 --- a/src/sagemaker/local/pipeline.py +++ b/src/sagemaker/local/pipeline.py @@ -444,7 +444,7 @@ def _resolve_not_condition(self, not_condition: dict): True if given ConditionNot evaluated as true, False otherwise. """ - return not self._resolve_condition(not_condition["Expression"]) + return not self._resolve_condition(not_condition["Condition"]) def _resolve_or_condition(self, or_condition: dict): """Resolve given ConditionOr. diff --git a/src/sagemaker/workflow/conditions.py b/src/sagemaker/workflow/conditions.py index 4cdec057f4..4b4996a7fa 100644 --- a/src/sagemaker/workflow/conditions.py +++ b/src/sagemaker/workflow/conditions.py @@ -259,7 +259,7 @@ def __init__(self, expression: Condition): def to_request(self) -> RequestType: """Get the request structure for workflow service calls.""" - return {"Type": self.condition_type.value, "Expression": self.expression.to_request()} + return {"Type": self.condition_type.value, "Condition": self.expression.to_request()} @property def _referenced_steps(self) -> List[str]: diff --git a/tests/integ/sagemaker/workflow/test_fail_steps.py b/tests/integ/sagemaker/workflow/test_fail_steps.py index 5f8c1e04ab..af9ed5368e 100644 --- a/tests/integ/sagemaker/workflow/test_fail_steps.py +++ b/tests/integ/sagemaker/workflow/test_fail_steps.py @@ -17,7 +17,7 @@ from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution from sagemaker import get_execution_role, utils from sagemaker.workflow.condition_step import ConditionStep -from sagemaker.workflow.conditions import ConditionEquals +from sagemaker.workflow.conditions import ConditionEquals, ConditionNot from sagemaker.workflow.fail_step import FailStep from sagemaker.workflow.functions import Join @@ -37,14 +37,15 @@ def pipeline_name(): def test_two_step_fail_pipeline_with_str_err_msg(sagemaker_session, role, pipeline_name): param = ParameterInteger(name="MyInt", default_value=2) - cond = ConditionEquals(left=param, right=1) + cond_equal = ConditionEquals(left=param, right=2) + cond_not_equal = ConditionNot(cond_equal) step_fail = FailStep( name="FailStep", error_message="Failed due to hitting in else branch", ) step_cond = ConditionStep( name="CondStep", - conditions=[cond], + conditions=[cond_not_equal], if_steps=[], else_steps=[step_fail], ) diff --git a/tests/unit/sagemaker/workflow/test_condition_step.py b/tests/unit/sagemaker/workflow/test_condition_step.py index 7ac335fbc3..315d549cce 100644 --- a/tests/unit/sagemaker/workflow/test_condition_step.py +++ b/tests/unit/sagemaker/workflow/test_condition_step.py @@ -202,7 +202,7 @@ def test_pipeline_condition_step_interpolated(sagemaker_session): }, { "Type": "Not", - "Expression": { + "Condition": { "Type": "Equals", "LeftValue": {"Get": "Parameters.MyInt1"}, "RightValue": {"Get": "Parameters.MyInt2"}, @@ -210,7 +210,7 @@ def test_pipeline_condition_step_interpolated(sagemaker_session): }, { "Type": "Not", - "Expression": { + "Condition": { "Type": "In", "QueryValue": {"Get": "Parameters.MyStr"}, "Values": ["abc", "def"], @@ -533,9 +533,9 @@ def func2(): assert len(step_dsl["Arguments"]["Conditions"]) == 1 condition_dsl = step_dsl["Arguments"]["Conditions"][0] assert condition_dsl["Type"] == "Not" - cond_expr_dsl = condition_dsl["Expression"] + cond_expr_dsl = condition_dsl["Condition"] assert cond_expr_dsl["Type"] == "Not" - cond_inner_expr_dsl = cond_expr_dsl["Expression"] + cond_inner_expr_dsl = cond_expr_dsl["Condition"] assert cond_inner_expr_dsl["Type"] == "Or" assert len(cond_inner_expr_dsl["Conditions"]) == 2 assert cond_inner_expr_dsl["Conditions"][0]["LeftValue"] == _get_expected_jsonget_expr( @@ -602,7 +602,7 @@ def func4(): assert len(step_dsl["Arguments"]["Conditions"]) == 1 condition_dsl = step_dsl["Arguments"]["Conditions"][0] assert condition_dsl["Type"] == "Not" - cond_expr_dsl = condition_dsl["Expression"] + cond_expr_dsl = condition_dsl["Condition"] assert cond_expr_dsl["Type"] == "In" assert cond_expr_dsl["QueryValue"] == _get_expected_jsonget_expr( step_name=step_output3._step.name, path="Result" diff --git a/tests/unit/sagemaker/workflow/test_conditions.py b/tests/unit/sagemaker/workflow/test_conditions.py index a7ec9c0c11..941a191856 100644 --- a/tests/unit/sagemaker/workflow/test_conditions.py +++ b/tests/unit/sagemaker/workflow/test_conditions.py @@ -122,7 +122,7 @@ def test_condition_not(): cond_not = ConditionNot(expression=cond_eq) assert cond_not.to_request() == { "Type": "Not", - "Expression": { + "Condition": { "Type": "Equals", "LeftValue": param, "RightValue": "foo", @@ -136,7 +136,7 @@ def test_condition_not_in(): cond_not = ConditionNot(expression=cond_in) assert cond_not.to_request() == { "Type": "Not", - "Expression": { + "Condition": { "Type": "In", "QueryValue": param, "Values": ["abc", "def"],