Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add deferrable BatchOperator #29300

Merged
merged 14 commits into from
Apr 5, 2023
239 changes: 238 additions & 1 deletion airflow/providers/amazon/aws/hooks/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@
"""
from __future__ import annotations

import asyncio
from random import uniform
from time import sleep
from typing import Any

import botocore.client
import botocore.exceptions
import botocore.waiter

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook
from airflow.typing_compat import Protocol, runtime_checkable


Expand Down Expand Up @@ -544,3 +546,238 @@ def exp(tries):
delay = 1 + pow(tries * 0.6, 2)
delay = min(max_interval, delay)
return uniform(delay / 3, delay)


class BatchClientAsyncHook(BatchClientHook, AwsBaseAsyncHook):
    """
    Async client for AWS Batch services.

    :param job_id: the job ID, usually unknown (None) until the
        submit_job operation gets the jobId defined by AWS Batch

    :param waiters: an :py:class:`.BatchWaiters` object (see note below);
        if None, polling is used with max_retries and status_retries.

    .. note::
        Several methods use a default random delay to check or poll for job status, i.e.
        ``random.uniform(DEFAULT_DELAY_MIN, DEFAULT_DELAY_MAX)``.
        Using a random interval helps to avoid AWS API throttle limits
        when many concurrent tasks request job-descriptions.

        To modify the global defaults for the range of jitter allowed when a
        random delay is used to check Batch job status, modify these defaults, e.g.::

            BatchClientHook.DEFAULT_DELAY_MIN = 0
            BatchClientHook.DEFAULT_DELAY_MAX = 5

        When explicit delay values are used, a 1 second random jitter is applied to the
        delay. It is generally recommended that random jitter is added to API requests.
        A convenience method is provided for this, e.g. to get a random delay of
        10 sec +/- 5 sec: ``delay = BatchClientHook.add_jitter(10, width=5, minima=0)``
    """

    def __init__(self, job_id: str | None, waiters: Any = None, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.job_id = job_id
        self.waiters = waiters

    async def monitor_job(self) -> dict[str, str] | None:
        """
        Monitor an AWS Batch job.

        monitor_job can raise an exception or an AirflowTaskTimeout can be raised if execution_timeout
        is given while creating the task. These exceptions should be handled in taskinstance.py
        instead of here like it was previously done.

        :raises: AirflowException
        """
        if not self.job_id:
            raise AirflowException("AWS Batch job - job_id was not found")

        if self.waiters:
            # Custom waiters (a BatchWaiters object) take precedence over polling.
            # NOTE(review): wait_for_job on the waiters object is a synchronous call
            # here — it will block the event loop; confirm this is intended.
            self.waiters.wait_for_job(self.job_id)
            return None
        else:
            await self.wait_for_job(self.job_id)
            await self.check_job_success(self.job_id)
            success_msg = f"AWS Batch job ({self.job_id}) succeeded"
            self.log.info(success_msg)
            return {"status": "success", "message": success_msg}

    async def check_job_success(self, job_id: str) -> bool:  # type: ignore[override]
        """
        Check the final status of the Batch job; return True if the job
        'SUCCEEDED', else raise an AirflowException.

        :param job_id: a Batch job ID

        :raises: AirflowException
        """
        job = await self.get_job_description(job_id)
        job_status = job.get("status")
        if job_status == self.SUCCESS_STATE:
            self.log.info("AWS Batch job (%s) succeeded: %s", job_id, job)
            return True

        if job_status == self.FAILURE_STATE:
            raise AirflowException(f"AWS Batch job ({job_id}) failed: {job}")

        if job_status in self.INTERMEDIATE_STATES:
            raise AirflowException(f"AWS Batch job ({job_id}) is not complete: {job}")

        raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}")

    @staticmethod
    async def delay(delay: int | float | None = None) -> None:  # type: ignore[override]
        """
        Pause execution for ``delay`` seconds, without blocking the event loop.

        :param delay: a delay to pause execution using ``asyncio.sleep(delay)``;
            a small 1 second jitter is applied to the delay.

        .. note::
            When no delay is given, a random interval within
            [DEFAULT_DELAY_MIN, DEFAULT_DELAY_MAX] is used; a random interval
            helps to avoid AWS API throttle limits when many concurrent tasks
            request job-descriptions.
        """
        if delay is None:
            delay = uniform(BatchClientHook.DEFAULT_DELAY_MIN, BatchClientHook.DEFAULT_DELAY_MAX)
        else:
            delay = BatchClientAsyncHook.add_jitter(delay)
        await asyncio.sleep(delay)

    async def wait_for_job(  # type: ignore[override]
        self, job_id: str, delay: int | float | None = None
    ) -> None:
        """
        Wait for Batch job to complete.

        :param job_id: a Batch job ID

        :param delay: a delay before polling for job status

        :raises: AirflowException
        """
        await self.delay(delay)
        await self.poll_for_job_running(job_id, delay)
        await self.poll_for_job_complete(job_id, delay)
        self.log.info("AWS Batch job (%s) has completed", job_id)

    async def poll_for_job_complete(  # type: ignore[override]
        self, job_id: str, delay: int | float | None = None
    ) -> None:
        """
        Poll for job completion. The statuses that indicate job completion
        are: 'SUCCEEDED'|'FAILED'.

        So the status options that this will wait for are the transitions from:
        'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED'

        :param job_id: a Batch job ID

        :param delay: a delay before polling for job status

        :raises: AirflowException
        """
        await self.delay(delay)
        complete_status = [self.SUCCESS_STATE, self.FAILURE_STATE]
        await self.poll_job_status(job_id, complete_status)

    async def poll_for_job_running(  # type: ignore[override]
        self, job_id: str, delay: int | float | None = None
    ) -> None:
        """
        Poll for job running. The statuses that indicate a job is running or
        already complete are: 'RUNNING'|'SUCCEEDED'|'FAILED'.

        So the status options that this will wait for are the transitions from:
        'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'|'SUCCEEDED'|'FAILED'

        The completed status options are included for cases where the status
        changes too quickly for polling to detect a RUNNING status that moves
        quickly from STARTING to RUNNING to completed (often a failure).

        :param job_id: a Batch job ID

        :param delay: a delay before polling for job status

        :raises: AirflowException
        """
        await self.delay(delay)
        running_status = [self.RUNNING_STATE, self.SUCCESS_STATE, self.FAILURE_STATE]
        await self.poll_job_status(job_id, running_status)

    async def get_job_description(self, job_id: str) -> dict[str, str]:  # type: ignore[override]
        """
        Get job description (using status_retries).

        :param job_id: a Batch job ID
        :raises: AirflowException
        """
        retries = 0
        async with await self.get_client_async() as client:
            while True:
                try:
                    # BUGFIX: the aiobotocore client returns a coroutine from
                    # describe_jobs; it must be awaited to get the response dict.
                    response = await client.describe_jobs(jobs=[job_id])
                    return self.parse_job_description(job_id, response)

                except botocore.exceptions.ClientError as err:
                    error = err.response.get("Error", {})
                    if error.get("Code") == "TooManyRequestsException":
                        pass  # allow it to retry, if possible
                    else:
                        raise AirflowException(f"AWS Batch job ({job_id}) description error: {err}")

                retries += 1
                if retries >= self.status_retries:
                    raise AirflowException(
                        f"AWS Batch job ({job_id}) description error: exceeded status_retries "
                        f"({self.status_retries})"
                    )

                # Exponential back-off with jitter before the next describe_jobs call.
                pause = self.exponential_delay(retries)
                self.log.info(
                    "AWS Batch job (%s) description retry (%d of %d) in the next %.2f seconds",
                    job_id,
                    retries,
                    self.status_retries,
                    pause,
                )
                await self.delay(pause)

    async def poll_job_status(self, job_id: str, match_status: list[str]) -> bool:  # type: ignore[override]
        """
        Poll for job status using an exponential back-off strategy (with max_retries).
        The Batch job statuses polled are:
        'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'

        :param job_id: a Batch job ID
        :param match_status: a list of job status to match
        :raises: AirflowException
        """
        retries = 0
        while True:
            job = await self.get_job_description(job_id)
            job_status = job.get("status")
            self.log.info(
                "AWS Batch job (%s) check status (%s) in %s",
                job_id,
                job_status,
                match_status,
            )
            if job_status in match_status:
                return True

            if retries >= self.max_retries:
                raise AirflowException(f"AWS Batch job ({job_id}) status checks exceed max_retries")

            retries += 1
            pause = self.exponential_delay(retries)
            self.log.info(
                "AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds",
                job_id,
                retries,
                self.max_retries,
                pause,
            )
            await self.delay(pause)
37 changes: 37 additions & 0 deletions airflow/providers/amazon/aws/operators/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
BatchJobQueueLink,
)
from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
from airflow.providers.amazon.aws.utils import trim_none_values

if TYPE_CHECKING:
Expand Down Expand Up @@ -71,6 +72,7 @@ class BatchOperator(BaseOperator):
Override the region_name in connection (if provided)
:param tags: collection of tags to apply to the AWS Batch job submission
if None, no tags are submitted
:param deferrable: Run operator in the deferrable mode.
rajaths010494 marked this conversation as resolved.
Show resolved Hide resolved

.. note::
Any custom waiters must return a waiter for these calls:
Expand Down Expand Up @@ -125,6 +127,7 @@ def __init__(
region_name: str | None = None,
tags: dict | None = None,
wait_for_completion: bool = True,
deferrable: bool = False,
**kwargs,
):

Expand All @@ -139,6 +142,8 @@ def __init__(
self.waiters = waiters
self.tags = tags or {}
self.wait_for_completion = wait_for_completion
self.deferrable = deferrable

self.hook = BatchClientHook(
max_retries=max_retries,
status_retries=status_retries,
Expand All @@ -154,11 +159,43 @@ def execute(self, context: Context):
"""
self.submit_job(context)

if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=BatchOperatorTrigger(
job_id=self.job_id,
job_name=self.job_name,
job_definition=self.job_definition,
job_queue=self.job_queue,
overrides=self.overrides,
array_properties=self.array_properties,
parameters=self.parameters,
waiters=self.waiters,
tags=self.tags,
max_retries=self.hook.max_retries,
status_retries=self.hook.status_retries,
aws_conn_id=self.hook.aws_conn_id,
region_name=self.hook.region_name,
),
method_name="execute_complete",
)

if self.wait_for_completion:
self.monitor_job(context)

return self.job_id

def execute_complete(self, context: Context, event: dict[str, Any]):
    """
    Callback for when the trigger fires - returns immediately.
    Relies on trigger to throw an exception, otherwise it assumes execution was
    successful.
    """
    if event.get("status") == "error":
        raise AirflowException(event["message"])
    self.log.info(event["message"])
    return self.job_id

def on_kill(self):
    """Terminate the running AWS Batch job when the task instance is killed."""
    termination_response = self.hook.client.terminate_job(
        jobId=self.job_id, reason="Task killed by the user"
    )
    self.log.info("AWS Batch job (%s) terminated: %s", self.job_id, termination_response)
Expand Down
Loading