From 11a5393b63cf9f944364cbf7d8d7f12a332c52ce Mon Sep 17 00:00:00 2001 From: pankajastro Date: Thu, 10 Oct 2024 18:08:32 +0530 Subject: [PATCH 1/2] Cast callbacks to functions when set with default_args on task groups --- dagfactory/dagbuilder.py | 53 ++++++++++++++++++++++++++++++++- tests/test_dagbuilder.py | 63 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 114 insertions(+), 2 deletions(-) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index ef7babec..2d8ca7c7 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -179,10 +179,18 @@ def get_dag_params(self) -> Dict[str, Any]: if utils.check_dict_key(dag_params["default_args"], "sla_miss_callback"): if isinstance(dag_params["default_args"]["sla_miss_callback"], str): - dag_params["default_args"]["sla_miss_callback"]: Callable = import_string( + dag_params["default_args"]["sla_miss_callback"] = import_string( dag_params["default_args"]["sla_miss_callback"] ) + if utils.check_dict_key(dag_params["default_args"], "on_execute_callback") and version.parse( + AIRFLOW_VERSION + ) >= version.parse("2.0.0"): + if isinstance(dag_params["default_args"]["on_execute_callback"], str): + dag_params["default_args"]["on_execute_callback"] = import_string( + dag_params["default_args"]["on_execute_callback"] + ) + if utils.check_dict_key(dag_params["default_args"], "on_success_callback"): if isinstance(dag_params["default_args"]["on_success_callback"], str): dag_params["default_args"]["on_success_callback"]: Callable = import_string( @@ -544,6 +552,47 @@ def make_task_groups(task_groups: Dict[str, Any], dag: DAG) -> Dict[str, "TaskGr for task_group_name, task_group_conf in task_groups.items(): task_group_conf["group_id"] = task_group_name task_group_conf["dag"] = dag + + if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0") and isinstance( + task_group_conf.get("default_args"), dict + ): + # https://github.com/apache/airflow/pull/16557 + if utils.check_dict_key(task_group_conf["default_args"], "on_success_callback"): + if isinstance( + task_group_conf["default_args"]["on_success_callback"], + str, + ): + task_group_conf["default_args"]["on_success_callback"]: Callable = import_string( + task_group_conf["default_args"]["on_success_callback"] + ) + + if utils.check_dict_key(task_group_conf["default_args"], "on_execute_callback"): + if isinstance( + task_group_conf["default_args"]["on_execute_callback"], + str, + ): + task_group_conf["default_args"]["on_execute_callback"]: Callable = import_string( + task_group_conf["default_args"]["on_execute_callback"] + ) + + if utils.check_dict_key(task_group_conf["default_args"], "on_failure_callback"): + if isinstance( + task_group_conf["default_args"]["on_failure_callback"], + str, + ): + task_group_conf["default_args"]["on_failure_callback"]: Callable = import_string( + task_group_conf["default_args"]["on_failure_callback"] + ) + + if utils.check_dict_key(task_group_conf["default_args"], "on_retry_callback"): + if isinstance( + task_group_conf["default_args"]["on_retry_callback"], + str, + ): + task_group_conf["default_args"]["on_retry_callback"]: Callable = import_string( + task_group_conf["default_args"]["on_retry_callback"] + ) + task_group = TaskGroup(**{k: v for k, v in task_group_conf.items() if k not in SYSTEM_PARAMS}) task_groups_dict[task_group.group_id] = task_group return task_groups_dict @@ -572,11 +621,13 @@ def set_dependencies( name = f"{group_id}.{name}" if conf.get("dependencies"): source: Union[BaseOperator, "TaskGroup"] = tasks_and_task_groups_instances[name] + for dep in conf["dependencies"]: if tasks_and_task_groups_config[dep].get("task_group"): group_id = tasks_and_task_groups_config[dep]["task_group"].group_id dep = f"{group_id}.{dep}" dep: Union[BaseOperator, "TaskGroup"] = tasks_and_task_groups_instances[dep] + source.set_upstream(dep) @staticmethod diff --git a/tests/test_dagbuilder.py b/tests/test_dagbuilder.py index c7dc0e88..000a98ef 100644 --- a/tests/test_dagbuilder.py +++ b/tests/test_dagbuilder.py @@ -130,6 +130,44 @@ }, }, } +DAG_CONFIG_TASK_GROUP_WITH_CALLBACKS = { + "default_args": {"owner": "custom_owner"}, + "schedule_interval": "0 3 * * *", + "task_groups": { + "task_group_1": { + "tooltip": "this is a task group", + "default_args": { + "on_failure_callback": f"{__name__}.print_context_callback", + "on_success_callback": f"{__name__}.print_context_callback", + "on_execute_callback": f"{__name__}.print_context_callback", + "on_retry_callback": f"{__name__}.print_context_callback", + }, + }, + }, + "tasks": { + "task_1": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 1", + "task_group_name": "task_group_1", + }, + "task_2": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 2", + "task_group_name": "task_group_1", + }, + "task_3": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 3", + "task_group_name": "task_group_1", + "dependencies": ["task_2"], + }, + "task_4": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 4", + "dependencies": ["task_group_1"], + }, + }, +} DAG_CONFIG_DYNAMIC_TASK_MAPPING = { "default_args": {"owner": "custom_owner"}, "description": "This is an example dag with dynamic task mapping", @@ -453,7 +491,6 @@ def test_build(): if version.parse(AIRFLOW_VERSION) >= version.parse("1.10.8"): assert actual["dag"].tags == ["tag1", "tag2"] - def test_get_dag_params_dag_with_task_group(): td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_TASK_GROUP, DEFAULT_CONFIG) expected = { @@ -545,6 +582,30 @@ def test_build_task_groups(): assert {"task_group_2.task_5", "task_group_2.task_6"} == task_group_2 +def test_build_task_groups_with_callbacks(): + td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_TASK_GROUP_WITH_CALLBACKS, DEFAULT_CONFIG) + if version.parse(AIRFLOW_VERSION) < version.parse("2.2.0"): + error_message = "`task_groups` key can only be used with Airflow 2.x.x" + with pytest.raises(Exception, match=error_message): + td.build() + else: + actual = td.build() + assert actual["dag_id"] == "test_dag" + assert isinstance(actual["dag"], DAG) + assert callable( + actual["dag"].task_group.get_task_group_dict()["task_group_1"].default_args["on_failure_callback"] + ) + assert callable( + actual["dag"].task_group.get_task_group_dict()["task_group_1"].default_args["on_execute_callback"] + ) + assert callable( + actual["dag"].task_group.get_task_group_dict()["task_group_1"].default_args["on_success_callback"] + ) + assert callable( + actual["dag"].task_group.get_task_group_dict()["task_group_1"].default_args["on_retry_callback"] + ) + + @patch("dagfactory.dagbuilder.TaskGroup", new=MockTaskGroup) def test_make_task_groups(): task_group_dict = { From e1a2b1b6c1b41825d85c4da160224a75374ce3e2 Mon Sep 17 00:00:00 2001 From: pankajastro Date: Thu, 10 Oct 2024 18:18:04 +0530 Subject: [PATCH 2/2] Cast callbacks to functions when set with default_args on task groups > Co-authored-by: Luiz Felipe de Mesquita Baraldo --- tests/test_dagbuilder.py | 77 ++++++++++++++++++++-------------------- 1 file changed, 39 insertions(+), 38 deletions(-) diff --git a/tests/test_dagbuilder.py b/tests/test_dagbuilder.py index 000a98ef..bc1260ff 100644 --- a/tests/test_dagbuilder.py +++ b/tests/test_dagbuilder.py @@ -130,44 +130,6 @@ }, }, } -DAG_CONFIG_TASK_GROUP_WITH_CALLBACKS = { - "default_args": {"owner": "custom_owner"}, - "schedule_interval": "0 3 * * *", - "task_groups": { - "task_group_1": { - "tooltip": "this is a task group", - "default_args": { - "on_failure_callback": f"{__name__}.print_context_callback", - "on_success_callback": f"{__name__}.print_context_callback", - "on_execute_callback": f"{__name__}.print_context_callback", - "on_retry_callback": f"{__name__}.print_context_callback", - }, - }, - }, - "tasks": { - "task_1": { - "operator": "airflow.operators.bash_operator.BashOperator", - "bash_command": "echo 1", - "task_group_name": "task_group_1", - }, - "task_2": { - "operator": "airflow.operators.bash_operator.BashOperator", - "bash_command": "echo 2", - "task_group_name": "task_group_1", - }, - "task_3": { - "operator": "airflow.operators.bash_operator.BashOperator", - "bash_command": "echo 3", - "task_group_name": "task_group_1", - "dependencies": ["task_2"], - }, - "task_4": { - "operator": "airflow.operators.bash_operator.BashOperator", - "bash_command": "echo 4", - "dependencies": ["task_group_1"], - }, - }, -} DAG_CONFIG_DYNAMIC_TASK_MAPPING = { "default_args": {"owner": "custom_owner"}, "description": "This is an example dag with dynamic task mapping", @@ -236,6 +198,45 @@ } UTC = pendulum.timezone("UTC") +DAG_CONFIG_TASK_GROUP_WITH_CALLBACKS = { + "default_args": {"owner": "custom_owner"}, + "schedule_interval": "0 3 * * *", + "task_groups": { + "task_group_1": { + "tooltip": "this is a task group", + "default_args": { + "on_failure_callback": f"{__name__}.print_context_callback", + "on_success_callback": f"{__name__}.print_context_callback", + "on_execute_callback": f"{__name__}.print_context_callback", + "on_retry_callback": f"{__name__}.print_context_callback", + }, + }, + }, + "tasks": { + "task_1": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 1", + "task_group_name": "task_group_1", + }, + "task_2": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 2", + "task_group_name": "task_group_1", + }, + "task_3": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 3", + "task_group_name": "task_group_1", + "dependencies": ["task_2"], + }, + "task_4": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 4", + "dependencies": ["task_group_1"], + }, + }, +} + class MockTaskGroup: def __init__(self, **kwargs):