Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Deferrable Mode to RedshiftDeleteClusterOperator #30870

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftCreateClusterTrigger,
RedshiftDeleteClusterTrigger,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -595,6 +596,9 @@ class RedshiftDeleteClusterOperator(BaseOperator):
The default value is ``True``
:param aws_conn_id: aws connection to use
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state
:param max_attempts: The maximum number of times the cluster state is polled to determine
    whether the cluster has been deleted.
:param deferrable: If True, the operator will run as a deferrable operator.
"""

template_fields: Sequence[str] = ("cluster_identifier",)
Expand All @@ -609,7 +613,9 @@ def __init__(
final_cluster_snapshot_identifier: str | None = None,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
poll_interval: float = 30.0,
poll_interval: int = 30,
max_attempts: int = 20,
deferrable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -624,6 +630,9 @@ def __init__(
self._attempts = 10
self._attempt_interval = 15
self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id)
self.deferrable = deferrable
self.max_attempts = max_attempts
self.aws_conn_id = aws_conn_id

def execute(self, context: Context):
while self._attempts >= 1:
Expand All @@ -642,10 +651,27 @@ def execute(self, context: Context):
time.sleep(self._attempt_interval)
else:
raise
if self.deferrable:
self.defer(
trigger=RedshiftDeleteClusterTrigger(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
)

if self.wait_for_completion:
waiter = self.redshift_hook.get_conn().get_waiter("cluster_deleted")
waiter.wait(
ClusterIdentifier=self.cluster_identifier,
WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": 30},
WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": self.max_attempts},
)

def execute_complete(self, context, event=None):
    """Act on the event the deferred trigger fired with.

    :param context: Airflow task context (unused).
    :param event: payload yielded by ``RedshiftDeleteClusterTrigger``;
        expected to be a dict with a ``"status"`` key.
    :raises AirflowException: if no event was received or the trigger
        reported a non-success status.
    """
    # A missing/odd event means the trigger died without reporting a result;
    # fail with an explicit AirflowException rather than a TypeError/KeyError.
    if event is None or event.get("status") != "success":
        raise AirflowException(f"Error deleting cluster: {event}")
    self.log.info("Cluster deleted successfully")
40 changes: 39 additions & 1 deletion airflow/providers/amazon/aws/triggers/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from airflow.compat.functools import cached_property
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
from typing import Any

from airflow.compat.functools import cached_property


class RedshiftClusterTrigger(BaseTrigger):
Expand Down Expand Up @@ -87,7 +90,6 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
if self.attempts < 1:
yield TriggerEvent({"status": "error", "message": str(e)})


class RedshiftCreateClusterTrigger(BaseTrigger):
"""
Trigger for RedshiftCreateClusterOperator.
Expand Down Expand Up @@ -137,3 +139,39 @@ async def run(self):
},
)
yield TriggerEvent({"status": "success", "message": "Cluster Created"})

class RedshiftDeleteClusterTrigger(BaseTrigger):
    """
    Trigger that asynchronously waits for a Redshift cluster to be deleted.

    :param cluster_identifier: unique identifier of the cluster being deleted
    :param poll_interval: time (in seconds) to wait between two consecutive polls
    :param max_attempts: maximum number of polls before the waiter gives up
    :param aws_conn_id: Airflow connection id used to build the Redshift hook
    """

    def __init__(
        self,
        cluster_identifier: str,
        poll_interval: int,
        max_attempts: int,
        aws_conn_id: str,
    ):
        # BaseTrigger.__init__ must run so the trigger is set up correctly
        # for serialization and the triggerer's bookkeeping.
        super().__init__()
        self.cluster_identifier = cluster_identifier
        self.poll_interval = poll_interval
        self.max_attempts = max_attempts
        self.aws_conn_id = aws_conn_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize the trigger's classpath and constructor arguments."""
        # NOTE: numeric args are stringified here and converted back with
        # int() in run(); kept as-is for backward compatibility with
        # already-serialized triggers.
        return (
            "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftDeleteClusterTrigger",
            {
                "cluster_identifier": str(self.cluster_identifier),
                "poll_interval": str(self.poll_interval),
                "max_attempts": str(self.max_attempts),
                "aws_conn_id": str(self.aws_conn_id),
            },
        )

    async def run(self):
        """Poll Redshift with the async ``cluster_deleted`` waiter, then fire a success event."""
        self.redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
        async with self.redshift_hook.async_conn as client:
            await client.get_waiter("cluster_deleted").wait(
                ClusterIdentifier=self.cluster_identifier,
                WaiterConfig={
                    "Delay": int(self.poll_interval),
                    "MaxAttempts": int(self.max_attempts),
                },
            )
        yield TriggerEvent({"status": "success", "message": "Cluster deleted"})
