diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index 0049143aea46..1e9848d7670a 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -1415,7 +1415,7 @@ def cancel_job( location: Optional[str] = None, ) -> None: """ - Cancels a job an wait for cancellation to complete + Cancel a job and wait for cancellation to complete :param job_id: id of the job. :param project_id: Google Cloud Project where the job is running diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 550c3174068e..e86c19e83b40 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -2139,7 +2139,7 @@ def _submit_job( hook: BigQueryHook, job_id: str, ) -> BigQueryJob: - # Submit a new job and wait for it to complete and get the result. + # Submit a new job without waiting for it to complete. return hook.insert_job( configuration=self.configuration, project_id=self.project_id, @@ -2147,6 +2147,7 @@ def _submit_job( job_id=job_id, timeout=self.result_timeout, retry=self.result_retry, + nowait=True, ) @staticmethod @@ -2174,7 +2175,6 @@ def execute(self, context: Any): try: self.log.info("Executing: %s'", self.configuration) job = self._submit_job(hook, job_id) - self._handle_job_error(job) except Conflict: # If the job already exists retrieve it job = hook.get_job( @@ -2182,11 +2182,7 @@ def execute(self, context: Any): location=self.location, job_id=job_id, ) - if job.state in self.reattach_states: - # We are reattaching to a job - job.result(timeout=self.result_timeout, retry=self.result_retry) - self._handle_job_error(job) - else: + if job.state not in self.reattach_states: # Same job configuration so we need force_rerun raise AirflowException( f"Job with id: {job_id} already exists and is in {job.state} state. If you " @@ -2221,10 +2217,16 @@ def execute(self, context: Any): BigQueryTableLink.persist(**persist_kwargs) self.job_id = job.job_id - return job.job_id + # Wait for the job to complete + job.result(timeout=self.result_timeout, retry=self.result_retry) + self._handle_job_error(job) + + return self.job_id def on_kill(self) -> None: if self.job_id and self.cancel_on_kill: self.hook.cancel_job( # type: ignore[union-attr] job_id=self.job_id, project_id=self.project_id, location=self.location ) + else: + self.log.info('Skipping to cancel job: %s:%s.%s', self.project_id, self.location, self.job_id) diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 7be855a9a0e6..494749282b6a 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -23,7 +23,7 @@ from google.cloud.bigquery import DEFAULT_RETRY from google.cloud.exceptions import Conflict -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowTaskTimeout from airflow.providers.google.cloud.operators.bigquery import ( BigQueryCheckOperator, BigQueryConsoleIndexableLink, @@ -830,6 +830,7 @@ def test_execute_query_success(self, mock_hook): configuration=configuration, location=TEST_DATASET_LOCATION, job_id=real_job_id, + nowait=True, project_id=TEST_GCP_PROJECT_ID, retry=DEFAULT_RETRY, timeout=None, @@ -870,6 +871,7 @@ def test_execute_copy_success(self, mock_hook): configuration=configuration, location=TEST_DATASET_LOCATION, job_id=real_job_id, + nowait=True, project_id=TEST_GCP_PROJECT_ID, retry=DEFAULT_RETRY, timeout=None, @@ -913,6 +915,45 @@ def test_on_kill(self, mock_hook): project_id=TEST_GCP_PROJECT_ID, ) + @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') + @mock.patch('airflow.providers.google.cloud.hooks.bigquery.BigQueryJob') + def test_on_kill_after_execution_timeout(self, mock_job, mock_hook): + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + + configuration = { + "query": { + "query": "SELECT * FROM any", + "useLegacySql": False, + } + } + + mock_job.job_id = real_job_id + mock_job.error_result = False + mock_job.result.side_effect = AirflowTaskTimeout() + + mock_hook.return_value.insert_job.return_value = mock_job + mock_hook.return_value.generate_job_id.return_value = real_job_id + + op = BigQueryInsertJobOperator( + task_id="insert_query_job", + configuration=configuration, + location=TEST_DATASET_LOCATION, + job_id=job_id, + project_id=TEST_GCP_PROJECT_ID, + cancel_on_kill=True, + ) + with pytest.raises(AirflowTaskTimeout): + op.execute(context=MagicMock()) + + op.on_kill() + mock_hook.return_value.cancel_job.assert_called_once_with( + job_id=real_job_id, + location=TEST_DATASET_LOCATION, + project_id=TEST_GCP_PROJECT_ID, + ) + @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_execute_failure(self, mock_hook): job_id = "123456" @@ -1018,6 +1059,7 @@ def test_execute_force_rerun(self, mock_hook): configuration=configuration, location=TEST_DATASET_LOCATION, job_id=real_job_id, + nowait=True, project_id=TEST_GCP_PROJECT_ID, retry=DEFAULT_RETRY, timeout=None, @@ -1038,7 +1080,7 @@ def test_execute_no_force_rerun(self, mock_hook): } } - mock_hook.return_value.insert_job.return_value.result.side_effect = Conflict("any") + mock_hook.return_value.insert_job.side_effect = Conflict("any") mock_hook.return_value.generate_job_id.return_value = real_job_id job = MagicMock( job_id=real_job_id,