diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py
index 526ab9a8a4f0..10b93afbbc2b 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -26,15 +26,17 @@
 """
 from __future__ import annotations
 
+import asyncio
 from random import uniform
 from time import sleep
+from typing import Any
 
 import botocore.client
 import botocore.exceptions
 import botocore.waiter
 
 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook
 from airflow.typing_compat import Protocol, runtime_checkable
 
@@ -544,3 +546,238 @@ def exp(tries):
             delay = 1 + pow(tries * 0.6, 2)
             delay = min(max_interval, delay)
             return uniform(delay / 3, delay)
+
+
+class BatchClientAsyncHook(BatchClientHook, AwsBaseAsyncHook):
+    """
+    Async client for AWS Batch services.
+
+    :param job_id: the job ID, usually unknown (None) until the
+        submit_job operation gets the jobId defined by AWS Batch
+
+    :param waiters: a :py:class:`.BatchWaiters` object (see note below);
+        if None, polling is used with max_retries and status_retries.
+
+    .. note::
+        Several methods use a default random delay to check or poll for job status, i.e.
+        ``random.uniform(DEFAULT_DELAY_MIN, DEFAULT_DELAY_MAX)``.
+        Using a random interval helps to avoid AWS API throttle limits
+        when many concurrent tasks request job descriptions.
+
+        To modify the global defaults for the range of jitter allowed when a
+        random delay is used to check Batch job status, modify these defaults, e.g.:
+
+            BatchClientHook.DEFAULT_DELAY_MIN = 0
+            BatchClientHook.DEFAULT_DELAY_MAX = 5
+
+        When explicit delay values are used, a 1-second random jitter is applied to the
+        delay. It is generally recommended that random jitter is added to API requests.
+        A convenience method is provided for this, e.g. to get a random delay of
+        10 sec +/- 5 sec: ``delay = BatchClientHook.add_jitter(10, width=5, minima=0)``
+    """
+
+    def __init__(self, job_id: str | None, waiters: Any = None, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.job_id = job_id
+        self.waiters = waiters
+
+    async def monitor_job(self) -> dict[str, str] | None:
+        """
+        Monitor an AWS Batch job.
+
+        This method can raise an exception, and an AirflowTaskTimeout can be raised
+        if an execution_timeout is given while creating the task.
+        These exceptions should be handled in taskinstance.py
+        instead of here, as was done previously.
+
+        :raises: AirflowException
+        """
+        if not self.job_id:
+            raise AirflowException("AWS Batch job - job_id was not found")
+
+        if self.waiters:
+            self.waiters.wait_for_job(self.job_id)
+            return None
+        else:
+            await self.wait_for_job(self.job_id)
+            await self.check_job_success(self.job_id)
+            success_msg = f"AWS Batch job ({self.job_id}) succeeded"
+            self.log.info(success_msg)
+            return {"status": "success", "message": success_msg}
+
+    async def check_job_success(self, job_id: str) -> bool:  # type: ignore[override]
+        """
+        Check the final status of the Batch job; return True if the job
+        'SUCCEEDED', else raise an AirflowException.
+
+        :param job_id: a Batch job ID
+
+        :raises: AirflowException
+        """
+        job = await self.get_job_description(job_id)
+        job_status = job.get("status")
+        if job_status == self.SUCCESS_STATE:
+            self.log.info("AWS Batch job (%s) succeeded: %s", job_id, job)
+            return True
+
+        if job_status == self.FAILURE_STATE:
+            raise AirflowException(f"AWS Batch job ({job_id}) failed: {job}")
+
+        if job_status in self.INTERMEDIATE_STATES:
+            raise AirflowException(f"AWS Batch job ({job_id}) is not complete: {job}")
+
+        raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}")
+
+    @staticmethod
+    async def delay(delay: int | float | None = None) -> None:  # type: ignore[override]
+        """
+        Pause execution for ``delay`` seconds.
+
+        :param delay: a delay to pause execution using ``asyncio.sleep(delay)``;
+            a small one-second jitter is applied to the delay.
+
+        .. note::
+            This method uses a default random delay, i.e.
+            ``random.uniform(DEFAULT_DELAY_MIN, DEFAULT_DELAY_MAX)``;
+            using a random interval helps to avoid AWS API throttle limits
+            when many concurrent tasks request job descriptions.
+        """
+        if delay is None:
+            delay = uniform(BatchClientHook.DEFAULT_DELAY_MIN, BatchClientHook.DEFAULT_DELAY_MAX)
+        else:
+            delay = BatchClientAsyncHook.add_jitter(delay)
+        await asyncio.sleep(delay)
+
+    async def wait_for_job(  # type: ignore[override]
+        self, job_id: str, delay: int | float | None = None
+    ) -> None:
+        """
+        Wait for Batch job to complete.
+
+        :param job_id: a Batch job ID
+
+        :param delay: a delay before polling for job status
+
+        :raises: AirflowException
+        """
+        await self.delay(delay)
+        await self.poll_for_job_running(job_id, delay)
+        await self.poll_for_job_complete(job_id, delay)
+        self.log.info("AWS Batch job (%s) has completed", job_id)
+
+    async def poll_for_job_complete(  # type: ignore[override]
+        self, job_id: str, delay: int | float | None = None
+    ) -> None:
+        """
+        Poll for job completion. The statuses that indicate job completion
+        are: 'SUCCEEDED'|'FAILED'.
+
+        So the status options that this will wait for are the transitions from:
+        'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED'
+
+        :param job_id: a Batch job ID
+
+        :param delay: a delay before polling for job status
+
+        :raises: AirflowException
+        """
+        await self.delay(delay)
+        complete_status = [self.SUCCESS_STATE, self.FAILURE_STATE]
+        await self.poll_job_status(job_id, complete_status)
+
+    async def poll_for_job_running(  # type: ignore[override]
+        self, job_id: str, delay: int | float | None = None
+    ) -> None:
+        """
+        Poll for job running. The statuses that indicate a job is running or
+        already complete are: 'RUNNING'|'SUCCEEDED'|'FAILED'.
+
+        So the status options that this will wait for are the transitions from:
+        'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'|'SUCCEEDED'|'FAILED'
+
+        The completed status options are included for cases where the status
+        changes too quickly for polling to detect a RUNNING status that moves
+        quickly from STARTING to RUNNING to completed (often a failure).
+
+        :param job_id: a Batch job ID
+
+        :param delay: a delay before polling for job status
+
+        :raises: AirflowException
+        """
+        await self.delay(delay)
+        running_status = [self.RUNNING_STATE, self.SUCCESS_STATE, self.FAILURE_STATE]
+        await self.poll_job_status(job_id, running_status)
+
+    async def get_job_description(self, job_id: str) -> dict[str, str]:  # type: ignore[override]
+        """
+        Get job description (using status_retries).
+
+        :param job_id: a Batch job ID
+        :raises: AirflowException
+        """
+        retries = 0
+        async with await self.get_client_async() as client:
+            while True:
+                try:
+                    # aiobotocore client methods are coroutines, so this call must be awaited
+                    response = await client.describe_jobs(jobs=[job_id])
+                    return self.parse_job_description(job_id, response)
+
+                except botocore.exceptions.ClientError as err:
+                    error = err.response.get("Error", {})
+                    if error.get("Code") == "TooManyRequestsException":
+                        pass  # allow it to retry, if possible
+                    else:
+                        raise AirflowException(f"AWS Batch job ({job_id}) description error: {err}")
+
+                retries += 1
+                if retries >= self.status_retries:
+                    raise AirflowException(
+                        f"AWS Batch job ({job_id}) description error: exceeded status_retries "
+                        f"({self.status_retries})"
+                    )
+
+                pause = self.exponential_delay(retries)
+                self.log.info(
+                    "AWS Batch job (%s) description retry (%d of %d) in the next %.2f seconds",
+                    job_id,
+                    retries,
+                    self.status_retries,
+                    pause,
+                )
+                await self.delay(pause)
+
+    async def poll_job_status(self, job_id: str, match_status: list[str]) -> bool:  # type: ignore[override]
+        """
+        Poll for job status using an exponential back-off strategy (with max_retries).
+        The Batch job statuses polled are:
+        'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
+
+        :param job_id: a Batch job ID
+        :param match_status: a list of job statuses to match
+        :raises: AirflowException
+        """
+        retries = 0
+        while True:
+            job = await self.get_job_description(job_id)
+            job_status = job.get("status")
+            self.log.info(
+                "AWS Batch job (%s) check status (%s) in %s",
+                job_id,
+                job_status,
+                match_status,
+            )
+            if job_status in match_status:
+                return True
+
+            if retries >= self.max_retries:
+                raise AirflowException(f"AWS Batch job ({job_id}) status checks exceed max_retries")
+
+            retries += 1
+            pause = self.exponential_delay(retries)
+            self.log.info(
+                "AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds",
+                job_id,
+                retries,
+                self.max_retries,
+                pause,
+            )
+            await self.delay(pause)
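
For orientation, here is a minimal sketch of how this async hook could be driven on its own, outside the trigger. It assumes aiobotocore is installed and default AWS credentials resolve; the job ID is a placeholder, since in practice it comes from submit_job:

import asyncio

from airflow.providers.amazon.aws.hooks.batch_client import BatchClientAsyncHook


async def main() -> None:
    # Placeholder job ID; in practice it is the jobId returned by submit_job
    hook = BatchClientAsyncHook(job_id="8ba9d676-4108-4474-9dca-8bbac1da9b19", waiters=None)
    # monitor_job waits for RUNNING, then SUCCEEDED/FAILED, and returns a
    # {"status": "success", ...} payload on success or raises AirflowException
    result = await hook.monitor_job()
    print(result)


asyncio.run(main())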
+ """ + if "status" in event and event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info(event["message"]) + return self.job_id + def on_kill(self): response = self.hook.client.terminate_job(jobId=self.job_id, reason="Task killed by the user") self.log.info("AWS Batch job (%s) terminated: %s", self.job_id, response) diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py new file mode 100644 index 000000000000..eb5a80a3c956 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/batch.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any, AsyncIterator + +from airflow.providers.amazon.aws.hooks.batch_client import BatchClientAsyncHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class BatchOperatorTrigger(BaseTrigger): + """ + Checks for the state of a previously submitted job to AWS Batch. + BatchOperatorTrigger is fired as deferred class with params to poll the job state in Triggerer + + :param job_id: the job ID, usually unknown (None) until the + submit_job operation gets the jobId defined by AWS Batch + :param job_name: the name for the job that will run on AWS Batch (templated) + :param job_definition: the job definition name on AWS Batch + :param job_queue: the queue name on AWS Batch + :param overrides: the `containerOverrides` parameter for boto3 (templated) + :param array_properties: the `arrayProperties` parameter for boto3 + :param parameters: the `parameters` for boto3 (templated) + :param waiters: a :class:`.BatchWaiters` object (see note below); + if None, polling is used with max_retries and status_retries. + :param tags: collection of tags to apply to the AWS Batch job submission + if None, no tags are submitted + :param max_retries: exponential back-off retries, 4200 = 48 hours; + polling is only used when waiters is None + :param status_retries: number of HTTP retries to get job status, 10; + polling is only used when waiters is None + :param aws_conn_id: connection id of AWS credentials / region name. If None, + credential boto3 strategy will be used. + :param region_name: AWS region name to use . 
diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py
new file mode 100644
index 000000000000..eb5a80a3c956
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/batch.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any, AsyncIterator
+
+from airflow.providers.amazon.aws.hooks.batch_client import BatchClientAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class BatchOperatorTrigger(BaseTrigger):
+    """
+    Checks for the state of a previously submitted job to AWS Batch.
+    BatchOperatorTrigger is fired as a deferred trigger with params to poll the job state in the triggerer.
+
+    :param job_id: the job ID, usually unknown (None) until the
+        submit_job operation gets the jobId defined by AWS Batch
+    :param job_name: the name for the job that will run on AWS Batch (templated)
+    :param job_definition: the job definition name on AWS Batch
+    :param job_queue: the queue name on AWS Batch
+    :param overrides: the `containerOverrides` parameter for boto3 (templated)
+    :param array_properties: the `arrayProperties` parameter for boto3
+    :param parameters: the `parameters` for boto3 (templated)
+    :param waiters: a :class:`.BatchWaiters` object;
+        if None, polling is used with max_retries and status_retries.
+    :param tags: collection of tags to apply to the AWS Batch job submission;
+        if None, no tags are submitted
+    :param max_retries: exponential back-off retries, 4200 = 48 hours;
+        polling is only used when waiters is None
+    :param status_retries: number of HTTP retries to get job status, 10;
+        polling is only used when waiters is None
+    :param aws_conn_id: connection id of AWS credentials / region name. If None,
+        the default boto3 credential strategy will be used.
+    :param region_name: AWS region name to use.
+        Override the region_name in connection (if provided)
+    """
+
+    def __init__(
+        self,
+        job_id: str | None,
+        job_name: str,
+        job_definition: str,
+        job_queue: str,
+        overrides: dict[str, str],
+        array_properties: dict[str, str],
+        parameters: dict[str, str],
+        waiters: Any,
+        tags: dict[str, str],
+        max_retries: int,
+        status_retries: int,
+        region_name: str | None,
+        aws_conn_id: str | None = "aws_default",
+    ):
+        super().__init__()
+        self.job_id = job_id
+        self.job_name = job_name
+        self.job_definition = job_definition
+        self.job_queue = job_queue
+        self.overrides = overrides or {}
+        self.array_properties = array_properties or {}
+        self.parameters = parameters or {}
+        self.waiters = waiters
+        self.tags = tags or {}
+        self.max_retries = max_retries
+        self.status_retries = status_retries
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes BatchOperatorTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger",
+            {
+                "job_id": self.job_id,
+                "job_name": self.job_name,
+                "job_definition": self.job_definition,
+                "job_queue": self.job_queue,
+                "overrides": self.overrides,
+                "array_properties": self.array_properties,
+                "parameters": self.parameters,
+                "waiters": self.waiters,
+                "tags": self.tags,
+                "max_retries": self.max_retries,
+                "status_retries": self.status_retries,
+                "aws_conn_id": self.aws_conn_id,
+                "region_name": self.region_name,
+            },
+        )
+
+    async def run(self) -> AsyncIterator["TriggerEvent"]:
+        """
+        Make an async connection to AWS Batch using the aiobotocore library and
+        periodically poll for the job status on the triggerer.
+
+        The statuses that indicate job completion are: 'SUCCEEDED'|'FAILED'.
+
+        So the status options that this will poll for are the transitions from:
+        'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED'
+        """
+        hook = BatchClientAsyncHook(
+            job_id=self.job_id,
+            waiters=self.waiters,
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+        )
+        try:
+            response = await hook.monitor_job()
+            if response:
+                yield TriggerEvent(response)
+            else:
+                error_message = f"{self.job_id} failed"
+                yield TriggerEvent({"status": "error", "message": error_message})
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e)})
diff --git a/docs/apache-airflow-providers-amazon/operators/batch.rst b/docs/apache-airflow-providers-amazon/operators/batch.rst
index ba280cb38d37..0c686184b922 100644
--- a/docs/apache-airflow-providers-amazon/operators/batch.rst
+++ b/docs/apache-airflow-providers-amazon/operators/batch.rst
@@ -37,7 +37,11 @@ Operators
 Submit a new AWS Batch job
 ==========================
 
-To submit a new AWS Batch job and monitor it until it reaches a terminal state you can
+To submit a new AWS Batch job and monitor it until it reaches a terminal state, you can
 use :class:`~airflow.providers.amazon.aws.operators.batch.BatchOperator`.
+You can also run this operator in deferrable mode by setting the parameter ``deferrable`` to ``True``.
+This leads to more efficient utilization of Airflow workers, as polling for job status happens
+asynchronously on the triggerer. Note that this requires the triggerer to be available in your
+Airflow deployment.
 
 .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_batch.py
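
To illustrate the serialize round-trip the triggerer relies on, a small sketch (all argument values are placeholders):

from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger

trigger = BatchOperatorTrigger(
    job_id="8ba9d676-4108-4474-9dca-8bbac1da9b19",  # placeholder
    job_name="hello-world",
    job_definition="hello-world",
    job_queue="queue",
    overrides={},
    array_properties={},
    parameters={},
    waiters=None,
    tags={},
    max_retries=4200,
    status_retries=10,
    region_name="eu-west-1",
)

classpath, kwargs = trigger.serialize()
assert classpath == "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger"
# The triggerer process re-instantiates the trigger from exactly these two values
rehydrated = BatchOperatorTrigger(**kwargs)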
diff --git a/tests/providers/amazon/aws/deferrable/hooks/test_batch_client.py b/tests/providers/amazon/aws/deferrable/hooks/test_batch_client.py
new file mode 100644
index 000000000000..10be746ef9c4
--- /dev/null
+++ b/tests/providers/amazon/aws/deferrable/hooks/test_batch_client.py
@@ -0,0 +1,213 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import sys
+
+import botocore
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.batch_client import BatchClientAsyncHook
+
+if sys.version_info < (3, 8):
+    # For compatibility with Python 3.7
+    from asynctest import mock as async_mock
+else:
+    from unittest import mock as async_mock
+
+pytest.importorskip("aiobotocore")
+
+
+class TestBatchClientAsyncHook:
+    JOB_ID = "e2a459c5-381b-494d-b6e8-d6ee334db4e2"
+    BATCH_API_SUCCESS_RESPONSE = {"jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]}
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status")
+    async def test_monitor_job_with_success(self, mock_poll_job_status, mock_client):
+        """Tests that the monitor_job method returns the expected event once successful"""
+        mock_poll_job_status.return_value = True
+        mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = (
+            self.BATCH_API_SUCCESS_RESPONSE
+        )
+        hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+        result = await hook.monitor_job()
+        assert result == {"status": "success", "message": f"AWS Batch job ({self.JOB_ID}) succeeded"}
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status")
+    async def test_monitor_job_with_no_job_id(self, mock_poll_job_status, mock_client):
+        """Tests that the monitor_job method raises the expected exception when an incorrect job id is passed"""
+        mock_poll_job_status.return_value = True
+        mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = (
+            self.BATCH_API_SUCCESS_RESPONSE
+        )
+
+        with pytest.raises(AirflowException) as exc_info:
+            hook = BatchClientAsyncHook(job_id=False, waiters=None)
+            await hook.monitor_job()
+        assert str(exc_info.value) == "AWS Batch job - job_id was not found"
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status") + async def test_hit_api_throttle(self, mock_poll_job_status, mock_client): + """ + Tests that the get_job_description method raises correct exception when retries + exceed the threshold + """ + mock_poll_job_status.return_value = True + mock_client.return_value.__aenter__.return_value.describe_jobs.side_effect = ( + botocore.exceptions.ClientError( + error_response={ + "Error": { + "Code": "TooManyRequestsException", + } + }, + operation_name="get job description", + ) + ) + """status_retries = 2 ensures that exponential_delay block is covered in batch_client.py + otherwise the code coverage will drop""" + hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None, status_retries=2) + with pytest.raises(AirflowException) as exc_info: + await hook.get_job_description(job_id=self.JOB_ID) + assert ( + str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) description error: exceeded " + "status_retries (2)" + ) + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status") + async def test_client_error(self, mock_poll_job_status, mock_client): + """Test that the get_job_description method raises correct exception when the error code + from boto3 api is not TooManyRequestsException""" + mock_poll_job_status.return_value = True + mock_client.return_value.__aenter__.return_value.describe_jobs.side_effect = ( + botocore.exceptions.ClientError( + error_response={"Error": {"Code": "InvalidClientTokenId", "Message": "Malformed Token"}}, + operation_name="get job description", + ) + ) + hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None, status_retries=1) + with pytest.raises(AirflowException) as exc_info: + await hook.get_job_description(job_id=self.JOB_ID) + assert ( + str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) description error: An error " + "occurred (InvalidClientTokenId) when calling the get job description operation: " + "Malformed Token" + ) + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") + async def test_check_job_success(self, mock_client): + """Tests that the check_job_success method returns True when job succeeds""" + mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = ( + self.BATCH_API_SUCCESS_RESPONSE + ) + hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None) + result = await hook.check_job_success(job_id=self.JOB_ID) + assert result is True + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") + async def test_check_job_raises_exception_failed(self, mock_client): + """Tests that the check_job_success method raises exception correctly as per job state""" + mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "FAILED"}]} + mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job + hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None) + with pytest.raises(AirflowException) as exc_info: + await hook.check_job_success(job_id=self.JOB_ID) + assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) failed" + ": " + str( + mock_job["jobs"][0] + ) + + @pytest.mark.asyncio + 
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") + async def test_check_job_raises_exception_pending(self, mock_client): + """Tests that the check_job_success method raises exception correctly as per job state""" + mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "PENDING"}]} + mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job + hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None) + with pytest.raises(AirflowException) as exc_info: + await hook.check_job_success(job_id=self.JOB_ID) + assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) is not complete" + ": " + str( + mock_job["jobs"][0] + ) + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") + async def test_check_job_raises_exception_strange(self, mock_client): + """Tests that the check_job_success method raises exception correctly as per job state""" + mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "STRANGE"}]} + mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job + hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None) + with pytest.raises(AirflowException) as exc_info: + await hook.check_job_success(job_id=self.JOB_ID) + assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) has unknown status" + ": " + str( + mock_job["jobs"][0] + ) + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") + async def test_check_job_raises_exception_runnable(self, mock_client): + """Tests that the check_job_success method raises exception correctly as per job state""" + mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "RUNNABLE"}]} + mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job + hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None) + with pytest.raises(AirflowException) as exc_info: + await hook.check_job_success(job_id=self.JOB_ID) + assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) is not complete" + ": " + str( + mock_job["jobs"][0] + ) + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") + async def test_check_job_raises_exception_submitted(self, mock_client): + """Tests that the check_job_success method raises exception correctly as per job state""" + mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "SUBMITTED"}]} + mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job + hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None) + with pytest.raises(AirflowException) as exc_info: + await hook.check_job_success(job_id=self.JOB_ID) + assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) is not complete" + ": " + str( + mock_job["jobs"][0] + ) + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") + async def test_poll_job_status_raises_for_max_retries(self, mock_client): + mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "RUNNABLE"}]} + mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job + hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None, max_retries=1) + with pytest.raises(AirflowException) as exc_info: + await hook.poll_job_status(job_id=self.JOB_ID, match_status=["SUCCEEDED"]) + assert 
+        assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) status checks exceed max_retries"
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    async def test_poll_job_status_in_match_status(self, mock_client):
+        """Tests that poll_job_status returns True when the job status is in match_status"""
+        mock_job = self.BATCH_API_SUCCESS_RESPONSE
+        mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job
+        hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None, max_retries=1)
+        result = await hook.poll_job_status(job_id=self.JOB_ID, match_status=["SUCCEEDED"])
+        assert result is True
diff --git a/tests/providers/amazon/aws/deferrable/triggers/test_batch.py b/tests/providers/amazon/aws/deferrable/triggers/test_batch.py
new file mode 100644
index 000000000000..ad534619f0c0
--- /dev/null
+++ b/tests/providers/amazon/aws/deferrable/triggers/test_batch.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import importlib.util
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.batch import (
+    BatchOperatorTrigger,
+)
+from airflow.triggers.base import TriggerEvent
+from tests.providers.amazon.aws.utils.compat import async_mock
+
+JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3"
+JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
+MAX_RETRIES = 2
+STATUS_RETRIES = 3
+POKE_INTERVAL = 5
+AWS_CONN_ID = "airflow_test"
+REGION_NAME = "eu-west-1"
+
+
+@pytest.mark.skipif(not bool(importlib.util.find_spec("aiobotocore")), reason="aiobotocore required")
+class TestBatchOperatorTrigger:
+    TRIGGER = BatchOperatorTrigger(
+        job_id=JOB_ID,
+        job_name=JOB_NAME,
+        job_definition="hello-world",
+        job_queue="queue",
+        waiters=None,
+        tags={},
+        max_retries=MAX_RETRIES,
+        status_retries=STATUS_RETRIES,
+        parameters={},
+        overrides={},
+        array_properties={},
+        region_name="eu-west-1",
+        aws_conn_id="airflow_test",
+    )
+
+    def test_batch_trigger_serialization(self):
+        """
+        Asserts that the BatchOperatorTrigger correctly serializes its arguments
+        and classpath.
+ """ + + classpath, kwargs = self.TRIGGER.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger" + assert kwargs == { + "job_id": JOB_ID, + "job_name": JOB_NAME, + "job_definition": "hello-world", + "job_queue": "queue", + "waiters": None, + "tags": {}, + "max_retries": MAX_RETRIES, + "status_retries": STATUS_RETRIES, + "parameters": {}, + "overrides": {}, + "array_properties": {}, + "region_name": "eu-west-1", + "aws_conn_id": "airflow_test", + } + + @pytest.mark.asyncio + async def test_batch_trigger_run(self): + """Test that the task is not done when event is not returned from trigger.""" + + task = asyncio.create_task(self.TRIGGER.run().__anext__()) + await asyncio.sleep(0.5) + # TriggerEvent was not returned + assert task.done() is False + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.monitor_job") + async def test_batch_trigger_completed(self, mock_response): + """Test if the success event is returned from trigger.""" + mock_response.return_value = {"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"} + + generator = self.TRIGGER.run() + actual_response = await generator.asend(None) + assert ( + TriggerEvent({"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}) + == actual_response + ) + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.monitor_job") + async def test_batch_trigger_failure(self, mock_response): + """Test if the failure event is returned from trigger.""" + mock_response.return_value = {"status": "error", "message": f"{JOB_ID} failed"} + + generator = self.TRIGGER.run() + actual_response = await generator.asend(None) + assert TriggerEvent({"status": "error", "message": f"{JOB_ID} failed"}) == actual_response + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.monitor_job") + async def test_batch_trigger_none(self, mock_response): + """Test if the failure event is returned when there is no response from hook.""" + mock_response.return_value = None + + generator = self.TRIGGER.run() + actual_response = await generator.asend(None) + assert TriggerEvent({"status": "error", "message": f"{JOB_ID} failed"}) == actual_response + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.monitor_job") + async def test_batch_trigger_exception(self, mock_response): + """Test if the exception is raised from trigger.""" + mock_response.side_effect = Exception("Test exception") + + task = [i async for i in self.TRIGGER.run()] + assert len(task) == 1 + assert TriggerEvent({"status": "error", "message": "Test exception"}) in task diff --git a/tests/providers/amazon/aws/hooks/test_batch_client.py b/tests/providers/amazon/aws/hooks/test_batch_client.py index 13726e5518ff..d7e06d9eb23a 100644 --- a/tests/providers/amazon/aws/hooks/test_batch_client.py +++ b/tests/providers/amazon/aws/hooks/test_batch_client.py @@ -20,6 +20,7 @@ import logging from unittest import mock +import botocore import botocore.exceptions import pytest diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index 0ddfcea59171..2192b7c4e20a 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -19,11 +19,18 @@ from unittest import mock +import pendulum import 
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.models import DAG
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
 from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator, BatchOperator
+from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from airflow.utils import timezone
+from airflow.utils.types import DagRunType
 
 # Use dummy AWS credentials
 AWS_REGION = "eu-west-1"
@@ -211,3 +218,114 @@ def test_execute(self, mock_conn):
             computeResources=compute_resources,
             tags=tags,
         )
+
+
+def create_context(task, dag=None):
+    if dag is None:
+        dag = DAG(dag_id="dag")
+    tzinfo = pendulum.timezone("UTC")
+    execution_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo)
+    dag_run = DagRun(
+        dag_id=dag.dag_id,
+        execution_date=execution_date,
+        run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
+    )
+
+    task_instance = TaskInstance(task=task)
+    task_instance.dag_run = dag_run
+    task_instance.xcom_push = mock.Mock()
+    return {
+        "dag": dag,
+        "ts": execution_date.isoformat(),
+        "task": task,
+        "ti": task_instance,
+        "task_instance": task_instance,
+        "run_id": dag_run.run_id,
+        "dag_run": dag_run,
+        "execution_date": execution_date,
+        "data_interval_end": execution_date,
+        "logical_date": execution_date,
+    }
+
+
+class TestBatchOperatorAsync:
+    JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3"
+    JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
+    MAX_RETRIES = 2
+    STATUS_RETRIES = 3
+    RESPONSE_WITHOUT_FAILURES = {
+        "jobName": JOB_NAME,
+        "jobId": JOB_ID,
+    }
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
+    def test_batch_op_async(self, get_client_type_mock):
+        """Tests that the operator defers with a BatchOperatorTrigger when deferrable=True"""
+        get_client_type_mock.return_value.submit_job.return_value = self.RESPONSE_WITHOUT_FAILURES
+        task = BatchOperator(
+            task_id="task",
+            job_name=self.JOB_NAME,
+            job_queue="queue",
+            job_definition="hello-world",
+            max_retries=self.MAX_RETRIES,
+            status_retries=self.STATUS_RETRIES,
+            parameters=None,
+            overrides={},
+            array_properties=None,
+            aws_conn_id="airflow_test",
+            region_name="eu-west-1",
+            tags={},
+            deferrable=True,
+        )
+        context = create_context(task)
+        with pytest.raises(TaskDeferred) as exc:
+            task.execute(context)
+        assert isinstance(exc.value.trigger, BatchOperatorTrigger), "Trigger is not a BatchOperatorTrigger"
+
+    def test_batch_op_async_execute_failure(self):
+        """Tests that an AirflowException is raised in case of an error event"""
+
+        task = BatchOperator(
+            task_id="task",
+            job_name=self.JOB_NAME,
+            job_queue="queue",
+            job_definition="hello-world",
+            max_retries=self.MAX_RETRIES,
+            status_retries=self.STATUS_RETRIES,
+            parameters=None,
+            overrides={},
+            array_properties=None,
+            aws_conn_id="airflow_test",
+            region_name="eu-west-1",
+            tags={},
+            deferrable=True,
+        )
+        with pytest.raises(AirflowException) as exc_info:
+            task.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
+
+        assert str(exc_info.value) == "test failure message"
+
+    @pytest.mark.parametrize(
+        "event",
+        [{"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}],
+    )
+    def test_batch_op_async_execute_complete(self, caplog, event):
+        """Tests that the execute_complete method returns None and prints the expected log message"""
+        task = BatchOperator(
+            task_id="task",
+            job_name=self.JOB_NAME,
+            job_queue="queue",
+            job_definition="hello-world",
+            max_retries=self.MAX_RETRIES,
+            status_retries=self.STATUS_RETRIES,
+            parameters=None,
+            overrides={},
+            array_properties=None,
+            aws_conn_id="airflow_test",
+            region_name="eu-west-1",
+            tags={},
+            deferrable=True,
+        )
+        with mock.patch.object(task.log, "info") as mock_log_info:
+            assert task.execute_complete(context=None, event=event) is None
+
+        mock_log_info.assert_called_with(f"AWS Batch job ({self.JOB_ID}) succeeded")
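
For reference, a minimal sketch of consuming the trigger's async generator directly, mirroring the pattern the tests above use (trigger construction elided; any BatchOperatorTrigger instance works):

import asyncio

from airflow.triggers.base import TriggerEvent


async def first_event(trigger) -> TriggerEvent:
    # run() yields exactly one TriggerEvent (success or error) and then stops
    async for event in trigger.run():
        return event


# event = asyncio.run(first_event(trigger))
# event.payload is {"status": "success", ...} or {"status": "error", ...}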