23 changes: 22 additions & 1 deletion tests/providers/amazon/aws/operators/test_redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
RedshiftPauseClusterOperator,
RedshiftResumeClusterOperator,
)
from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftDeleteClusterTrigger,
)


class TestRedshiftCreateClusterOperator:
Expand Down Expand Up @@ -481,3 +484,21 @@ def test_delete_cluster_multiple_attempts_fail(self, _, mock_conn, mock_delete_c
redshift_operator.execute(None)

assert mock_delete_cluster.call_count == 10

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.delete_cluster")
def test_delete_cluster_deferrable(self, delete_cluster_mock):
    """With deferrable=True the operator defers, handing off a RedshiftDeleteClusterTrigger."""
    delete_cluster_mock.return_value = True

    operator = RedshiftDeleteClusterOperator(
        task_id="task_test",
        cluster_identifier="test_cluster",
        aws_conn_id="aws_conn_test",
        deferrable=True,
    )

    with pytest.raises(TaskDeferred) as deferred:
        operator.execute(None)

    assert isinstance(
        deferred.value.trigger, RedshiftDeleteClusterTrigger
    ), "Trigger is not a RedshiftDeleteClusterTrigger"
39 changes: 38 additions & 1 deletion tests/providers/amazon/aws/triggers/test_redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import pytest

from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftCreateClusterTrigger
from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftCreateClusterTrigger, RedshiftDeleteClusterTrigger
from airflow.triggers.base import TriggerEvent

if sys.version_info < (3, 8):
Expand Down Expand Up @@ -72,3 +72,40 @@ async def test_redshift_create_cluster_trigger_run(self, mock_async_conn):
response = await generator.asend(None)

assert response == TriggerEvent({"status": "success", "message": "Cluster Created"})

class TestRedshiftDeleteClusterTrigger:
    """Tests for RedshiftDeleteClusterTrigger serialization and run()."""

    # Renamed from test_redshift_create_cluster_trigger_serialize: these tests
    # exercise the *delete* trigger; the old names were a copy-paste from the
    # create-trigger tests.
    def test_redshift_delete_cluster_trigger_serialize(self):
        """serialize() returns the trigger classpath and stringified ctor args."""
        redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger(
            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
            poll_interval=TEST_POLL_INTERVAL,
            max_attempts=TEST_MAX_ATTEMPT,
            aws_conn_id=TEST_AWS_CONN_ID,
        )
        class_path, args = redshift_delete_cluster_trigger.serialize()
        assert (
            class_path
            == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftDeleteClusterTrigger"
        )
        assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER
        assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
        assert args["max_attempts"] == str(TEST_MAX_ATTEMPT)
        assert args["aws_conn_id"] == TEST_AWS_CONN_ID

    @pytest.mark.asyncio
    @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn")
    async def test_redshift_delete_cluster_trigger_run(self, mock_async_conn):
        """run() awaits the cluster_deleted waiter and yields a success event."""
        mock = async_mock.MagicMock()
        mock_async_conn.__aenter__.return_value = mock
        mock.get_waiter().wait = AsyncMock()

        redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger(
            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
            poll_interval=TEST_POLL_INTERVAL,
            max_attempts=TEST_MAX_ATTEMPT,
            aws_conn_id=TEST_AWS_CONN_ID,
        )

        generator = redshift_delete_cluster_trigger.run()
        response = await generator.asend(None)

        assert response == TriggerEvent({"status": "success", "message": "Cluster deleted"})