From 2e8f0a8dbd7acdf1e6f1cd74ad295ea21139b189 Mon Sep 17 00:00:00 2001
From: binhnq94
Date: Tue, 12 Sep 2023 04:12:28 +0700
Subject: [PATCH] Remove prefix of run and test tasks if test_behavior = TestBehavior.AFTER_EACH (#524)

When test_behavior = TestBehavior.AFTER_EACH, the DbtNode.name prefix in the
task_id is redundant, because the parent task group is already named after
`DbtNode.name`.

Comes from: [Slack thread](https://apache-airflow.slack.com/archives/C059CC42E9W/p1692776042134929)

Rendered tasks before this PR:

![image](https://github.com/astronomer/astronomer-cosmos/assets/8995895/d7b55ad8-1c3a-4355-a5ce-04bf3037fd90)

Rendered tasks after this PR:

![image](https://github.com/astronomer/astronomer-cosmos/assets/8995895/a8b933f6-7990-4f4f-9fa8-d590dd63f8b2)
---
 cosmos/airflow/graph.py     | 31 +++++++++++++++------
 tests/airflow/test_graph.py | 54 ++++++++++++++++++++++++++-----------
 2 files changed, 62 insertions(+), 23 deletions(-)

diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py
index f0792b8fb..9e750ddab 100644
--- a/cosmos/airflow/graph.py
+++ b/cosmos/airflow/graph.py
@@ -2,16 +2,15 @@
 
 from typing import Any, Callable
 
+from airflow.models import BaseOperator
 from airflow.models.dag import DAG
 from airflow.utils.task_group import TaskGroup
 
-from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode
+from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior
 from cosmos.core.airflow import get_airflow_task as create_airflow_task
 from cosmos.core.graph.entities import Task as TaskMetadata
 from cosmos.dbt.graph import DbtNode
 from cosmos.log import get_logger
 
-from airflow.models import BaseOperator
-
 logger = get_logger(__name__)
 
@@ -51,7 +50,9 @@ def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[st
     return leaves
 
 
-def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any]) -> TaskMetadata | None:
+def create_task_metadata(
+    node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any], use_name_as_task_id_prefix: bool = True
+) -> TaskMetadata | None:
     """
     Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
 
@@ -59,6 +60,8 @@ def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dic
     :param execution_mode: Where Cosmos should run each dbt task (e.g. ExecutionMode.LOCAL,
         ExecutionMode.KUBERNETES). Default is ExecutionMode.LOCAL.
     :param args: Arguments to be used to instantiate an Airflow Task
+    :param use_name_as_task_id_prefix: Only applies when resource_type is DbtResourceType.MODEL.
+        If True, task_id = "<node.name>_run"; otherwise, task_id = "run".
     :returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
""" dbt_resource_to_class = { @@ -70,9 +73,16 @@ def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dic args = {**args, **{"models": node.name}} if hasattr(node.resource_type, "value") and node.resource_type in dbt_resource_to_class: - task_id_suffix = "run" if node.resource_type == DbtResourceType.MODEL else node.resource_type.value + if node.resource_type == DbtResourceType.MODEL: + if use_name_as_task_id_prefix: + task_id = f"{node.name}_run" + else: + task_id = "run" + else: + task_id = f"{node.name}_{node.resource_type.value}" + task_metadata = TaskMetadata( - id=f"{node.name}_{task_id_suffix}", + id=task_id, operator_class=calculate_operator_class( execution_mode=execution_mode, dbt_class=dbt_resource_to_class[node.resource_type] ), @@ -157,13 +167,18 @@ def build_airflow_graph( # The exception are the test nodes, since it would be too slow to run test tasks individually. # If test_behaviour=="after_each", each model task will be bundled with a test task, using TaskGroup for node_id, node in nodes.items(): - task_meta = create_task_metadata(node=node, execution_mode=execution_mode, args=task_args) + task_meta = create_task_metadata( + node=node, + execution_mode=execution_mode, + args=task_args, + use_name_as_task_id_prefix=test_behavior != TestBehavior.AFTER_EACH, + ) if task_meta and node.resource_type != DbtResourceType.TEST: if node.resource_type == DbtResourceType.MODEL and test_behavior == TestBehavior.AFTER_EACH: with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group: task = create_airflow_task(task_meta, dag, task_group=model_task_group) test_meta = create_test_task_metadata( - f"{node.name}_test", + "test", execution_mode, task_args=task_args, model_name=node.name, diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index 92b630ee6..7b539bb5b 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -1,5 +1,5 @@ -from pathlib import Path from datetime import datetime +from pathlib import Path from unittest.mock import patch import pytest @@ -10,15 +10,14 @@ from cosmos.airflow.graph import ( build_airflow_graph, calculate_leaves, + calculate_operator_class, create_task_metadata, create_test_task_metadata, - calculate_operator_class, ) from cosmos.config import ProfileConfig -from cosmos.profiles import PostgresUserPasswordProfileMapping -from cosmos.constants import ExecutionMode, DbtResourceType, TestBehavior +from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior from cosmos.dbt.graph import DbtNode - +from cosmos.profiles import PostgresUserPasswordProfileMapping SAMPLE_PROJ_PATH = Path("/home/user/path/dbt-proj/") @@ -92,23 +91,23 @@ def test_build_airflow_graph_with_after_each(): topological_sort = [task.task_id for task in dag.topological_sort()] expected_sort = [ "seed_parent_seed", - "parent.parent_run", - "parent.parent_test", - "child.child_run", - "child.child_test", + "parent.run", + "parent.test", + "child.run", + "child.test", ] assert topological_sort == expected_sort task_groups = dag.task_group_dict assert len(task_groups) == 2 assert task_groups["parent"].upstream_task_ids == {"seed_parent_seed"} - assert list(task_groups["parent"].children.keys()) == ["parent.parent_run", "parent.parent_test"] + assert list(task_groups["parent"].children.keys()) == ["parent.run", "parent.test"] - assert task_groups["child"].upstream_task_ids == {"parent.parent_test"} - assert list(task_groups["child"].children.keys()) == ["child.child_run", 
"child.child_test"] + assert task_groups["child"].upstream_task_ids == {"parent.test"} + assert list(task_groups["child"].children.keys()) == ["child.run", "child.test"] assert len(dag.leaves) == 1 - assert dag.leaves[0].task_id == "child.child_test" + assert dag.leaves[0].task_id == "child.test" @pytest.mark.skipif( @@ -232,7 +231,24 @@ def test_create_task_metadata_model(caplog): assert metadata.arguments == {"models": "my_model"} -def test_create_task_metadata_seed(caplog): +def test_create_task_metadata_model_use_name_as_task_id_prefix(caplog): + child_node = DbtNode( + name="my_model", + unique_id="my_folder.my_model", + resource_type=DbtResourceType.MODEL, + depends_on=[], + file_path=Path(""), + tags=[], + config={}, + ) + metadata = create_task_metadata( + child_node, execution_mode=ExecutionMode.LOCAL, args={}, use_name_as_task_id_prefix=False + ) + assert metadata.id == "run" + + +@pytest.mark.parametrize("use_name_as_task_id_prefix", (None, True, False)) +def test_create_task_metadata_seed(caplog, use_name_as_task_id_prefix): sample_node = DbtNode( name="my_seed", unique_id="my_folder.my_seed", @@ -242,7 +258,15 @@ def test_create_task_metadata_seed(caplog): tags=[], config={}, ) - metadata = create_task_metadata(sample_node, execution_mode=ExecutionMode.DOCKER, args={}) + if use_name_as_task_id_prefix is None: + metadata = create_task_metadata(sample_node, execution_mode=ExecutionMode.DOCKER, args={}) + else: + metadata = create_task_metadata( + sample_node, + execution_mode=ExecutionMode.DOCKER, + args={}, + use_name_as_task_id_prefix=use_name_as_task_id_prefix, + ) assert metadata.id == "my_seed_seed" assert metadata.operator_class == "cosmos.operators.docker.DbtSeedDockerOperator" assert metadata.arguments == {"models": "my_seed"}