diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 1ad8705a7047a..d588a7f26e3b2 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1305,6 +1305,13 @@ operators: type: string example: ~ default: "airflow" + default_deferrable: + description: | + The default value of attribute "deferrable" in operators and sensors. + version_added: ~ + type: boolean + example: ~ + default: "false" default_cpus: description: ~ version_added: ~ diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 0cc99a3f4e5f2..ae6bdec085247 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -703,6 +703,9 @@ password = # The default owner assigned to each new operator, unless # provided explicitly or passed via ``default_args`` default_owner = airflow + +# The default value of attribute "deferrable" in operators and sensors. 
+default_deferrable = false default_cpus = 1 default_ram = 512 default_disk = 512 diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index 0165bb470ed92..548ef9189418c 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -25,6 +25,7 @@ from sqlalchemy.orm.exc import NoResultFound from airflow.api.common.trigger_dag import trigger_dag +from airflow.configuration import conf from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists from airflow.models.baseoperator import BaseOperator, BaseOperatorLink from airflow.models.dag import DagModel @@ -113,7 +114,7 @@ def __init__( poke_interval: int = 60, allowed_states: list | None = None, failed_states: list | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -135,7 +136,6 @@ def __init__( self.execution_date = execution_date def execute(self, context: Context): - if isinstance(self.execution_date, datetime.datetime): parsed_execution_date = self.execution_date elif isinstance(self.execution_date, str): @@ -187,7 +187,6 @@ def execute(self, context: Context): ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id) if self.wait_for_completion: - # Kick off the deferral process if self._defer: self.defer( @@ -219,7 +218,6 @@ def execute(self, context: Context): @provide_session def execute_complete(self, context: Context, session: Session, event: tuple[str, dict[str, Any]]): - # This execution date is parsed from the return trigger event provided_execution_date = event[1]["execution_dates"][0] try: diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index 0467fe6d11aed..6dd1432ea4803 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -21,6 +21,7 @@ from typing import 
TYPE_CHECKING, Any, Sequence from airflow import AirflowException +from airflow.configuration import conf from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger @@ -74,7 +75,7 @@ def __init__( sleep_time: int = 30, max_polling_attempts: int | None = None, log_query: bool = True, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index b9b3322c49102..9dd954d05cd80 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -29,6 +29,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, Sequence +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook @@ -154,7 +155,7 @@ def __init__( region_name: str | None = None, tags: dict | None = None, wait_for_completion: bool = True, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: int = 30, awslogs_enabled: bool = False, awslogs_fetch_interval: timedelta = timedelta(seconds=30), @@ -437,7 +438,7 @@ def __init__( max_retries: int | None = None, aws_conn_id: str | None = None, region_name: str | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): if "status_retries" in kwargs: diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 91533cfa62112..e5833bf4c3d53 100644 --- 
a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -26,20 +26,14 @@ import boto3 +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator, XCom from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook -from airflow.providers.amazon.aws.hooks.ecs import ( - EcsClusterStates, - EcsHook, - should_retry_eni, -) +from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, should_retry_eni from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook -from airflow.providers.amazon.aws.triggers.ecs import ( - ClusterWaiterTrigger, - TaskDoneTrigger, -) +from airflow.providers.amazon.aws.triggers.ecs import ClusterWaiterTrigger, TaskDoneTrigger from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher from airflow.utils.helpers import prune_dict from airflow.utils.session import provide_session @@ -118,7 +112,7 @@ def __init__( wait_for_completion: bool = True, waiter_delay: int = 15, waiter_max_attempts: int = 60, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -201,7 +195,7 @@ def __init__( wait_for_completion: bool = True, waiter_delay: int = 15, waiter_max_attempts: int = 60, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -482,7 +476,7 @@ def __init__( wait_for_completion: bool = True, waiter_delay: int = 6, waiter_max_attempts: int = 100, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(**kwargs) @@ -727,7 +721,6 @@ def 
_check_success_task(self) -> None: raise AirflowException(response) for task in response["tasks"]: - if task.get("stopCode", "") == "TaskFailedToStart": # Reset task arn here otherwise the retry run will not start # a new task but keep polling the old dead one diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py index bea4223987265..56e9269f88d79 100644 --- a/airflow/providers/amazon/aws/operators/eks.py +++ b/airflow/providers/amazon/aws/operators/eks.py @@ -25,6 +25,7 @@ from botocore.exceptions import ClientError, WaiterError +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.eks import EksHook @@ -83,7 +84,6 @@ def _create_compute( log = logging.getLogger(__name__) eks_hook = EksHook(aws_conn_id=aws_conn_id, region_name=region) if compute == "nodegroup" and nodegroup_name: - # this is to satisfy mypy subnets = subnets or [] create_nodegroup_kwargs = create_nodegroup_kwargs or {} @@ -107,7 +107,6 @@ def _create_compute( status_args=["nodegroup.status"], ) elif compute == "fargate" and fargate_profile_name: - # this is to satisfy mypy create_fargate_profile_kwargs = create_fargate_profile_kwargs or {} fargate_selectors = fargate_selectors or [] @@ -366,7 +365,7 @@ def __init__( region: str | None = None, waiter_delay: int = 30, waiter_max_attempts: int = 80, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: self.nodegroup_subnets = nodegroup_subnets @@ -489,7 +488,7 @@ def __init__( region: str | None = None, waiter_delay: int = 10, waiter_max_attempts: int = 60, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: self.cluster_name = cluster_name @@ -690,7 +689,7 @@ def __init__( region: 
str | None = None, waiter_delay: int = 30, waiter_max_attempts: int = 40, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: self.cluster_name = cluster_name @@ -780,7 +779,7 @@ def __init__( region: str | None = None, waiter_delay: int = 30, waiter_max_attempts: int = 60, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index f9eacdb79d191..8330a586e4426 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any, Sequence from uuid import uuid4 +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook @@ -96,7 +97,7 @@ def __init__( waiter_delay: int | None = None, waiter_max_attempts: int | None = None, execution_role_arn: str | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): if not exactly_one(job_flow_id is None, job_flow_name is None): @@ -510,7 +511,7 @@ def __init__( max_tries: int | None = None, tags: dict | None = None, max_polling_attempts: int | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -695,7 +696,7 @@ def __init__( waiter_delay: int | None | ArgNotSet = NOTSET, waiter_countdown: int | None = None, waiter_check_interval_seconds: int = 60, - deferrable: bool = False, + deferrable: bool = 
conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs: Any, ): if waiter_max_attempts is NOTSET: @@ -900,7 +901,7 @@ def __init__( aws_conn_id: str = "aws_default", waiter_delay: int = 60, waiter_max_attempts: int = 20, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(**kwargs) diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py index 060ac358a40c2..265d057de51ae 100644 --- a/airflow/providers/amazon/aws/operators/glue.py +++ b/airflow/providers/amazon/aws/operators/glue.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Sequence from airflow import AirflowException +from airflow.configuration import conf from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.glue import GlueJobHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -98,7 +99,7 @@ def __init__( create_job_kwargs: dict | None = None, run_job_kwargs: dict | None = None, wait_for_completion: bool = True, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), verbose: bool = False, update_config: bool = False, job_poll_interval: int | float = 6, diff --git a/airflow/providers/amazon/aws/operators/glue_crawler.py b/airflow/providers/amazon/aws/operators/glue_crawler.py index a7efb9f5c0cd2..71e2607039c35 100644 --- a/airflow/providers/amazon/aws/operators/glue_crawler.py +++ b/airflow/providers/amazon/aws/operators/glue_crawler.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Sequence from airflow import AirflowException +from airflow.configuration import conf from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger if TYPE_CHECKING: @@ -61,7 +62,7 @@ def __init__( region_name: str | None = None, poll_interval: int = 5, wait_for_completion: bool = True, - deferrable: bool = False, + 
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(**kwargs) diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py index 9aef6701660ed..c58961db2e8f2 100644 --- a/airflow/providers/amazon/aws/operators/rds.py +++ b/airflow/providers/amazon/aws/operators/rds.py @@ -24,6 +24,7 @@ from mypy_boto3_rds.type_defs import TagTypeDef +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.rds import RdsHook @@ -554,7 +555,7 @@ def __init__( rds_kwargs: dict | None = None, aws_conn_id: str = "aws_default", wait_for_completion: bool = True, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), waiter_delay: int = 30, waiter_max_attempts: int = 60, **kwargs, @@ -645,7 +646,7 @@ def __init__( rds_kwargs: dict | None = None, aws_conn_id: str = "aws_default", wait_for_completion: bool = True, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), waiter_delay: int = 30, waiter_max_attempts: int = 60, **kwargs, diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index 905c34ff3ac56..cde4a32226e91 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -20,6 +20,7 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, Sequence +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook @@ -148,7 +149,7 @@ def __init__( wait_for_completion: bool = False, max_attempt: int = 5, 
poll_interval: int = 60, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(**kwargs) @@ -327,7 +328,7 @@ def __init__( poll_interval: int = 15, max_attempt: int = 20, aws_conn_id: str = "aws_default", - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(**kwargs) @@ -470,7 +471,7 @@ def __init__( cluster_identifier: str, aws_conn_id: str = "aws_default", wait_for_completion: bool = False, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: int = 10, max_attempts: int = 10, **kwargs, @@ -560,7 +561,7 @@ def __init__( *, cluster_identifier: str, aws_conn_id: str = "aws_default", - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: int = 10, max_attempts: int = 15, **kwargs, @@ -647,7 +648,7 @@ def __init__( wait_for_completion: bool = True, aws_conn_id: str = "aws_default", poll_interval: int = 30, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), max_attempts: int = 30, **kwargs, ): @@ -668,7 +669,6 @@ def __init__( self.max_attempts = max_attempts def execute(self, context: Context): - while self._attempts >= 1: try: self.redshift_hook.delete_cluster( diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 4dac7df00790f..ac1b7a73d2de6 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -25,6 +25,7 @@ from botocore.exceptions import ClientError +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator from 
airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -198,7 +199,7 @@ def __init__( max_attempts: int | None = None, max_ingestion_time: int | None = None, action_if_job_exists: str = "timestamp", - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) @@ -392,7 +393,7 @@ def __init__( check_interval: int = CHECK_INTERVAL_SECOND, max_ingestion_time: int | None = None, operation: str = "create", - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) @@ -551,7 +552,7 @@ def __init__( max_ingestion_time: int | None = None, check_if_job_exists: bool = True, action_if_job_exists: str = "timestamp", - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) @@ -700,7 +701,7 @@ def __init__( wait_for_completion: bool = True, check_interval: int = CHECK_INTERVAL_SECOND, max_ingestion_time: int | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) @@ -862,7 +863,7 @@ def __init__( max_ingestion_time: int | None = None, check_if_job_exists: bool = True, action_if_job_exists: str = "timestamp", - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) diff --git a/airflow/providers/amazon/aws/sensors/batch.py b/airflow/providers/amazon/aws/sensors/batch.py index 2033d1e86bee8..32da5b4cf2524 100644 --- a/airflow/providers/amazon/aws/sensors/batch.py +++ 
b/airflow/providers/amazon/aws/sensors/batch.py @@ -22,6 +22,7 @@ from deprecated import deprecated +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger @@ -58,7 +59,7 @@ def __init__( job_id: str, aws_conn_id: str = "aws_default", region_name: str | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poke_interval: float = 5, max_retries: int = 5, **kwargs, diff --git a/airflow/providers/amazon/aws/sensors/ec2.py b/airflow/providers/amazon/aws/sensors/ec2.py index c5d77610319ee..2b7b63f7e6c7a 100644 --- a/airflow/providers/amazon/aws/sensors/ec2.py +++ b/airflow/providers/amazon/aws/sensors/ec2.py @@ -20,6 +20,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, Sequence +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger @@ -55,7 +56,7 @@ def __init__( instance_id: str, aws_conn_id: str = "aws_default", region_name: str | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): if target_state not in self.valid_states: diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py index 2f44caab06196..9953dfa78260c 100644 --- a/airflow/providers/amazon/aws/sensors/emr.py +++ b/airflow/providers/amazon/aws/sensors/emr.py @@ -23,6 +23,7 @@ from deprecated import deprecated +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from 
airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri @@ -271,7 +272,7 @@ def __init__( max_retries: int | None = None, aws_conn_id: str = "aws_default", poll_interval: int = 10, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -425,7 +426,7 @@ def __init__( target_states: Iterable[str] | None = None, failed_states: Iterable[str] | None = None, max_attempts: int = 60, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(**kwargs) @@ -549,7 +550,7 @@ def __init__( target_states: Iterable[str] | None = None, failed_states: Iterable[str] | None = None, max_attempts: int = 60, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(**kwargs) diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py index 7192585afd58b..4d15cdc2125d3 100644 --- a/airflow/providers/amazon/aws/sensors/s3.py +++ b/airflow/providers/amazon/aws/sensors/s3.py @@ -26,6 +26,8 @@ from deprecated import deprecated +from airflow.configuration import conf + if TYPE_CHECKING: from airflow.utils.context import Context @@ -87,7 +89,7 @@ def __init__( check_fn: Callable[..., bool] | None = None, aws_conn_id: str = "aws_default", verify: str | bool | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(**kwargs) @@ -238,10 +240,9 @@ def __init__( min_objects: int = 1, previous_objects: set[str] | None = None, allow_delete: bool = True, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: - super().__init__(**kwargs) self.bucket_name = 
bucket_name diff --git a/airflow/providers/apache/livy/operators/livy.py b/airflow/providers/apache/livy/operators/livy.py index fa4f35734371e..f5e519315f4b9 100644 --- a/airflow/providers/apache/livy/operators/livy.py +++ b/airflow/providers/apache/livy/operators/livy.py @@ -20,6 +20,7 @@ from time import sleep from typing import TYPE_CHECKING, Any, Sequence +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook @@ -88,10 +89,9 @@ def __init__( extra_options: dict[str, Any] | None = None, extra_headers: dict[str, Any] | None = None, retry_args: dict[str, Any] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs: Any, ) -> None: - super().__init__(**kwargs) self.spark_params = { diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py index 696611c6c2e98..e3ac7708e74cd 100644 --- a/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/airflow/providers/cncf/kubernetes/operators/pod.py @@ -33,6 +33,7 @@ from slugify import slugify from urllib3.exceptions import HTTPError +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException from airflow.kubernetes import pod_generator from airflow.kubernetes.pod_generator import PodGenerator @@ -305,7 +306,7 @@ def __init__( configmaps: list[str] | None = None, skip_on_exit_code: int | Container[int] | None = None, base_container_name: str | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: float = 2, log_pod_spec_on_failure: bool = True, on_finish_action: str = "delete_pod", diff --git a/airflow/providers/databricks/operators/databricks.py 
b/airflow/providers/databricks/operators/databricks.py index ab93d8f49b7cc..fb27f0c01a118 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -24,6 +24,7 @@ from logging import Logger from typing import TYPE_CHECKING, Any, Sequence +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState @@ -315,7 +316,7 @@ def __init__( access_control_list: list[dict[str, str]] | None = None, wait_for_termination: bool = True, git_source: dict[str, str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: """Creates a new ``DatabricksSubmitRunOperator``.""" @@ -605,7 +606,7 @@ def __init__( databricks_retry_args: dict[Any, Any] | None = None, do_xcom_push: bool = True, wait_for_termination: bool = True, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: """Creates a new ``DatabricksRunNowOperator``.""" diff --git a/airflow/providers/dbt/cloud/operators/dbt.py b/airflow/providers/dbt/cloud/operators/dbt.py index f316c47f3db42..c977539afb73a 100644 --- a/airflow/providers/dbt/cloud/operators/dbt.py +++ b/airflow/providers/dbt/cloud/operators/dbt.py @@ -22,6 +22,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.providers.dbt.cloud.hooks.dbt import ( @@ -99,7 +100,7 @@ def __init__( timeout: int = 60 * 60 * 24 * 7, check_interval: int = 60, additional_run_config: dict[str, Any] | None = None, - deferrable: bool = False, + deferrable: bool 
= conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/dbt/cloud/sensors/dbt.py b/airflow/providers/dbt/cloud/sensors/dbt.py index 5838f6d6247d2..3b5ae549a35a5 100644 --- a/airflow/providers/dbt/cloud/sensors/dbt.py +++ b/airflow/providers/dbt/cloud/sensors/dbt.py @@ -20,6 +20,7 @@ import warnings from typing import TYPE_CHECKING, Any +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger @@ -50,7 +51,7 @@ def __init__( dbt_cloud_conn_id: str = DbtCloudHook.default_conn_name, run_id: int, account_id: int | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: if deferrable: diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 98929b6e6dec6..970e7813ed9c2 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -29,6 +29,7 @@ from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, QueryJob from google.cloud.bigquery.table import RowIterator +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException from airflow.models import BaseOperator, BaseOperatorLink from airflow.models.xcom import XCom @@ -200,7 +201,7 @@ def __init__( location: str | None = None, impersonation_chain: str | Sequence[str] | None = None, labels: dict | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: float = 4.0, 
**kwargs, ) -> None: @@ -320,7 +321,7 @@ def __init__( location: str | None = None, impersonation_chain: str | Sequence[str] | None = None, labels: dict | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: float = 4.0, **kwargs, ) -> None: @@ -460,7 +461,7 @@ def __init__( location: str | None = None, impersonation_chain: str | Sequence[str] | None = None, labels: dict | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: float = 4.0, **kwargs, ) -> None: @@ -854,7 +855,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", location: str | None = None, impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: float = 4.0, as_dict: bool = False, use_legacy_sql: bool = True, @@ -1876,7 +1877,6 @@ def __init__( exists_ok: bool | None = None, **kwargs, ) -> None: - self.dataset_id = dataset_id self.project_id = project_id self.location = location @@ -2623,7 +2623,7 @@ def __init__( cancel_on_kill: bool = True, result_retry: Retry = DEFAULT_RETRY, result_timeout: float | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: float = 4.0, **kwargs, ) -> None: diff --git a/airflow/providers/google/cloud/operators/bigquery_dts.py b/airflow/providers/google/cloud/operators/bigquery_dts.py index d9e013afa68c3..e10618bc39bdf 100644 --- a/airflow/providers/google/cloud/operators/bigquery_dts.py +++ b/airflow/providers/google/cloud/operators/bigquery_dts.py @@ -32,6 +32,7 @@ ) from airflow import AirflowException +from airflow.configuration import conf from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook, get_object_id from 
airflow.providers.google.cloud.links.bigquery_dts import BigQueryDataTransferConfigLink from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator @@ -279,7 +280,7 @@ def __init__( metadata: Sequence[tuple[str, str]] = (), gcp_conn_id="google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/google/cloud/operators/cloud_build.py b/airflow/providers/google/cloud/operators/cloud_build.py index 4242f561c1003..14fed55a3af6b 100644 --- a/airflow/providers/google/cloud/operators/cloud_build.py +++ b/airflow/providers/google/cloud/operators/cloud_build.py @@ -28,6 +28,7 @@ from google.api_core.retry import Retry from google.cloud.devtools.cloudbuild_v1.types import Build, BuildTrigger, RepoSource +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.cloud_build import CloudBuildHook from airflow.providers.google.cloud.links.cloud_build import ( @@ -176,7 +177,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, poll_interval: float = 4.0, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), location: str = "global", **kwargs, ) -> None: diff --git a/airflow/providers/google/cloud/operators/cloud_composer.py b/airflow/providers/google/cloud/operators/cloud_composer.py index d04a1606fc84f..c9b52d855915a 100644 --- a/airflow/providers/google/cloud/operators/cloud_composer.py +++ b/airflow/providers/google/cloud/operators/cloud_composer.py @@ -27,6 +27,7 @@ from google.protobuf.field_mask_pb2 import FieldMask from airflow import AirflowException +from airflow.configuration import conf from 
airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook from airflow.providers.google.cloud.links.base import BaseGoogleLink from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator @@ -135,7 +136,7 @@ def __init__( retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: Sequence[tuple[str, str]] = (), - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), pooling_period_seconds: int = 30, **kwargs, ) -> None: @@ -264,7 +265,7 @@ def __init__( metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), pooling_period_seconds: int = 30, **kwargs, ) -> None: @@ -509,7 +510,7 @@ def __init__( metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), pooling_period_seconds: int = 30, **kwargs, ) -> None: diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py b/airflow/providers/google/cloud/operators/cloud_sql.py index 5c77cbd86c948..b1144663c7065 100644 --- a/airflow/providers/google/cloud/operators/cloud_sql.py +++ b/airflow/providers/google/cloud/operators/cloud_sql.py @@ -22,6 +22,7 @@ from googleapiclient.errors import HttpError +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from airflow.models import Connection @@ -955,7 +956,7 @@ def __init__( api_version: str = "v1beta4", validate_body: bool = True, impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", 
"default_deferrable", fallback=False), poke_interval: int = 10, **kwargs, ) -> None: diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py index 5ae1115a34ae6..a5e9588214858 100644 --- a/airflow/providers/google/cloud/operators/dataflow.py +++ b/airflow/providers/google/cloud/operators/dataflow.py @@ -28,6 +28,7 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow import AirflowException +from airflow.configuration import conf from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType from airflow.providers.google.cloud.hooks.dataflow import ( @@ -419,7 +420,6 @@ def set_current_job_id(job_id): variables=pipeline_options, ) while is_running and self.check_if_running == CheckJobRunning.WaitForRun: - is_running = self.dataflow_hook.is_job_dataflow_running( name=self.job_name, variables=pipeline_options, @@ -611,7 +611,7 @@ def __init__( cancel_timeout: int | None = 10 * 60, wait_until_finished: bool | None = None, append_job_name: bool = True, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -801,7 +801,7 @@ def __init__( cancel_timeout: int | None = 10 * 60, wait_until_finished: bool | None = None, impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), append_job_name: bool = True, *args, **kwargs, diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index db7d785347a83..d14a495bc0cbd 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -36,6 +36,7 @@ from google.protobuf.duration_pb2 import Duration from 
google.protobuf.field_mask_pb2 import FieldMask +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder from airflow.providers.google.cloud.hooks.gcs import GCSHook @@ -484,7 +485,7 @@ def __init__( metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), polling_interval_seconds: int = 10, **kwargs, ) -> None: @@ -849,7 +850,7 @@ def __init__( metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), polling_interval_seconds: int = 10, **kwargs, ): @@ -981,7 +982,7 @@ def __init__( job_error_states: set[str] | None = None, impersonation_chain: str | Sequence[str] | None = None, asynchronous: bool = False, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), polling_interval_seconds: int = 10, **kwargs, ) -> None: @@ -1731,7 +1732,7 @@ def __init__( metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), polling_interval_seconds: int = 10, **kwargs, ) -> None: @@ -1859,7 +1860,7 @@ def __init__( metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), 
polling_interval_seconds: int = 10, **kwargs, ) -> None: @@ -1979,7 +1980,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, asynchronous: bool = False, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), polling_interval_seconds: int = 10, cancel_on_kill: bool = True, wait_timeout: int | None = None, @@ -2139,7 +2140,7 @@ def __init__( metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), polling_interval_seconds: int = 10, **kwargs, ): @@ -2270,7 +2271,7 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, result_retry: Retry | _MethodDefault = DEFAULT, asynchronous: bool = False, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), polling_interval_seconds: int = 5, **kwargs, ): diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py index bf14828d87027..086a7d99b7e66 100644 --- a/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -26,6 +26,7 @@ from google.cloud.container_v1.types import Cluster from kubernetes.client.models import V1Pod +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction @@ -34,6 +35,7 @@ except ImportError: # preserve backward compatibility for older versions of cncf.kubernetes provider from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator + from 
airflow.providers.google.cloud.hooks.kubernetes_engine import GKEHook, GKEPodHook from airflow.providers.google.cloud.links.kubernetes_engine import ( KubernetesEngineClusterLink, @@ -108,7 +110,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", api_version: str = "v2", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: int = 10, **kwargs, ) -> None: @@ -255,7 +257,7 @@ def __init__( api_version: str = "v2", impersonation_chain: str | Sequence[str] | None = None, poll_interval: int = 10, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/google/cloud/operators/mlengine.py b/airflow/providers/google/cloud/operators/mlengine.py index b776d20dff42d..20d9d19886234 100644 --- a/airflow/providers/google/cloud/operators/mlengine.py +++ b/airflow/providers/google/cloud/operators/mlengine.py @@ -27,6 +27,7 @@ from googleapiclient.errors import HttpError +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook from airflow.providers.google.cloud.links.mlengine import ( @@ -722,7 +723,6 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) self._project_id = project_id self._model_name = model_name @@ -804,7 +804,6 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) self._project_id = project_id self._model_name = model_name @@ -883,7 +882,6 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) self._project_id = project_id self._model_name = 
model_name @@ -961,7 +959,6 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) self._project_id = project_id self._model_name = model_name @@ -1098,7 +1095,7 @@ def __init__( labels: dict[str, str] | None = None, impersonation_chain: str | Sequence[str] | None = None, hyperparameters: dict | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), cancel_on_kill: bool = True, **kwargs, ) -> None: @@ -1370,7 +1367,6 @@ def __init__( raise AirflowException("Google Cloud project id is required.") def execute(self, context: Context): - hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, impersonation_chain=self._impersonation_chain, diff --git a/airflow/providers/google/cloud/sensors/bigquery.py b/airflow/providers/google/cloud/sensors/bigquery.py index db109bf2c1e89..e4e1819ef1341 100644 --- a/airflow/providers/google/cloud/sensors/bigquery.py +++ b/airflow/providers/google/cloud/sensors/bigquery.py @@ -22,6 +22,7 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, Sequence +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook from airflow.providers.google.cloud.triggers.bigquery import ( @@ -71,7 +72,7 @@ def __init__( table_id: str, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: if deferrable and "poke_interval" not in kwargs: @@ -184,7 +185,7 @@ def __init__( partition_id: str, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False), **kwargs, ) -> None: if deferrable and "poke_interval" not in kwargs: diff --git a/airflow/providers/google/cloud/sensors/gcs.py b/airflow/providers/google/cloud/sensors/gcs.py index 08fd37022dfe7..7048789601e4d 100644 --- a/airflow/providers/google/cloud/sensors/gcs.py +++ b/airflow/providers/google/cloud/sensors/gcs.py @@ -27,6 +27,7 @@ from google.api_core.retry import Retry from google.cloud.storage.retry import DEFAULT_RETRY +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.cloud.triggers.gcs import ( @@ -76,10 +77,9 @@ def __init__( google_cloud_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, retry: Retry = DEFAULT_RETRY, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: - super().__init__(**kwargs) self.bucket = bucket self.object = object @@ -208,10 +208,9 @@ def __init__( ts_func: Callable = ts_function, google_cloud_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: - super().__init__(**kwargs) self.bucket = bucket self.object = object @@ -298,7 +297,7 @@ def __init__( prefix: str, google_cloud_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -412,10 +411,9 @@ def __init__( allow_delete: bool = True, google_cloud_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = False, + deferrable: bool = 
conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: - super().__init__(**kwargs) self.bucket = bucket diff --git a/airflow/providers/google/cloud/sensors/pubsub.py b/airflow/providers/google/cloud/sensors/pubsub.py index 2e03b3669d3dc..db9f39b19b0b3 100644 --- a/airflow/providers/google/cloud/sensors/pubsub.py +++ b/airflow/providers/google/cloud/sensors/pubsub.py @@ -23,6 +23,7 @@ from google.cloud.pubsub_v1.types import ReceivedMessage +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.pubsub import PubSubHook from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger @@ -103,10 +104,9 @@ def __init__( messages_callback: Callable[[list[ReceivedMessage], Context], Any] | None = None, impersonation_chain: str | Sequence[str] | None = None, poke_interval: float = 10.0, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: - super().__init__(**kwargs) self.gcp_conn_id = gcp_conn_id self.project_id = project_id diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index 7ec62db9bfbd0..8836e3ee354e0 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -25,6 +25,7 @@ from google.cloud.bigquery import DEFAULT_RETRY, UnknownJob from airflow import AirflowException +from airflow.configuration import conf from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink @@ -114,7 +115,7 @@ def __init__( job_id: str | None = None, force_rerun: bool = False, reattach_states: set[str] | None = None, - deferrable: bool = False, + deferrable: bool = 
conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index 88b6d09323708..da462eae74df9 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -36,6 +36,7 @@ from google.cloud.bigquery.table import EncryptionConfiguration, Table, TableReference from airflow import AirflowException +from airflow.configuration import conf from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob from airflow.providers.google.cloud.hooks.gcs import GCSHook @@ -218,7 +219,7 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, labels=None, description=None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), result_retry: Retry = DEFAULT_RETRY, result_timeout: float | None = None, cancel_on_kill: bool = True, @@ -228,7 +229,6 @@ def __init__( project_id: str | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) self.hook: BigQueryHook | None = None self.configuration: dict[str, Any] = {} @@ -718,7 +718,6 @@ def _validate_src_fmt_configs( def _cleanse_time_partitioning( self, destination_dataset_table: str | None, time_partitioning_in: dict | None ) -> dict: # if it is a partitioned table ($ is in the table name) add partition load option - if time_partitioning_in is None: time_partitioning_in = {} diff --git a/airflow/providers/microsoft/azure/operators/data_factory.py b/airflow/providers/microsoft/azure/operators/data_factory.py index 8906c02ae12c7..a2b2c528bf9b9 100644 --- a/airflow/providers/microsoft/azure/operators/data_factory.py +++ b/airflow/providers/microsoft/azure/operators/data_factory.py @@ -20,6 +20,7 @@ import warnings from typing 
import TYPE_CHECKING, Any, Sequence +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from airflow.models import BaseOperator, BaseOperatorLink, XCom @@ -140,7 +141,7 @@ def __init__( parameters: dict[str, Any] | None = None, timeout: int = 60 * 60 * 24 * 7, check_interval: int = 60, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py b/airflow/providers/microsoft/azure/sensors/data_factory.py index 70ae1f69dc3f7..b4ebedce698aa 100644 --- a/airflow/providers/microsoft/azure/sensors/data_factory.py +++ b/airflow/providers/microsoft/azure/sensors/data_factory.py @@ -20,6 +20,7 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, Sequence +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.microsoft.azure.hooks.data_factory import ( AzureDataFactoryHook, @@ -60,7 +61,7 @@ def __init__( azure_data_factory_conn_id: str = AzureDataFactoryHook.default_conn_name, resource_group_name: str | None = None, factory_name: str | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/microsoft/azure/sensors/wasb.py b/airflow/providers/microsoft/azure/sensors/wasb.py index 4e2ec2d502c17..0c227f2ea37b7 100644 --- a/airflow/providers/microsoft/azure/sensors/wasb.py +++ b/airflow/providers/microsoft/azure/sensors/wasb.py @@ -21,6 +21,7 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, Sequence +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from 
airflow.providers.microsoft.azure.hooks.wasb import WasbHook from airflow.providers.microsoft.azure.triggers.wasb import WasbBlobSensorTrigger, WasbPrefixSensorTrigger @@ -53,7 +54,7 @@ def __init__( wasb_conn_id: str = "wasb_default", check_options: dict | None = None, public_read: bool = False, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -151,7 +152,7 @@ def __init__( prefix: str, wasb_conn_id: str = "wasb_default", check_options: dict | None = None, - deferrable: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst index b1e6c6be983d0..81526737e8d10 100644 --- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst +++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst @@ -56,6 +56,36 @@ Writing a deferrable operator takes a bit more work. There are some main points * You can defer multiple times, and you can defer before/after your Operator does significant work, or only defer if certain conditions are met (e.g. a system does not have an immediate answer). Deferral is entirely under your control. * Any Operator can defer; no special marking on its class is needed, and it's not limited to Sensors. * In order for any changes to a Trigger to be reflected, the *triggerer* needs to be restarted whenever the Trigger is modified. +* If you want to add an operator or sensor that supports both deferrable and non-deferrable modes, it's suggested to add ``deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False)`` to the ``__init__`` method of the operator and use it to decide whether to run the operator in deferrable mode.
You'll be able to configure the default value of ``deferrable`` for all the operators and sensors that support switching between deferrable and non-deferrable mode through ``default_deferrable`` in the ``operators`` section. Here's an example of a sensor that supports both modes:: + + import time + from datetime import timedelta + + from airflow.configuration import conf + from airflow.sensors.base import BaseSensorOperator + from airflow.triggers.temporal import TimeDeltaTrigger + + class WaitOneHourSensor(BaseSensorOperator): + def __init__( + self, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs + ): + super().__init__(**kwargs) + self.deferrable = deferrable + + def execute(self, context): + if self.deferrable: + self.defer( + trigger=TimeDeltaTrigger(timedelta(hours=1)), + method_name="execute_complete" + ) + else: + time.sleep(3600) + + def execute_complete(self, context, event=None): + # We have no more work to do here. Mark as complete. + return Triggering Deferral diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 2b118176721d5..c49e0bb034795 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -919,7 +919,6 @@ def test_render_template_fields_logging( caplog, monkeypatch, task, context, expected_exception, expected_rendering, expected_log, not_expected_log ): """Verify if operator attributes are correctly templated.""" - # Trigger templating and verify results def _do_render(): task.render_template_fields(context=context)