Change DAG.clear to take dag_run_state #9824

Merged · 5 commits · Jul 15, 2020
4 changes: 3 additions & 1 deletion airflow/cli/commands/dag_command.py
@@ -39,6 +39,7 @@
from airflow.utils.cli import get_dag, get_dag_by_file_location, process_subdir, sigint_handler
from airflow.utils.dot_renderer import render_dag
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State


def _tabulate_dag_runs(dag_runs: List[DagRun], tablefmt: str = "fancy_grid") -> str:
@@ -123,6 +124,7 @@ def dag_backfill(args, dag=None):
end_date=args.end_date,
confirm_prompt=not args.yes,
include_subdags=True,
dag_run_state=State.NONE,
Member:

Do I understand correctly that the default value was previously State.RUNNING? If so, we should add a note to UPDATING.md.

@milton0825 (Contributor, Author) · Jul 15, 2020:

The backfill experience from the user's side should be almost the same. If the user backfills with --reset-dagruns, it will first clear the DagRun (setting its state to None), and then the backfill scheduler will set the state to RUNNING.
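
For illustration, a minimal sketch of that two-step flow using the new keyword from this PR; the DAG id and date are placeholders, and the backfill internals are reduced to the dag.run() call:

from airflow.models import DagBag
from airflow.utils.state import State

# Placeholder DAG and date, purely for illustration.
dag = DagBag().get_dag("example_bash_operator")
run_date = dag.start_date

# Step 1: `airflow dags backfill --reset-dagruns` now clears with
# dag_run_state=State.NONE, so the cleared DagRuns are left in state None.
dag.clear(start_date=run_date, end_date=run_date, dag_run_state=State.NONE)

# Step 2: the backfill itself then moves those DagRuns back to RUNNING while
# it executes, so the user-visible behaviour stays essentially the same.
dag.run(start_date=run_date, end_date=run_date)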

)

dag.run(
@@ -381,7 +383,7 @@ def dag_list_dag_runs(args, dag=None):
def dag_test(args, session=None):
"""Execute one single DagRun for a given DAG and execution date, using the DebugExecutor."""
dag = get_dag(subdir=args.subdir, dag_id=args.dag_id)
dag.clear(start_date=args.execution_date, end_date=args.execution_date, reset_dag_runs=True)
dag.clear(start_date=args.execution_date, end_date=args.execution_date, dag_run_state=State.NONE)
try:
dag.run(executor=DebugExecutor(), start_date=args.execution_date, end_date=args.execution_date)
except BackfillUnfinished as e:
55 changes: 28 additions & 27 deletions airflow/models/dag.py
@@ -27,7 +27,7 @@
import warnings
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union
from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union, cast

import jinja2
import pendulum
@@ -297,7 +297,7 @@ def __init__(
template_searchpath = [template_searchpath]
self.template_searchpath = template_searchpath
self.template_undefined = template_undefined
self.parent_dag = None # Gets set when DAGs are loaded
self.parent_dag: Optional[DAG] = None # Gets set when DAGs are loaded
self.last_loaded = timezone.utcnow()
self.safe_dag_id = dag_id.replace('.', '__dot__')
self.max_active_runs = max_active_runs
@@ -966,7 +966,7 @@ def clear(
confirm_prompt=False,
include_subdags=True,
include_parentdag=True,
reset_dag_runs=True,
dag_run_state: str = State.RUNNING,
dry_run=False,
session=None,
get_tis=False,
@@ -993,8 +993,7 @@
:type include_subdags: bool
:param include_parentdag: Clear tasks in the parent dag of the subdag.
:type include_parentdag: bool
:param reset_dag_runs: Set state of dag to RUNNING
:type reset_dag_runs: bool
:param dag_run_state: state to set DagRun to
:param dry_run: Find the tasks to clear but don't clear them.
:type dry_run: bool
:param session: The sqlalchemy session to use
@@ -1025,8 +1024,7 @@
tis = session.query(TI).filter(TI.dag_id == self.dag_id)
tis = tis.filter(TI.task_id.in_(self.task_ids))

if include_parentdag and self.is_subdag:

if include_parentdag and self.is_subdag and self.parent_dag is not None:
p_dag = self.parent_dag.sub_dag(
task_regex=r"^{}$".format(self.dag_id.split('.')[1]),
include_upstream=False,
@@ -1039,7 +1037,7 @@
confirm_prompt=confirm_prompt,
include_subdags=include_subdags,
include_parentdag=False,
reset_dag_runs=reset_dag_runs,
dag_run_state=dag_run_state,
get_tis=True,
session=session,
recursion_depth=recursion_depth,
@@ -1065,12 +1063,13 @@
instances = tis.all()
for ti in instances:
if ti.operator == ExternalTaskMarker.__name__:
ti.task = self.get_task(ti.task_id)
task: ExternalTaskMarker = cast(ExternalTaskMarker, self.get_task(ti.task_id))
Contributor (Author):

This is to fix a mypy issue.
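
For context, a self-contained sketch of the pattern, with toy stand-ins for BaseOperator and ExternalTaskMarker rather than the real Airflow classes:

from typing import cast


class BaseOperator:
    """Toy stand-in for airflow.models.BaseOperator."""


class ExternalTaskMarker(BaseOperator):
    """Toy stand-in; the real marker adds recursion_depth, external_dag_id, etc."""

    recursion_depth = 10


def get_task(task_id: str) -> BaseOperator:
    # Typed to return the base class, just like DAG.get_task.
    return ExternalTaskMarker()


# Without cast(), mypy rejects .recursion_depth because BaseOperator does not
# declare it; cast() narrows the static type and is a no-op at runtime.
task = cast(ExternalTaskMarker, get_task("parent_task"))
print(task.recursion_depth)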

ti.task = task

if recursion_depth == 0:
# Maximum recursion depth allowed is the recursion_depth of the first
# ExternalTaskMarker in the tasks to be cleared.
max_recursion_depth = ti.task.recursion_depth
max_recursion_depth = task.recursion_depth

if recursion_depth + 1 > max_recursion_depth:
# Prevent cycles or accidents.
@@ -1080,10 +1079,10 @@
.format(max_recursion_depth,
ExternalTaskMarker.__name__, ti.task_id))
ti.render_templates()
external_tis = session.query(TI).filter(TI.dag_id == ti.task.external_dag_id,
TI.task_id == ti.task.external_task_id,
external_tis = session.query(TI).filter(TI.dag_id == task.external_dag_id,
TI.task_id == task.external_task_id,
TI.execution_date ==
pendulum.parse(ti.task.execution_date))
pendulum.parse(task.execution_date))

for tii in external_tis:
if not dag_bag:
@@ -1103,7 +1102,7 @@
confirm_prompt=confirm_prompt,
include_subdags=include_subdags,
include_parentdag=False,
reset_dag_runs=reset_dag_runs,
dag_run_state=dag_run_state,
get_tis=True,
session=session,
recursion_depth=recursion_depth + 1,
@@ -1134,16 +1133,18 @@
do_it = utils.helpers.ask_yesno(question)

if do_it:
clear_task_instances(tis,
session,
dag=self,
)
if reset_dag_runs:
self.set_dag_runs_state(session=session,
start_date=start_date,
end_date=end_date,
state=State.NONE,
)
clear_task_instances(
tis,
session,
dag=self,
activate_dag_runs=False, # We will set DagRun state later.
Contributor (Author):

When activate_dag_runs=True, the DagRun will be set to RUNNING.

)
self.set_dag_runs_state(
session=session,
start_date=start_date,
end_date=end_date,
state=dag_run_state,
)
else:
count = 0
print("Bail. Nothing was cleared.")
@@ -1161,7 +1162,7 @@ def clear_dags(
confirm_prompt=False,
include_subdags=True,
include_parentdag=False,
reset_dag_runs=True,
dag_run_state=State.RUNNING,
dry_run=False,
):
all_tis = []
@@ -1174,7 +1175,7 @@
confirm_prompt=False,
include_subdags=include_subdags,
include_parentdag=include_parentdag,
reset_dag_runs=reset_dag_runs,
dag_run_state=dag_run_state,
dry_run=True)
all_tis.extend(tis)

@@ -1202,7 +1203,7 @@ def clear_dags(
only_running=only_running,
confirm_prompt=False,
include_subdags=include_subdags,
reset_dag_runs=reset_dag_runs,
dag_run_state=dag_run_state,
dry_run=False,
)
else:
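Taken together, the migration for callers of DAG.clear is a one-keyword change. A hedged before/after sketch (placeholder DAG; import paths as of the Airflow master branch at the time of this PR):

import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.state import State

# Placeholder DAG purely for illustration.
start = datetime.datetime(2020, 7, 1)
dag = DAG("clear_migration_example", start_date=start)
DummyOperator(task_id="t1", dag=dag)

# Before this PR (keyword now removed):
#     dag.clear(start_date=start, end_date=start, reset_dag_runs=True)

# After this PR the target DagRun state is passed explicitly; State.NONE
# reproduces the old reset_dag_runs=True behaviour.
dag.clear(start_date=start, end_date=start, dag_run_state=State.NONE)

# dag_run_state defaults to State.RUNNING, so omitting it leaves the cleared
# DagRuns in the RUNNING state.
dag.clear(start_date=start, end_date=start)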
@@ -29,6 +29,7 @@
CloudDataFusionStopPipelineOperator, CloudDataFusionUpdateInstanceOperator,
)
from airflow.utils import dates
from airflow.utils.state import State

# [START howto_data_fusion_env_variables]
LOCATION = "europe-north1"
@@ -227,5 +228,5 @@
delete_pipeline >> delete_instance

if __name__ == "__main__":
dag.clear(reset_dag_runs=True)
dag.clear(dag_run_state=State.NONE)
dag.run()
3 changes: 2 additions & 1 deletion airflow/providers/google/cloud/example_dags/example_gcs.py
@@ -32,6 +32,7 @@
from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator
from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator
from airflow.utils.dates import days_ago
from airflow.utils.state import State

default_args = {"start_date": days_ago(1)}

@@ -155,5 +156,5 @@


if __name__ == '__main__':
dag.clear(reset_dag_runs=True)
dag.clear(dag_run_state=State.NONE)
dag.run()
@@ -31,6 +31,7 @@
GoogleCampaignManagerReportSensor,
)
from airflow.utils import dates
from airflow.utils.state import State

PROFILE_ID = os.environ.get("MARKETING_PROFILE_ID", "123456789")
FLOODLIGHT_ACTIVITY_ID = os.environ.get("FLOODLIGHT_ACTIVITY_ID", 12345)
@@ -157,5 +158,5 @@
insert_conversion >> update_conversion

if __name__ == "__main__":
dag.clear(reset_dag_runs=True)
dag.clear(dag_run_state=State.NONE)
dag.run()
7 changes: 5 additions & 2 deletions tests/cli/commands/test_dag_command.py
@@ -433,7 +433,8 @@ def test_dag_test(self, mock_get_dag, mock_executor):
subdir=cli_args.subdir, dag_id='example_bash_operator'
),
mock.call().clear(
start_date=cli_args.execution_date, end_date=cli_args.execution_date, reset_dag_runs=True
start_date=cli_args.execution_date, end_date=cli_args.execution_date,
dag_run_state=State.NONE,
),
mock.call().run(
executor=mock_executor.return_value,
@@ -461,7 +462,9 @@ def test_dag_test_show_dag(self, mock_get_dag, mock_executor, mock_render_dag):
subdir=cli_args.subdir, dag_id='example_bash_operator'
),
mock.call().clear(
start_date=cli_args.execution_date, end_date=cli_args.execution_date, reset_dag_runs=True
start_date=cli_args.execution_date,
end_date=cli_args.execution_date,
dag_run_state=State.NONE,
),
mock.call().run(
executor=mock_executor.return_value,
60 changes: 52 additions & 8 deletions tests/models/test_dag.py
@@ -25,6 +25,7 @@
import unittest
from contextlib import redirect_stdout
from tempfile import NamedTemporaryFile
from typing import Optional
from unittest import mock
from unittest.mock import patch

@@ -55,6 +56,12 @@

class TestDag(unittest.TestCase):

def setUp(self) -> None:
clear_db_runs()

def tearDown(self) -> None:
clear_db_runs()

@staticmethod
def _clean_up(dag_id: str):
with create_session() as session:
@@ -1355,8 +1362,14 @@ def test_create_dagrun_run_type_is_obtained_from_run_id(self):
dr = dag.create_dagrun(run_id="custom_is_set_to_manual", state=State.NONE)
assert dr.run_type == DagRunType.MANUAL.value

def test_clear_reset_dagruns(self):
dag_id = 'test_clear_dag_reset_dagruns'
@parameterized.expand(
[
(State.NONE,),
(State.RUNNING,),
]
)
def test_clear_set_dagrun_state(self, dag_run_state):
dag_id = 'test_clear_set_dagrun_state'
self._clean_up(dag_id)
task_id = 't1'
dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
@@ -1365,7 +1378,7 @@ def test_clear_reset_dagruns(self):
session = settings.Session()
dagrun_1 = dag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=State.RUNNING,
state=State.FAILED,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
@@ -1378,7 +1391,7 @@
dag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
reset_dag_runs=True,
dag_run_state=dag_run_state,
include_subdags=False,
include_parentdag=False,
session=session,
@@ -1392,17 +1405,48 @@

self.assertEqual(len(dagruns), 1)
dagrun = dagruns[0] # type: DagRun
self.assertEqual(dagrun.state, State.NONE)
self.assertEqual(dagrun.state, dag_run_state)

@parameterized.expand([
(state, State.NONE)
for state in State.task_states if state != State.RUNNING
] + [(State.RUNNING, State.SHUTDOWN)]) # type: ignore
def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]):
dag_id = 'test_clear_dag'
self._clean_up(dag_id)
task_id = 't1'
dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
t_1 = DummyOperator(task_id=task_id, dag=dag)

session = settings.Session() # type: ignore
dagrun_1 = dag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=State.RUNNING,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
session.merge(dagrun_1)

task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=ti_state_begin)
task_instance_1.job_id = 123
session.merge(task_instance_1)
session.commit()

dag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
session=session,
)

task_instances = session.query(
DagRun,
TI,
).filter(
DagRun.dag_id == dag_id,
TI.dag_id == dag_id,
).all()

self.assertEqual(len(task_instances), 1)
task_instance = task_instances[0] # type: TI
self.assertEqual(task_instance.state, State.NONE)
self.assertEqual(task_instance.state, ti_state_end)
self._clean_up(dag_id)

