feat: add Hook Level Lineage support for GCSHook
Signed-off-by: Kacper Muda <mudakacper@gmail.com>
kacpermuda committed Sep 26, 2024
1 parent 3de9e14 commit 4e70d6a
Showing 3 changed files with 99 additions and 2 deletions.
52 changes: 50 additions & 2 deletions airflow/providers/google/cloud/hooks/gcs.py
@@ -43,6 +43,7 @@
from requests import Session

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import (
@@ -213,6 +214,16 @@ def copy(
destination_object = source_bucket.copy_blob( # type: ignore[attr-defined]
blob=source_object, destination_bucket=destination_bucket, new_name=destination_object
)
get_hook_lineage_collector().add_input_dataset(
context=self,
scheme="gs",
dataset_kwargs={"bucket": source_bucket.name, "key": source_object.name},
)
get_hook_lineage_collector().add_output_dataset(
context=self,
scheme="gs",
dataset_kwargs={"bucket": destination_bucket.name, "key": destination_object.name},
)

self.log.info(
"Object %s in bucket %s copied to object %s in bucket %s",
@@ -266,6 +277,16 @@ def rewrite(
).rewrite(source=source_object, token=token)

self.log.info("Total Bytes: %s | Bytes Written: %s", total_bytes, bytes_rewritten)
get_hook_lineage_collector().add_input_dataset(
context=self,
scheme="gs",
dataset_kwargs={"bucket": source_bucket.name, "key": source_object.name},
)
get_hook_lineage_collector().add_output_dataset(
context=self,
scheme="gs",
dataset_kwargs={"bucket": destination_bucket.name, "key": destination_object},
)
self.log.info(
"Object %s in bucket %s rewritten to object %s in bucket %s",
source_object.name, # type: ignore[attr-defined]
@@ -344,9 +365,18 @@ def download(

if filename:
blob.download_to_filename(filename, timeout=timeout)
get_hook_lineage_collector().add_input_dataset(
context=self, scheme="gs", dataset_kwargs={"bucket": bucket.name, "key": blob.name}
)
get_hook_lineage_collector().add_output_dataset(
context=self, scheme="file", dataset_kwargs={"path": filename}
)
self.log.info("File downloaded to %s", filename)
return filename
else:
get_hook_lineage_collector().add_input_dataset(
context=self, scheme="gs", dataset_kwargs={"bucket": bucket.name, "key": blob.name}
)
return blob.download_as_bytes()

except GoogleCloudError:
@@ -554,6 +584,9 @@ def _call_with_retry(f: Callable[[], None]) -> None:
_call_with_retry(
partial(blob.upload_from_filename, filename=filename, content_type=mime_type, timeout=timeout)
)
get_hook_lineage_collector().add_input_dataset(
context=self, scheme="file", dataset_kwargs={"path": filename}
)

if gzip:
os.remove(filename)
@@ -575,6 +608,10 @@ def _call_with_retry(f: Callable[[], None]) -> None:
else:
raise ValueError("'filename' and 'data' parameter missing. One is required to upload to gcs.")

get_hook_lineage_collector().add_output_dataset(
context=self, scheme="gs", dataset_kwargs={"bucket": bucket.name, "key": blob.name}
)

def exists(self, bucket_name: str, object_name: str, retry: Retry = DEFAULT_RETRY) -> bool:
"""
Check for the existence of a file in Google Cloud Storage.
@@ -694,6 +731,9 @@ def delete(self, bucket_name: str, object_name: str) -> None:
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name=object_name)
blob.delete()
get_hook_lineage_collector().add_input_dataset(
context=self, scheme="gs", dataset_kwargs={"bucket": bucket.name, "key": blob.name}
)

self.log.info("Blob %s deleted.", object_name)

@@ -1193,9 +1233,17 @@ def compose(self, bucket_name: str, source_objects: List[str], destination_objec
client = self.get_conn()
bucket = client.bucket(bucket_name)
destination_blob = bucket.blob(destination_object)
-destination_blob.compose(
-    sources=[bucket.blob(blob_name=source_object) for source_object in source_objects]
+source_blobs = [bucket.blob(blob_name=source_object) for source_object in source_objects]
+destination_blob.compose(sources=source_blobs)
get_hook_lineage_collector().add_output_dataset(
context=self, scheme="gs", dataset_kwargs={"bucket": bucket.name, "key": destination_blob.name}
)
for single_source_blob in source_blobs:
get_hook_lineage_collector().add_input_dataset(
context=self,
scheme="gs",
dataset_kwargs={"bucket": bucket.name, "key": single_source_blob.name},
)

self.log.info("Completed successfully.")

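The collector calls added above only record datasets; nothing is emitted unless hook-level lineage collection is actually enabled (a registered HookLineageReader is required, otherwise the compat shim returns a no-op collector). Below is a minimal sketch of how the recorded lineage could be inspected after a hook call. The connection id, bucket and object names are placeholders, and the `collected_datasets` attribute follows the Airflow 2.10 hook-lineage API rather than anything introduced by this commit.

```python
from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
from airflow.providers.google.cloud.hooks.gcs import GCSHook

# Placeholder connection and object names; assumes a HookLineageReader is registered
# so the collector records datasets instead of silently doing nothing.
hook = GCSHook(gcp_conn_id="google_cloud_default")
hook.copy(
    source_bucket="source-bucket",
    source_object="data/input.csv",
    destination_bucket="dest-bucket",
    destination_object="data/output.csv",
)

# With this change, copy() registers one input and one output dataset.
lineage = get_hook_lineage_collector().collected_datasets
print(lineage.inputs)   # dataset built from gs://source-bucket/data/input.csv
print(lineage.outputs)  # dataset built from gs://dest-bucket/data/output.csv
```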
45 changes: 45 additions & 0 deletions airflow/providers/google/datasets/gcs.py
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.datasets import Dataset
from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url

if TYPE_CHECKING:
from urllib.parse import SplitResult

from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset


def create_dataset(*, bucket: str, key: str, extra: dict | None = None) -> Dataset:
return Dataset(uri=f"gs://{bucket}/{key}", extra=extra)


def sanitize_uri(uri: SplitResult) -> SplitResult:
if not uri.netloc:
raise ValueError("URI format gs:// must contain a bucket name")
return uri


def convert_dataset_to_openlineage(dataset: Dataset, lineage_context) -> OpenLineageDataset:
"""Translate Dataset with valid AIP-60 uri to OpenLineage with assistance from the hook."""
from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset

bucket, key = _parse_gcs_url(dataset.uri)
return OpenLineageDataset(namespace=f"gs://{bucket}", name=key if key else "/")
4 changes: 4 additions & 0 deletions airflow/providers/google/provider.yaml
@@ -773,6 +773,10 @@ dataset-uris:
handler: null
- schemes: [bigquery]
handler: airflow.providers.google.datasets.bigquery.sanitize_uri
- schemes: [gs]
handler: airflow.providers.google.datasets.gcs.sanitize_uri
factory: airflow.providers.google.datasets.gcs.create_dataset
to_openlineage_converter: airflow.providers.google.datasets.gcs.convert_dataset_to_openlineage

hooks:
- integration-name: Google Ads
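With the `gs` scheme registered here, AIP-60 dataset URIs pointing at GCS objects are validated by `sanitize_uri` and can be used for data-aware scheduling like any other dataset. A small illustrative DAG follows, with a placeholder bucket and file name.

```python
import pendulum

from airflow.datasets import Dataset
from airflow.decorators import dag, task

# Placeholder URI; the gs:// handler registered in provider.yaml validates it.
daily_report = Dataset("gs://example-bucket/reports/daily.csv")


@dag(schedule=[daily_report], start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def consume_daily_report():
    @task
    def process() -> None:
        # Runs whenever an upstream task declares daily_report as an outlet.
        ...

    process()


consume_daily_report()
```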
