From 39ea8722c04fb1c0b286b4248a52e8d974a47b30 Mon Sep 17 00:00:00 2001 From: Tomek Urbaszek Date: Sun, 15 Nov 2020 12:28:04 +0100 Subject: [PATCH] Check for TaskGroup in _PythonDecoratedOperator (#12312) Crucial feature of functions decorated by @task is to be able to invoke them multiple times in single DAG. To do this we are generating custom task_id for each invocation. However, this didn't work with TaskGroup as the task_id is already altered by adding group_id prefix. This PR fixes it. closes: #12309 Co-authored-by: Kaxil Naik --- airflow/operators/python.py | 17 ++++++++++++++--- docs/concepts.rst | 10 ++++++++++ tests/operators/test_python.py | 18 ++++++++++++++++++ 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/airflow/operators/python.py b/airflow/operators/python.py index 6d254a0f7bc2..c03f6b575033 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -40,6 +40,7 @@ from airflow.utils.operator_helpers import determine_kwargs from airflow.utils.process_utils import execute_in_subprocess from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script +from airflow.utils.task_group import TaskGroup, TaskGroupContext class PythonOperator(BaseOperator): @@ -165,7 +166,7 @@ def __init__( multiple_outputs: bool = False, **kwargs, ) -> None: - kwargs['task_id'] = self._get_unique_task_id(task_id, kwargs.get('dag')) + kwargs['task_id'] = self._get_unique_task_id(task_id, kwargs.get('dag'), kwargs.get('task_group')) super().__init__(**kwargs) self.python_callable = python_callable @@ -176,7 +177,9 @@ def __init__( self.op_kwargs = op_kwargs @staticmethod - def _get_unique_task_id(task_id: str, dag: Optional[DAG] = None) -> str: + def _get_unique_task_id( + task_id: str, dag: Optional[DAG] = None, task_group: Optional[TaskGroup] = None + ) -> str: """ Generate unique task id given a DAG (or if run in a DAG context) Ids are generated by appending a unique number to the end of @@ -190,7 +193,15 @@ def _get_unique_task_id(task_id: str, dag: Optional[DAG] = None) -> str: task_id__20 """ dag = dag or DagContext.get_current_dag() - if not dag or task_id not in dag.task_ids: + if not dag: + return task_id + + # We need to check if we are in the context of TaskGroup as the task_id may + # already be altered + task_group = task_group or TaskGroupContext.get_current_task_group(dag) + tg_task_id = task_group.child_id(task_id) if task_group else task_id + + if tg_task_id not in dag.task_ids: return task_id core = re.split(r'__\d+$', task_id)[0] suffixes = sorted( diff --git a/docs/concepts.rst b/docs/concepts.rst index 4c94f3da3b5d..a0136e435dae 100644 --- a/docs/concepts.rst +++ b/docs/concepts.rst @@ -315,6 +315,16 @@ a unique ``task_id`` for each generated operator. Task ids are generated by appending a number at the end of the original task id. For the above example, the DAG will have the following task ids: ``[update_user, update_user__1, update_user__2, ... update_user__n]``. +Due to dynamic nature of the ids generations users should be aware that changing a DAG by adding or removing additional +invocations of task-decorated function may change ``task_id`` of other task of the same type withing a single DAG. + +For example, if there are many task-decorated tasks without explicitly given task_id. Their ``task_id`` will be +generated sequentially: ``task__1``, ``task__2``, ``task__3``, etc. After the DAG goes into production, one day +someone inserts a new task before ``task__2``. The ``task_id`` after that will all be shifted forward by one place. +This is going to produce ``task__1``, ``task__2``, ``task__3``, ``task__4``. But at this point the ``task__3`` is +no longer the same ``task__3`` as before. This may create confusion when analyzing history logs / DagRuns of a DAG +that changed over time. + Accessing current context ------------------------- diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py index f1bb085d8033..a6300c030754 100644 --- a/tests/operators/test_python.py +++ b/tests/operators/test_python.py @@ -45,6 +45,7 @@ from airflow.utils.dates import days_ago from airflow.utils.session import create_session from airflow.utils.state import State +from airflow.utils.task_group import TaskGroup from airflow.utils.types import DagRunType from tests.test_utils.db import clear_db_runs @@ -487,6 +488,23 @@ def do_run(): assert do_run_1.operator.task_id == 'do_run__1' # pylint: disable=maybe-no-member assert do_run_2.operator.task_id == 'do_run__2' # pylint: disable=maybe-no-member + def test_multiple_calls_in_task_group(self): + """Test calling task multiple times in a TaskGroup""" + + @task_decorator + def do_run(): + return 4 + + group_id = "KnightsOfNii" + with self.dag: + with TaskGroup(group_id=group_id): + do_run() + assert [f"{group_id}.do_run"] == self.dag.task_ids + do_run() + assert [f"{group_id}.do_run", f"{group_id}.do_run__1"] == self.dag.task_ids + + assert len(self.dag.task_ids) == 2 + def test_call_20(self): """Test calling decorated function 21 times in a DAG"""