Commit 9e15f87

Merge branch 'main' into community-issue

tatiana authored Sep 11, 2023
2 parents 94a19f6 + 2e8f0a8, commit 9e15f87

Showing 2 changed files with 62 additions and 23 deletions.
cosmos/airflow/graph.py (31 changes: 23 additions & 8 deletions)
@@ -2,16 +2,15 @@
 
 from typing import Any, Callable
 
+from airflow.models import BaseOperator
 from airflow.models.dag import DAG
 from airflow.utils.task_group import TaskGroup
 
-from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode
+from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior
 from cosmos.core.airflow import get_airflow_task as create_airflow_task
 from cosmos.core.graph.entities import Task as TaskMetadata
 from cosmos.dbt.graph import DbtNode
 from cosmos.log import get_logger
-from airflow.models import BaseOperator
 
 
 logger = get_logger(__name__)
@@ -51,14 +50,18 @@ def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[str]:
     return leaves
 
 
-def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any]) -> TaskMetadata | None:
+def create_task_metadata(
+    node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any], use_name_as_task_id_prefix=True
+) -> TaskMetadata | None:
     """
     Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
     :param node: The dbt node which we desired to convert into an Airflow Task
     :param execution_mode: Where Cosmos should run each dbt task (e.g. ExecutionMode.LOCAL, ExecutionMode.KUBERNETES).
         Default is ExecutionMode.LOCAL.
     :param args: Arguments to be used to instantiate an Airflow Task
+    :param use_name_as_task_id_prefix: Only applies when resource_type is DbtResourceType.MODEL. Determines
+        whether the node name prefixes the task id: if True, task_id = <node.name>_run; otherwise task_id = run.
     :returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
     """
     dbt_resource_to_class = {
@@ -70,9 +73,16 @@ def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any]) -> TaskMetadata | None:
     args = {**args, **{"models": node.name}}
 
     if hasattr(node.resource_type, "value") and node.resource_type in dbt_resource_to_class:
-        task_id_suffix = "run" if node.resource_type == DbtResourceType.MODEL else node.resource_type.value
+        if node.resource_type == DbtResourceType.MODEL:
+            if use_name_as_task_id_prefix:
+                task_id = f"{node.name}_run"
+            else:
+                task_id = "run"
+        else:
+            task_id = f"{node.name}_{node.resource_type.value}"
 
         task_metadata = TaskMetadata(
-            id=f"{node.name}_{task_id_suffix}",
+            id=task_id,
             operator_class=calculate_operator_class(
                 execution_mode=execution_mode, dbt_class=dbt_resource_to_class[node.resource_type]
             ),
@@ -157,13 +167,18 @@ def build_airflow_graph(
     # The exception are the test nodes, since it would be too slow to run test tasks individually.
     # If test_behaviour=="after_each", each model task will be bundled with a test task, using TaskGroup
     for node_id, node in nodes.items():
-        task_meta = create_task_metadata(node=node, execution_mode=execution_mode, args=task_args)
+        task_meta = create_task_metadata(
+            node=node,
+            execution_mode=execution_mode,
+            args=task_args,
+            use_name_as_task_id_prefix=test_behavior != TestBehavior.AFTER_EACH,
+        )
         if task_meta and node.resource_type != DbtResourceType.TEST:
             if node.resource_type == DbtResourceType.MODEL and test_behavior == TestBehavior.AFTER_EACH:
                 with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group:
                     task = create_airflow_task(task_meta, dag, task_group=model_task_group)
                     test_meta = create_test_task_metadata(
-                        f"{node.name}_test",
+                        "test",
                         execution_mode,
                         task_args=task_args,
                         model_name=node.name,
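The net effect of the graph.py change: for model nodes, the default use_name_as_task_id_prefix=True preserves the old "<model_name>_run" task ids, while False collapses the id to a bare "run", avoiding the redundant "parent.parent_run" naming inside a TaskGroup that already carries the model name. A minimal sketch of the two behaviours (not part of the commit; it reuses the imports from this diff and the DbtNode fields used by the tests below):

    from pathlib import Path

    from cosmos.airflow.graph import create_task_metadata
    from cosmos.constants import DbtResourceType, ExecutionMode
    from cosmos.dbt.graph import DbtNode

    # A model node built with the same fields the tests below use.
    node = DbtNode(
        name="my_model",
        unique_id="my_folder.my_model",
        resource_type=DbtResourceType.MODEL,
        depends_on=[],
        file_path=Path(""),
        tags=[],
        config={},
    )

    # Default: the node name still prefixes the task id, as before this commit.
    metadata = create_task_metadata(node, execution_mode=ExecutionMode.LOCAL, args={})
    assert metadata.id == "my_model_run"

    # Flag disabled: the id is the bare suffix, so inside a TaskGroup named
    # "my_model" the rendered id is "my_model.run" rather than "my_model.my_model_run".
    metadata = create_task_metadata(
        node, execution_mode=ExecutionMode.LOCAL, args={}, use_name_as_task_id_prefix=False
    )
    assert metadata.id == "run"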
tests/airflow/test_graph.py (54 changes: 39 additions & 15 deletions)
@@ -1,5 +1,5 @@
-from pathlib import Path
 from datetime import datetime
+from pathlib import Path
 from unittest.mock import patch
 
 import pytest
@@ -10,15 +10,14 @@
 from cosmos.airflow.graph import (
     build_airflow_graph,
     calculate_leaves,
+    calculate_operator_class,
     create_task_metadata,
     create_test_task_metadata,
-    calculate_operator_class,
 )
 from cosmos.config import ProfileConfig
+from cosmos.profiles import PostgresUserPasswordProfileMapping
-from cosmos.constants import ExecutionMode, DbtResourceType, TestBehavior
+from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior
 from cosmos.dbt.graph import DbtNode
-
-from cosmos.profiles import PostgresUserPasswordProfileMapping
 
 SAMPLE_PROJ_PATH = Path("/home/user/path/dbt-proj/")
 
@@ -92,23 +91,23 @@ def test_build_airflow_graph_with_after_each():
     topological_sort = [task.task_id for task in dag.topological_sort()]
     expected_sort = [
         "seed_parent_seed",
-        "parent.parent_run",
-        "parent.parent_test",
-        "child.child_run",
-        "child.child_test",
+        "parent.run",
+        "parent.test",
+        "child.run",
+        "child.test",
     ]
     assert topological_sort == expected_sort
     task_groups = dag.task_group_dict
     assert len(task_groups) == 2
 
     assert task_groups["parent"].upstream_task_ids == {"seed_parent_seed"}
-    assert list(task_groups["parent"].children.keys()) == ["parent.parent_run", "parent.parent_test"]
+    assert list(task_groups["parent"].children.keys()) == ["parent.run", "parent.test"]
 
-    assert task_groups["child"].upstream_task_ids == {"parent.parent_test"}
-    assert list(task_groups["child"].children.keys()) == ["child.child_run", "child.child_test"]
+    assert task_groups["child"].upstream_task_ids == {"parent.test"}
+    assert list(task_groups["child"].children.keys()) == ["child.run", "child.test"]
 
     assert len(dag.leaves) == 1
-    assert dag.leaves[0].task_id == "child.child_test"
+    assert dag.leaves[0].task_id == "child.test"
 
 
 @pytest.mark.skipif(
@@ -232,7 +231,24 @@ def test_create_task_metadata_model(caplog):
     assert metadata.arguments == {"models": "my_model"}
 
 
-def test_create_task_metadata_seed(caplog):
+def test_create_task_metadata_model_use_name_as_task_id_prefix(caplog):
+    child_node = DbtNode(
+        name="my_model",
+        unique_id="my_folder.my_model",
+        resource_type=DbtResourceType.MODEL,
+        depends_on=[],
+        file_path=Path(""),
+        tags=[],
+        config={},
+    )
+    metadata = create_task_metadata(
+        child_node, execution_mode=ExecutionMode.LOCAL, args={}, use_name_as_task_id_prefix=False
+    )
+    assert metadata.id == "run"
+
+
+@pytest.mark.parametrize("use_name_as_task_id_prefix", (None, True, False))
+def test_create_task_metadata_seed(caplog, use_name_as_task_id_prefix):
     sample_node = DbtNode(
         name="my_seed",
         unique_id="my_folder.my_seed",
@@ -242,7 +258,15 @@ def test_create_task_metadata_seed(caplog):
         tags=[],
         config={},
     )
-    metadata = create_task_metadata(sample_node, execution_mode=ExecutionMode.DOCKER, args={})
+    if use_name_as_task_id_prefix is None:
+        metadata = create_task_metadata(sample_node, execution_mode=ExecutionMode.DOCKER, args={})
+    else:
+        metadata = create_task_metadata(
+            sample_node,
+            execution_mode=ExecutionMode.DOCKER,
+            args={},
+            use_name_as_task_id_prefix=use_name_as_task_id_prefix,
+        )
     assert metadata.id == "my_seed_seed"
     assert metadata.operator_class == "cosmos.operators.docker.DbtSeedDockerOperator"
     assert metadata.arguments == {"models": "my_seed"}
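For DAG authors, the naming change is visible whenever tests run after each model. A hedged usage sketch, assuming the cosmos 1.x top-level API (DbtDag, ProjectConfig, ProfileConfig, RenderConfig) and a hypothetical Postgres connection id, neither of which is shown in this diff:

    from datetime import datetime

    from cosmos import DbtDag, ProfileConfig, ProjectConfig, RenderConfig
    from cosmos.constants import TestBehavior
    from cosmos.profiles import PostgresUserPasswordProfileMapping

    dag = DbtDag(
        dag_id="my_dbt_dag",
        start_date=datetime(2023, 1, 1),
        schedule_interval=None,
        project_config=ProjectConfig("/home/user/path/dbt-proj/"),
        profile_config=ProfileConfig(
            profile_name="default",
            target_name="dev",
            # "example_postgres_conn" is a placeholder Airflow connection id.
            profile_mapping=PostgresUserPasswordProfileMapping(
                conn_id="example_postgres_conn", profile_args={"schema": "public"}
            ),
        ),
        render_config=RenderConfig(test_behavior=TestBehavior.AFTER_EACH),
    )

    # With AFTER_EACH, each model renders as a TaskGroup holding "run" and
    # "test" tasks, e.g. "parent.run" and "parent.test" instead of the
    # previous "parent.parent_run" and "parent.parent_test".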
