From 59cd614bcb192fccc6c274ed1cecbb65407ba572 Mon Sep 17 00:00:00 2001 From: "Damon P. Cortesi" Date: Tue, 23 May 2023 05:49:36 +0000 Subject: [PATCH 01/18] Add extra operator links for EMR Serverless - Includes Dashboard UI, S3 and CloudWatch consoles - Only shows links relevant to the job --- airflow/providers/amazon/aws/links/emr.py | 124 ++++++++++++++++- airflow/providers/amazon/aws/operators/emr.py | 125 +++++++++++++++++- airflow/providers/amazon/provider.yaml | 4 + 3 files changed, 245 insertions(+), 8 deletions(-) diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py index 6c8cd2181eee1..f6d329b4b6c8c 100644 --- a/airflow/providers/amazon/aws/links/emr.py +++ b/airflow/providers/amazon/aws/links/emr.py @@ -16,19 +16,24 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any, TYPE_CHECKING +from airflow.models import XCom + +if TYPE_CHECKING: + import boto3 + + from airflow.models import BaseOperator + from airflow.models.taskinstancekey import TaskInstanceKey from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink from airflow.utils.helpers import exactly_one -if TYPE_CHECKING: - import boto3 - class EmrClusterLink(BaseAwsLink): - """Helper class for constructing AWS EMR Cluster Link.""" + """Helper class for constructing Amazon EMR Cluster Link.""" name = "EMR Cluster" key = "emr_cluster" @@ -36,7 +41,7 @@ class EmrClusterLink(BaseAwsLink): class EmrLogsLink(BaseAwsLink): - """Helper class for constructing AWS EMR Logs Link.""" + """Helper class for constructing Amazon EMR Logs Link.""" name = "EMR Cluster Logs" key = "emr_logs" @@ -48,6 +53,16 @@ def format_link(self, **kwargs) -> str: return super().format_link(**kwargs) +def get_serverless_log_uri(*, s3_log_uri: str, application_id: str, job_run_id: str) -> str: + """ + Retrieves the S3 URI to EMR Serverless Job logs. + + Any EMR Serverless job may have a different S3 logging location (or none), which is an S3 URI. + The logging location is then {s3_uri}/applications/{application_id}/jobs/{job_run_id}. + """ + return f"{s3_log_uri}/applications/{application_id}/jobs/{job_run_id}" + + def get_log_uri( *, cluster: dict[str, Any] | None = None, emr_client: boto3.client = None, job_flow_id: str | None = None ) -> str | None: @@ -66,3 +81,100 @@ def get_log_uri( return None log_uri = S3Hook.parse_s3_url(cluster_info["LogUri"]) return "/".join(log_uri) + + +class EmrServerlessLogsLink(BaseAwsLink): + """Helper class for constructing Amazon EMR Serverless Logs Link.""" + + name = "Spark Driver stdout" + key = "emr_serverless_logs" + + def get_link( + self, + operator: BaseOperator, + *, + ti_key: TaskInstanceKey, + ) -> str: + """ + Link to Amazon Web Services Console. 
+ + :param operator: airflow operator + :param ti_key: TaskInstance ID to return link for + :return: link to external system + """ + conf = XCom.get_value(key=self.key, ti_key=ti_key) + if not conf: + return "" + conn_id = operator.aws_conn_id + hook = EmrServerlessHook(aws_conn_id=conn_id) + resp = hook.conn.get_dashboard_for_job_run( + applicationId=conf.get("application_id"), jobRunId=conf.get("job_run_id") + ) + o = urlparse(resp["url"]) + return o._replace(path="/logs/SPARK_DRIVER/stdout.gz").geturl() + + +class EmrServerlessDashboardLink(BaseAwsLink): + """Helper class for constructing Amazon EMR Serverless Dashboard Link.""" + + name = "EMR Serverless Dashboard" + key = "emr_serverless_dashboard" + + def get_link( + self, + operator: BaseOperator, + *, + ti_key: TaskInstanceKey, + ) -> str: + """ + Link to Amazon Web Services Console. + + :param operator: airflow operator + :param ti_key: TaskInstance ID to return link for + :return: link to external system + """ + conf = XCom.get_value(key=self.key, ti_key=ti_key) + if not conf: + return "" + conn_id = operator.aws_conn_id + hook = EmrServerlessHook(aws_conn_id=conn_id) + # Dashboard cannot be served when job is pending/scheduled + resp = hook.conn.get_dashboard_for_job_run( + applicationId=conf.get("application_id"), jobRunId=conf.get("job_run_id") + ) + return resp["url"] + + +class EmrServerlessS3LogsLink(BaseAwsLink): + """Helper class for constructing Amazon EMR Serverless Logs Link.""" + + name = "S3 Logs" + key = "emr_serverless_s3_logs" + format_str = ( + BASE_AWS_CONSOLE_LINK + + "/s3/buckets/{bucket_name}?region={region_name}&prefix={prefix}/applications/{application_id}/jobs/{job_run_id}/" # noqa: E501 + ) + + def format_link(self, **kwargs) -> str: + bucket, prefix = S3Hook.parse_s3_url(kwargs["log_uri"]) + kwargs["bucket_name"] = bucket + kwargs["prefix"] = prefix.rstrip("/") + return super().format_link(**kwargs) + + +class EmrServerlessCloudWatchLogsLink(BaseAwsLink): + """Helper class for constructing Amazon EMR Serverless Logs Link.""" + + name = "CloudWatch Logs" + key = "emr_serverless_cloudwatch_logs" + format_str = ( + BASE_AWS_CONSOLE_LINK + + "/cloudwatch/home?region={region_name}#logsV2:log-groups/log-group/{awslogs_group}{stream_prefix}" + ) + + def format_link(self, **kwargs) -> str: + kwargs["awslogs_group"] = quote_plus(kwargs["awslogs_group"]) + kwargs["stream_prefix"] = quote_plus("?logStreamNameFilter=").replace("%", "$") + quote_plus( + kwargs["stream_prefix"] + ) + return super().format_link(**kwargs) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 1bf2375a16a2f..9d39298510fe8 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -28,7 +28,16 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook -from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.links.emr import ( + EmrClusterLink, + EmrLogsLink, + EmrServerlessCloudWatchLogsLink, + EmrServerlessDashboardLink, + EmrServerlessLogsLink, + EmrServerlessS3LogsLink, + get_log_uri, +) from airflow.providers.amazon.aws.triggers.emr import ( EmrAddStepsTrigger, EmrContainerTrigger, @@ -1166,6 +1175,7 @@ class 
EmrServerlessStartJobOperator(BaseOperator): "execution_role_arn", "job_driver", "configuration_overrides", + "aws_conn_id", ) template_fields_renderers = { @@ -1173,6 +1183,18 @@ class EmrServerlessStartJobOperator(BaseOperator): "configuration_overrides": "json", } + @property + def operator_extra_links(self): + op_extra_links = [EmrServerlessDashboardLink()] + if "sparkSubmit" in self.job_driver: + op_extra_links.extend([EmrServerlessLogsLink()]) + if self.has_monitoring_enabled("s3MonitoringConfiguration"): + op_extra_links.extend([EmrServerlessS3LogsLink()]) + if self.has_monitoring_enabled("cloudWatchLoggingConfiguration"): + op_extra_links.extend([EmrServerlessCloudWatchLogsLink()]) + + return tuple(op_extra_links) + def __init__( self, application_id: str, @@ -1234,7 +1256,6 @@ def hook(self) -> EmrServerlessHook: return EmrServerlessHook(aws_conn_id=self.aws_conn_id) def execute(self, context: Context, event: dict[str, Any] | None = None) -> str | None: - app_state = self.hook.conn.get_application(applicationId=self.application_id)["application"]["state"] if app_state not in EmrServerlessHook.APPLICATION_SUCCESS_STATES: self.log.info("Application state is %s", app_state) @@ -1277,6 +1298,9 @@ def execute(self, context: Context, event: dict[str, Any] | None = None) -> str self.job_id = response["jobRunId"] self.log.info("EMR serverless job started: %s", self.job_id) + + self.persist_links(context) + if self.deferrable: self.defer( trigger=EmrServerlessStartJobTrigger( @@ -1289,6 +1313,7 @@ def execute(self, context: Context, event: dict[str, Any] | None = None) -> str method_name="execute_complete", timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay), ) + if self.wait_for_completion: waiter = self.hook.get_waiter("serverless_job_completed") wait( @@ -1346,6 +1371,102 @@ def on_kill(self) -> None: check_interval_seconds=self.waiter_delay, ) + def has_monitoring_enabled(self, config_key: str) -> bool: + """ + Check if monitoring is enabled for the job. + + This is used to determine what extra links should be shown. 
+ """ + monitoring_config = (self.configuration_overrides or {}).get("monitoringConfiguration") + if monitoring_config is None or config_key not in monitoring_config: + return False + + # CloudWatch can have an "enabled" flag set to False + if config_key == "cloudWatchLoggingConfiguration": + return monitoring_config.get(config_key).get("enabled") is True + + return config_key in monitoring_config + + def persist_links(self, context: Context): + """Populate the relevant extra links for the EMR Serverless jobs.""" + # Persist the EMR Serverless Dashboard link (Spark/Tez UI) + EmrServerlessDashboardLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + application_id=self.application_id, + job_run_id=self.job_id, + ) + + # If this is a Spark job, persist the EMR Serverless logs link (Driver stdout) + if "sparkSubmit" in self.job_driver: + EmrServerlessLogsLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + application_id=self.application_id, + job_run_id=self.job_id, + ) + + # Add S3 and/or CloudWatch links if either is enabled + if self.has_monitoring_enabled("s3MonitoringConfiguration"): + log_uri = ( + (self.configuration_overrides or {}) + .get("monitoringConfiguration", {}) + .get("s3MonitoringConfiguration", {}) + .get("logUri") + ) + bucket, prefix = S3Hook.parse_s3_url( + f"{log_uri.rstrip('/')}/applications/{self.application_id}/jobs/{self.job_id}" + ) + EmrServerlessS3LogsLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + log_uri=log_uri, + application_id=self.application_id, + job_run_id=self.job_id, + ) + emrs_s3_url = EmrServerlessS3LogsLink().format_link( + aws_domain=EmrServerlessCloudWatchLogsLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + log_uri=log_uri, + application_id=self.application_id, + job_run_id=self.job_id, + ) + self.log.info("You can view EMR Serverless Job run S3 logs at: %s", emrs_s3_url) + + if self.has_monitoring_enabled("cloudWatchLoggingConfiguration"): + cloudwatch_config = ( + (self.configuration_overrides or {}) + .get("monitoringConfiguration", {}) + .get("cloudWatchLoggingConfiguration", {}) + ) + log_group_name = cloudwatch_config.get("logGroupName", "/aws/emr-serverless") + log_stream_prefix = cloudwatch_config.get("logStreamNamePrefix", "") + log_stream_prefix = f"{log_stream_prefix}/applications/{self.application_id}/jobs/{self.job_id}" + + EmrServerlessCloudWatchLogsLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + awslogs_group=log_group_name, + stream_prefix=log_stream_prefix, + ) + emrs_cloudwatch_url = EmrServerlessCloudWatchLogsLink().format_link( + aws_domain=EmrServerlessCloudWatchLogsLink.get_aws_domain(self.hook.conn_partition), + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + awslogs_group=log_group_name, + stream_prefix=log_stream_prefix, + ) + self.log.info("You can view EMR Serverless Job run CloudWatch logs at: %s", emrs_cloudwatch_url) + class EmrServerlessStopApplicationOperator(BaseOperator): """ diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 19c92314d3a72..cfa6aacd22115 100644 --- a/airflow/providers/amazon/provider.yaml +++ 
b/airflow/providers/amazon/provider.yaml @@ -673,6 +673,10 @@ extra-links: - airflow.providers.amazon.aws.links.batch.BatchJobQueueLink - airflow.providers.amazon.aws.links.emr.EmrClusterLink - airflow.providers.amazon.aws.links.emr.EmrLogsLink + - airflow.providers.amazon.aws.links.emr.EmrServerlessCloudWatchLogsLink + - airflow.providers.amazon.aws.links.emr.EmrServerlessDashboardLink + - airflow.providers.amazon.aws.links.emr.EmrServerlessLogsLink + - airflow.providers.amazon.aws.links.emr.EmrServerlessS3LogsLink - airflow.providers.amazon.aws.links.glue.GlueJobRunDetailsLink - airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink From 93170c5405be336852163148abe0fe2f6fbe04fa Mon Sep 17 00:00:00 2001 From: "Damon P. Cortesi" Date: Fri, 8 Sep 2023 15:39:08 -0700 Subject: [PATCH 02/18] Fix imports and add context mock to tests --- airflow/providers/amazon/aws/links/emr.py | 4 ++- .../aws/operators/test_emr_serverless.py | 29 ++++++++++--------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py index f6d329b4b6c8c..c994cdd0986ab 100644 --- a/airflow/providers/amazon/aws/links/emr.py +++ b/airflow/providers/amazon/aws/links/emr.py @@ -16,7 +16,9 @@ # under the License. from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any +from urllib.parse import quote_plus, urlparse + from airflow.models import XCom if TYPE_CHECKING: diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py index a93791e1a1e3f..eed160f55babf 100644 --- a/tests/providers/amazon/aws/operators/test_emr_serverless.py +++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py @@ -356,6 +356,9 @@ def test_create_application_deferrable(self, mock_conn): class TestEmrServerlessStartJobOperator: + def setup_method(self): + self.mock_context = MagicMock() + @mock.patch.object(EmrServerlessHook, "get_waiter") @mock.patch.object(EmrServerlessHook, "conn") def test_job_run_app_started(self, mock_conn, mock_get_waiter): @@ -376,7 +379,7 @@ def test_job_run_app_started(self, mock_conn, mock_get_waiter): configuration_overrides=configuration_overrides, ) default_name = operator.name - id = operator.execute(None) + id = operator.execute(self.mock_context) assert operator.wait_for_completion is True mock_conn.get_application.assert_called_once_with(applicationId=application_id) @@ -415,7 +418,7 @@ def test_job_run_job_failed(self, mock_conn, mock_get_waiter): ) default_name = operator.name with pytest.raises(AirflowException) as ex_message: - id = operator.execute(None) + id = operator.execute(self.mock_context) assert id == job_run_id assert "Serverless Job failed:" in str(ex_message.value) mock_conn.get_application.assert_called_once_with(applicationId=application_id) @@ -448,7 +451,7 @@ def test_job_run_app_not_started(self, mock_conn, mock_get_waiter): ) default_name = operator.name - id = operator.execute(None) + id = operator.execute(self.mock_context) assert operator.wait_for_completion is True mock_conn.get_application.assert_called_once_with(applicationId=application_id) @@ -492,7 +495,7 @@ def test_job_run_app_not_started_app_failed(self, mock_conn, mock_get_waiter, mo configuration_overrides=configuration_overrides, ) with pytest.raises(AirflowException) as ex_message: - operator.execute(None) + operator.execute(self.mock_context) assert "Serverless Application failed to start:" in 
str(ex_message.value) assert operator.wait_for_completion is True assert mock_get_waiter().wait.call_count == 2 @@ -517,7 +520,7 @@ def test_job_run_app_not_started_no_wait_for_completion(self, mock_conn, mock_ge wait_for_completion=False, ) default_name = operator.name - id = operator.execute(None) + id = operator.execute(self.mock_context) mock_conn.get_application.assert_called_once_with(applicationId=application_id) mock_get_waiter().wait.assert_called_once() @@ -551,7 +554,7 @@ def test_job_run_app_started_no_wait_for_completion(self, mock_conn, mock_get_wa wait_for_completion=False, ) default_name = operator.name - id = operator.execute(None) + id = operator.execute(self.mock_context) assert id == job_run_id mock_conn.start_job_run.assert_called_once_with( clientToken=client_request_token, @@ -583,7 +586,7 @@ def test_failed_start_job_run(self, mock_conn, mock_get_waiter): ) default_name = operator.name with pytest.raises(AirflowException) as ex_message: - operator.execute(None) + operator.execute(self.mock_context) assert "EMR serverless job failed to start:" in str(ex_message.value) mock_conn.get_application.assert_called_once_with(applicationId=application_id) @@ -621,7 +624,7 @@ def test_start_job_run_fail_on_wait_for_completion(self, mock_conn, mock_get_wai ) default_name = operator.name with pytest.raises(AirflowException) as ex_message: - operator.execute(None) + operator.execute(self.mock_context) assert "Serverless Job failed:" in str(ex_message.value) mock_conn.get_application.call_count == 2 @@ -653,7 +656,7 @@ def test_start_job_default_name(self, mock_conn, mock_get_waiter): job_driver=job_driver, configuration_overrides=configuration_overrides, ) - operator.execute(None) + operator.execute(self.mock_context) default_name = operator.name generated_name_uuid = default_name.split("_")[-1] assert default_name.startswith("emr_serverless_job_airflow") @@ -687,7 +690,7 @@ def test_start_job_custom_name(self, mock_conn, mock_get_waiter): configuration_overrides=configuration_overrides, name=custom_name, ) - operator.execute(None) + operator.execute(self.mock_context) mock_conn.start_job_run.assert_called_once_with( clientToken=client_request_token, @@ -717,7 +720,7 @@ def test_cancel_job_run(self, mock_conn): wait_for_completion=False, ) - id = operator.execute(None) + id = operator.execute(self.mock_context) operator.on_kill() mock_conn.cancel_job_run.assert_called_once_with( applicationId=application_id, @@ -768,7 +771,7 @@ def test_start_job_deferrable(self, mock_conn): ) with pytest.raises(TaskDeferred): - operator.execute(None) + operator.execute(self.mock_context) @mock.patch.object(EmrServerlessHook, "get_waiter") @mock.patch.object(EmrServerlessHook, "conn") @@ -788,7 +791,7 @@ def test_start_job_deferrable_app_not_started(self, mock_conn, mock_get_waiter): ) with pytest.raises(TaskDeferred): - operator.execute(None) + operator.execute(self.mock_context) class TestEmrServerlessDeleteOperator: From 906c1de9084da3a7e8d7e989912072ba36c4e461 Mon Sep 17 00:00:00 2001 From: "Damon P. 
Cortesi" Date: Fri, 8 Sep 2023 15:41:16 -0700 Subject: [PATCH 03/18] Move TYPE_CHECK --- airflow/providers/amazon/aws/links/emr.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py index c994cdd0986ab..29d31aa61e26f 100644 --- a/airflow/providers/amazon/aws/links/emr.py +++ b/airflow/providers/amazon/aws/links/emr.py @@ -19,7 +19,12 @@ from typing import TYPE_CHECKING, Any from urllib.parse import quote_plus, urlparse +from airflow.exceptions import AirflowException from airflow.models import XCom +from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink +from airflow.utils.helpers import exactly_one if TYPE_CHECKING: import boto3 @@ -27,12 +32,6 @@ from airflow.models import BaseOperator from airflow.models.taskinstancekey import TaskInstanceKey -from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook -from airflow.providers.amazon.aws.hooks.s3 import S3Hook -from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink -from airflow.utils.helpers import exactly_one - class EmrClusterLink(BaseAwsLink): """Helper class for constructing Amazon EMR Cluster Link.""" From 18a6b0a2d690df774701f613543cbbb86297b431 Mon Sep 17 00:00:00 2001 From: "Damon P. Cortesi" Date: Tue, 24 Oct 2023 12:46:20 -0700 Subject: [PATCH 04/18] Remove unused variables --- airflow/providers/amazon/aws/operators/emr.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 9d39298510fe8..a34c70014e670 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -28,7 +28,6 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook -from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.links.emr import ( EmrClusterLink, EmrLogsLink, @@ -1418,9 +1417,6 @@ def persist_links(self, context: Context): .get("s3MonitoringConfiguration", {}) .get("logUri") ) - bucket, prefix = S3Hook.parse_s3_url( - f"{log_uri.rstrip('/')}/applications/{self.application_id}/jobs/{self.job_id}" - ) EmrServerlessS3LogsLink.persist( context=context, operator=self, @@ -1438,7 +1434,7 @@ def persist_links(self, context: Context): application_id=self.application_id, job_run_id=self.job_id, ) - self.log.info("You can view EMR Serverless Job run S3 logs at: %s", emrs_s3_url) + self.log.info("S3 logs available at: %s", emrs_s3_url) if self.has_monitoring_enabled("cloudWatchLoggingConfiguration"): cloudwatch_config = ( @@ -1465,7 +1461,7 @@ def persist_links(self, context: Context): awslogs_group=log_group_name, stream_prefix=log_stream_prefix, ) - self.log.info("You can view EMR Serverless Job run CloudWatch logs at: %s", emrs_cloudwatch_url) + self.log.info("CloudWatch logs available at: %s", emrs_cloudwatch_url) class EmrServerlessStopApplicationOperator(BaseOperator): From e1d8970b16fcb02beac760a151e6bfcf3cfa6e82 Mon Sep 17 00:00:00 2001 From: "Damon P. 
Cortesi" Date: Tue, 24 Oct 2023 12:51:03 -0700 Subject: [PATCH 05/18] Pass in connection ID string instead of operator --- airflow/providers/amazon/aws/links/emr.py | 6 ++---- airflow/providers/amazon/aws/operators/emr.py | 2 ++ 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py index 29d31aa61e26f..20da2a06fbaef 100644 --- a/airflow/providers/amazon/aws/links/emr.py +++ b/airflow/providers/amazon/aws/links/emr.py @@ -106,8 +106,7 @@ def get_link( conf = XCom.get_value(key=self.key, ti_key=ti_key) if not conf: return "" - conn_id = operator.aws_conn_id - hook = EmrServerlessHook(aws_conn_id=conn_id) + hook = EmrServerlessHook(aws_conn_id=conf.get("conn_id")) resp = hook.conn.get_dashboard_for_job_run( applicationId=conf.get("application_id"), jobRunId=conf.get("job_run_id") ) @@ -137,8 +136,7 @@ def get_link( conf = XCom.get_value(key=self.key, ti_key=ti_key) if not conf: return "" - conn_id = operator.aws_conn_id - hook = EmrServerlessHook(aws_conn_id=conn_id) + hook = EmrServerlessHook(aws_conn_id=conf.get("conn_id")) # Dashboard cannot be served when job is pending/scheduled resp = hook.conn.get_dashboard_for_job_run( applicationId=conf.get("application_id"), jobRunId=conf.get("job_run_id") diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index a34c70014e670..3d0ee2d52a4a5 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -1394,6 +1394,7 @@ def persist_links(self, context: Context): operator=self, region_name=self.hook.conn_region_name, aws_partition=self.hook.conn_partition, + conn_id=self.hook.aws_conn_id, application_id=self.application_id, job_run_id=self.job_id, ) @@ -1405,6 +1406,7 @@ def persist_links(self, context: Context): operator=self, region_name=self.hook.conn_region_name, aws_partition=self.hook.conn_partition, + conn_id=self.hook.aws_conn_id, application_id=self.application_id, job_run_id=self.job_id, ) From b56a0fcbc41cedbfafa20e313c17db5dc8dcb328 Mon Sep 17 00:00:00 2001 From: "Damon P. Cortesi" Date: Tue, 24 Oct 2023 13:55:33 -0700 Subject: [PATCH 06/18] Use mock.MagicMock --- tests/providers/amazon/aws/operators/test_emr_serverless.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py index 733271bb4b6c2..6b42c61d1c056 100644 --- a/tests/providers/amazon/aws/operators/test_emr_serverless.py +++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py @@ -357,7 +357,7 @@ def test_create_application_deferrable(self, mock_conn): class TestEmrServerlessStartJobOperator: def setup_method(self): - self.mock_context = MagicMock() + self.mock_context = mock.MagicMock() @mock.patch.object(EmrServerlessHook, "get_waiter") @mock.patch.object(EmrServerlessHook, "conn") From 987cce066d643dd0df33ddbd9a64bab3de4e4627 Mon Sep 17 00:00:00 2001 From: "Damon P. 
Cortesi" Date: Tue, 24 Oct 2023 15:31:01 -0700 Subject: [PATCH 07/18] Disable application UI logs by default --- airflow/providers/amazon/aws/operators/emr.py | 20 ++++++++++++++++--- .../operators/emr/emr_serverless.rst | 12 +++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 625f573826bf3..1596f45049913 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -1173,6 +1173,9 @@ class EmrServerlessStartJobOperator(BaseOperator): :param deferrable: If True, the operator will wait asynchronously for the crawl to complete. This implies waiting for completion. This mode requires aiobotocore module to be installed. (default: False, but can be overridden in config file by setting default_deferrable to True) + :param enable_application_ui_links: If True, the operator will generate one-time links to EMR Serverless application UIs. + The generated links will allow any user with access to the DAG to see the Spark or Tez UI or Spark stdout logs. + Defaults to False. """ template_fields: Sequence[str] = ( @@ -1191,9 +1194,18 @@ class EmrServerlessStartJobOperator(BaseOperator): @property def operator_extra_links(self): - op_extra_links = [EmrServerlessDashboardLink()] - if "sparkSubmit" in self.job_driver: - op_extra_links.extend([EmrServerlessLogsLink()]) + """ + Dynamically add extra links depending on the job type and if they're enabled. + + If S3 or CloudWatch monitoring configurations exist, add links directly to the relevant conoles. + Only add dashboard links if they're explicitly enabled. These are one-time links that any user can acccess, + but expire on first click or one hour, whichever comes first. + """ + op_extra_links = [] + if self.enable_application_ui_links: + op_extra_links.extend([EmrServerlessDashboardLink()]) + if "sparkSubmit" in self.job_driver: + op_extra_links.extend([EmrServerlessLogsLink()]) if self.has_monitoring_enabled("s3MonitoringConfiguration"): op_extra_links.extend([EmrServerlessS3LogsLink()]) if self.has_monitoring_enabled("cloudWatchLoggingConfiguration"): @@ -1217,6 +1229,7 @@ def __init__( waiter_max_attempts: int | ArgNotSet = NOTSET, waiter_delay: int | ArgNotSet = NOTSET, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + enable_application_ui_links: bool = False, **kwargs, ): if waiter_check_interval_seconds is NOTSET: @@ -1252,6 +1265,7 @@ def __init__( self.waiter_delay = int(waiter_delay) # type: ignore[arg-type] self.job_id: str | None = None self.deferrable = deferrable + self.enable_application_ui_links = enable_application_ui_links super().__init__(**kwargs) self.client_request_token = client_request_token or str(uuid4()) diff --git a/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst b/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst index bcd5995e5c909..d4ce08da7bc31 100644 --- a/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst +++ b/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst @@ -67,6 +67,18 @@ the aiobotocore module to be installed. .. _howto/operator:EmrServerlessStopApplicationOperator: +Open Application UIs +"""""""""""""""""""" + +The operator can also be configured to generate one-time links to the application UIs and Spark stdout logs +by passing the ``enable_application_ui_links=True`` as a parameter. 
+ +You need to ensure you have the following IAM permissions to generate the dashboard link. + +.. code-block:: + + "emr-serverless:GetDashboardForJobRun" + Stop an EMR Serverless Application ================================== From 76e19acb19c346970f0798bf26ca2bcba0c3221c Mon Sep 17 00:00:00 2001 From: "Damon P. Cortesi" Date: Tue, 24 Oct 2023 16:15:40 -0700 Subject: [PATCH 08/18] Update doc lints --- airflow/providers/amazon/aws/operators/emr.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 1596f45049913..99fd7688e2444 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -1173,9 +1173,9 @@ class EmrServerlessStartJobOperator(BaseOperator): :param deferrable: If True, the operator will wait asynchronously for the crawl to complete. This implies waiting for completion. This mode requires aiobotocore module to be installed. (default: False, but can be overridden in config file by setting default_deferrable to True) - :param enable_application_ui_links: If True, the operator will generate one-time links to EMR Serverless application UIs. - The generated links will allow any user with access to the DAG to see the Spark or Tez UI or Spark stdout logs. - Defaults to False. + :param enable_application_ui_links: If True, the operator will generate one-time links to EMR Serverless + application UIs. The generated links will allow any user with access to the DAG to see the Spark or + Tez UI or Spark stdout logs. Defaults to False. """ template_fields: Sequence[str] = ( @@ -1197,9 +1197,9 @@ def operator_extra_links(self): """ Dynamically add extra links depending on the job type and if they're enabled. - If S3 or CloudWatch monitoring configurations exist, add links directly to the relevant conoles. - Only add dashboard links if they're explicitly enabled. These are one-time links that any user can acccess, - but expire on first click or one hour, whichever comes first. + If S3 or CloudWatch monitoring configurations exist, add links directly to the relevant consoles. + Only add dashboard links if they're explicitly enabled. These are one-time links that any user + can access, but expire on first click or one hour, whichever comes first. """ op_extra_links = [] if self.enable_application_ui_links: From 625500c06e6b7fd81d4df6593b966fba49cee1ba Mon Sep 17 00:00:00 2001 From: "Damon P. 
Cortesi" Date: Mon, 30 Oct 2023 11:47:11 -0700 Subject: [PATCH 09/18] Update airflow/providers/amazon/aws/links/emr.py Co-authored-by: Andrey Anshin --- airflow/providers/amazon/aws/links/emr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py index 20da2a06fbaef..19a87dbf97752 100644 --- a/airflow/providers/amazon/aws/links/emr.py +++ b/airflow/providers/amazon/aws/links/emr.py @@ -149,9 +149,9 @@ class EmrServerlessS3LogsLink(BaseAwsLink): name = "S3 Logs" key = "emr_serverless_s3_logs" - format_str = ( - BASE_AWS_CONSOLE_LINK - + "/s3/buckets/{bucket_name}?region={region_name}&prefix={prefix}/applications/{application_id}/jobs/{job_run_id}/" # noqa: E501 + format_str = BASE_AWS_CONSOLE_LINK + ( + "/s3/buckets/{bucket_name}?region={region_name}" + "&prefix={prefix}/applications/{application_id}/jobs/{job_run_id}/" ) def format_link(self, **kwargs) -> str: From 7a73d32093f4380c7eeb8b54fb70f914de9b2305 Mon Sep 17 00:00:00 2001 From: "Damon P. Cortesi" Date: Mon, 30 Oct 2023 14:08:08 -0700 Subject: [PATCH 10/18] Support dynamic task mapping --- airflow/providers/amazon/aws/operators/emr.py | 41 +++++++++++++++---- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 99fd7688e2444..63dbf113f7fa5 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -27,6 +27,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator +from airflow.models.mappedoperator import MappedOperator from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from airflow.providers.amazon.aws.links.emr import ( EmrClusterLink, @@ -1202,14 +1203,33 @@ def operator_extra_links(self): can access, but expire on first click or one hour, whichever comes first. 
""" op_extra_links = [] - if self.enable_application_ui_links: + + if isinstance(self, MappedOperator): + enable_application_ui_links = self.partial_kwargs.get( + "enable_application_ui_links" + ) or self.expand_input.value.get("enable_application_ui_links") + job_driver = self.partial_kwargs.get( + "job_driver" + ) or self.expand_input.value.get("job_driver") + configuration_overrides = self.partial_kwargs.get( + "configuration_overrides" + ) or self.expand_input.value.get("configuration_overrides") + + else: + enable_application_ui_links = self.enable_application_ui_links + configuration_overrides = self.configuration_overrides + job_driver = self.job_driver + + + if enable_application_ui_links: op_extra_links.extend([EmrServerlessDashboardLink()]) - if "sparkSubmit" in self.job_driver: + if "sparkSubmit" in job_driver: op_extra_links.extend([EmrServerlessLogsLink()]) - if self.has_monitoring_enabled("s3MonitoringConfiguration"): + if self.is_monitoring_in_job_override("s3MonitoringConfiguration", configuration_overrides): op_extra_links.extend([EmrServerlessS3LogsLink()]) - if self.has_monitoring_enabled("cloudWatchLoggingConfiguration"): + if self.is_monitoring_in_job_override("cloudWatchLoggingConfiguration", configuration_overrides): op_extra_links.extend([EmrServerlessCloudWatchLogsLink()]) + return tuple(op_extra_links) @@ -1218,7 +1238,7 @@ def __init__( application_id: str, execution_role_arn: str, job_driver: dict, - configuration_overrides: dict | None, + configuration_overrides: dict | None = None, client_request_token: str = "", config: dict | None = None, wait_for_completion: bool = True, @@ -1391,13 +1411,16 @@ def on_kill(self) -> None: check_interval_seconds=self.waiter_delay, ) - def has_monitoring_enabled(self, config_key: str) -> bool: + def is_monitoring_in_job_override(self, config_key: str, job_override: dict | None) -> bool: """ Check if monitoring is enabled for the job. + Note: This is not compatible with application defaults: + https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/default-configs.html + This is used to determine what extra links should be shown. """ - monitoring_config = (self.configuration_overrides or {}).get("monitoringConfiguration") + monitoring_config = job_override.get("monitoringConfiguration") if monitoring_config is None or config_key not in monitoring_config: return False @@ -1433,7 +1456,7 @@ def persist_links(self, context: Context): ) # Add S3 and/or CloudWatch links if either is enabled - if self.has_monitoring_enabled("s3MonitoringConfiguration"): + if self.is_monitoring_in_job_override("s3MonitoringConfiguration", self.configuration_overrides): log_uri = ( (self.configuration_overrides or {}) .get("monitoringConfiguration", {}) @@ -1459,7 +1482,7 @@ def persist_links(self, context: Context): ) self.log.info("S3 logs available at: %s", emrs_s3_url) - if self.has_monitoring_enabled("cloudWatchLoggingConfiguration"): + if self.is_monitoring_in_job_override("cloudWatchLoggingConfiguration", self.configuration_overrides): cloudwatch_config = ( (self.configuration_overrides or {}) .get("monitoringConfiguration", {}) From e4b7efb0669395a34d8c8c2b6a98508f453c8010 Mon Sep 17 00:00:00 2001 From: "Damon P. 
Cortesi" Date: Tue, 31 Oct 2023 10:00:13 -0700 Subject: [PATCH 11/18] Lint/static check fixes --- airflow/providers/amazon/aws/operators/emr.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 63dbf113f7fa5..b9de6d5e29c59 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -1208,9 +1208,7 @@ def operator_extra_links(self): enable_application_ui_links = self.partial_kwargs.get( "enable_application_ui_links" ) or self.expand_input.value.get("enable_application_ui_links") - job_driver = self.partial_kwargs.get( - "job_driver" - ) or self.expand_input.value.get("job_driver") + job_driver = self.partial_kwargs.get("job_driver") or self.expand_input.value.get("job_driver") configuration_overrides = self.partial_kwargs.get( "configuration_overrides" ) or self.expand_input.value.get("configuration_overrides") @@ -1220,7 +1218,6 @@ def operator_extra_links(self): configuration_overrides = self.configuration_overrides job_driver = self.job_driver - if enable_application_ui_links: op_extra_links.extend([EmrServerlessDashboardLink()]) if "sparkSubmit" in job_driver: @@ -1229,7 +1226,6 @@ def operator_extra_links(self): op_extra_links.extend([EmrServerlessS3LogsLink()]) if self.is_monitoring_in_job_override("cloudWatchLoggingConfiguration", configuration_overrides): op_extra_links.extend([EmrServerlessCloudWatchLogsLink()]) - return tuple(op_extra_links) @@ -1420,7 +1416,7 @@ def is_monitoring_in_job_override(self, config_key: str, job_override: dict | No This is used to determine what extra links should be shown. """ - monitoring_config = job_override.get("monitoringConfiguration") + monitoring_config = (job_override or {}).get("monitoringConfiguration") if monitoring_config is None or config_key not in monitoring_config: return False From eb62bf1f0a7f169b2269409902868124c6e2402c Mon Sep 17 00:00:00 2001 From: "Damon P. Cortesi" Date: Fri, 17 Nov 2023 13:34:25 -0800 Subject: [PATCH 12/18] Update review comments --- airflow/providers/amazon/aws/links/emr.py | 18 +++++++++++------- .../operators/emr/emr_serverless.rst | 7 ++++++- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py index 19a87dbf97752..2df62afdada8a 100644 --- a/airflow/providers/amazon/aws/links/emr.py +++ b/airflow/providers/amazon/aws/links/emr.py @@ -85,7 +85,7 @@ def get_log_uri( class EmrServerlessLogsLink(BaseAwsLink): - """Helper class for constructing Amazon EMR Serverless Logs Link.""" + """Helper class for constructing Amazon EMR Serverless link to Spark stdout logs.""" name = "Spark Driver stdout" key = "emr_serverless_logs" @@ -97,11 +97,11 @@ def get_link( ti_key: TaskInstanceKey, ) -> str: """ - Link to Amazon Web Services Console. + Pre-signed URL to the Spark stdout log. :param operator: airflow operator :param ti_key: TaskInstance ID to return link for - :return: link to external system + :return: Pre-signed URL to Spark stdout log. Empty string if no Spark stdout log is available. """ conf = XCom.get_value(key=self.key, ti_key=ti_key) if not conf: @@ -127,11 +127,11 @@ def get_link( ti_key: TaskInstanceKey, ) -> str: """ - Link to Amazon Web Services Console. + Pre-signed URL to the application UI for the EMR Serverless job. 
:param operator: airflow operator :param ti_key: TaskInstance ID to return link for - :return: link to external system + :return: Pre-signed URL to application UI. """ conf = XCom.get_value(key=self.key, ti_key=ti_key) if not conf: @@ -145,7 +145,7 @@ def get_link( class EmrServerlessS3LogsLink(BaseAwsLink): - """Helper class for constructing Amazon EMR Serverless Logs Link.""" + """Helper class for constructing link to S3 console for Amazon EMR Serverless Logs.""" name = "S3 Logs" key = "emr_serverless_s3_logs" @@ -162,7 +162,11 @@ def format_link(self, **kwargs) -> str: class EmrServerlessCloudWatchLogsLink(BaseAwsLink): - """Helper class for constructing Amazon EMR Serverless Logs Link.""" + """ + Helper class for constructing link to CloudWatch console for Amazon EMR Serverless Logs. + + This is a deep link that filters on a specific job run. + """ name = "CloudWatch Logs" key = "emr_serverless_cloudwatch_logs" diff --git a/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst b/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst index d4ce08da7bc31..e69018f3a9932 100644 --- a/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst +++ b/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst @@ -71,7 +71,8 @@ Open Application UIs """""""""""""""""""" The operator can also be configured to generate one-time links to the application UIs and Spark stdout logs -by passing the ``enable_application_ui_links=True`` as a parameter. +by passing the ``enable_application_ui_links=True`` as a parameter. Once the job starts running, these links +are available in the Details section of the relevant Task. You need to ensure you have the following IAM permissions to generate the dashboard link. @@ -79,6 +80,10 @@ You need to ensure you have the following IAM permissions to generate the dashbo "emr-serverless:GetDashboardForJobRun" +If Amazon S3 or Amazon CloudWatch logs are +`enabled for EMR Serverless `__, +links to the respective console will also be available in the task logs and task Details. + Stop an EMR Serverless Application ================================== From fedc00a197e86b61cbc2d23df6348488fab88d9c Mon Sep 17 00:00:00 2001 From: "Damon P. Cortesi" Date: Fri, 12 Jan 2024 12:23:37 -0800 Subject: [PATCH 13/18] Configure get_dashboard call for EMR Serverless to only retry once --- airflow/providers/amazon/aws/links/emr.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py index c89b0ca5cd696..34fd7141a119f 100644 --- a/airflow/providers/amazon/aws/links/emr.py +++ b/airflow/providers/amazon/aws/links/emr.py @@ -106,7 +106,11 @@ def get_link( conf = XCom.get_value(key=self.key, ti_key=ti_key) if not conf: return "" - hook = EmrServerlessHook(aws_conn_id=conf.get("conn_id")) + # If get_dashboard_for_job_run fails for whatever reason, fail after 1 attempt + # so that the rest of the links load in a reasonable time frame. 
+ hook = EmrServerlessHook( + aws_conn_id=conf.get("conn_id"), config={"retries": {"total_max_attempts": 1}} + ) resp = hook.conn.get_dashboard_for_job_run( applicationId=conf.get("application_id"), jobRunId=conf.get("job_run_id") ) @@ -136,8 +140,13 @@ def get_link( conf = XCom.get_value(key=self.key, ti_key=ti_key) if not conf: return "" - hook = EmrServerlessHook(aws_conn_id=conf.get("conn_id")) - # Dashboard cannot be served when job is pending/scheduled + # If get_dashboard_for_job_run fails for whatever reason, fail after 1 attempt + # so that the rest of the links load in a reasonable time frame. + hook = EmrServerlessHook( + aws_conn_id=conf.get("conn_id"), config={"retries": {"total_max_attempts": 1}} + ) + # Dashboard cannot be served when job is pending/scheduled, + # in which case an empty string still gets returned. resp = hook.conn.get_dashboard_for_job_run( applicationId=conf.get("application_id"), jobRunId=conf.get("job_run_id") ) @@ -164,7 +173,7 @@ def format_link(self, **kwargs) -> str: class EmrServerlessCloudWatchLogsLink(BaseAwsLink): """ Helper class for constructing link to CloudWatch console for Amazon EMR Serverless Logs. - + This is a deep link that filters on a specific job run. """ From dc2df1546d4fbf52f0ffaadb24faecc978ad8c50 Mon Sep 17 00:00:00 2001 From: "Damon P. Cortesi" Date: Fri, 12 Jan 2024 12:50:18 -0800 Subject: [PATCH 14/18] Whitespace --- .../operators/emr/emr_serverless.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst b/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst index e69018f3a9932..65a0fc8bfebe6 100644 --- a/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst +++ b/docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst @@ -80,7 +80,7 @@ You need to ensure you have the following IAM permissions to generate the dashbo "emr-serverless:GetDashboardForJobRun" -If Amazon S3 or Amazon CloudWatch logs are +If Amazon S3 or Amazon CloudWatch logs are `enabled for EMR Serverless `__, links to the respective console will also be available in the task logs and task Details. From b9acdf6c93abbc3786f0deb0d102d54f5b877a33 Mon Sep 17 00:00:00 2001 From: "Damon P. 
Cortesi" Date: Mon, 22 Jan 2024 11:02:36 -0800 Subject: [PATCH 15/18] Add unit tests for EMRS link generation --- airflow/providers/amazon/aws/operators/emr.py | 21 +- .../aws/operators/test_emr_serverless.py | 304 +++++++++++++++++- 2 files changed, 314 insertions(+), 11 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index af5a8dd1be786..4c8b94ac60709 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -1434,18 +1434,19 @@ def is_monitoring_in_job_override(self, config_key: str, job_override: dict | No def persist_links(self, context: Context): """Populate the relevant extra links for the EMR Serverless jobs.""" # Persist the EMR Serverless Dashboard link (Spark/Tez UI) - EmrServerlessDashboardLink.persist( - context=context, - operator=self, - region_name=self.hook.conn_region_name, - aws_partition=self.hook.conn_partition, - conn_id=self.hook.aws_conn_id, - application_id=self.application_id, - job_run_id=self.job_id, - ) + if self.enable_application_ui_links: + EmrServerlessDashboardLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + conn_id=self.hook.aws_conn_id, + application_id=self.application_id, + job_run_id=self.job_id, + ) # If this is a Spark job, persist the EMR Serverless logs link (Driver stdout) - if "sparkSubmit" in self.job_driver: + if self.enable_application_ui_links and "sparkSubmit" in self.job_driver: EmrServerlessLogsLink.persist( context=context, operator=self, diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py index 8103fe03ed9ab..eed292c3cd7e5 100644 --- a/tests/providers/amazon/aws/operators/test_emr_serverless.py +++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py @@ -45,8 +45,24 @@ execution_role_arn = "test_emr_serverless_role_arn" job_driver = {"test_key": "test_value"} +spark_job_driver = {"sparkSubmit": {"entryPoint": "test.py"}} configuration_overrides = {"monitoringConfiguration": {"test_key": "test_value"}} job_run_id = "test_job_run_id" +s3_logs_location = "s3://test_bucket/test_key/" +cloudwatch_logs_group_name = "/aws/emrs" +cloudwatch_logs_prefix = "myapp" +s3_configuration_overrides = { + "monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": s3_logs_location}} +} +cloudwatch_configuration_overrides = { + "monitoringConfiguration": { + "cloudWatchLoggingConfiguration": { + "enabled": True, + "logGroupName": cloudwatch_logs_group_name, + "logStreamNamePrefix": cloudwatch_logs_prefix, + } + } +} application_id_delete_operator = "test_emr_serverless_delete_application_operator" @@ -777,7 +793,7 @@ def test_start_job_deferrable(self, mock_conn): @mock.patch.object(EmrServerlessHook, "get_waiter") @mock.patch.object(EmrServerlessHook, "conn") def test_start_job_deferrable_app_not_started(self, mock_conn, mock_get_waiter): - mock_get_waiter.return_value = True + mock_get_waiter.wait.return_value = True mock_conn.get_application.return_value = {"application": {"state": "CREATING"}} mock_conn.start_application.return_value = { "ResponseMetadata": {"HTTPStatusCode": 200}, @@ -794,6 +810,292 @@ def test_start_job_deferrable_app_not_started(self, mock_conn, mock_get_waiter): with pytest.raises(TaskDeferred): operator.execute(self.mock_context) + @mock.patch.object(EmrServerlessHook, "get_waiter") + 
@mock.patch.object(EmrServerlessHook, "conn") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessCloudWatchLogsLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessDashboardLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessLogsLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessS3LogsLink.persist") + def test_links_start_job_default( + self, + mock_s3_logs_link, + mock_logs_link, + mock_dashboard_link, + mock_cloudwatch_link, + mock_conn, + mock_get_waiter, + ): + mock_get_waiter.wait.return_value = True + mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} + mock_conn.start_job_run.return_value = { + "jobRunId": job_run_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + operator = EmrServerlessStartJobOperator( + task_id=task_id, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=configuration_overrides, + ) + operator.execute(self.mock_context) + mock_conn.start_job_run.assert_called_once() + + mock_s3_logs_link.assert_not_called() + mock_logs_link.assert_not_called() + mock_dashboard_link.assert_not_called() + mock_cloudwatch_link.assert_not_called() + + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessCloudWatchLogsLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessDashboardLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessLogsLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessS3LogsLink.persist") + def test_links_s3_enabled( + self, + mock_s3_logs_link, + mock_logs_link, + mock_dashboard_link, + mock_cloudwatch_link, + mock_conn, + mock_get_waiter, + ): + mock_get_waiter.wait.return_value = True + mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} + mock_conn.start_job_run.return_value = { + "jobRunId": job_run_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + + operator = EmrServerlessStartJobOperator( + task_id=task_id, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=s3_configuration_overrides, + ) + operator.execute(self.mock_context) + mock_conn.start_job_run.assert_called_once() + + mock_logs_link.assert_not_called() + mock_dashboard_link.assert_not_called() + mock_cloudwatch_link.assert_not_called() + mock_s3_logs_link.assert_called_once_with( + context=mock.ANY, + operator=mock.ANY, + region_name=mock.ANY, + aws_partition=mock.ANY, + log_uri=s3_logs_location, + application_id=application_id, + job_run_id=job_run_id, + ) + + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessCloudWatchLogsLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessDashboardLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessLogsLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessS3LogsLink.persist") + def test_links_cloudwatch_enabled( + self, + mock_s3_logs_link, + mock_logs_link, + mock_dashboard_link, + mock_cloudwatch_link, + mock_conn, + mock_get_waiter, + ): + mock_get_waiter.wait.return_value = True + mock_conn.get_application.return_value = {"application": {"state": 
"STARTED"}} + mock_conn.start_job_run.return_value = { + "jobRunId": job_run_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + + operator = EmrServerlessStartJobOperator( + task_id=task_id, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=cloudwatch_configuration_overrides, + ) + operator.execute(self.mock_context) + mock_conn.start_job_run.assert_called_once() + + mock_logs_link.assert_not_called() + mock_dashboard_link.assert_not_called() + mock_s3_logs_link.assert_not_called() + mock_cloudwatch_link.assert_called_once_with( + context=mock.ANY, + operator=mock.ANY, + region_name=mock.ANY, + aws_partition=mock.ANY, + awslogs_group=cloudwatch_logs_group_name, + stream_prefix=f"{cloudwatch_logs_prefix}/applications/{application_id}/jobs/{job_run_id}", + ) + + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessCloudWatchLogsLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessDashboardLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessLogsLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessS3LogsLink.persist") + def test_links_applicationui_enabled( + self, + mock_s3_logs_link, + mock_logs_link, + mock_dashboard_link, + mock_cloudwatch_link, + mock_conn, + mock_get_waiter, + ): + mock_get_waiter.wait.return_value = True + mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} + mock_conn.start_job_run.return_value = { + "jobRunId": job_run_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + + operator = EmrServerlessStartJobOperator( + task_id=task_id, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=cloudwatch_configuration_overrides, + enable_application_ui_links=True, + ) + operator.execute(self.mock_context) + mock_conn.start_job_run.assert_called_once() + + mock_logs_link.assert_not_called() + mock_s3_logs_link.assert_not_called() + mock_dashboard_link.assert_called_with( + context=mock.ANY, + operator=mock.ANY, + region_name=mock.ANY, + aws_partition=mock.ANY, + conn_id=mock.ANY, + application_id=application_id, + job_run_id=job_run_id, + ) + mock_cloudwatch_link.assert_called_once_with( + context=mock.ANY, + operator=mock.ANY, + region_name=mock.ANY, + aws_partition=mock.ANY, + awslogs_group=cloudwatch_logs_group_name, + stream_prefix=f"{cloudwatch_logs_prefix}/applications/{application_id}/jobs/{job_run_id}", + ) + + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessCloudWatchLogsLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessDashboardLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessLogsLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessS3LogsLink.persist") + def test_links_applicationui_with_spark_enabled( + self, + mock_s3_logs_link, + mock_logs_link, + mock_dashboard_link, + mock_cloudwatch_link, + mock_conn, + mock_get_waiter, + ): + mock_get_waiter.wait.return_value = True + mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} + mock_conn.start_job_run.return_value = { + "jobRunId": job_run_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + + 
operator = EmrServerlessStartJobOperator( + task_id=task_id, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=spark_job_driver, + configuration_overrides=s3_configuration_overrides, + enable_application_ui_links=True, + ) + operator.execute(self.mock_context) + mock_conn.start_job_run.assert_called_once() + + mock_logs_link.assert_called_once_with( + context=mock.ANY, + operator=mock.ANY, + region_name=mock.ANY, + aws_partition=mock.ANY, + conn_id=mock.ANY, + application_id=application_id, + job_run_id=job_run_id, + ) + mock_dashboard_link.assert_called_with( + context=mock.ANY, + operator=mock.ANY, + region_name=mock.ANY, + aws_partition=mock.ANY, + conn_id=mock.ANY, + application_id=application_id, + job_run_id=job_run_id, + ) + mock_cloudwatch_link.assert_not_called() + mock_s3_logs_link.assert_called_once_with( + context=mock.ANY, + operator=mock.ANY, + region_name=mock.ANY, + aws_partition=mock.ANY, + log_uri=s3_logs_location, + application_id=application_id, + job_run_id=job_run_id, + ) + + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessCloudWatchLogsLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessDashboardLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessLogsLink.persist") + @mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessS3LogsLink.persist") + def test_links_spark_without_applicationui_enabled( + self, + mock_s3_logs_link, + mock_logs_link, + mock_dashboard_link, + mock_cloudwatch_link, + mock_conn, + mock_get_waiter, + ): + mock_get_waiter.wait.return_value = True + mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} + mock_conn.start_job_run.return_value = { + "jobRunId": job_run_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + + operator = EmrServerlessStartJobOperator( + task_id=task_id, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=spark_job_driver, + configuration_overrides=s3_configuration_overrides, + enable_application_ui_links=False, + ) + operator.execute(self.mock_context) + mock_conn.start_job_run.assert_called_once() + + mock_logs_link.assert_not_called() + mock_dashboard_link.assert_not_called() + mock_cloudwatch_link.assert_not_called() + mock_s3_logs_link.assert_called_once_with( + context=mock.ANY, + operator=mock.ANY, + region_name=mock.ANY, + aws_partition=mock.ANY, + log_uri=s3_logs_location, + application_id=application_id, + job_run_id=job_run_id, + ) + class TestEmrServerlessDeleteOperator: @mock.patch.object(EmrServerlessHook, "get_waiter") From 91fb733a66d72b04e0d26cc3c23f300c59e62db8 Mon Sep 17 00:00:00 2001 From: "Damon P. Cortesi" Date: Mon, 12 Feb 2024 10:42:26 -0800 Subject: [PATCH 16/18] Address D401 check --- airflow/providers/amazon/aws/links/emr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py index 34fd7141a119f..6417797e204ab 100644 --- a/airflow/providers/amazon/aws/links/emr.py +++ b/airflow/providers/amazon/aws/links/emr.py @@ -56,7 +56,7 @@ def format_link(self, **kwargs) -> str: def get_serverless_log_uri(*, s3_log_uri: str, application_id: str, job_run_id: str) -> str: """ - Retrieves the S3 URI to EMR Serverless Job logs. + Retrieve the S3 URI to EMR Serverless Job logs. 
     Any EMR Serverless job may have a different S3 logging location (or none), which is an S3 URI.
     The logging location is then {s3_uri}/applications/{application_id}/jobs/{job_run_id}.

From fd3f378ea1086df03456e7d924e53600a73f0af7 Mon Sep 17 00:00:00 2001
From: "Damon P. Cortesi"
Date: Mon, 12 Feb 2024 22:34:11 -0800
Subject: [PATCH 17/18] Refactor get_serverless_dashboard_url into its own method, add link tests

---
 airflow/providers/amazon/aws/links/emr.py    |  94 +++++------
 tests/providers/amazon/aws/links/test_emr.py | 161 ++++++++++++++++++-
 2 files changed, 202 insertions(+), 53 deletions(-)

diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py
index 6417797e204ab..163fc87ce261e 100644
--- a/airflow/providers/amazon/aws/links/emr.py
+++ b/airflow/providers/amazon/aws/links/emr.py
@@ -17,7 +17,7 @@
 from __future__ import annotations

 from typing import TYPE_CHECKING, Any
-from urllib.parse import quote_plus, urlparse
+from urllib.parse import ParseResult, quote_plus, urlparse

 from airflow.exceptions import AirflowException
 from airflow.models import XCom
@@ -64,6 +64,33 @@ def get_serverless_log_uri(*, s3_log_uri: str, application_id: str, job_run_id:
     return f"{s3_log_uri}/applications/{application_id}/jobs/{job_run_id}"


+def get_serverless_dashboard_url(
+    *, aws_conn_id: str | None = None, emrs_client: boto3.client = None, application_id: str, job_run_id: str
+) -> ParseResult | None:
+    """
+    Retrieve the URL to the EMR Serverless dashboard.
+
+    The URL is a one-use, ephemeral link that expires in 1 hour and is accessible without authentication.
+
+    Either an AWS connection ID or an existing EMR Serverless client must be passed.
+    If the connection ID is passed, a client is generated using that connection.
+    """
+    if not exactly_one(aws_conn_id, emrs_client):
+        raise AirflowException("Requires either an AWS connection ID or an EMR Serverless Client.")
+
+    if aws_conn_id:
+        # If get_dashboard_for_job_run fails for whatever reason, fail after 1 attempt
+        # so that the rest of the links load in a reasonable time frame.
+        hook = EmrServerlessHook(aws_conn_id=aws_conn_id, config={"retries": {"total_max_attempts": 1}})
+        emrs_client = hook.conn
+
+    response = emrs_client.get_dashboard_for_job_run(applicationId=application_id, jobRunId=job_run_id)
+    if "url" not in response:
+        return None
+    log_uri = urlparse(response["url"])
+    return log_uri
+
+
 def get_log_uri(
     *, cluster: dict[str, Any] | None = None, emr_client: boto3.client = None, job_flow_id: str | None = None
 ) -> str | None:
@@ -90,32 +117,14 @@ class EmrServerlessLogsLink(BaseAwsLink):
     name = "Spark Driver stdout"
     key = "emr_serverless_logs"

-    def get_link(
-        self,
-        operator: BaseOperator,
-        *,
-        ti_key: TaskInstanceKey,
-    ) -> str:
-        """
-        Pre-signed URL to the Spark stdout log.
-
-        :param operator: airflow operator
-        :param ti_key: TaskInstance ID to return link for
-        :return: Pre-signed URL to Spark stdout log. Empty string if no Spark stdout log is available.
-        """
-        conf = XCom.get_value(key=self.key, ti_key=ti_key)
-        if not conf:
-            return ""
-        # If get_dashboard_for_job_run fails for whatever reason, fail after 1 attempt
-        # so that the rest of the links load in a reasonable time frame.
-        hook = EmrServerlessHook(
-            aws_conn_id=conf.get("conn_id"), config={"retries": {"total_max_attempts": 1}}
+    def format_link(self, application_id: str, job_run_id: str, **kwargs) -> str:
+        url = get_serverless_dashboard_url(
+            aws_conn_id=kwargs.get("conn_id"), application_id=application_id, job_run_id=job_run_id
         )
-        resp = hook.conn.get_dashboard_for_job_run(
-            applicationId=conf.get("application_id"), jobRunId=conf.get("job_run_id")
-        )
-        o = urlparse(resp["url"])
-        return o._replace(path="/logs/SPARK_DRIVER/stdout.gz").geturl()
+        if url:
+            return url._replace(path="/logs/SPARK_DRIVER/stdout.gz").geturl()
+        else:
+            return ""


 class EmrServerlessDashboardLink(BaseAwsLink):
@@ -124,33 +133,14 @@ class EmrServerlessDashboardLink(BaseAwsLink):
     name = "EMR Serverless Dashboard"
     key = "emr_serverless_dashboard"

-    def get_link(
-        self,
-        operator: BaseOperator,
-        *,
-        ti_key: TaskInstanceKey,
-    ) -> str:
-        """
-        Pre-signed URL to the application UI for the EMR Serverless job.
-
-        :param operator: airflow operator
-        :param ti_key: TaskInstance ID to return link for
-        :return: Pre-signed URL to application UI.
-        """
-        conf = XCom.get_value(key=self.key, ti_key=ti_key)
-        if not conf:
-            return ""
-        # If get_dashboard_for_job_run fails for whatever reason, fail after 1 attempt
-        # so that the rest of the links load in a reasonable time frame.
-        hook = EmrServerlessHook(
-            aws_conn_id=conf.get("conn_id"), config={"retries": {"total_max_attempts": 1}}
+    def format_link(self, application_id: str, job_run_id: str, **kwargs) -> str:
+        url = get_serverless_dashboard_url(
+            aws_conn_id=kwargs.get("conn_id"), application_id=application_id, job_run_id=job_run_id
         )
-        # Dashboard cannot be served when job is pending/scheduled,
-        # in which case an empty string still gets returned.
-        resp = hook.conn.get_dashboard_for_job_run(
-            applicationId=conf.get("application_id"), jobRunId=conf.get("job_run_id")
-        )
-        return resp["url"]
+        if url:
+            return url.geturl()
+        else:
+            return ""
diff --git a/tests/providers/amazon/aws/links/test_emr.py b/tests/providers/amazon/aws/links/test_emr.py
index 590e7f1c61f4a..292b977b084aa 100644
--- a/tests/providers/amazon/aws/links/test_emr.py
+++ b/tests/providers/amazon/aws/links/test_emr.py
@@ -16,11 +16,22 @@
 # under the License.

 from __future__ import annotations

+from unittest import mock
 from unittest.mock import MagicMock

 import pytest

-from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.links.emr import (
+    EmrClusterLink,
+    EmrLogsLink,
+    EmrServerlessCloudWatchLogsLink,
+    EmrServerlessDashboardLink,
+    EmrServerlessLogsLink,
+    EmrServerlessS3LogsLink,
+    get_log_uri,
+    get_serverless_dashboard_url,
+)
 from tests.providers.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase
@@ -75,3 +86,151 @@ def test_extra_link(self):
     )
     def test_missing_log_url(self, log_url_extra: dict):
         self.assert_extra_link_url(expected_url="", **log_url_extra)
+
+
+@pytest.fixture
+def mocked_emr_serverless_hook():
+    with mock.patch("airflow.providers.amazon.aws.links.emr.EmrServerlessHook") as m:
+        yield m
+
+
+class TestEmrServerlessLogsLink(BaseAwsLinksTestCase):
+    link_class = EmrServerlessLogsLink
+
+    def test_extra_link(self, mocked_emr_serverless_hook):
+        mocked_client = mocked_emr_serverless_hook.return_value.conn
+        mocked_client.get_dashboard_for_job_run.return_value = {"url": "https://example.com/?authToken=1234"}
+
+        self.assert_extra_link_url(
+            expected_url="https://example.com/logs/SPARK_DRIVER/stdout.gz?authToken=1234",
+            conn_id="aws-test",
+            application_id="app-id",
+            job_run_id="job-run-id",
+        )
+
+        mocked_emr_serverless_hook.assert_called_with(
+            aws_conn_id="aws-test", config={"retries": {"total_max_attempts": 1}}
+        )
+        mocked_client.get_dashboard_for_job_run.assert_called_with(
+            applicationId="app-id",
+            jobRunId="job-run-id",
+        )
+
+
+class TestEmrServerlessDashboardLink(BaseAwsLinksTestCase):
+    link_class = EmrServerlessDashboardLink
+
+    def test_extra_link(self, mocked_emr_serverless_hook):
+        mocked_client = mocked_emr_serverless_hook.return_value.conn
+        mocked_client.get_dashboard_for_job_run.return_value = {"url": "https://example.com/?authToken=1234"}
+
+        self.assert_extra_link_url(
+            expected_url="https://example.com/?authToken=1234",
+            conn_id="aws-test",
+            application_id="app-id",
+            job_run_id="job-run-id",
+        )
+
+        mocked_emr_serverless_hook.assert_called_with(
+            aws_conn_id="aws-test", config={"retries": {"total_max_attempts": 1}}
+        )
+        mocked_client.get_dashboard_for_job_run.assert_called_with(
+            applicationId="app-id",
+            jobRunId="job-run-id",
+        )
+
+
+@pytest.mark.parametrize(
+    "dashboard_info, expected_uri",
+    [
+        pytest.param(
+            {"url": "https://example.com/?authToken=first-unique-value"},
+            "https://example.com/?authToken=first-unique-value",
+            id="first-call",
+        ),
+        pytest.param(
+            {"url": "https://example.com/?authToken=second-unique-value"},
+            "https://example.com/?authToken=second-unique-value",
+            id="second-call",
+        ),
+    ],
+)
+def test_get_serverless_dashboard_url_with_client(mocked_emr_serverless_hook, dashboard_info, expected_uri):
+    mocked_client = mocked_emr_serverless_hook.return_value.conn
+    mocked_client.get_dashboard_for_job_run.return_value = dashboard_info
+
+    url = get_serverless_dashboard_url(
+        emrs_client=mocked_client, application_id="anything", job_run_id="anything"
+    )
+    assert url and url.geturl() == expected_uri
+    mocked_emr_serverless_hook.assert_not_called()
+    mocked_client.get_dashboard_for_job_run.assert_called_with(
+        applicationId="anything",
+        jobRunId="anything",
+    )
+
+
+def test_get_serverless_dashboard_url_with_conn_id(mocked_emr_serverless_hook):
+    mocked_client = mocked_emr_serverless_hook.return_value.conn
+    mocked_client.get_dashboard_for_job_run.return_value = {
+        "url": "https://example.com/?authToken=some-unique-value"
+    }
+
+    url = get_serverless_dashboard_url(
+        aws_conn_id="aws-test", application_id="anything", job_run_id="anything"
+    )
+    assert url and url.geturl() == "https://example.com/?authToken=some-unique-value"
+    mocked_emr_serverless_hook.assert_called_with(
+        aws_conn_id="aws-test", config={"retries": {"total_max_attempts": 1}}
+    )
+    mocked_client.get_dashboard_for_job_run.assert_called_with(
+        applicationId="anything",
+        jobRunId="anything",
+    )
+
+
+def test_get_serverless_dashboard_url_parameters():
+    with pytest.raises(
+        AirflowException, match="Requires either an AWS connection ID or an EMR Serverless Client"
+    ):
+        get_serverless_dashboard_url(application_id="anything", job_run_id="anything")
+
+    with pytest.raises(
+        AirflowException, match="Requires either an AWS connection ID or an EMR Serverless Client"
+    ):
+        get_serverless_dashboard_url(
+            aws_conn_id="a", emrs_client="b", application_id="anything", job_run_id="anything"
+        )
+
+
+class TestEmrServerlessS3LogsLink(BaseAwsLinksTestCase):
+    link_class = EmrServerlessS3LogsLink
+
+    def test_extra_link(self):
+        self.assert_extra_link_url(
+            expected_url=(
+                "https://console.aws.amazon.com/s3/buckets/bucket-name?region=us-west-1&prefix=logs/applications/app-id/jobs/job-run-id/"
+            ),
+            region_name="us-west-1",
+            aws_partition="aws",
+            log_uri="s3://bucket-name/logs/",
+            application_id="app-id",
+            job_run_id="job-run-id",
+        )
+
+
+class TestEmrServerlessCloudWatchLogsLink(BaseAwsLinksTestCase):
+    link_class = EmrServerlessCloudWatchLogsLink
+
+    def test_extra_link(self):
+        self.assert_extra_link_url(
+            expected_url=(
+                "https://console.aws.amazon.com/cloudwatch/home?region=us-west-1#logsV2:log-groups/log-group/%2Faws%2Femrs$3FlogStreamNameFilter$3Dsome-prefix"
+            ),
+            region_name="us-west-1",
+            aws_partition="aws",
+            awslogs_group="/aws/emrs",
+            stream_prefix="some-prefix",
+            application_id="app-id",
+            job_run_id="job-run-id",
+        )

From 12842db66e75c24aa0acebf7025f767a4ac5e3b2 Mon Sep 17 00:00:00 2001
From: "Damon P. Cortesi"
Cortesi" Date: Mon, 12 Feb 2024 23:08:34 -0800 Subject: [PATCH 18/18] Fix lints --- airflow/providers/amazon/aws/links/emr.py | 26 ++++++++++++-------- tests/providers/amazon/aws/links/test_emr.py | 4 +-- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py index 163fc87ce261e..d81bc93cc9b07 100644 --- a/airflow/providers/amazon/aws/links/emr.py +++ b/airflow/providers/amazon/aws/links/emr.py @@ -20,7 +20,6 @@ from urllib.parse import ParseResult, quote_plus, urlparse from airflow.exceptions import AirflowException -from airflow.models import XCom from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink @@ -29,9 +28,6 @@ if TYPE_CHECKING: import boto3 - from airflow.models import BaseOperator - from airflow.models.taskinstancekey import TaskInstanceKey - class EmrClusterLink(BaseAwsLink): """Helper class for constructing Amazon EMR Cluster Link.""" @@ -65,7 +61,11 @@ def get_serverless_log_uri(*, s3_log_uri: str, application_id: str, job_run_id: def get_serverless_dashboard_url( - *, aws_conn_id: str | None = None, emrs_client: boto3.client = None, application_id: str, job_run_id: str + *, + aws_conn_id: str | None = None, + emr_serverless_client: boto3.client = None, + application_id: str, + job_run_id: str, ) -> ParseResult | None: """ Retrieve the URL to EMR Serverless dashboard. @@ -75,16 +75,18 @@ def get_serverless_dashboard_url( Either an AWS connection ID or existing EMR Serverless client must be passed. If the connection ID is passed, a client is generated using that connection. """ - if not exactly_one(aws_conn_id, emrs_client): + if not exactly_one(aws_conn_id, emr_serverless_client): raise AirflowException("Requires either an AWS connection ID or an EMR Serverless Client.") if aws_conn_id: # If get_dashboard_for_job_run fails for whatever reason, fail after 1 attempt # so that the rest of the links load in a reasonable time frame. 
         hook = EmrServerlessHook(aws_conn_id=aws_conn_id, config={"retries": {"total_max_attempts": 1}})
-        emrs_client = hook.conn
+        emr_serverless_client = hook.conn

-    response = emrs_client.get_dashboard_for_job_run(applicationId=application_id, jobRunId=job_run_id)
+    response = emr_serverless_client.get_dashboard_for_job_run(
+        applicationId=application_id, jobRunId=job_run_id
+    )
     if "url" not in response:
         return None
     log_uri = urlparse(response["url"])
@@ -117,7 +119,9 @@ class EmrServerlessLogsLink(BaseAwsLink):
     name = "Spark Driver stdout"
     key = "emr_serverless_logs"

-    def format_link(self, application_id: str, job_run_id: str, **kwargs) -> str:
+    def format_link(self, application_id: str | None = None, job_run_id: str | None = None, **kwargs) -> str:
+        if not application_id or not job_run_id:
+            return ""
         url = get_serverless_dashboard_url(
             aws_conn_id=kwargs.get("conn_id"), application_id=application_id, job_run_id=job_run_id
         )
@@ -133,7 +137,9 @@ class EmrServerlessDashboardLink(BaseAwsLink):
     name = "EMR Serverless Dashboard"
     key = "emr_serverless_dashboard"

-    def format_link(self, application_id: str, job_run_id: str, **kwargs) -> str:
+    def format_link(self, application_id: str | None = None, job_run_id: str | None = None, **kwargs) -> str:
+        if not application_id or not job_run_id:
+            return ""
         url = get_serverless_dashboard_url(
             aws_conn_id=kwargs.get("conn_id"), application_id=application_id, job_run_id=job_run_id
         )
diff --git a/tests/providers/amazon/aws/links/test_emr.py b/tests/providers/amazon/aws/links/test_emr.py
index 292b977b084aa..00e983ed16826 100644
--- a/tests/providers/amazon/aws/links/test_emr.py
+++ b/tests/providers/amazon/aws/links/test_emr.py
@@ -160,7 +160,7 @@ def test_get_serverless_dashboard_url_with_client(mocked_emr_serverless_hook, da
     mocked_client.get_dashboard_for_job_run.return_value = dashboard_info

     url = get_serverless_dashboard_url(
-        emrs_client=mocked_client, application_id="anything", job_run_id="anything"
+        emr_serverless_client=mocked_client, application_id="anything", job_run_id="anything"
     )
     assert url and url.geturl() == expected_uri
     mocked_emr_serverless_hook.assert_not_called()
@@ -199,7 +199,7 @@ def test_get_serverless_dashboard_url_parameters():
         AirflowException, match="Requires either an AWS connection ID or an EMR Serverless Client"
     ):
         get_serverless_dashboard_url(
-            aws_conn_id="a", emrs_client="b", application_id="anything", job_run_id="anything"
+            aws_conn_id="a", emr_serverless_client="b", application_id="anything", job_run_id="anything"
         )
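
Usage note (not part of the patches above): a minimal sketch of how the two helpers introduced in this series can be called directly. It assumes a configured "aws_default" Airflow connection; the bucket, application ID, and job run ID are placeholders, not values from this series.

    from airflow.providers.amazon.aws.links.emr import (
        get_serverless_dashboard_url,
        get_serverless_log_uri,
    )

    # get_serverless_log_uri is pure string formatting; no AWS call is made.
    log_uri = get_serverless_log_uri(
        s3_log_uri="s3://my-bucket/logs",  # placeholder bucket
        application_id="00abc123example",  # placeholder application ID
        job_run_id="00def456example",  # placeholder job run ID
    )
    # -> "s3://my-bucket/logs/applications/00abc123example/jobs/00def456example"

    # get_serverless_dashboard_url calls GetDashboardForJobRun and returns a
    # urllib.parse.ParseResult (or None), so callers such as
    # EmrServerlessLogsLink can rewrite the path before rendering the URL.
    url = get_serverless_dashboard_url(
        aws_conn_id="aws_default",  # assumed connection; a boto3 client may be passed instead
        application_id="00abc123example",
        job_run_id="00def456example",
    )
    if url:
        print(url.geturl())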
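And a DAG-level sketch of the behavior the operator tests pin down. Every identifier here (DAG ID, application ID, role ARN, bucket, script path) is a placeholder. Per the tests, an S3 monitoring configuration combined with enable_application_ui_links=True should surface the "EMR Serverless Dashboard", "Spark Driver stdout", and "S3 Logs" extra links on the task instance, while a CloudWatch monitoring configuration would surface the "CloudWatch Logs" link instead.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.emr import EmrServerlessStartJobOperator

    with DAG(dag_id="emr_serverless_links_demo", start_date=datetime(2024, 1, 1), schedule=None):
        EmrServerlessStartJobOperator(
            task_id="start_spark_job",
            application_id="00abc123example",  # placeholder application ID
            execution_role_arn="arn:aws:iam::123456789012:role/emr-serverless-job",  # placeholder role
            job_driver={"sparkSubmit": {"entryPoint": "s3://my-bucket/scripts/job.py"}},
            configuration_overrides={
                "monitoringConfiguration": {
                    # S3 logging location; drives the S3 Logs extra link.
                    "s3MonitoringConfiguration": {"logUri": "s3://my-bucket/logs/"},
                },
            },
            # Enables the Dashboard and Spark Driver stdout links added in this series.
            enable_application_ui_links=True,
        )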