Skip to content

Commit

Permalink
Add parameter to pass role ARN to GlueJobOperator (#33408)
Browse files Browse the repository at this point in the history
  • Loading branch information
erdos2n authored Aug 15, 2023
1 parent cc360b7 commit 60df705
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 5 deletions.
16 changes: 12 additions & 4 deletions airflow/providers/amazon/aws/hooks/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class GlueJobHook(AwsBaseHook):
:param retry_limit: Maximum number of times to retry this job if it fails
:param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job
:param region_name: aws region name (example: us-east-1)
:param iam_role_name: AWS IAM Role for Glue Job Execution
:param iam_role_name: AWS IAM Role for Glue Job Execution. If set, `iam_role_arn` must equal None.
:param iam_role_arn: AWS IAM Role ARN for Glue Job Execution. If set, `iam_role_name` must equal None.
:param create_job_kwargs: Extra arguments for Glue Job Creation
:param update_config: Update job configuration on Glue (default: False)
Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(
retry_limit: int = 0,
num_of_dpus: int | float | None = None,
iam_role_name: str | None = None,
iam_role_arn: str | None = None,
create_job_kwargs: dict | None = None,
update_config: bool = False,
job_poll_interval: int | float = 6,
Expand All @@ -85,6 +87,7 @@ def __init__(
self.retry_limit = retry_limit
self.s3_bucket = s3_bucket
self.role_name = iam_role_name
self.role_arn = iam_role_arn
self.s3_glue_logs = "logs/glue-logs/"
self.create_job_kwargs = create_job_kwargs or {}
self.update_config = update_config
Expand All @@ -93,6 +96,8 @@ def __init__(
worker_type_exists = "WorkerType" in self.create_job_kwargs
num_workers_exists = "NumberOfWorkers" in self.create_job_kwargs

if self.role_arn and self.role_name:
raise ValueError("Cannot set iam_role_arn and iam_role_name simultaneously")
if worker_type_exists and num_workers_exists:
if num_of_dpus is not None:
raise ValueError("Cannot specify num_of_dpus with custom WorkerType")
Expand All @@ -114,12 +119,16 @@ def create_glue_job_config(self) -> dict:
"ScriptLocation": self.script_location,
}
command = self.create_job_kwargs.pop("Command", default_command)
execution_role = self.get_iam_execution_role()
if not self.role_arn:
execution_role = self.get_iam_execution_role()
role_arn = execution_role["Role"]["Arn"]
else:
role_arn = self.role_arn

config = {
"Name": self.job_name,
"Description": self.desc,
"Role": execution_role["Role"]["Arn"],
"Role": role_arn,
"ExecutionProperty": {"MaxConcurrentRuns": self.concurrent_run_limit},
"Command": command,
"MaxRetries": self.retry_limit,
Expand All @@ -144,7 +153,6 @@ def list_jobs(self) -> list:
return self.conn.get_jobs()

def get_iam_execution_role(self) -> dict:
"""Get IAM Role for job execution."""
try:
iam_client = self.get_session(region_name=self.region_name).client(
"iam", endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify
Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/amazon/aws/operators/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class GlueJobOperator(BaseOperator):
:param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job.
:param region_name: aws region name (example: us-east-1)
:param s3_bucket: S3 bucket where logs and local etl script will be uploaded
:param iam_role_name: AWS IAM Role for Glue Job Execution
:param iam_role_name: AWS IAM Role for Glue Job Execution. If set, `iam_role_arn` must equal None.
:param iam_role_arn: AWS IAM Role ARN for Glue Job Execution. If set, `iam_role_name` must equal None.
:param create_job_kwargs: Extra arguments for Glue Job Creation
:param run_job_kwargs: Extra arguments for Glue Job Run
:param wait_for_completion: Whether to wait for job run completion. (default: True)
Expand All @@ -72,6 +73,7 @@ class GlueJobOperator(BaseOperator):
"create_job_kwargs",
"s3_bucket",
"iam_role_name",
"iam_role_arn",
)
template_ext: Sequence[str] = ()
template_fields_renderers = {
Expand All @@ -96,6 +98,7 @@ def __init__(
region_name: str | None = None,
s3_bucket: str | None = None,
iam_role_name: str | None = None,
iam_role_arn: str | None = None,
create_job_kwargs: dict | None = None,
run_job_kwargs: dict | None = None,
wait_for_completion: bool = True,
Expand All @@ -118,6 +121,7 @@ def __init__(
self.region_name = region_name
self.s3_bucket = s3_bucket
self.iam_role_name = iam_role_name
self.iam_role_arn = iam_role_arn
self.s3_protocol = "s3://"
self.s3_artifacts_prefix = "artifacts/glue-scripts/"
self.create_job_kwargs = create_job_kwargs
Expand Down Expand Up @@ -154,6 +158,7 @@ def glue_job_hook(self) -> GlueJobHook:
region_name=self.region_name,
s3_bucket=self.s3_bucket,
iam_role_name=self.iam_role_name,
iam_role_arn=self.iam_role_arn,
create_job_kwargs=self.create_job_kwargs,
update_config=self.update_config,
job_poll_interval=self.job_poll_interval,
Expand Down
66 changes: 66 additions & 0 deletions tests/providers/amazon/aws/hooks/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,22 @@ def test_get_iam_execution_role(self, role_path):
assert "Arn" in iam_role["Role"]
assert iam_role["Role"]["Arn"] == f"arn:aws:iam::123456789012:role{role_path}{expected_role}"

@mock.patch.object(GlueJobHook, "get_iam_execution_role")
@mock.patch.object(GlueJobHook, "conn")
def test_init_iam_role_value_error(self, mock_conn, mock_get_iam_execution_role):
mock_get_iam_execution_role.return_value = mock.MagicMock(
Role={"RoleName": "my_test_role_name", "RoleArn": "my_test_role"}
)

with pytest.raises(ValueError, match="Cannot set iam_role_arn and iam_role_name simultaneously"):
GlueJobHook(
job_name="aws_test_glue_job",
desc="This is test case job from Airflow",
s3_bucket="some-bucket",
iam_role_name="my_test_role_name",
iam_role_arn="my_test_role",
)

@mock.patch.object(AwsBaseHook, "conn")
def test_has_job_exists(self, mock_conn):
job_name = "aws_test_glue_job"
Expand All @@ -90,6 +106,56 @@ class JobNotFoundException(Exception):
assert result is False
mock_conn.get_job.assert_called_once_with(JobName=job_name)

@mock.patch.object(GlueJobHook, "get_iam_execution_role")
@mock.patch.object(AwsBaseHook, "conn")
def test_role_arn_has_job_exists(self, mock_conn, mock_get_iam_execution_role):
"""
Calls 'create_or_update_glue_job' with no existing job.
Should create a new job.
"""

class JobNotFoundException(Exception):
pass

expected_job_name = "aws_test_glue_job"
job_description = "This is test case job from Airflow"
role_name = "my_test_role"
role_name_arn = "test_role"
some_s3_bucket = "bucket"

mock_conn.exceptions.EntityNotFoundException = JobNotFoundException
mock_conn.get_job.side_effect = JobNotFoundException()
mock_get_iam_execution_role.return_value = {"Role": {"RoleName": role_name, "Arn": role_name_arn}}

hook = GlueJobHook(
s3_bucket=some_s3_bucket,
job_name=expected_job_name,
desc=job_description,
concurrent_run_limit=2,
retry_limit=3,
num_of_dpus=5,
iam_role_arn=role_name_arn,
create_job_kwargs={"Command": {}},
region_name=self.some_aws_region,
update_config=True,
)

result = hook.create_or_update_glue_job()

mock_conn.get_job.assert_called_once_with(JobName=expected_job_name)
mock_conn.create_job.assert_called_once_with(
Command={},
Description=job_description,
ExecutionProperty={"MaxConcurrentRuns": 2},
LogUri=f"s3://{some_s3_bucket}/logs/glue-logs/{expected_job_name}",
MaxCapacity=5,
MaxRetries=3,
Name=expected_job_name,
Role=role_name_arn,
)
mock_conn.update_job.assert_not_called()
assert result == expected_job_name

@mock.patch.object(GlueJobHook, "get_iam_execution_role")
@mock.patch.object(GlueJobHook, "conn")
def test_create_or_update_glue_job_create_new_job(self, mock_conn, mock_get_iam_execution_role):
Expand Down
23 changes: 23 additions & 0 deletions tests/providers/amazon/aws/operators/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_render_template(self, create_task_instance_of_operator):
script_args="{{ dag.dag_id }}",
create_job_kwargs="{{ dag.dag_id }}",
iam_role_name="{{ dag.dag_id }}",
iam_role_arn="{{ dag.dag_id }}",
s3_bucket="{{ dag.dag_id }}",
job_name="{{ dag.dag_id }}",
)
Expand All @@ -57,6 +58,7 @@ def test_render_template(self, create_task_instance_of_operator):
assert DAG_ID == rendered_template.script_args
assert DAG_ID == rendered_template.create_job_kwargs
assert DAG_ID == rendered_template.iam_role_name
assert DAG_ID == rendered_template.iam_role_arn
assert DAG_ID == rendered_template.s3_bucket
assert DAG_ID == rendered_template.job_name

Expand Down Expand Up @@ -99,6 +101,27 @@ def test_execute_without_failure(
mock_print_job_logs.assert_not_called()
assert glue.job_name == JOB_NAME

@mock.patch.object(GlueJobHook, "initialize_job")
@mock.patch.object(GlueJobHook, "get_conn")
def test_role_arn_execute_deferrable(self, _, mock_initialize_job):
glue = GlueJobOperator(
task_id=TASK_ID,
job_name=JOB_NAME,
script_location="s3://folder/file",
aws_conn_id="aws_default",
region_name="us-west-2",
s3_bucket="some_bucket",
iam_role_arn="test_role",
deferrable=True,
)
mock_initialize_job.return_value = {"JobRunState": "RUNNING", "JobRunId": JOB_RUN_ID}

with pytest.raises(TaskDeferred) as defer:
glue.execute(mock.MagicMock())

assert defer.value.trigger.job_name == JOB_NAME
assert defer.value.trigger.run_id == JOB_RUN_ID

@mock.patch.object(GlueJobHook, "initialize_job")
@mock.patch.object(GlueJobHook, "get_conn")
def test_execute_deferrable(self, _, mock_initialize_job):
Expand Down

0 comments on commit 60df705

Please sign in to comment.