Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EMR serverless Create/Start/Stop/Delete Application deferrable mode #32513

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions airflow/providers/amazon/aws/hooks/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["client_type"] = "emr-serverless"
super().__init__(*args, **kwargs)

def cancel_running_jobs(self, application_id: str, waiter_config: dict = {}):
def cancel_running_jobs(
self, application_id: str, waiter_config: dict | None = None, wait_for_completion: bool = True
) -> int:
"""
List all jobs in an intermediate state, cancel them, then wait for those jobs to reach terminal state.
Cancel jobs in an intermediate state, and return the number of cancelled jobs.

If wait_for_completion is True, then the method will wait until all jobs are
cancelled before returning.

Note: if new jobs are triggered while this operation is ongoing,
it's going to time out and return an error.
Expand All @@ -284,13 +289,16 @@ def cancel_running_jobs(self, application_id: str, waiter_config: dict = {}):
)
for job_id in job_ids:
self.conn.cancel_job_run(applicationId=application_id, jobRunId=job_id)
if count > 0:
self.log.info("now waiting for the %s cancelled job(s) to terminate", count)
self.get_waiter("no_job_running").wait(
applicationId=application_id,
states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
WaiterConfig=waiter_config,
)
if wait_for_completion:
if count > 0:
self.log.info("now waiting for the %s cancelled job(s) to terminate", count)
self.get_waiter("no_job_running").wait(
applicationId=application_id,
states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
WaiterConfig=waiter_config or {},
)

return count

syedahsn marked this conversation as resolved.
Show resolved Hide resolved

class EmrContainerHook(AwsBaseHook):
Expand Down
172 changes: 155 additions & 17 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@
EmrAddStepsTrigger,
EmrContainerTrigger,
EmrCreateJobFlowTrigger,
EmrServerlessCancelJobsTrigger,
EmrServerlessCreateApplicationTrigger,
EmrServerlessDeleteApplicationTrigger,
EmrServerlessStartApplicationTrigger,
EmrServerlessStartJobTrigger,
EmrServerlessStopApplicationTrigger,
EmrTerminateJobFlowTrigger,
)
from airflow.providers.amazon.aws.utils.waiter import waiter
Expand Down Expand Up @@ -974,7 +978,7 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
:param release_label: The EMR release version associated with the application.
:param job_type: The type of application you want to start, such as Spark or Hive.
:param wait_for_completion: If true, wait for the Application to start before returning. Default to True.
If set to False, ``waiter_countdown`` and ``waiter_check_interval_seconds`` will only be applied when
If set to False, ``waiter_max_attempts`` and ``waiter_delay`` will only be applied when
waiting for the application to be in the ``CREATED`` state.
:param client_request_token: The client idempotency token of the application to create.
Its value must be unique for each request.
Expand All @@ -987,6 +991,9 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
:waiter_max_attempts: Number of times the waiter should poll the application to check the state.
If not set, the waiter will use its default value.
:param waiter_delay: Number of seconds between polling the state of the application.
:param deferrable: If True, the operator will wait asynchronously for application to be created.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
"""

