
Commit 3d575fe

Fix logic to cancel the external job if the TaskInstance is not in a running or deferred state for DataprocCreateClusterOperator (#39446)

sunank200 authored May 8, 2024
1 parent 9a50475 commit 3d575fe

Showing 2 changed files with 76 additions and 3 deletions.
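At its core, the change adds a guard to the trigger's asyncio.CancelledError handler: the external resource is cleaned up only when the TaskInstance has actually left the deferred state (a real task cancellation), not when the triggerer process itself is shutting down. Below is a minimal, self-contained sketch of that pattern; ExternalJobTrigger, TIState, and the polling loop are hypothetical stand-ins, not code from this commit.

import asyncio
from enum import Enum


class TIState(str, Enum):  # stand-in for airflow.utils.state.TaskInstanceState
    DEFERRED = "deferred"
    RUNNING = "running"


class ExternalJobTrigger:  # hypothetical trigger, for illustration only
    def __init__(self, current_ti_state: TIState):
        self._ti_state = current_ti_state
        self.external_job_cancelled = False

    def safe_to_cancel(self) -> bool:
        # On triggerer shutdown the TaskInstance is still DEFERRED and the trigger
        # will be resumed by another triggerer, so the external job must stay alive.
        return self._ti_state != TIState.DEFERRED

    async def run(self):
        try:
            while True:  # poll the (imaginary) external job
                await asyncio.sleep(0.01)
        except asyncio.CancelledError:
            if self.safe_to_cancel():
                self.external_job_cancelled = True  # stands in for real cleanup
            raise


async def main():
    trigger = ExternalJobTrigger(current_ti_state=TIState.DEFERRED)
    task = asyncio.create_task(trigger.run())
    await asyncio.sleep(0)  # let the polling loop start
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    assert not trigger.external_job_cancelled  # triggerer shutdown: job kept alive


asyncio.run(main())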
44 changes: 42 additions & 2 deletions airflow/providers/google/cloud/triggers/dataproc.py
@@ -22,16 +22,22 @@
 import asyncio
 import re
 import time
-from typing import Any, AsyncIterator, Sequence
+from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence
 
 from google.api_core.exceptions import NotFound
 from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
 
 from airflow.exceptions import AirflowException
+from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.session import provide_session
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session


 class DataprocBaseTrigger(BaseTrigger):

@@ -178,6 +184,36 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )
 
+    @provide_session
+    def get_task_instance(self, session: Session) -> TaskInstance:
+        query = session.query(TaskInstance).filter(
+            TaskInstance.dag_id == self.task_instance.dag_id,
+            TaskInstance.task_id == self.task_instance.task_id,
+            TaskInstance.run_id == self.task_instance.run_id,
+            TaskInstance.map_index == self.task_instance.map_index,
+        )
+        task_instance = query.one_or_none()
+        if task_instance is None:
+            raise AirflowException(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} and map_index: {self.task_instance.map_index} was not found."
+            )
+        return task_instance
+
+    def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the external job which is being executed by this trigger.
+
+        This is to avoid the case when `asyncio.CancelledError` is raised because the trigger
+        itself is stopped: in such cases the external job should NOT be cancelled.
+        """
+        # A database query is needed to get the latest state of the task instance.
+        task_instance = self.get_task_instance()  # type: ignore[call-arg]
+        return task_instance.state != TaskInstanceState.DEFERRED
+
     async def run(self) -> AsyncIterator[TriggerEvent]:
         try:
             while True:
@@ -207,7 +243,11 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
                 await asyncio.sleep(self.polling_interval_seconds)
         except asyncio.CancelledError:
             try:
-                if self.delete_on_error:
+                if self.delete_on_error and self.safe_to_cancel():
+                    self.log.info(
+                        "Deleting the cluster as it is safe to do so: the Airflow "
+                        "TaskInstance is not in a deferred state."
+                    )
                     self.log.info("Deleting cluster %s.", self.cluster_name)
                     # The synchronous hook is utilized to delete the cluster when a task is cancelled.
                     # This is because the asynchronous hook deletion is not awaited when the trigger task
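The comment above touches on a subtlety that is easy to miss: calling an async hook's coroutine function inside a cancellation handler only creates a coroutine object, and if nothing awaits it before the trigger task is torn down, the deletion silently never runs. A tiny standalone illustration of that failure mode (the delete_cluster function here is a hypothetical stand-in, not this provider's API):

import asyncio

deleted = False


async def delete_cluster():  # stands in for an async hook's deletion call
    global deleted
    deleted = True


def cancellation_handler():
    # Calling a coroutine function does not execute its body; it only creates
    # a coroutine object. With nothing left to await it, the cleanup is lost
    # (Python warns: "coroutine 'delete_cluster' was never awaited").
    delete_cluster()


cancellation_handler()
print(deleted)  # False -> why the synchronous hook is used for deletion here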
35 changes: 34 additions & 1 deletion tests/providers/google/cloud/triggers/test_dataproc.py
@@ -18,7 +18,7 @@

 import asyncio
 import logging
-from asyncio import Future
+from asyncio import CancelledError, Future, sleep
 from unittest import mock
 
 import pytest
@@ -60,6 +60,7 @@ def cluster_trigger():
         gcp_conn_id=TEST_GCP_CONN_ID,
         impersonation_chain=None,
         polling_interval_seconds=TEST_POLL_INTERVAL,
+        delete_on_error=True,
     )


@@ -328,6 +329,38 @@ async def test_delete_when_error_occurred(self, mock_delete_cluster, cluster_trigger):

         mock_delete_cluster.assert_not_called()
 
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook")
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.safe_to_cancel")
+    async def test_cluster_trigger_run_cancelled_not_safe_to_cancel(
+        self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, cluster_trigger
+    ):
+        """Test the trigger's cancellation behavior when it is not safe to cancel."""
+        mock_safe_to_cancel.return_value = False
+        cluster = Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING))
+        future_cluster = asyncio.Future()
+        future_cluster.set_result(cluster)
+        mock_get_async_hook.return_value.get_cluster.return_value = future_cluster
+
+        mock_delete_cluster = mock.MagicMock()
+        mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster
+
+        cluster_trigger.delete_on_error = True
+
+        async_gen = cluster_trigger.run()
+        task = asyncio.create_task(async_gen.__anext__())
+        await sleep(0)
+        task.cancel()
+
+        try:
+            await task
+        except CancelledError:
+            pass
+
+        assert mock_delete_cluster.call_count == 0
+        mock_delete_cluster.assert_not_called()


 @pytest.mark.db_test
 class TestDataprocBatchTrigger:
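A small but load-bearing detail in the test above is await sleep(0): it yields control to the event loop so the trigger's run() generator actually starts and suspends at its first await before task.cancel() fires. Cancelling a task that has never been scheduled would skip the generator body entirely, including its except asyncio.CancelledError handler. A standalone illustration of the same technique (the generator here is hypothetical, not from the test suite):

import asyncio


async def gen():  # stands in for a trigger's run() async generator
    try:
        await asyncio.sleep(3600)  # suspend, like polling an external job
        yield "event"
    except asyncio.CancelledError:
        print("cancellation handler ran inside the generator")
        raise


async def main():
    task = asyncio.create_task(gen().__anext__())
    await asyncio.sleep(0)  # without this, gen() never starts and its
    task.cancel()           # except-block would never execute
    try:
        await task
    except asyncio.CancelledError:
        pass


asyncio.run(main())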
