Commit

Improve DataprocCreateClusterOperator Triggers for Better Error Handling and Resource Cleanup (#39130)
sunank200 authored Apr 26, 2024
1 parent 0c96b06 commit bea1b7f
Showing 3 changed files with 189 additions and 29 deletions.
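For orientation, here is a minimal usage sketch of the deferrable operator this commit touches. The DAG id, project, region, and cluster config are illustrative assumptions, not taken from the commit; deferrable and delete_on_error are the parameters at issue:

# Hypothetical DAG: the ids and config values below are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocCreateClusterOperator

with DAG("dataproc_cluster_example", start_date=datetime(2024, 4, 1), schedule=None):
    create_cluster = DataprocCreateClusterOperator(
        task_id="create_cluster",
        project_id="my-project",  # placeholder
        region="us-central1",  # placeholder
        cluster_name="example-cluster",  # placeholder
        cluster_config={
            "master_config": {"num_instances": 1, "machine_type_uri": "n1-standard-2"},
            "worker_config": {"num_instances": 2, "machine_type_uri": "n1-standard-2"},
        },
        deferrable=True,  # hand the wait off to DataprocClusterTrigger
        delete_on_error=True,  # forwarded to the trigger by this commit
    )

With deferrable=True the operator defers to DataprocClusterTrigger; the change below forwards delete_on_error to that trigger so a failed or cancelled creation can tear the cluster down.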
1 change: 1 addition & 0 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -816,6 +816,7 @@ def execute(self, context: Context) -> dict:
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
delete_on_error=self.delete_on_error,
),
method_name="execute_complete",
)
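The trigger's terminal event is handed back to the operator's execute_complete method named above. That method is not part of this diff, so the following is only a hedged sketch of how such a callback can consume the payload fields ("cluster_name", "cluster_state", "cluster") that the reworked trigger yields; the error handling shown is an assumption, not the operator's actual logic.

# Sketch only: one plausible consumer of the trigger's event payload.
from __future__ import annotations

from typing import Any

from google.cloud.dataproc_v1 import ClusterStatus

from airflow.exceptions import AirflowException


def execute_complete(context: dict, event: dict[str, Any] | None = None) -> Any:
    if event is None:
        raise AirflowException("Trigger returned no event.")
    state = event["cluster_state"]
    # On creation failure the trigger below reports DELETING while it cleans up.
    if state in (ClusterStatus.State.ERROR, ClusterStatus.State.DELETING):
        raise AirflowException(f"Cluster {event['cluster_name']} failed; state: {state}.")
    return event["cluster"]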
92 changes: 77 additions & 15 deletions airflow/providers/google/cloud/triggers/dataproc.py
@@ -25,9 +25,10 @@
from typing import Any, AsyncIterator, Sequence

from google.api_core.exceptions import NotFound
from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus

from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -43,20 +44,32 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
polling_interval_seconds: int = 30,
delete_on_error: bool = True,
):
super().__init__()
self.region = region
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.polling_interval_seconds = polling_interval_seconds
self.delete_on_error = delete_on_error

def get_async_hook(self):
return DataprocAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

def get_sync_hook(self):
# The synchronous hook is utilized to delete the cluster when a task is cancelled.
# This is because the asynchronous hook deletion is not awaited when the trigger task
# is cancelled. The call for deleting the cluster through the sync hook is not a blocking
# call, which means it does not wait until the cluster is deleted.
return DataprocHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)


class DataprocSubmitTrigger(DataprocBaseTrigger):
"""
@@ -140,24 +153,73 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"polling_interval_seconds": self.polling_interval_seconds,
"delete_on_error": self.delete_on_error,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
        while True:
            cluster = await self.get_async_hook().get_cluster(
                project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
            )
            state = cluster.status.state
            self.log.info("Dataproc cluster: %s is in state: %s", self.cluster_name, state)
            if state in (
                ClusterStatus.State.ERROR,
                ClusterStatus.State.RUNNING,
            ):
                break
            self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
            await asyncio.sleep(self.polling_interval_seconds)
        yield TriggerEvent({"cluster_name": self.cluster_name, "cluster_state": state, "cluster": cluster})
        try:
            while True:
                cluster = await self.fetch_cluster()
state = cluster.status.state
if state == ClusterStatus.State.ERROR:
await self.delete_when_error_occurred(cluster)
yield TriggerEvent(
{
"cluster_name": self.cluster_name,
"cluster_state": ClusterStatus.State.DELETING,
"cluster": cluster,
}
)
return
elif state == ClusterStatus.State.RUNNING:
yield TriggerEvent(
{
"cluster_name": self.cluster_name,
"cluster_state": state,
"cluster": cluster,
}
)
return
self.log.info("Current state is %s", state)
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
await asyncio.sleep(self.polling_interval_seconds)
except asyncio.CancelledError:
try:
if self.delete_on_error:
self.log.info("Deleting cluster %s.", self.cluster_name)
# The synchronous hook is utilized to delete the cluster when a task is cancelled.
# This is because the asynchronous hook deletion is not awaited when the trigger task
# is cancelled. The call for deleting the cluster through the sync hook is not a blocking
# call, which means it does not wait until the cluster is deleted.
self.get_sync_hook().delete_cluster(
region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
)
self.log.info("Deleted cluster %s during cancellation.", self.cluster_name)
except Exception as e:
self.log.error("Error during cancellation handling: %s", e)
                raise AirflowException(f"Error during cancellation handling: {e}")

async def fetch_cluster(self) -> Cluster:
"""Fetch the cluster status."""
return await self.get_async_hook().get_cluster(
project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
)

async def delete_when_error_occurred(self, cluster: Cluster) -> None:
"""
Delete the cluster on error.

        :param cluster: The cluster to delete.
"""
if self.delete_on_error:
self.log.info("Deleting cluster %s.", self.cluster_name)
await self.get_async_hook().delete_cluster(
region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
)
self.log.info("Cluster %s has been deleted.", self.cluster_name)
else:
self.log.info("Cluster %s is not deleted as delete_on_error is set to False.", self.cluster_name)


class DataprocBatchTrigger(DataprocBaseTrigger):
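The comment in get_sync_hook above carries the key reasoning of this change: a coroutine awaited inside an asyncio.CancelledError handler can itself be interrupted before it finishes, so the cleanup call must be synchronous. A self-contained sketch of the pattern (the names here are illustrative, not from the provider):

import asyncio


def sync_cleanup() -> None:
    # A plain synchronous call runs to completion during teardown;
    # this mirrors the sync hook's delete_cluster call in the trigger.
    print("cleanup finished")


async def trigger_like_task() -> None:
    try:
        await asyncio.sleep(3600)  # stand-in for the polling loop
    except asyncio.CancelledError:
        # Awaiting an async cleanup here could be interrupted by a further
        # cancellation before it completes; the synchronous call cannot.
        sync_cleanup()
        raise  # propagate the cancellation as asyncio expects


async def main() -> None:
    task = asyncio.create_task(trigger_like_task())
    await asyncio.sleep(0.1)  # let the task start polling
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        print("task cancelled, cleanup already ran")


asyncio.run(main())

In the trigger, the synchronous delete_cluster plays the role of sync_cleanup; as the comment notes, it only initiates deletion and does not wait for the cluster to be gone.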
125 changes: 111 additions & 14 deletions tests/providers/google/cloud/triggers/test_dataproc.py
@@ -22,7 +22,7 @@
from unittest import mock

import pytest
from google.cloud.dataproc_v1 import Batch, ClusterStatus
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus
from google.protobuf.any_pb2 import Any
from google.rpc.status_pb2 import Status

@@ -70,6 +70,7 @@ def batch_trigger():
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
polling_interval_seconds=TEST_POLL_INTERVAL,
delete_on_error=True,
)
return trigger

@@ -96,6 +97,7 @@ def diagnose_operation_trigger():
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
polling_interval_seconds=TEST_POLL_INTERVAL,
delete_on_error=True,
)


@@ -147,6 +149,7 @@ def test_async_cluster_trigger_serialization_should_execute_successfully(self, c
"gcp_conn_id": TEST_GCP_CONN_ID,
"impersonation_chain": None,
"polling_interval_seconds": TEST_POLL_INTERVAL,
"delete_on_error": True,
}

@pytest.mark.asyncio
@@ -175,27 +178,37 @@ async def test_async_cluster_triggers_on_success_should_execute_successfully(

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
@mock.patch(
"airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster",
return_value=asyncio.Future(),
)
@mock.patch("google.auth.default")
async def test_async_cluster_trigger_run_returns_error_event(
self, mock_hook, cluster_trigger, async_get_cluster
self, mock_auth, mock_delete_cluster, mock_get_cluster, cluster_trigger, async_get_cluster, caplog
):
mock_hook.return_value = async_get_cluster(
mock_credentials = mock.MagicMock()
mock_credentials.universe_domain = "googleapis.com"

mock_auth.return_value = (mock_credentials, "project-id")

mock_delete_cluster.return_value = asyncio.Future()
mock_delete_cluster.return_value.set_result(None)

mock_get_cluster.return_value = async_get_cluster(
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
cluster_name=TEST_CLUSTER_NAME,
status=ClusterStatus(state=ClusterStatus.State.ERROR),
)

        actual_event = await cluster_trigger.run().asend(None)
        await asyncio.sleep(0.5)

        expected_event = TriggerEvent(
            {
                "cluster_name": TEST_CLUSTER_NAME,
                "cluster_state": ClusterStatus.State.ERROR,
                "cluster": actual_event.payload["cluster"],
            }
        )
        assert expected_event == actual_event
        caplog.set_level(logging.INFO)

        trigger_event = None
        async for event in cluster_trigger.run():
            trigger_event = event

        assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME
        assert trigger_event.payload["cluster_state"] == ClusterStatus.State.DELETING

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
@@ -215,9 +228,93 @@ async def test_cluster_run_loop_is_still_running(
await asyncio.sleep(0.5)

assert not task.done()
assert f"Current state is: {ClusterStatus.State.CREATING}"
assert f"Current state is: {ClusterStatus.State.CREATING}."
assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook")
async def test_cluster_trigger_cancellation_handling(
self, mock_get_sync_hook, mock_get_async_hook, caplog
):
cluster = Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING))
mock_get_async_hook.return_value.get_cluster.return_value = asyncio.Future()
mock_get_async_hook.return_value.get_cluster.return_value.set_result(cluster)

mock_delete_cluster = mock.MagicMock()
mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster

cluster_trigger = DataprocClusterTrigger(
cluster_name="cluster_name",
project_id="project-id",
region="region",
gcp_conn_id="google_cloud_default",
impersonation_chain=None,
polling_interval_seconds=5,
delete_on_error=True,
)

cluster_trigger_gen = cluster_trigger.run()

try:
await cluster_trigger_gen.__anext__()
await cluster_trigger_gen.aclose()

except asyncio.CancelledError:
# Verify that cancellation was handled as expected
if cluster_trigger.delete_on_error:
mock_get_sync_hook.assert_called_once()
mock_delete_cluster.assert_called_once_with(
region=cluster_trigger.region,
cluster_name=cluster_trigger.cluster_name,
project_id=cluster_trigger.project_id,
)
assert "Deleting cluster" in caplog.text
assert "Deleted cluster" in caplog.text
else:
mock_delete_cluster.assert_not_called()
except Exception as e:
pytest.fail(f"Unexpected exception raised: {e}")

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
async def test_fetch_cluster_status(self, mock_get_cluster, cluster_trigger, async_get_cluster):
mock_get_cluster.return_value = async_get_cluster(
status=ClusterStatus(state=ClusterStatus.State.RUNNING)
)
cluster = await cluster_trigger.fetch_cluster()

assert cluster.status.state == ClusterStatus.State.RUNNING, "The cluster state should be RUNNING"

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster")
async def test_delete_when_error_occurred(self, mock_delete_cluster, cluster_trigger):
mock_cluster = mock.MagicMock(spec=Cluster)
type(mock_cluster).status = mock.PropertyMock(
return_value=mock.MagicMock(state=ClusterStatus.State.ERROR)
)

mock_delete_future = asyncio.Future()
mock_delete_future.set_result(None)
mock_delete_cluster.return_value = mock_delete_future

cluster_trigger.delete_on_error = True

await cluster_trigger.delete_when_error_occurred(mock_cluster)

mock_delete_cluster.assert_called_once_with(
region=cluster_trigger.region,
cluster_name=cluster_trigger.cluster_name,
project_id=cluster_trigger.project_id,
)

mock_delete_cluster.reset_mock()
cluster_trigger.delete_on_error = False

await cluster_trigger.delete_when_error_occurred(mock_cluster)

mock_delete_cluster.assert_not_called()


@pytest.mark.db_test
class TestDataprocBatchTrigger: