From b359c7261a3e46edefc96ff7b3f2b4cb0d6ada18 Mon Sep 17 00:00:00 2001
From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com>
Date: Wed, 8 May 2024 14:40:42 +0545
Subject: [PATCH] Fix logic to cancel the external job if the TaskInstance is
 not in a running or deferred state for DataprocSubmitJobOperator (#39447)

---
 Reviewer note: this copy differs from the commit named in the "From" line.
 The AirflowException message in get_task_instance is now formatted eagerly
 (exceptions do not interpolate logging-style "%s" arguments), and the two new
 methods gained :return:/:raises: documentation. Hunk counts and the diffstat
 were updated accordingly; the post-image blob id on the dataproc.py "index"
 line was not recomputed.

 .../google/cloud/triggers/dataproc.py         | 45 ++++++++++++++++++-
 .../google/cloud/triggers/test_dataproc.py    |  8 +++-
 2 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index 939e5bbcac716..99800d266a86a 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -116,6 +116,45 @@ def serialize(self):
             },
         )
 
+    @provide_session
+    def get_task_instance(self, session: Session) -> TaskInstance:
+        """
+        Get the task instance for the current task.
+
+        :param session: Sqlalchemy session
+        :return: The TaskInstance row matching this trigger's dag_id, task_id, run_id and map_index.
+        :raises AirflowException: If no matching task instance is found.
+        """
+        query = session.query(TaskInstance).filter(
+            TaskInstance.dag_id == self.task_instance.dag_id,
+            TaskInstance.task_id == self.task_instance.task_id,
+            TaskInstance.run_id == self.task_instance.run_id,
+            TaskInstance.map_index == self.task_instance.map_index,
+        )
+        task_instance = query.one_or_none()
+        if task_instance is None:
+            # Exception classes do not interpolate logging-style "%s" arguments, so format eagerly.
+            raise AirflowException(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} "
+                f"and map_index: {self.task_instance.map_index} is not found"
+            )
+        return task_instance
+
+    def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the external job which is being executed by this trigger.
+
+        This is to avoid the case that `asyncio.CancelledError` is raised because the trigger itself is
+        stopped. In those cases, we should NOT cancel the external job.
+
+        :return: ``True`` if the latest task instance state is anything other than deferred.
+        """
+        # Database query is needed to get the latest state of the task instance.
+        task_instance = self.get_task_instance()  # type: ignore[call-arg]
+        return task_instance.state != TaskInstanceState.DEFERRED
+
     async def run(self):
         try:
             while True:
@@ -131,7 +170,11 @@ async def run(self):
         except asyncio.CancelledError:
             self.log.info("Task got cancelled.")
             try:
-                if self.job_id and self.cancel_on_kill:
+                if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+                    self.log.info(
+                        "Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
+                        " in deferred state."
+                    )
                     self.log.info("Cancelling the job: %s", self.job_id)
                     # The synchronous hook is utilized to delete the cluster when a task is cancelled. This
                     # is because the asynchronous hook deletion is not awaited when the trigger task is
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py
index 08294a5ac59d2..39ed949463c4b 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -124,6 +124,7 @@ def submit_trigger():
         region=TEST_REGION,
         gcp_conn_id=TEST_GCP_CONN_ID,
         polling_interval_seconds=TEST_POLL_INTERVAL,
+        cancel_on_kill=True,
     )
 
 
@@ -569,12 +570,15 @@ async def test_submit_trigger_run_error(self, mock_get_async_hook, submit_trigge
         assert event.payload == expected_event.payload
 
     @pytest.mark.asyncio
+    @pytest.mark.parametrize("is_safe_to_cancel", [True, False])
     @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
     @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_sync_hook")
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.safe_to_cancel")
     async def test_submit_trigger_run_cancelled(
-        self, mock_get_sync_hook, mock_get_async_hook, submit_trigger
+        self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, submit_trigger, is_safe_to_cancel
     ):
         """Test the trigger correctly handles an asyncio.CancelledError."""
+        mock_safe_to_cancel.return_value = is_safe_to_cancel
         mock_async_hook = mock_get_async_hook.return_value
         mock_async_hook.get_job.side_effect = asyncio.CancelledError
 
@@ -598,7 +602,7 @@ async def test_submit_trigger_run_cancelled(
             pytest.fail(f"Unexpected exception raised: {e}")
 
         # Check if cancel_job was correctly called
-        if submit_trigger.cancel_on_kill:
+        if submit_trigger.cancel_on_kill and is_safe_to_cancel:
             mock_sync_hook.cancel_job.assert_called_once_with(
                 job_id=submit_trigger.job_id,
                 project_id=submit_trigger.project_id,