Skip to content

Commit

Permalink
Check for TaskGroup in _PythonDecoratedOperator (#12312)
Browse files Browse the repository at this point in the history
A crucial feature of functions decorated with @task is that they can be
invoked multiple times in a single DAG. To support this, we generate a
unique task_id for each invocation. However, this didn't work inside a
TaskGroup, because the task_id is already altered there by the added
group_id prefix. This PR fixes it.

closes: #12309

Co-authored-by: Kaxil Naik <kaxilnaik@gmail.com>
  • Loading branch information
turbaszek and kaxil committed Nov 15, 2020
1 parent 823b3aa commit 39ea872
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 3 deletions.
17 changes: 14 additions & 3 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from airflow.utils.operator_helpers import determine_kwargs
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script
from airflow.utils.task_group import TaskGroup, TaskGroupContext


class PythonOperator(BaseOperator):
Expand Down Expand Up @@ -165,7 +166,7 @@ def __init__(
multiple_outputs: bool = False,
**kwargs,
) -> None:
kwargs['task_id'] = self._get_unique_task_id(task_id, kwargs.get('dag'))
kwargs['task_id'] = self._get_unique_task_id(task_id, kwargs.get('dag'), kwargs.get('task_group'))
super().__init__(**kwargs)
self.python_callable = python_callable

Expand All @@ -176,7 +177,9 @@ def __init__(
self.op_kwargs = op_kwargs

@staticmethod
def _get_unique_task_id(task_id: str, dag: Optional[DAG] = None) -> str:
def _get_unique_task_id(
task_id: str, dag: Optional[DAG] = None, task_group: Optional[TaskGroup] = None
) -> str:
"""
Generate unique task id given a DAG (or if run in a DAG context)
Ids are generated by appending a unique number to the end of
Expand All @@ -190,7 +193,15 @@ def _get_unique_task_id(task_id: str, dag: Optional[DAG] = None) -> str:
task_id__20
"""
dag = dag or DagContext.get_current_dag()
if not dag or task_id not in dag.task_ids:
if not dag:
return task_id

# We need to check if we are in the context of TaskGroup as the task_id may
# already be altered
task_group = task_group or TaskGroupContext.get_current_task_group(dag)
tg_task_id = task_group.child_id(task_id) if task_group else task_id

if tg_task_id not in dag.task_ids:
return task_id
core = re.split(r'__\d+$', task_id)[0]
suffixes = sorted(
Expand Down
10 changes: 10 additions & 0 deletions docs/concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,16 @@ a unique ``task_id`` for each generated operator.
Task ids are generated by appending a number at the end of the original task id. For the above example, the DAG will have
the following task ids: ``[update_user, update_user__1, update_user__2, ... update_user__n]``.

Due to the dynamic nature of the id generation, users should be aware that changing a DAG by adding or removing
invocations of a task-decorated function may change the ``task_id`` of other tasks of the same type within a single DAG.

For example, suppose there are many task-decorated tasks without an explicitly given ``task_id``. Their ``task_id`` will be
generated sequentially: ``task__1``, ``task__2``, ``task__3``, etc. After the DAG goes into production, one day
someone inserts a new task before ``task__2``. The ``task_id`` of every task after that point will be shifted forward by
one place. This is going to produce ``task__1``, ``task__2``, ``task__3``, ``task__4``. But at this point the ``task__3`` is
no longer the same ``task__3`` as before. This may create confusion when analyzing history logs / DagRuns of a DAG
that changed over time.

Accessing current context
-------------------------

Expand Down
18 changes: 18 additions & 0 deletions tests/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from airflow.utils.dates import days_ago
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.task_group import TaskGroup
from airflow.utils.types import DagRunType
from tests.test_utils.db import clear_db_runs

Expand Down Expand Up @@ -487,6 +488,23 @@ def do_run():
assert do_run_1.operator.task_id == 'do_run__1' # pylint: disable=maybe-no-member
assert do_run_2.operator.task_id == 'do_run__2' # pylint: disable=maybe-no-member

def test_multiple_calls_in_task_group(self):
"""Test calling task multiple times in a TaskGroup"""

@task_decorator
def do_run():
return 4

group_id = "KnightsOfNii"
with self.dag:
with TaskGroup(group_id=group_id):
# First call: the DAG's only task id is the function name prefixed with the group id.
do_run()
assert [f"{group_id}.do_run"] == self.dag.task_ids
# Second call of the same decorated function: a second task id appears,
# carrying a "__1" suffix and still prefixed with the group id.
do_run()
assert [f"{group_id}.do_run", f"{group_id}.do_run__1"] == self.dag.task_ids

# After both calls the DAG is expected to hold exactly two task ids.
assert len(self.dag.task_ids) == 2

def test_call_20(self):
"""Test calling decorated function 21 times in a DAG"""

Expand Down

0 comments on commit 39ea872

Please sign in to comment.