Add a deferrable mode to BatchCreateComputeEnvironmentOperator #32036

Merged · 9 commits · Jun 27, 2023
52 changes: 41 additions & 11 deletions airflow/providers/amazon/aws/operators/batch.py
@@ -38,7 +38,10 @@
BatchJobQueueLink,
)
from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
from airflow.providers.amazon.aws.triggers.batch import (
BatchCreateComputeEnvironmentTrigger,
BatchOperatorTrigger,
)
from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher

@@ -402,14 +405,16 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
services on your behalf (templated).
:param tags: Tags that you apply to the compute-environment to help you
categorize and organize your resources.
:param max_retries: Exponential back-off retries, 4200 = 48 hours; polling
is only used when waiters is None.
:param status_retries: Number of HTTP retries to get job status, 10; polling
is only used when waiters is None.
:param poll_interval: How long to wait in seconds between two polls of the environment status.
Only useful when deferrable is True.
:param max_retries: How many times to poll for the environment status.
Only useful when deferrable is True.
:param aws_conn_id: Connection ID of AWS credentials / region name. If None,
the default boto3 credential strategy will be used.
:param region_name: Region name to use in AWS Hook. Overrides the
``region_name`` in connection if provided.
:param deferrable: If True, the operator will wait asynchronously for the environment to be created.
This mode requires the aiobotocore module to be installed. (default: False)
"""

template_fields: Sequence[str] = (
@@ -428,31 +433,41 @@ def __init__(
unmanaged_v_cpus: int | None = None,
service_role: str | None = None,
tags: dict | None = None,
poll_interval: int = 30,
max_retries: int | None = None,
status_retries: int | None = None,
aws_conn_id: str | None = None,
region_name: str | None = None,
deferrable: bool = False,
**kwargs,
):
if "status_retries" in kwargs:
warnings.warn(
"The `status_retries` parameter is unused and should be removed. "
"It'll be deleted in a future version.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
kwargs.pop("status_retries") # remove before calling super() to prevent unexpected arg error

super().__init__(**kwargs)

self.compute_environment_name = compute_environment_name
self.environment_type = environment_type
self.state = state
self.unmanaged_v_cpus = unmanaged_v_cpus
self.compute_resources = compute_resources
self.service_role = service_role
self.tags = tags or {}
self.max_retries = max_retries
self.status_retries = status_retries
self.poll_interval = poll_interval
self.max_retries = max_retries or 120
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.deferrable = deferrable

@cached_property
def hook(self):
"""Create and return a BatchClientHook."""
return BatchClientHook(
max_retries=self.max_retries,
status_retries=self.status_retries,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
)
@@ -468,6 +483,21 @@ def execute(self, context: Context):
"serviceRole": self.service_role,
"tags": self.tags,
}
self.hook.client.create_compute_environment(**trim_none_values(kwargs))
response = self.hook.client.create_compute_environment(**trim_none_values(kwargs))
arn = response["computeEnvironmentArn"]

if self.deferrable:
self.defer(
trigger=BatchCreateComputeEnvironmentTrigger(
arn, self.poll_interval, self.max_retries, self.aws_conn_id, self.region_name
),
method_name="execute_complete",
)

self.log.info("AWS Batch compute environment created successfully")
return arn

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while waiting for the compute environment to be ready: {event}")
return event["value"]
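For context, a minimal usage sketch of the new deferrable path (not part of this diff; the DAG id, environment name, and compute-resource values are placeholders):

# Hypothetical DAG snippet exercising the new deferrable flag; all values are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator

with DAG("example_batch_create_compute_env", start_date=datetime(2023, 6, 1), schedule=None):
    create_env = BatchCreateComputeEnvironmentOperator(
        task_id="create_compute_env",
        compute_environment_name="my-batch-ce",  # placeholder
        environment_type="MANAGED",
        state="ENABLED",
        compute_resources={
            "type": "FARGATE",
            "maxvCpus": 16,
            "subnets": ["subnet-placeholder"],
            "securityGroupIds": ["sg-placeholder"],
        },
        deferrable=True,    # defer to the new trigger instead of blocking a worker slot
        poll_interval=30,   # seconds between waiter polls inside the trigger
        max_retries=120,    # maximum number of polls before the trigger gives up
    )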
57 changes: 57 additions & 0 deletions airflow/providers/amazon/aws/triggers/batch.py
@@ -23,6 +23,7 @@
from botocore.exceptions import WaiterError

from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent


@@ -188,3 +189,59 @@ async def run(self):
"message": f"Job {self.job_id} Succeeded",
}
)


class BatchCreateComputeEnvironmentTrigger(BaseTrigger):
"""
Trigger for BatchCreateComputeEnvironmentOperator.
The trigger will asynchronously poll the boto3 API and wait for the compute environment to be ready.

:param compute_env_arn: The ARN of the compute environment to wait for.
:param poll_interval: The amount of time in seconds to wait between attempts.
:param max_retries: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param region_name: Region name to use in the AWS Hook.
"""

def __init__(
self,
compute_env_arn: str | None = None,
poll_interval: int = 30,
max_retries: int = 10,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
):
super().__init__()
self.compute_env_arn = compute_env_arn
self.max_retries = max_retries
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.poll_interval = poll_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BatchOperatorTrigger arguments and classpath."""
return (
self.__class__.__module__ + "." + self.__class__.__qualname__,
{
"compute_env_arn": self.compute_env_arn,
"max_retries": self.max_retries,
"aws_conn_id": self.aws_conn_id,
"region_name": self.region_name,
"poll_interval": self.poll_interval,
},
)

async def run(self):
hook = BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
async with hook.async_conn as client:
waiter = hook.get_waiter("compute_env_ready", deferrable=True, client=client)
await async_wait(
waiter=waiter,
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_retries,
args={"computeEnvironments": [self.compute_env_arn]},
failure_message="Failure while creating Compute Environment",
status_message="Compute Environment not ready yet",
status_args=["computeEnvironments[].status", "computeEnvironments[].statusReason"],
)
yield TriggerEvent({"status": "success", "value": self.compute_env_arn})
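For reviewers, the async_wait call above is roughly equivalent to the following blocking poll loop (a sketch only, using a plain boto3 client instead of the hook; the states it checks mirror the compute_env_ready acceptors added below):

# Rough synchronous equivalent of what the trigger waits for; for illustration only.
import time

import boto3


def wait_for_compute_env(arn: str, poll_interval: int = 30, max_retries: int = 120) -> None:
    client = boto3.client("batch")
    for _ in range(max_retries):
        envs = client.describe_compute_environments(computeEnvironments=[arn])["computeEnvironments"]
        statuses = [env["status"] for env in envs]
        if statuses and all(status == "VALID" for status in statuses):
            return  # success: every matched environment is VALID
        if any(status in ("INVALID", "DELETED") for status in statuses):
            raise RuntimeError(f"Failure while creating Compute Environment: {envs}")
        time.sleep(poll_interval)
    raise RuntimeError("Compute Environment not ready after the maximum number of attempts")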
26 changes: 26 additions & 0 deletions airflow/providers/amazon/aws/waiters/batch.json
@@ -20,6 +20,32 @@
"state": "failed"
}
]
},

"compute_env_ready": {
"delay": 30,
"operation": "DescribeComputeEnvironments",
"maxAttempts": 100,
"acceptors": [
{
"argument": "computeEnvironments[].status",
"expected": "VALID",
"matcher": "pathAll",
"state": "success"
},
{
"argument": "computeEnvironments[].status",
"expected": "INVALID",
"matcher": "pathAny",
"state": "failed"
},
{
"argument": "computeEnvironments[].status",
"expected": "DELETED",
"matcher": "pathAny",
"state": "failed"
}
]
}
}
}
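As a side note, the new compute_env_ready definition can also be exercised synchronously through the hook's get_waiter helper; a hedged sketch, assuming this works the same way as for the provider's other custom waiters (connection id, region, and ARN are placeholders):

# Sketch: loading the custom waiter from batch.json and running it synchronously.
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook

hook = BatchClientHook(aws_conn_id="aws_default", region_name="eu-west-1")
waiter = hook.get_waiter("compute_env_ready")  # resolved from the provider's waiters/batch.json
waiter.wait(
    computeEnvironments=["arn:aws:batch:region:account:compute-environment/placeholder"],
    WaiterConfig={"Delay": 30, "MaxAttempts": 100},  # overrides the defaults in the JSON
)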
28 changes: 27 additions & 1 deletion tests/providers/amazon/aws/operators/test_batch.py
@@ -27,7 +27,10 @@
from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator, BatchOperator

# Use dummy AWS credentials
from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
from airflow.providers.amazon.aws.triggers.batch import (
BatchCreateComputeEnvironmentTrigger,
BatchOperatorTrigger,
)

AWS_REGION = "eu-west-1"
AWS_ACCESS_KEY_ID = "airflow_dummy_key"
@@ -326,3 +329,26 @@ def test_execute(self, mock_conn):
computeResources=compute_resources,
tags=tags,
)

@mock.patch.object(BatchClientHook, "client")
def test_defer(self, client_mock):
client_mock.create_compute_environment.return_value = {"computeEnvironmentArn": "my_arn"}

operator = BatchCreateComputeEnvironmentOperator(
task_id="task",
compute_environment_name="my_env_name",
environment_type="my_env_type",
state="my_state",
compute_resources={},
max_retries=123456,
poll_interval=456789,
deferrable=True,
)

with pytest.raises(TaskDeferred) as deferred:
operator.execute(None)

assert isinstance(deferred.value.trigger, BatchCreateComputeEnvironmentTrigger)
assert deferred.value.trigger.compute_env_arn == "my_arn"
assert deferred.value.trigger.poll_interval == 456789
assert deferred.value.trigger.max_retries == 123456
43 changes: 42 additions & 1 deletion tests/providers/amazon/aws/triggers/test_batch.py
@@ -22,7 +22,13 @@
import pytest
from botocore.exceptions import WaiterError

from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger, BatchSensorTrigger
from airflow import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.triggers.batch import (
BatchCreateComputeEnvironmentTrigger,
BatchOperatorTrigger,
BatchSensorTrigger,
)
from airflow.triggers.base import TriggerEvent

BATCH_JOB_ID = "job_id"
@@ -181,3 +187,38 @@ async def test_batch_sensor_trigger_failure(
assert actual_response == TriggerEvent(
{"status": "failure", "message": f"Job Failed: Waiter {name} failed: {reason}"}
)


class TestBatchCreateComputeEnvironmentTrigger:
@pytest.mark.asyncio
@mock.patch.object(BatchClientHook, "async_conn")
@mock.patch.object(BatchClientHook, "get_waiter")
async def test_success(self, get_waiter_mock, conn_mock):
get_waiter_mock().wait = AsyncMock(
side_effect=[
WaiterError(
"situation normal", "first try", {"computeEnvironments": [{"status": "my_status"}]}
),
{},
]
)
trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=0, max_retries=3)

generator = trigger.run()
response: TriggerEvent = await generator.asend(None)

assert response.payload["status"] == "success"
assert response.payload["value"] == "my_arn"

@pytest.mark.asyncio
@mock.patch.object(BatchClientHook, "async_conn")
@mock.patch.object(BatchClientHook, "get_waiter")
async def test_failure(self, get_waiter_mock, conn_mock):
get_waiter_mock().wait = AsyncMock(
side_effect=[WaiterError("terminal failure", "terminal failure reason", {})]
)
trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=0, max_retries=3)

with pytest.raises(AirflowException):
generator = trigger.run()
await generator.asend(None)