diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py index 685ca21c782c6..55c40b4d812bd 100644 --- a/airflow/cli/commands/dag_command.py +++ b/airflow/cli/commands/dag_command.py @@ -39,6 +39,7 @@ from airflow.utils.cli import get_dag, get_dag_by_file_location, process_subdir, sigint_handler from airflow.utils.dot_renderer import render_dag from airflow.utils.session import create_session, provide_session +from airflow.utils.state import State def _tabulate_dag_runs(dag_runs: List[DagRun], tablefmt: str = "fancy_grid") -> str: @@ -123,6 +124,7 @@ def dag_backfill(args, dag=None): end_date=args.end_date, confirm_prompt=not args.yes, include_subdags=True, + dag_run_state=State.NONE, ) dag.run( @@ -381,7 +383,7 @@ def dag_list_dag_runs(args, dag=None): def dag_test(args, session=None): """Execute one single DagRun for a given DAG and execution date, using the DebugExecutor.""" dag = get_dag(subdir=args.subdir, dag_id=args.dag_id) - dag.clear(start_date=args.execution_date, end_date=args.execution_date, reset_dag_runs=True) + dag.clear(start_date=args.execution_date, end_date=args.execution_date, dag_run_state=State.NONE) try: dag.run(executor=DebugExecutor(), start_date=args.execution_date, end_date=args.execution_date) except BackfillUnfinished as e: diff --git a/airflow/models/dag.py b/airflow/models/dag.py index e6aafd37b110c..dfb6409c69003 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -27,7 +27,7 @@ import warnings from collections import OrderedDict from datetime import datetime, timedelta -from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union +from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union, cast import jinja2 import pendulum @@ -297,7 +297,7 @@ def __init__( template_searchpath = [template_searchpath] self.template_searchpath = template_searchpath self.template_undefined = template_undefined - self.parent_dag = None # Gets set when DAGs are loaded + self.parent_dag: Optional[DAG] = None # Gets set when DAGs are loaded self.last_loaded = timezone.utcnow() self.safe_dag_id = dag_id.replace('.', '__dot__') self.max_active_runs = max_active_runs @@ -966,7 +966,7 @@ def clear( confirm_prompt=False, include_subdags=True, include_parentdag=True, - reset_dag_runs=True, + dag_run_state: str = State.RUNNING, dry_run=False, session=None, get_tis=False, @@ -993,8 +993,7 @@ def clear( :type include_subdags: bool :param include_parentdag: Clear tasks in the parent dag of the subdag. :type include_parentdag: bool - :param reset_dag_runs: Set state of dag to RUNNING - :type reset_dag_runs: bool + :param dag_run_state: state to set DagRun to :param dry_run: Find the tasks to clear but don't clear them. :type dry_run: bool :param session: The sqlalchemy session to use @@ -1025,8 +1024,7 @@ def clear( tis = session.query(TI).filter(TI.dag_id == self.dag_id) tis = tis.filter(TI.task_id.in_(self.task_ids)) - if include_parentdag and self.is_subdag: - + if include_parentdag and self.is_subdag and self.parent_dag is not None: p_dag = self.parent_dag.sub_dag( task_regex=r"^{}$".format(self.dag_id.split('.')[1]), include_upstream=False, @@ -1039,7 +1037,7 @@ def clear( confirm_prompt=confirm_prompt, include_subdags=include_subdags, include_parentdag=False, - reset_dag_runs=reset_dag_runs, + dag_run_state=dag_run_state, get_tis=True, session=session, recursion_depth=recursion_depth, @@ -1065,12 +1063,13 @@ def clear( instances = tis.all() for ti in instances: if ti.operator == ExternalTaskMarker.__name__: - ti.task = self.get_task(ti.task_id) + task: ExternalTaskMarker = cast(ExternalTaskMarker, self.get_task(ti.task_id)) + ti.task = task if recursion_depth == 0: # Maximum recursion depth allowed is the recursion_depth of the first # ExternalTaskMarker in the tasks to be cleared. - max_recursion_depth = ti.task.recursion_depth + max_recursion_depth = task.recursion_depth if recursion_depth + 1 > max_recursion_depth: # Prevent cycles or accidents. @@ -1080,10 +1079,10 @@ def clear( .format(max_recursion_depth, ExternalTaskMarker.__name__, ti.task_id)) ti.render_templates() - external_tis = session.query(TI).filter(TI.dag_id == ti.task.external_dag_id, - TI.task_id == ti.task.external_task_id, + external_tis = session.query(TI).filter(TI.dag_id == task.external_dag_id, + TI.task_id == task.external_task_id, TI.execution_date == - pendulum.parse(ti.task.execution_date)) + pendulum.parse(task.execution_date)) for tii in external_tis: if not dag_bag: @@ -1103,7 +1102,7 @@ def clear( confirm_prompt=confirm_prompt, include_subdags=include_subdags, include_parentdag=False, - reset_dag_runs=reset_dag_runs, + dag_run_state=dag_run_state, get_tis=True, session=session, recursion_depth=recursion_depth + 1, @@ -1134,16 +1133,18 @@ def clear( do_it = utils.helpers.ask_yesno(question) if do_it: - clear_task_instances(tis, - session, - dag=self, - ) - if reset_dag_runs: - self.set_dag_runs_state(session=session, - start_date=start_date, - end_date=end_date, - state=State.NONE, - ) + clear_task_instances( + tis, + session, + dag=self, + activate_dag_runs=False, # We will set DagRun state later. + ) + self.set_dag_runs_state( + session=session, + start_date=start_date, + end_date=end_date, + state=dag_run_state, + ) else: count = 0 print("Bail. Nothing was cleared.") @@ -1161,7 +1162,7 @@ def clear_dags( confirm_prompt=False, include_subdags=True, include_parentdag=False, - reset_dag_runs=True, + dag_run_state=State.RUNNING, dry_run=False, ): all_tis = [] @@ -1174,7 +1175,7 @@ def clear_dags( confirm_prompt=False, include_subdags=include_subdags, include_parentdag=include_parentdag, - reset_dag_runs=reset_dag_runs, + dag_run_state=dag_run_state, dry_run=True) all_tis.extend(tis) @@ -1202,7 +1203,7 @@ def clear_dags( only_running=only_running, confirm_prompt=False, include_subdags=include_subdags, - reset_dag_runs=reset_dag_runs, + dag_run_state=dag_run_state, dry_run=False, ) else: diff --git a/airflow/providers/google/cloud/example_dags/example_datafusion.py b/airflow/providers/google/cloud/example_dags/example_datafusion.py index e2b686cfe6463..62ab1d46fdaf8 100644 --- a/airflow/providers/google/cloud/example_dags/example_datafusion.py +++ b/airflow/providers/google/cloud/example_dags/example_datafusion.py @@ -29,6 +29,7 @@ CloudDataFusionStopPipelineOperator, CloudDataFusionUpdateInstanceOperator, ) from airflow.utils import dates +from airflow.utils.state import State # [START howto_data_fusion_env_variables] LOCATION = "europe-north1" @@ -227,5 +228,5 @@ delete_pipeline >> delete_instance if __name__ == "__main__": - dag.clear(reset_dag_runs=True) + dag.clear(dag_run_state=State.NONE) dag.run() diff --git a/airflow/providers/google/cloud/example_dags/example_gcs.py b/airflow/providers/google/cloud/example_dags/example_gcs.py index 4cdac3636088e..18f173f66edbe 100644 --- a/airflow/providers/google/cloud/example_dags/example_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_gcs.py @@ -32,6 +32,7 @@ from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator from airflow.utils.dates import days_ago +from airflow.utils.state import State default_args = {"start_date": days_ago(1)} @@ -155,5 +156,5 @@ if __name__ == '__main__': - dag.clear(reset_dag_runs=True) + dag.clear(dag_run_state=State.NONE) dag.run() diff --git a/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py b/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py index ca82c93269ae9..74fb6d328210a 100644 --- a/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py +++ b/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py @@ -31,6 +31,7 @@ GoogleCampaignManagerReportSensor, ) from airflow.utils import dates +from airflow.utils.state import State PROFILE_ID = os.environ.get("MARKETING_PROFILE_ID", "123456789") FLOODLIGHT_ACTIVITY_ID = os.environ.get("FLOODLIGHT_ACTIVITY_ID", 12345) @@ -157,5 +158,5 @@ insert_conversion >> update_conversion if __name__ == "__main__": - dag.clear(reset_dag_runs=True) + dag.clear(dag_run_state=State.NONE) dag.run() diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index fa4a32e50ccbc..6dda923931ae6 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -433,7 +433,8 @@ def test_dag_test(self, mock_get_dag, mock_executor): subdir=cli_args.subdir, dag_id='example_bash_operator' ), mock.call().clear( - start_date=cli_args.execution_date, end_date=cli_args.execution_date, reset_dag_runs=True + start_date=cli_args.execution_date, end_date=cli_args.execution_date, + dag_run_state=State.NONE, ), mock.call().run( executor=mock_executor.return_value, @@ -461,7 +462,9 @@ def test_dag_test_show_dag(self, mock_get_dag, mock_executor, mock_render_dag): subdir=cli_args.subdir, dag_id='example_bash_operator' ), mock.call().clear( - start_date=cli_args.execution_date, end_date=cli_args.execution_date, reset_dag_runs=True + start_date=cli_args.execution_date, + end_date=cli_args.execution_date, + dag_run_state=State.NONE, ), mock.call().run( executor=mock_executor.return_value, diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 8207b2457f0c1..8891d56c8e98d 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -25,6 +25,7 @@ import unittest from contextlib import redirect_stdout from tempfile import NamedTemporaryFile +from typing import Optional from unittest import mock from unittest.mock import patch @@ -55,6 +56,12 @@ class TestDag(unittest.TestCase): + def setUp(self) -> None: + clear_db_runs() + + def tearDown(self) -> None: + clear_db_runs() + @staticmethod def _clean_up(dag_id: str): with create_session() as session: @@ -1355,8 +1362,14 @@ def test_create_dagrun_run_type_is_obtained_from_run_id(self): dr = dag.create_dagrun(run_id="custom_is_set_to_manual", state=State.NONE) assert dr.run_type == DagRunType.MANUAL.value - def test_clear_reset_dagruns(self): - dag_id = 'test_clear_dag_reset_dagruns' + @parameterized.expand( + [ + (State.NONE,), + (State.RUNNING,), + ] + ) + def test_clear_set_dagrun_state(self, dag_run_state): + dag_id = 'test_clear_set_dagrun_state' self._clean_up(dag_id) task_id = 't1' dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1) @@ -1365,7 +1378,7 @@ def test_clear_reset_dagruns(self): session = settings.Session() dagrun_1 = dag.create_dagrun( run_type=DagRunType.BACKFILL_JOB, - state=State.RUNNING, + state=State.FAILED, start_date=DEFAULT_DATE, execution_date=DEFAULT_DATE, ) @@ -1378,7 +1391,7 @@ def test_clear_reset_dagruns(self): dag.clear( start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=1), - reset_dag_runs=True, + dag_run_state=dag_run_state, include_subdags=False, include_parentdag=False, session=session, @@ -1392,17 +1405,48 @@ def test_clear_reset_dagruns(self): self.assertEqual(len(dagruns), 1) dagrun = dagruns[0] # type: DagRun - self.assertEqual(dagrun.state, State.NONE) + self.assertEqual(dagrun.state, dag_run_state) + + @parameterized.expand([ + (state, State.NONE) + for state in State.task_states if state != State.RUNNING + ] + [(State.RUNNING, State.SHUTDOWN)]) # type: ignore + def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]): + dag_id = 'test_clear_dag' + self._clean_up(dag_id) + task_id = 't1' + dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1) + t_1 = DummyOperator(task_id=task_id, dag=dag) + + session = settings.Session() # type: ignore + dagrun_1 = dag.create_dagrun( + run_type=DagRunType.BACKFILL_JOB, + state=State.RUNNING, + start_date=DEFAULT_DATE, + execution_date=DEFAULT_DATE, + ) + session.merge(dagrun_1) + + task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=ti_state_begin) + task_instance_1.job_id = 123 + session.merge(task_instance_1) + session.commit() + + dag.clear( + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=1), + session=session, + ) task_instances = session.query( - DagRun, + TI, ).filter( - DagRun.dag_id == dag_id, + TI.dag_id == dag_id, ).all() self.assertEqual(len(task_instances), 1) task_instance = task_instances[0] # type: TI - self.assertEqual(task_instance.state, State.NONE) + self.assertEqual(task_instance.state, ti_state_end) self._clean_up(dag_id)