Fix deferrable mode for DataflowTemplatedJobStartOperator and DataflowStartFlexTemplateOperator
e-galan committed Apr 25, 2024
1 parent 667ee1b commit 3d59b64
Showing 7 changed files with 336 additions and 93 deletions.
177 changes: 145 additions & 32 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -41,9 +41,13 @@
MessagesV1Beta3AsyncClient,
MetricsV1Beta3AsyncClient,
)
from google.cloud.dataflow_v1beta3.types import GetJobMetricsRequest, JobMessageImportance, JobMetrics
from google.cloud.dataflow_v1beta3.types import (
GetJobMetricsRequest,
JobMessageImportance,
JobMetrics,
)
from google.cloud.dataflow_v1beta3.types.jobs import ListJobsRequest
from googleapiclient.discovery import build
from googleapiclient.discovery import Resource, build

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
@@ -573,7 +577,7 @@ def __init__(
impersonation_chain=impersonation_chain,
)

def get_conn(self) -> build:
def get_conn(self) -> Resource:
"""Return a Google Cloud Dataflow service object."""
http_authorized = self._authorize()
return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False)
@@ -653,9 +657,9 @@ def start_template_dataflow(
on_new_job_callback: Callable[[dict], None] | None = None,
location: str = DEFAULT_DATAFLOW_LOCATION,
environment: dict | None = None,
) -> dict:
) -> dict[str, str]:
"""
Start Dataflow template job.
Launch a Dataflow job with a Classic Template and wait for its completion.

:param job_name: The name of the job.
:param variables: Map of job runtime environment options.
@@ -688,34 +692,22 @@ def start_template_dataflow(
environment=environment,
)

service = self.get_conn()

request = (
service.projects()
.locations()
.templates()
.launch(
projectId=project_id,
location=location,
gcsPath=dataflow_template,
body={
"jobName": name,
"parameters": parameters,
"environment": environment,
},
)
job: dict[str, str] = self.send_launch_template_request(
project_id=project_id,
location=location,
gcs_path=dataflow_template,
job_name=name,
parameters=parameters,
environment=environment,
)
response = request.execute(num_retries=self.num_retries)

job = response["job"]

if on_new_job_id_callback:
warnings.warn(
"on_new_job_id_callback is Deprecated. Please start using on_new_job_callback",
AirflowProviderDeprecationWarning,
stacklevel=3,
)
on_new_job_id_callback(job.get("id"))
on_new_job_id_callback(job["id"])

if on_new_job_callback:
on_new_job_callback(job)
@@ -734,7 +726,62 @@ def start_template_dataflow(
expected_terminal_state=self.expected_terminal_state,
)
jobs_controller.wait_for_done()
return response["job"]
return job
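
# Illustrative sketch, not part of this commit: migrating off the deprecated
# ``on_new_job_id_callback`` per the warning above. The replacement
# ``on_new_job_callback`` receives the full job dict, so the id is available
# as ``job["id"]``. The hook instance, template path, and callback body below
# are hypothetical.
#
#   hook.start_template_dataflow(
#       job_name="example",
#       variables={},
#       parameters={},
#       dataflow_template="gs://example-bucket/templates/wordcount",
#       on_new_job_callback=lambda job: print("Dataflow job id:", job["id"]),
#   )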

@_fallback_to_location_from_variables
@_fallback_to_project_id_from_variables
@GoogleBaseHook.fallback_to_default_project_id
def launch_job_with_template(
self,
*,
job_name: str,
variables: dict,
parameters: dict,
dataflow_template: str,
project_id: str,
append_job_name: bool = True,
location: str = DEFAULT_DATAFLOW_LOCATION,
environment: dict | None = None,
) -> dict[str, str]:
"""
Launch a Dataflow job with a Classic Template and exit without waiting for its completion.

:param job_name: The name of the job.
:param variables: Map of job runtime environment options.
If passed, it will update the ``environment`` argument.
.. seealso::
For more information on possible configurations, look at the API documentation
`https://cloud.google.com/dataflow/pipelines/specifying-exec-params
<https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
:param parameters: Parameters for the template
:param dataflow_template: GCS path to the template.
:param project_id: Optional, the Google Cloud project ID in which to start a job.
If set to None or missing, the default project_id from the Google Cloud connection is used.
:param append_job_name: True if a unique suffix has to be appended to the job name.
:param location: Job location.
.. seealso::
For more information on possible configurations, look at the API documentation
`https://cloud.google.com/dataflow/pipelines/specifying-exec-params
<https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
:return: the Dataflow job response
"""
name = self.build_dataflow_job_name(job_name, append_job_name)
environment = self._update_environment(
variables=variables,
environment=environment,
)
job: dict[str, str] = self.send_launch_template_request(
project_id=project_id,
location=location,
gcs_path=dataflow_template,
job_name=name,
parameters=parameters,
environment=environment,
)
return job
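
# Illustrative sketch, not part of this commit: this non-blocking variant is
# what the deferrable path of DataflowTemplatedJobStartOperator would call
# before handing the job id to a trigger that polls the job asynchronously.
# The hook instance and GCS paths below are hypothetical.
#
#   job = hook.launch_job_with_template(
#       job_name="example",
#       variables={},
#       parameters={"inputFile": "gs://example-bucket/input.txt"},
#       dataflow_template="gs://dataflow-templates/latest/Word_Count",
#       project_id="example-project",
#   )
#   trigger_kwargs = {"job_id": job["id"]}  # consumed by an async trigger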

def _update_environment(self, variables: dict, environment: dict | None = None) -> dict:
environment = environment or {}
@@ -770,6 +817,35 @@ def _check_one(key, val):

return environment

def send_launch_template_request(
self,
*,
project_id: str,
location: str,
gcs_path: str,
job_name: str,
parameters: dict,
environment: dict,
) -> dict[str, str]:
service: Resource = self.get_conn()
request = (
service.projects()
.locations()
.templates()
.launch(
projectId=project_id,
location=location,
gcsPath=gcs_path,
body={
"jobName": job_name,
"parameters": parameters,
"environment": environment,
},
)
)
response: dict = request.execute(num_retries=self.num_retries)
return response["job"]
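
# For reference, not part of this diff: ``templates().launch(...)`` returns a
# LaunchTemplateResponse whose ``job`` field is a Job resource, roughly
#
#   {"job": {"id": "...", "name": "...", "projectId": "...", "currentState": "..."}}
#
# which is why this helper returns ``response["job"]``.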

@GoogleBaseHook.fallback_to_default_project_id
def start_flex_template(
self,
@@ -778,9 +854,9 @@ def start_flex_template(
project_id: str,
on_new_job_id_callback: Callable[[str], None] | None = None,
on_new_job_callback: Callable[[dict], None] | None = None,
) -> dict:
) -> dict[str, str]:
"""
Start flex templates with the Dataflow pipeline.
Launch a Dataflow job with a Flex Template and wait for its completion.

:param body: The request body. See:
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
@@ -791,31 +867,32 @@ def start_flex_template(
:param on_new_job_callback: A callback that is called when a Job is detected.
:return: the Job
"""
service = self.get_conn()
service: Resource = self.get_conn()
request = (
service.projects()
.locations()
.flexTemplates()
.launch(projectId=project_id, body=body, location=location)
)
response = request.execute(num_retries=self.num_retries)
response: dict = request.execute(num_retries=self.num_retries)
job = response["job"]
job_id: str = job["id"]

if on_new_job_id_callback:
warnings.warn(
"on_new_job_id_callback is Deprecated. Please start using on_new_job_callback",
AirflowProviderDeprecationWarning,
stacklevel=3,
)
on_new_job_id_callback(job.get("id"))
on_new_job_id_callback(job_id)

if on_new_job_callback:
on_new_job_callback(job)

jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
job_id=job.get("id"),
job_id=job_id,
location=location,
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
@@ -826,6 +903,42 @@

return jobs_controller.get_jobs(refresh=True)[0]

@GoogleBaseHook.fallback_to_default_project_id
def launch_job_with_flex_template(
self,
body: dict,
location: str,
project_id: str,
) -> dict[str, str]:
"""
Launch a Dataflow job with a Flex Template and exit without waiting for its completion.

:param body: The request body. See:
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
:param location: The location of the Dataflow job (for example europe-west1)
:param project_id: The ID of the GCP project that owns the job.
If set to ``None`` or missing, the default project_id from the GCP connection is used.
:return: a Dataflow job response
"""
service: Resource = self.get_conn()
request = (
service.projects()
.locations()
.flexTemplates()
.launch(projectId=project_id, body=body, location=location)
)
response: dict = request.execute(num_retries=self.num_retries)
return response["job"]
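
# Illustrative sketch, not part of this commit: a minimal request ``body`` for
# ``launch_job_with_flex_template``, following the flexTemplates.launch request
# format linked above. The bucket, container spec path, and parameter names are
# hypothetical.
#
#   body = {
#       "launchParameter": {
#           "jobName": "example-flex-job",
#           "containerSpecGcsPath": "gs://example-bucket/templates/spec.json",
#           "parameters": {"inputSubscription": "projects/p/subscriptions/s"},
#       }
#   }
#   job = hook.launch_job_with_flex_template(
#       body=body, location="europe-west1", project_id="example-project"
#   )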

@staticmethod
def extract_job_id(job: dict) -> str:
try:
return job["id"]
except KeyError:
raise AirflowException(
"Error occurred while reading the job object after template execution: "
"the job object has no id."
)
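
# Illustrative, not part of this commit: ``extract_job_id`` guards against a
# malformed launch response before the id is handed to a deferrable trigger.
#
#   job_id = DataflowHook.extract_job_id(job)  # raises AirflowException if "id" is missing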

@_fallback_to_location_from_variables
@_fallback_to_project_id_from_variables
@GoogleBaseHook.fallback_to_default_project_id
(Diff truncated: the remaining 6 changed files are not shown.)