Add default_deferrable config (#31712)
Lee-W committed Jul 5, 2023
1 parent ab2c861 commit f859350
Showing 41 changed files with 160 additions and 111 deletions.
7 changes: 7 additions & 0 deletions airflow/config_templates/config.yml
@@ -1305,6 +1305,13 @@ operators:
type: string
example: ~
default: "airflow"
default_deferrable:
description: |
The default value of attribute "deferrable" in operators and sensors.
version_added: ~
type: boolean
example: ~
default: "false"
default_cpus:
description: ~
version_added: ~
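The change applied throughout the files below follows one pattern: instead of hard-coding ``deferrable: bool = False``, each constructor now reads its default from the new ``[operators] default_deferrable`` option, with ``fallback=False`` so configurations that predate the key keep the old behaviour. A minimal sketch of the pattern, using a hypothetical operator for illustration:

    from airflow.configuration import conf
    from airflow.models.baseoperator import BaseOperator


    class MyDeferrableOperator(BaseOperator):
        """Hypothetical operator; shown only to illustrate the new constructor default."""

        def __init__(
            self,
            *,
            # Read from [operators] default_deferrable; fall back to the
            # historical default when the key is absent.
            deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
            **kwargs,
        ) -> None:
            super().__init__(**kwargs)
            self.deferrable = deferrable

Note that Python binds default argument values when the operator's module is imported, so the configured value is read at import time rather than per task instance.
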
3 changes: 3 additions & 0 deletions airflow/config_templates/default_airflow.cfg
@@ -703,6 +703,9 @@ password =
# The default owner assigned to each new operator, unless
# provided explicitly or passed via ``default_args``
default_owner = airflow

# The default value of attribute "deferrable" in operators and sensors.
default_deferrable = false
default_cpus = 1
default_ram = 512
default_disk = 512
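A deployment that wants deferrable behaviour everywhere therefore only needs to flip this one option, either in ``airflow.cfg`` or (assuming Airflow's usual ``AIRFLOW__SECTION__KEY`` environment-variable mapping) by exporting ``AIRFLOW__OPERATORS__DEFAULT_DEFERRABLE=true``:

    [operators]
    default_deferrable = true
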
6 changes: 2 additions & 4 deletions airflow/operators/trigger_dagrun.py
@@ -25,6 +25,7 @@
from sqlalchemy.orm.exc import NoResultFound

from airflow.api.common.trigger_dag import trigger_dag
from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.dag import DagModel
@@ -113,7 +114,7 @@ def __init__(
poke_interval: int = 60,
allowed_states: list | None = None,
failed_states: list | None = None,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -135,7 +136,6 @@ def __init__(
self.execution_date = execution_date

def execute(self, context: Context):

if isinstance(self.execution_date, datetime.datetime):
parsed_execution_date = self.execution_date
elif isinstance(self.execution_date, str):
@@ -187,7 +187,6 @@ def execute(self, context: Context):
ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id)

if self.wait_for_completion:

# Kick off the deferral process
if self._defer:
self.defer(
@@ -219,7 +218,6 @@ def execute(self, context: Context):

@provide_session
def execute_complete(self, context: Context, session: Session, event: tuple[str, dict[str, Any]]):

# This execution date is parsed from the return trigger event
provided_execution_date = event[1]["execution_dates"][0]
try:
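With the option enabled, DAG code no longer has to pass ``deferrable=True`` task by task; an explicit argument still overrides the configured value, since the option only supplies the constructor's default. A sketch against ``TriggerDagRunOperator`` (DAG ids and schedule are placeholders):

    from datetime import datetime

    from airflow import DAG
    from airflow.operators.trigger_dagrun import TriggerDagRunOperator

    with DAG(dag_id="example_trigger", start_date=datetime(2023, 7, 1), schedule=None):
        # With [operators] default_deferrable = true this task waits for the
        # triggered run in the triggerer instead of holding a worker slot.
        TriggerDagRunOperator(
            task_id="trigger_downstream",
            trigger_dag_id="downstream_dag",  # placeholder target DAG id
            wait_for_completion=True,
        )

        # An explicit argument still takes precedence over the config default.
        TriggerDagRunOperator(
            task_id="trigger_downstream_blocking",
            trigger_dag_id="downstream_dag",
            wait_for_completion=True,
            deferrable=False,
        )
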
3 changes: 2 additions & 1 deletion airflow/providers/amazon/aws/operators/athena.py
@@ -21,6 +21,7 @@
from typing import TYPE_CHECKING, Any, Sequence

from airflow import AirflowException
from airflow.configuration import conf
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
@@ -74,7 +75,7 @@ def __init__(
sleep_time: int = 30,
max_polling_attempts: int | None = None,
log_query: bool = True,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
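The Amazon provider operators changed below pick up the same core option. For instance, with the option set, an Athena query submission hands the wait off to ``AthenaTrigger`` without any per-task flag; the argument values below are placeholders:

    from airflow.providers.amazon.aws.operators.athena import AthenaOperator

    run_query = AthenaOperator(
        task_id="run_athena_query",
        query="SELECT 1",                               # placeholder query
        database="example_db",                          # placeholder Athena database
        output_location="s3://example-bucket/athena/",  # placeholder result location
        # deferrable not passed: the default comes from [operators] default_deferrable
    )
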
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/operators/batch.py
@@ -29,6 +29,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
@@ -154,7 +155,7 @@ def __init__(
region_name: str | None = None,
tags: dict | None = None,
wait_for_completion: bool = True,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: int = 30,
awslogs_enabled: bool = False,
awslogs_fetch_interval: timedelta = timedelta(seconds=30),
@@ -437,7 +438,7 @@ def __init__(
max_retries: int | None = None,
aws_conn_id: str | None = None,
region_name: str | None = None,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
if "status_retries" in kwargs:
19 changes: 6 additions & 13 deletions airflow/providers/amazon/aws/operators/ecs.py
@@ -26,20 +26,14 @@

import boto3

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator, XCom
from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.ecs import (
EcsClusterStates,
EcsHook,
should_retry_eni,
)
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, should_retry_eni
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.triggers.ecs import (
ClusterWaiterTrigger,
TaskDoneTrigger,
)
from airflow.providers.amazon.aws.triggers.ecs import ClusterWaiterTrigger, TaskDoneTrigger
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.utils.helpers import prune_dict
from airflow.utils.session import provide_session
@@ -118,7 +112,7 @@ def __init__(
wait_for_completion: bool = True,
waiter_delay: int = 15,
waiter_max_attempts: int = 60,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -201,7 +195,7 @@ def __init__(
wait_for_completion: bool = True,
waiter_delay: int = 15,
waiter_max_attempts: int = 60,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -482,7 +476,7 @@ def __init__(
wait_for_completion: bool = True,
waiter_delay: int = 6,
waiter_max_attempts: int = 100,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(**kwargs)
@@ -727,7 +721,6 @@ def _check_success_task(self) -> None:
raise AirflowException(response)

for task in response["tasks"]:

if task.get("stopCode", "") == "TaskFailedToStart":
# Reset task arn here otherwise the retry run will not start
# a new task but keep polling the old dead one
11 changes: 5 additions & 6 deletions airflow/providers/amazon/aws/operators/eks.py
@@ -25,6 +25,7 @@

from botocore.exceptions import ClientError, WaiterError

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.eks import EksHook
@@ -83,7 +84,6 @@ def _create_compute(
log = logging.getLogger(__name__)
eks_hook = EksHook(aws_conn_id=aws_conn_id, region_name=region)
if compute == "nodegroup" and nodegroup_name:

# this is to satisfy mypy
subnets = subnets or []
create_nodegroup_kwargs = create_nodegroup_kwargs or {}
@@ -107,7 +107,6 @@ def _create_compute(
status_args=["nodegroup.status"],
)
elif compute == "fargate" and fargate_profile_name:

# this is to satisfy mypy
create_fargate_profile_kwargs = create_fargate_profile_kwargs or {}
fargate_selectors = fargate_selectors or []
@@ -366,7 +365,7 @@ def __init__(
region: str | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 80,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
self.nodegroup_subnets = nodegroup_subnets
Expand Down Expand Up @@ -489,7 +488,7 @@ def __init__(
region: str | None = None,
waiter_delay: int = 10,
waiter_max_attempts: int = 60,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
self.cluster_name = cluster_name
Expand Down Expand Up @@ -690,7 +689,7 @@ def __init__(
region: str | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 40,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
self.cluster_name = cluster_name
Expand Down Expand Up @@ -780,7 +779,7 @@ def __init__(
region: str | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
9 changes: 5 additions & 4 deletions airflow/providers/amazon/aws/operators/emr.py
@@ -24,6 +24,7 @@
from typing import TYPE_CHECKING, Any, Sequence
from uuid import uuid4

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
@@ -96,7 +97,7 @@ def __init__(
waiter_delay: int | None = None,
waiter_max_attempts: int | None = None,
execution_role_arn: str | None = None,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
if not exactly_one(job_flow_id is None, job_flow_name is None):
@@ -510,7 +511,7 @@ def __init__(
max_tries: int | None = None,
tags: dict | None = None,
max_polling_attempts: int | None = None,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -695,7 +696,7 @@ def __init__(
waiter_delay: int | None | ArgNotSet = NOTSET,
waiter_countdown: int | None = None,
waiter_check_interval_seconds: int = 60,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
):
if waiter_max_attempts is NOTSET:
@@ -900,7 +901,7 @@ def __init__(
aws_conn_id: str = "aws_default",
waiter_delay: int = 60,
waiter_max_attempts: int = 20,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(**kwargs)
3 changes: 2 additions & 1 deletion airflow/providers/amazon/aws/operators/glue.py
@@ -23,6 +23,7 @@
from typing import TYPE_CHECKING, Sequence

from airflow import AirflowException
from airflow.configuration import conf
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -98,7 +99,7 @@ def __init__(
create_job_kwargs: dict | None = None,
run_job_kwargs: dict | None = None,
wait_for_completion: bool = True,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
verbose: bool = False,
update_config: bool = False,
job_poll_interval: int | float = 6,
3 changes: 2 additions & 1 deletion airflow/providers/amazon/aws/operators/glue_crawler.py
@@ -21,6 +21,7 @@
from typing import TYPE_CHECKING, Sequence

from airflow import AirflowException
from airflow.configuration import conf
from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger

if TYPE_CHECKING:
@@ -61,7 +62,7 @@ def __init__(
region_name: str | None = None,
poll_interval: int = 5,
wait_for_completion: bool = True,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(**kwargs)
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/operators/rds.py
@@ -24,6 +24,7 @@

from mypy_boto3_rds.type_defs import TagTypeDef

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.rds import RdsHook
@@ -554,7 +555,7 @@ def __init__(
rds_kwargs: dict | None = None,
aws_conn_id: str = "aws_default",
wait_for_completion: bool = True,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
**kwargs,
@@ -645,7 +646,7 @@ def __init__(
rds_kwargs: dict | None = None,
aws_conn_id: str = "aws_default",
wait_for_completion: bool = True,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
**kwargs,
12 changes: 6 additions & 6 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -20,6 +20,7 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
@@ -148,7 +149,7 @@ def __init__(
wait_for_completion: bool = False,
max_attempt: int = 5,
poll_interval: int = 60,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(**kwargs)
@@ -327,7 +328,7 @@ def __init__(
poll_interval: int = 15,
max_attempt: int = 20,
aws_conn_id: str = "aws_default",
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(**kwargs)
@@ -470,7 +471,7 @@ def __init__(
cluster_identifier: str,
aws_conn_id: str = "aws_default",
wait_for_completion: bool = False,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: int = 10,
max_attempts: int = 10,
**kwargs,
@@ -560,7 +561,7 @@ def __init__(
*,
cluster_identifier: str,
aws_conn_id: str = "aws_default",
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: int = 10,
max_attempts: int = 15,
**kwargs,
@@ -647,7 +648,7 @@ def __init__(
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
poll_interval: int = 30,
deferrable: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
max_attempts: int = 30,
**kwargs,
):
@@ -668,7 +669,6 @@ def __init__(
self.max_attempts = max_attempts

def execute(self, context: Context):

while self._attempts >= 1:
try:
self.redshift_hook.delete_cluster(
(Remaining changed files not shown.)