def __init__(
Expand All @@ -1001,6 +1008,7 @@ def __init__(
waiter_check_interval_seconds: int | ArgNotSet = NOTSET,
waiter_max_attempts: int | ArgNotSet = NOTSET,
waiter_delay: int | ArgNotSet = NOTSET,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
if waiter_check_interval_seconds is NOTSET:
Expand Down Expand Up @@ -1032,6 +1040,7 @@ def __init__(
self.config = config or {}
self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type]
self.waiter_delay = int(waiter_delay) # type: ignore[arg-type]
self.deferrable = deferrable
super().__init__(**kwargs)

self.client_request_token = client_request_token or str(uuid4())
Expand All @@ -1054,8 +1063,19 @@ def execute(self, context: Context) -> str | None:
raise AirflowException(f"Application Creation failed: {response}")

self.log.info("EMR serverless application created: %s", application_id)
waiter = self.hook.get_waiter("serverless_app_created")
if self.deferrable:
self.defer(
trigger=EmrServerlessCreateApplicationTrigger(
application_id=application_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
method_name="start_application_deferred",
)

waiter = self.hook.get_waiter("serverless_app_created")
wait(
waiter=waiter,
waiter_delay=self.waiter_delay,
Expand All @@ -1081,6 +1101,32 @@ def execute(self, context: Context) -> str | None:
)
return application_id

def start_application_deferred(self, context: Context, event: dict[str, Any] | None = None) -> None:
if event is None:
self.log.error("Trigger error: event is None")
raise AirflowException("Trigger error: event is None")
elif event["status"] != "success":
raise AirflowException(f"Application {event['application_id']} failed to create")
self.log.info("Starting application %s", event["application_id"])
self.hook.conn.start_application(applicationId=event["application_id"])
self.defer(
trigger=EmrServerlessStartApplicationTrigger(
application_id=event["application_id"],
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
if event is None or event["status"] != "success":
raise AirflowException(f"Trigger error: Application failed to start, event is {event}")

self.log.info("Application %s started", event["application_id"])
return event["application_id"]


class EmrServerlessStartJobOperator(BaseOperator):
"""
Expand Down Expand Up @@ -1312,14 +1358,21 @@ class EmrServerlessStopApplicationOperator(BaseOperator):
:param application_id: ID of the EMR Serverless application to stop.
:param wait_for_completion: If true, wait for the Application to stop before returning. Default to True
:param aws_conn_id: AWS connection to use
:param waiter_countdown: Total amount of time, in seconds, the operator will wait for
:param waiter_countdown: (deprecated) Total amount of time, in seconds, the operator will wait for
the application be stopped. Defaults to 5 minutes.
:param waiter_check_interval_seconds: Number of seconds between polling the state of the application.
Defaults to 30 seconds.
:param waiter_check_interval_seconds: (deprecated) Number of seconds between polling the state of the
application. Defaults to 60 seconds.
:param force_stop: If set to True, any job for that app that is not in a terminal state will be cancelled.
Otherwise, trying to stop an app with running jobs will return an error.
If you want to wait for the jobs to finish gracefully, use
:class:`airflow.providers.amazon.aws.sensors.emr.EmrServerlessJobSensor`
:waiter_max_attempts: Number of times the waiter should poll the application to check the state.
Default is 25.
:param waiter_delay: Number of seconds between polling the state of the application.
Default is 60 seconds.
:param deferrable: If True, the operator will wait asynchronously for the application to stop.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
"""

template_fields: Sequence[str] = ("application_id",)
Expand All @@ -1334,6 +1387,7 @@ def __init__(
waiter_max_attempts: int | ArgNotSet = NOTSET,
waiter_delay: int | ArgNotSet = NOTSET,
force_stop: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
if waiter_check_interval_seconds is NOTSET:
Expand All @@ -1359,10 +1413,11 @@ def __init__(
)
self.aws_conn_id = aws_conn_id
self.application_id = application_id
self.wait_for_completion = wait_for_completion
self.wait_for_completion = False if deferrable else wait_for_completion
self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type]
self.waiter_delay = int(waiter_delay) # type: ignore[arg-type]
self.force_stop = force_stop
self.deferrable = deferrable
super().__init__(**kwargs)

@cached_property
Expand All @@ -1374,16 +1429,46 @@ def execute(self, context: Context) -> None:
self.log.info("Stopping application: %s", self.application_id)

if self.force_stop:
self.hook.cancel_running_jobs(
self.application_id,
waiter_config={
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
},
count = self.hook.cancel_running_jobs(
application_id=self.application_id,
wait_for_completion=False,
)
if count > 0:
self.log.info("now waiting for the %s cancelled job(s) to terminate", count)
if self.deferrable:
self.defer(
trigger=EmrServerlessCancelJobsTrigger(
application_id=self.application_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
method_name="stop_application",
)
self.hook.get_waiter("no_job_running").wait(
syedahsn marked this conversation as resolved.
Show resolved Hide resolved
applicationId=self.application_id,
states=list(self.hook.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
WaiterConfig={
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
},
)
else:
self.log.info("no running jobs found with application ID %s", self.application_id)

self.hook.conn.stop_application(applicationId=self.application_id)

if self.deferrable:
self.defer(
trigger=EmrServerlessStopApplicationTrigger(
application_id=self.application_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
method_name="execute_complete",
)
if self.wait_for_completion:
waiter = self.hook.get_waiter("serverless_app_stopped")
wait(
Expand All @@ -1397,6 +1482,30 @@ def execute(self, context: Context) -> None:
)
self.log.info("EMR serverless application %s stopped successfully", self.application_id)

def stop_application(self, context: Context, event: dict[str, Any] | None = None) -> None:
if event is None:
self.log.error("Trigger error: event is None")
raise AirflowException("Trigger error: event is None")
elif event["status"] == "success":
self.hook.conn.stop_application(applicationId=self.application_id)
self.defer(
trigger=EmrServerlessStopApplicationTrigger(
application_id=self.application_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
if event is None:
self.log.error("Trigger error: event is None")
raise AirflowException("Trigger error: event is None")
elif event["status"] == "success":
self.log.info("EMR serverless application %s stopped successfully", self.application_id)
syedahsn marked this conversation as resolved.
Show resolved Hide resolved


class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperator):
"""
Expand All @@ -1410,10 +1519,17 @@ class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato
:param wait_for_completion: If true, wait for the Application to be deleted before returning.
Defaults to True. Note that this operator will always wait for the application to be STOPPED first.
:param aws_conn_id: AWS connection to use
:param waiter_countdown: Total amount of time, in seconds, the operator will wait for each step of first,
the application to be stopped, and then deleted. Defaults to 25 minutes.
:param waiter_check_interval_seconds: Number of seconds between polling the state of the application.
:param waiter_countdown: (deprecated) Total amount of time, in seconds, the operator will wait for each
step of first,the application to be stopped, and then deleted. Defaults to 25 minutes.
:param waiter_check_interval_seconds: (deprecated) Number of seconds between polling the state
of the application. Defaults to 60 seconds.
:waiter_max_attempts: Number of times the waiter should poll the application to check the state.
Defaults to 25.
:param waiter_delay: Number of seconds between polling the state of the application.
Defaults to 60 seconds.
:param deferrable: If True, the operator will wait asynchronously for application to be deleted.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
:param force_stop: If set to True, any job for that app that is not in a terminal state will be cancelled.
Otherwise, trying to delete an app with running jobs will return an error.
If you want to wait for the jobs to finish gracefully, use
Expand All @@ -1432,6 +1548,7 @@ def __init__(
waiter_max_attempts: int | ArgNotSet = NOTSET,
waiter_delay: int | ArgNotSet = NOTSET,
force_stop: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
if waiter_check_interval_seconds is NOTSET:
Expand Down Expand Up @@ -1467,6 +1584,8 @@ def __init__(
force_stop=force_stop,
**kwargs,
)
self.deferrable = deferrable
self.wait_for_delete_completion = False if deferrable else wait_for_completion

def execute(self, context: Context) -> None:
# super stops the app (or makes sure it's already stopped)
Expand All @@ -1478,7 +1597,19 @@ def execute(self, context: Context) -> None:
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Application deletion failed: {response}")

if self.wait_for_delete_completion:
if self.deferrable:
self.defer(
trigger=EmrServerlessDeleteApplicationTrigger(
application_id=self.application_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
method_name="execute_complete",
)

elif self.wait_for_delete_completion:
waiter = self.hook.get_waiter("serverless_app_terminated")

wait(
Expand All @@ -1492,3 +1623,10 @@ def execute(self, context: Context) -> None:
)

self.log.info("EMR serverless application deleted")

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
if event is None:
self.log.error("Trigger error: event is None")
raise AirflowException("Trigger error: event is None")
elif event["status"] == "success":
self.log.info("EMR serverless application %s deleted successfully", self.application_id)
Loading