Support impersonation service account parameter for Dataflow runner (#…
Łukasz Wyszomirski authored May 31, 2022
1 parent 7c7dbfe commit 41e94b4
Showing 3 changed files with 22 additions and 4 deletions.
15 changes: 15 additions & 0 deletions airflow/providers/apache/beam/operators/beam.py
@@ -51,6 +51,7 @@ class BeamDataflowMixin(metaclass=ABCMeta):
     dataflow_config: DataflowConfiguration
     gcp_conn_id: str
     delegate_to: Optional[str]
+    dataflow_support_impersonation: bool = True

     def _set_dataflow(
         self,
@@ -91,6 +92,13 @@ def __get_dataflow_pipeline_options(
             pipeline_options[job_name_key] = job_name
         if self.dataflow_config.service_account:
             pipeline_options["serviceAccount"] = self.dataflow_config.service_account
+        if self.dataflow_support_impersonation and self.dataflow_config.impersonation_chain:
+            if isinstance(self.dataflow_config.impersonation_chain, list):
+                pipeline_options["impersonateServiceAccount"] = ",".join(
+                    self.dataflow_config.impersonation_chain
+                )
+            else:
+                pipeline_options["impersonateServiceAccount"] = self.dataflow_config.impersonation_chain
         pipeline_options["project"] = self.dataflow_config.project_id
         pipeline_options["region"] = self.dataflow_config.location
         pipeline_options.setdefault("labels", {}).update(
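The key behaviour added in the hunk above is that a list-valued impersonation chain is flattened into the single comma-separated value Dataflow expects. A minimal, self-contained sketch of just that branch (the option name matches the diff; the helper itself is only illustrative, not provider code):

from typing import List, Optional, Union


def impersonation_option(chain: Optional[Union[str, List[str]]]) -> dict:
    # Mirror the diff above: turn an impersonation chain into a pipeline option.
    if not chain:
        return {}
    if isinstance(chain, list):
        # A delegation chain is passed as one comma-separated value.
        return {"impersonateServiceAccount": ",".join(chain)}
    return {"impersonateServiceAccount": chain}


print(impersonation_option(["sa-one@example.iam.gserviceaccount.com",
                            "sa-two@example.iam.gserviceaccount.com"]))
# {'impersonateServiceAccount': 'sa-one@example.iam.gserviceaccount.com,sa-two@example.iam.gserviceaccount.com'}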
@@ -550,6 +558,13 @@ def __init__(
             **kwargs,
         )

+        if self.dataflow_config.impersonation_chain:
+            self.log.info(
+                "Impersonation chain parameter is not supported for Apache Beam GO SDK and will be skipped "
+                "in the execution"
+            )
+            self.dataflow_support_impersonation = False
+
         self.go_file = go_file
         self.should_init_go_module = False
         self.pipeline_options.setdefault("labels", {}).update(
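Taken together, the operator changes mean a DAG author only has to set impersonation_chain on the DataflowConfiguration. A minimal usage sketch, assuming placeholder DAG id, bucket paths, project and service-account names (none of these are taken from the commit):

from datetime import datetime

from airflow import DAG
from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator
from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration

with DAG(
    dag_id="beam_dataflow_impersonation_example",  # placeholder DAG id
    start_date=datetime(2022, 6, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    run_pipeline = BeamRunPythonPipelineOperator(
        task_id="run_pipeline",
        runner="DataflowRunner",
        py_file="gs://example-bucket/wordcount.py",  # placeholder pipeline file
        pipeline_options={"output": "gs://example-bucket/output"},
        dataflow_config=DataflowConfiguration(
            job_name="example-job",
            project_id="example-project",
            location="us-central1",
            # With this change the chain below is forwarded to Dataflow as the
            # impersonateServiceAccount pipeline option (a list would be comma-joined).
            impersonation_chain="workload-sa@example-project.iam.gserviceaccount.com",
        ),
    )

For BeamRunGoPipelineOperator the same configuration is accepted, but per the __init__ hunk above the chain is logged as unsupported and dataflow_support_impersonation is switched off, so no impersonateServiceAccount option is emitted.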
2 changes: 1 addition & 1 deletion setup.py
@@ -202,7 +202,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
     'mypy-boto3-redshift-data>=1.21.0',
 ]
 apache_beam = [
-    'apache-beam>=2.33.0',
+    'apache-beam>=2.39.0',
 ]
 arangodb = ['python-arango>=7.3.2']
 asana = ['asana>=0.10']
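The floor bump itself is just a dependency-constraint change, but it is easy to sanity-check an existing environment against it. A small illustrative sketch (not part of the commit; assumes the packaging distribution is available, which it normally is alongside pip/setuptools):

from importlib.metadata import version  # Python 3.8+

from packaging.version import Version

installed = Version(version("apache-beam"))
if installed < Version("2.39.0"):
    raise RuntimeError(
        f"apache-beam {installed} is installed, but this provider change expects >=2.39.0"
    )
print(f"apache-beam {installed} satisfies the >=2.39.0 floor")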
9 changes: 6 additions & 3 deletions tests/providers/apache/beam/operators/test_beam.py
@@ -47,6 +47,7 @@
     'output': 'gs://test/output',
     'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
 }
+TEST_IMPERSONATION_ACCOUNT = "test@impersonation.com"


 class TestBeamRunPythonPipelineOperator(unittest.TestCase):
@@ -104,7 +105,7 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
         """Test DataflowHook is created and the right args are passed to
         start_python_dataflow.
         """
-        dataflow_config = DataflowConfiguration()
+        dataflow_config = DataflowConfiguration(impersonation_chain=TEST_IMPERSONATION_ACCOUNT)
         self.operator.runner = "DataflowRunner"
         self.operator.dataflow_config = dataflow_config
         gcs_provide_file = gcs_hook.return_value.provide_file
@@ -126,6 +127,7 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
             'output': 'gs://test/output',
             'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
             'region': 'us-central1',
+            'impersonate_service_account': TEST_IMPERSONATION_ACCOUNT,
         }
         gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
         persist_link_mock.assert_called_once_with(
@@ -223,7 +225,7 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
         """Test DataflowHook is created and the right args are passed to
         start_java_dataflow.
         """
-        dataflow_config = DataflowConfiguration()
+        dataflow_config = DataflowConfiguration(impersonation_chain="test@impersonation.com")
         self.operator.runner = "DataflowRunner"
         self.operator.dataflow_config = dataflow_config
         gcs_provide_file = gcs_hook.return_value.provide_file
@@ -248,6 +250,7 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
             'region': 'us-central1',
             'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
             'output': 'gs://test/output',
+            'impersonateServiceAccount': TEST_IMPERSONATION_ACCOUNT,
         }
         persist_link_mock.assert_called_once_with(
             self.operator,
@@ -374,7 +377,7 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
         """Test DataflowHook is created and the right args are passed to
         start_go_dataflow.
         """
-        dataflow_config = DataflowConfiguration()
+        dataflow_config = DataflowConfiguration(impersonation_chain="test@impersonation.com")
         self.operator.runner = "DataflowRunner"
         self.operator.dataflow_config = dataflow_config
         gcs_provide_file = gcs_hook.return_value.provide_file
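One detail worth noting in the expectations above: the Python-operator test asserts the snake_case key impersonate_service_account, while the Java-operator test asserts the camelCase impersonateServiceAccount that the mixin actually sets. This is consistent with the Python operator normalizing pipeline-option keys to snake_case before calling the Beam hook. A rough illustrative sketch of that conversion (not the provider's own helper):

import re


def camel_to_snake(name: str) -> str:
    # "impersonateServiceAccount" -> "impersonate_service_account"
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


assert camel_to_snake("impersonateServiceAccount") == "impersonate_service_account"
assert camel_to_snake("serviceAccount") == "service_account"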
