diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py index 1d4c87d6767b1..baf6780e07802 100644 --- a/airflow/providers/amazon/aws/hooks/glue.py +++ b/airflow/providers/amazon/aws/hooks/glue.py @@ -44,7 +44,8 @@ class GlueJobHook(AwsBaseHook): :param retry_limit: Maximum number of times to retry this job if it fails :param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job :param region_name: aws region name (example: us-east-1) - :param iam_role_name: AWS IAM Role for Glue Job Execution + :param iam_role_name: AWS IAM Role for Glue Job Execution. If set `iam_role_arn` must equal None. + :param iam_role_arn: AWS IAM Role ARN for Glue Job Execution. If set `iam_role_name` must equal None. :param create_job_kwargs: Extra arguments for Glue Job Creation :param update_config: Update job configuration on Glue (default: False) @@ -72,6 +73,7 @@ def __init__( retry_limit: int = 0, num_of_dpus: int | float | None = None, iam_role_name: str | None = None, + iam_role_arn: str | None = None, create_job_kwargs: dict | None = None, update_config: bool = False, job_poll_interval: int | float = 6, @@ -85,6 +87,7 @@ def __init__( self.retry_limit = retry_limit self.s3_bucket = s3_bucket self.role_name = iam_role_name + self.role_arn = iam_role_arn self.s3_glue_logs = "logs/glue-logs/" self.create_job_kwargs = create_job_kwargs or {} self.update_config = update_config @@ -93,6 +96,8 @@ def __init__( worker_type_exists = "WorkerType" in self.create_job_kwargs num_workers_exists = "NumberOfWorkers" in self.create_job_kwargs + if self.role_arn and self.role_name: + raise ValueError("Cannot set iam_role_arn and iam_role_name simultaneously") if worker_type_exists and num_workers_exists: if num_of_dpus is not None: raise ValueError("Cannot specify num_of_dpus with custom WorkerType") @@ -114,12 +119,16 @@ def create_glue_job_config(self) -> dict: "ScriptLocation": self.script_location, } command = 
self.create_job_kwargs.pop("Command", default_command) - execution_role = self.get_iam_execution_role() + if not self.role_arn: + execution_role = self.get_iam_execution_role() + role_arn = execution_role["Role"]["Arn"] + else: + role_arn = self.role_arn config = { "Name": self.job_name, "Description": self.desc, - "Role": execution_role["Role"]["Arn"], + "Role": role_arn, "ExecutionProperty": {"MaxConcurrentRuns": self.concurrent_run_limit}, "Command": command, "MaxRetries": self.retry_limit, @@ -144,7 +153,7 @@ def list_jobs(self) -> list: return self.conn.get_jobs() def get_iam_execution_role(self) -> dict: """Get IAM Role for job execution.""" try: iam_client = self.get_session(region_name=self.region_name).client( "iam", endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py index 265d057de51ae..d47d1c25de5ec 100644 --- a/airflow/providers/amazon/aws/operators/glue.py +++ b/airflow/providers/amazon/aws/operators/glue.py @@ -53,7 +53,8 @@ class GlueJobOperator(BaseOperator): :param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job. :param region_name: aws region name (example: us-east-1) :param s3_bucket: S3 bucket where logs and local etl script will be uploaded - :param iam_role_name: AWS IAM Role for Glue Job Execution + :param iam_role_name: AWS IAM Role for Glue Job Execution. If set `iam_role_arn` must equal None. + :param iam_role_arn: AWS IAM Role ARN for Glue Job Execution. If set `iam_role_name` must equal None. :param create_job_kwargs: Extra arguments for Glue Job Creation :param run_job_kwargs: Extra arguments for Glue Job Run :param wait_for_completion: Whether to wait for job run completion. 
(default: True) @@ -72,6 +73,7 @@ class GlueJobOperator(BaseOperator): "create_job_kwargs", "s3_bucket", "iam_role_name", + "iam_role_arn", ) template_ext: Sequence[str] = () template_fields_renderers = { @@ -96,6 +98,7 @@ def __init__( region_name: str | None = None, s3_bucket: str | None = None, iam_role_name: str | None = None, + iam_role_arn: str | None = None, create_job_kwargs: dict | None = None, run_job_kwargs: dict | None = None, wait_for_completion: bool = True, @@ -118,6 +121,7 @@ def __init__( self.region_name = region_name self.s3_bucket = s3_bucket self.iam_role_name = iam_role_name + self.iam_role_arn = iam_role_arn self.s3_protocol = "s3://" self.s3_artifacts_prefix = "artifacts/glue-scripts/" self.create_job_kwargs = create_job_kwargs @@ -154,6 +158,7 @@ def glue_job_hook(self) -> GlueJobHook: region_name=self.region_name, s3_bucket=self.s3_bucket, iam_role_name=self.iam_role_name, + iam_role_arn=self.iam_role_arn, create_job_kwargs=self.create_job_kwargs, update_config=self.update_config, job_poll_interval=self.job_poll_interval, diff --git a/tests/providers/amazon/aws/hooks/test_glue.py b/tests/providers/amazon/aws/hooks/test_glue.py index c41598f3d9aa6..1fae16e339558 100644 --- a/tests/providers/amazon/aws/hooks/test_glue.py +++ b/tests/providers/amazon/aws/hooks/test_glue.py @@ -66,6 +66,22 @@ def test_get_iam_execution_role(self, role_path): assert "Arn" in iam_role["Role"] assert iam_role["Role"]["Arn"] == f"arn:aws:iam::123456789012:role{role_path}{expected_role}" + @mock.patch.object(GlueJobHook, "get_iam_execution_role") + @mock.patch.object(GlueJobHook, "conn") + def test_init_iam_role_value_error(self, mock_conn, mock_get_iam_execution_role): + mock_get_iam_execution_role.return_value = mock.MagicMock( + Role={"RoleName": "my_test_role_name", "RoleArn": "my_test_role"} + ) + + with pytest.raises(ValueError, match="Cannot set iam_role_arn and iam_role_name simultaneously"): + GlueJobHook( + job_name="aws_test_glue_job", + desc="This is 
test case job from Airflow", + s3_bucket="some-bucket", + iam_role_name="my_test_role_name", + iam_role_arn="my_test_role", + ) + @mock.patch.object(AwsBaseHook, "conn") def test_has_job_exists(self, mock_conn): job_name = "aws_test_glue_job" @@ -90,6 +106,56 @@ class JobNotFoundException(Exception): assert result is False mock_conn.get_job.assert_called_once_with(JobName=job_name) + @mock.patch.object(GlueJobHook, "get_iam_execution_role") + @mock.patch.object(AwsBaseHook, "conn") + def test_role_arn_has_job_exists(self, mock_conn, mock_get_iam_execution_role): + """ + Calls 'create_or_update_glue_job' with no existing job. + Should create a new job. + """ + + class JobNotFoundException(Exception): + pass + + expected_job_name = "aws_test_glue_job" + job_description = "This is test case job from Airflow" + role_name = "my_test_role" + role_name_arn = "test_role" + some_s3_bucket = "bucket" + + mock_conn.exceptions.EntityNotFoundException = JobNotFoundException + mock_conn.get_job.side_effect = JobNotFoundException() + mock_get_iam_execution_role.return_value = {"Role": {"RoleName": role_name, "Arn": role_name_arn}} + + hook = GlueJobHook( + s3_bucket=some_s3_bucket, + job_name=expected_job_name, + desc=job_description, + concurrent_run_limit=2, + retry_limit=3, + num_of_dpus=5, + iam_role_arn=role_name_arn, + create_job_kwargs={"Command": {}}, + region_name=self.some_aws_region, + update_config=True, + ) + + result = hook.create_or_update_glue_job() + + mock_conn.get_job.assert_called_once_with(JobName=expected_job_name) + mock_conn.create_job.assert_called_once_with( + Command={}, + Description=job_description, + ExecutionProperty={"MaxConcurrentRuns": 2}, + LogUri=f"s3://{some_s3_bucket}/logs/glue-logs/{expected_job_name}", + MaxCapacity=5, + MaxRetries=3, + Name=expected_job_name, + Role=role_name_arn, + ) + mock_conn.update_job.assert_not_called() + assert result == expected_job_name + @mock.patch.object(GlueJobHook, "get_iam_execution_role") 
@mock.patch.object(GlueJobHook, "conn") def test_create_or_update_glue_job_create_new_job(self, mock_conn, mock_get_iam_execution_role): diff --git a/tests/providers/amazon/aws/operators/test_glue.py b/tests/providers/amazon/aws/operators/test_glue.py index 9eed48e47adc9..dc298563ae65c 100644 --- a/tests/providers/amazon/aws/operators/test_glue.py +++ b/tests/providers/amazon/aws/operators/test_glue.py @@ -48,6 +48,7 @@ def test_render_template(self, create_task_instance_of_operator): script_args="{{ dag.dag_id }}", create_job_kwargs="{{ dag.dag_id }}", iam_role_name="{{ dag.dag_id }}", + iam_role_arn="{{ dag.dag_id }}", s3_bucket="{{ dag.dag_id }}", job_name="{{ dag.dag_id }}", ) @@ -57,6 +58,7 @@ def test_render_template(self, create_task_instance_of_operator): assert DAG_ID == rendered_template.script_args assert DAG_ID == rendered_template.create_job_kwargs assert DAG_ID == rendered_template.iam_role_name + assert DAG_ID == rendered_template.iam_role_arn assert DAG_ID == rendered_template.s3_bucket assert DAG_ID == rendered_template.job_name @@ -99,6 +101,27 @@ def test_execute_without_failure( mock_print_job_logs.assert_not_called() assert glue.job_name == JOB_NAME + @mock.patch.object(GlueJobHook, "initialize_job") + @mock.patch.object(GlueJobHook, "get_conn") + def test_role_arn_execute_deferrable(self, _, mock_initialize_job): + glue = GlueJobOperator( + task_id=TASK_ID, + job_name=JOB_NAME, + script_location="s3://folder/file", + aws_conn_id="aws_default", + region_name="us-west-2", + s3_bucket="some_bucket", + iam_role_arn="test_role", + deferrable=True, + ) + mock_initialize_job.return_value = {"JobRunState": "RUNNING", "JobRunId": JOB_RUN_ID} + + with pytest.raises(TaskDeferred) as defer: + glue.execute(mock.MagicMock()) + + assert defer.value.trigger.job_name == JOB_NAME + assert defer.value.trigger.run_id == JOB_RUN_ID + @mock.patch.object(GlueJobHook, "initialize_job") @mock.patch.object(GlueJobHook, "get_conn") def test_execute_deferrable(self, _, 
mock_initialize_job):