From 37ca318ba1b20eb13be69f173a74e5b9156e264f Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Mon, 6 Mar 2023 00:25:08 -0800 Subject: [PATCH 1/3] Change base_aws.py to support async_conn Add async custom waiter support in get_waiter, and base_waiter.py Add Deferrable mode to RedshiftCreateClusterOperator Add RedshiftCreateClusterTrigger and unit test Add README.md for writing Triggers for AMPP --- airflow/providers/amazon/aws/hooks/base_aws.py | 9 +++++++++ .../providers/amazon/aws/triggers/redshift_cluster.py | 4 +++- airflow/providers/amazon/aws/waiters/base_waiter.py | 1 + 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index 541ef37d312a..5b29c7c0dedd 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -42,6 +42,7 @@ import jinja2 import requests import tenacity +from aiobotocore.session import AioSession, get_session as async_get_session from botocore.client import ClientMeta from botocore.config import Config from botocore.credentials import ReadOnlyCredentials @@ -658,6 +659,14 @@ def async_conn(self): return self.get_client_type(region_name=self.region_name, deferrable=True) + @cached_property + def async_conn(self): + """Get an Aiobotocore client to use for async operations (cached).""" + if not self.client_type: + raise ValueError("client_type must be specified.") + + return self.get_client_type(region_name=self.region_name, deferrable=True) + @cached_property def conn_client_meta(self) -> ClientMeta: """Get botocore client metadata from Hook connection (cached).""" diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py index ef19d0b5a1d6..7d39011946f8 100644 --- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py +++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py @@ -21,6 +21,9 @@ from airflow.compat.functools import cached_property from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook from airflow.triggers.base import BaseTrigger, TriggerEvent +from typing import Any + +from airflow.compat.functools import cached_property class RedshiftClusterTrigger(BaseTrigger): @@ -87,7 +90,6 @@ async def run(self) -> AsyncIterator[TriggerEvent]: if self.attempts < 1: yield TriggerEvent({"status": "error", "message": str(e)}) - class RedshiftCreateClusterTrigger(BaseTrigger): """ Trigger for RedshiftCreateClusterOperator. diff --git a/airflow/providers/amazon/aws/waiters/base_waiter.py b/airflow/providers/amazon/aws/waiters/base_waiter.py index 488767a084a2..0662c049a96f 100644 --- a/airflow/providers/amazon/aws/waiters/base_waiter.py +++ b/airflow/providers/amazon/aws/waiters/base_waiter.py @@ -18,6 +18,7 @@ from __future__ import annotations import boto3 +from aiobotocore.waiter import create_waiter_with_client as create_async_waiter_with_client from botocore.waiter import Waiter, WaiterModel, create_waiter_with_client From e60898bbb596f29605005e0d1dcb534ccdd7e2cd Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 22 Mar 2023 05:33:36 -0700 Subject: [PATCH 2/3] Add deferrable mode to redshift delete cluster --- .../amazon/aws/operators/redshift_cluster.py | 31 ++++++++++++--- .../amazon/aws/triggers/redshift_cluster.py | 36 +++++++++++++++++ .../aws/triggers/test_redshift_cluster.py | 39 ++++++++++++++++++- 3 files changed, 99 insertions(+), 7 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index 77ac521c9baf..c84551bfc1d4 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -22,10 +22,7 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook -from airflow.providers.amazon.aws.triggers.redshift_cluster import ( - RedshiftClusterTrigger, - RedshiftCreateClusterTrigger, -) +from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger, RedshiftCreateClusterTrigger, RedshiftDeleteClusterTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -595,6 +592,9 @@ class RedshiftDeleteClusterOperator(BaseOperator): The default value is ``True`` :param aws_conn_id: aws connection to use :param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state + :param max_attempts: Number of attempts the cluster should be polled to detemine the cluster + was deleted. + :param deferrable: If True, the operator will run as a deferrable operator. """ template_fields: Sequence[str] = ("cluster_identifier",) @@ -609,7 +609,9 @@ def __init__( final_cluster_snapshot_identifier: str | None = None, wait_for_completion: bool = True, aws_conn_id: str = "aws_default", - poll_interval: float = 30.0, + poll_interval: int = 30, + max_attempts: int = 20, + deferrable: bool = True, **kwargs, ): super().__init__(**kwargs) @@ -624,6 +626,9 @@ def __init__( self._attempts = 10 self._attempt_interval = 15 self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id) + self.deferrable = deferrable + self.max_attempts = max_attempts + self.aws_conn_id = aws_conn_id def execute(self, context: Context): while self._attempts >= 1: @@ -642,10 +647,24 @@ def execute(self, context: Context): time.sleep(self._attempt_interval) else: raise + if self.deferrable: + self.defer( + trigger=RedshiftDeleteClusterTrigger( + cluster_identifier=self.cluster_identifier, + poll_interval=self.poll_interval, + max_attempts=self.max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) if self.wait_for_completion: waiter = self.redshift_hook.get_conn().get_waiter("cluster_deleted") waiter.wait( ClusterIdentifier=self.cluster_identifier, - WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": 30}, + WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": self.max_attempts}, ) + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error deleting cluster: {event}") + return diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py index 7d39011946f8..8197b45ebeab 100644 --- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py +++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py @@ -139,3 +139,39 @@ async def run(self): }, ) yield TriggerEvent({"status": "success", "message": "Cluster Created"}) + +class RedshiftDeleteClusterTrigger(BaseTrigger): + def __init__( + self, + cluster_identifier: str, + poll_interval: int, + max_attempts: int, + aws_conn_id: str, + ): + self.cluster_identifier = cluster_identifier + self.poll_interval = poll_interval + self.max_attempts = max_attempts + self.aws_conn_id = aws_conn_id + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftDeleteClusterTrigger", + { + "cluster_identifier": str(self.cluster_identifier), + "poll_interval": str(self.poll_interval), + "max_attempts": str(self.max_attempts), + "aws_conn_id": str(self.aws_conn_id), + }, + ) + + async def run(self): + self.redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) + async with self.redshift_hook.async_conn as client: + await client.get_waiter("cluster_deleted").wait( + ClusterIdentifier=self.cluster_identifier, + WaiterConfig={ + "Delay": int(self.poll_interval), + "MaxAttempts": int(self.max_attempts), + }, + ) + yield TriggerEvent({"status": "success", "message": "Cluster deleted"}) diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py index 941258659e9a..7a95c54fbec2 100644 --- a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py @@ -20,7 +20,7 @@ import pytest -from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftCreateClusterTrigger +from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftCreateClusterTrigger, RedshiftDeleteClusterTrigger from airflow.triggers.base import TriggerEvent if sys.version_info < (3, 8): @@ -72,3 +72,40 @@ async def test_redshift_create_cluster_trigger_run(self, mock_async_conn): response = await generator.asend(None) assert response == TriggerEvent({"status": "success", "message": "Cluster Created"}) + +class TestRedshiftDeleteClusterTrigger: + def test_redshift_create_cluster_trigger_serialize(self): + redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ) + class_path, args = redshift_delete_cluster_trigger.serialize() + assert ( + class_path + == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftDeleteClusterTrigger" + ) + assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER + assert args["poll_interval"] == str(TEST_POLL_INTERVAL) + assert args["max_attempts"] == str(TEST_MAX_ATTEMPT) + assert args["aws_conn_id"] == TEST_AWS_CONN_ID + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn") + async def test_redshift_create_cluster_trigger_run(self, mock_async_conn): + mock = async_mock.MagicMock() + mock_async_conn.__aenter__.return_value = mock + mock.get_waiter().wait = AsyncMock() + + redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ) + + generator = redshift_delete_cluster_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "message": "Cluster deleted"}) \ No newline at end of file From cc1c7dd97127d7f335f4f5834dfd5fb40652d1ee Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Tue, 25 Apr 2023 09:54:33 -0700 Subject: [PATCH 3/3] Rebase on main Add/update some unit tests --- .../providers/amazon/aws/hooks/base_aws.py | 9 -------- .../amazon/aws/operators/redshift_cluster.py | 13 ++++++++--- .../amazon/aws/waiters/base_waiter.py | 1 - .../aws/operators/test_redshift_cluster.py | 23 ++++++++++++++++++- 4 files changed, 32 insertions(+), 14 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index 5b29c7c0dedd..541ef37d312a 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -42,7 +42,6 @@ import jinja2 import requests import tenacity -from aiobotocore.session import AioSession, get_session as async_get_session from botocore.client import ClientMeta from botocore.config import Config from botocore.credentials import ReadOnlyCredentials @@ -659,14 +658,6 @@ def async_conn(self): return self.get_client_type(region_name=self.region_name, deferrable=True) - @cached_property - def async_conn(self): - """Get an Aiobotocore client to use for async operations (cached).""" - if not self.client_type: - raise ValueError("client_type must be specified.") - - return self.get_client_type(region_name=self.region_name, deferrable=True) - @cached_property def conn_client_meta(self) -> ClientMeta: """Get botocore client metadata from Hook connection (cached).""" diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index c84551bfc1d4..70a540557d1f 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -22,7 +22,11 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook -from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger, RedshiftCreateClusterTrigger, RedshiftDeleteClusterTrigger +from airflow.providers.amazon.aws.triggers.redshift_cluster import ( + RedshiftClusterTrigger, + RedshiftCreateClusterTrigger, + RedshiftDeleteClusterTrigger, +) if TYPE_CHECKING: from airflow.utils.context import Context @@ -592,7 +596,7 @@ class RedshiftDeleteClusterOperator(BaseOperator): The default value is ``True`` :param aws_conn_id: aws connection to use :param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state - :param max_attempts: Number of attempts the cluster should be polled to detemine the cluster + :param max_attempts: Number of attempts the cluster should be polled to determine the cluster was deleted. :param deferrable: If True, the operator will run as a deferrable operator. """ @@ -611,7 +615,7 @@ def __init__( aws_conn_id: str = "aws_default", poll_interval: int = 30, max_attempts: int = 20, - deferrable: bool = True, + deferrable: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -664,7 +668,10 @@ def execute(self, context: Context): ClusterIdentifier=self.cluster_identifier, WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": self.max_attempts}, ) + def execute_complete(self, context, event=None): if event["status"] != "success": raise AirflowException(f"Error deleting cluster: {event}") + else: + self.log.info("Cluster deleted successfully") return diff --git a/airflow/providers/amazon/aws/waiters/base_waiter.py b/airflow/providers/amazon/aws/waiters/base_waiter.py index 0662c049a96f..488767a084a2 100644 --- a/airflow/providers/amazon/aws/waiters/base_waiter.py +++ b/airflow/providers/amazon/aws/waiters/base_waiter.py @@ -18,7 +18,6 @@ from __future__ import annotations import boto3 -from aiobotocore.waiter import create_waiter_with_client as create_async_waiter_with_client from botocore.waiter import Waiter, WaiterModel, create_waiter_with_client diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py index 64a276f14d02..c8bad24b91bb 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py @@ -31,7 +31,10 @@ RedshiftPauseClusterOperator, RedshiftResumeClusterOperator, ) -from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger +from airflow.providers.amazon.aws.triggers.redshift_cluster import ( + RedshiftClusterTrigger, + RedshiftDeleteClusterTrigger, +) class TestRedshiftCreateClusterOperator: @@ -481,3 +484,21 @@ def test_delete_cluster_multiple_attempts_fail(self, _, mock_conn, mock_delete_c redshift_operator.execute(None) assert mock_delete_cluster.call_count == 10 + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.delete_cluster") + def test_delete_cluster_deferrable(self, mock_delete_cluster): + mock_delete_cluster.return_value = True + + redshift_operator = RedshiftDeleteClusterOperator( + task_id="task_test", + cluster_identifier="test_cluster", + aws_conn_id="aws_conn_test", + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc: + redshift_operator.execute(None) + + assert isinstance( + exc.value.trigger, RedshiftDeleteClusterTrigger + ), "Trigger is not a RedshiftDeleteClusterTrigger"