Skip to content

Commit

Permalink
Deprecate delimiter param and source object's wildcards in GCS, int…
Browse files Browse the repository at this point in the history
…roduce `match_glob` param. (#31261)

* Deprecate `delimiter` param and source object's wildcards in GCS, introduce `match_glob` param.

---------

Co-authored-by: eladkal <45845474+eladkal@users.noreply.github.com>
  • Loading branch information
shahar1 and eladkal authored Jun 30, 2023
1 parent 7e06c80 commit d6e254d
Show file tree
Hide file tree
Showing 14 changed files with 448 additions and 138 deletions.
18 changes: 16 additions & 2 deletions airflow/providers/amazon/aws/transfers/gcs_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from __future__ import annotations

import os
import warnings
from typing import TYPE_CHECKING, Sequence

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.google.cloud.hooks.gcs import GCSHook
Expand All @@ -40,7 +42,7 @@ class GCSToS3Operator(BaseOperator):
:param bucket: The Google Cloud Storage bucket to find the objects. (templated)
:param prefix: Prefix string which filters objects whose name begin with
this prefix. (templated)
:param delimiter: The delimiter by which you want to filter the objects. (templated)
:param delimiter: (Deprecated) The delimiter by which you want to filter the objects. (templated)
For example, to list the CSV files in a directory in GCS you would use
delimiter='.csv'.
:param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
Expand Down Expand Up @@ -76,6 +78,8 @@ class GCSToS3Operator(BaseOperator):
object to be uploaded in S3
:param keep_directory_structure: (Optional) When set to False the path of the file
on the bucket is recreated within path passed in dest_s3_key.
:param match_glob: (Optional) filters objects based on the glob pattern given by the string
(e.g, ``'**/*/.json'``)
"""

template_fields: Sequence[str] = (
Expand All @@ -102,12 +106,19 @@ def __init__(
dest_s3_extra_args: dict | None = None,
s3_acl_policy: str | None = None,
keep_directory_structure: bool = True,
match_glob: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)

self.bucket = bucket
self.prefix = prefix
if delimiter:
warnings.warn(
"Usage of 'delimiter' is deprecated, please use 'match_glob' instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
self.delimiter = delimiter
self.gcp_conn_id = gcp_conn_id
self.dest_aws_conn_id = dest_aws_conn_id
Expand All @@ -118,6 +129,7 @@ def __init__(
self.dest_s3_extra_args = dest_s3_extra_args or {}
self.s3_acl_policy = s3_acl_policy
self.keep_directory_structure = keep_directory_structure
self.match_glob = match_glob

def execute(self, context: Context) -> list[str]:
# list all files in a Google Cloud Storage bucket
Expand All @@ -133,7 +145,9 @@ def execute(self, context: Context) -> list[str]:
self.prefix,
)

files = hook.list(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)
files = hook.list(
bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter, match_glob=self.match_glob
)

s3_hook = S3Hook(
aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify, extra_args=self.dest_s3_extra_args
Expand Down
126 changes: 108 additions & 18 deletions airflow/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import os
import shutil
import time
import warnings
from contextlib import contextmanager
from datetime import datetime
from functools import partial
Expand All @@ -44,7 +45,7 @@
from google.cloud.storage.retry import DEFAULT_RETRY
from requests import Session

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
Expand Down Expand Up @@ -709,6 +710,7 @@ def list(
max_results: int | None = None,
prefix: str | List[str] | None = None,
delimiter: str | None = None,
match_glob: str | None = None,
):
"""
List all objects from the bucket with the given single prefix or multiple prefixes.
Expand All @@ -717,9 +719,19 @@ def list(
:param versions: if true, list all versions of the objects
:param max_results: max count of items to return in a single page of responses
:param prefix: string or list of strings which filter objects whose name begin with it/them
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
:param delimiter: (Deprecated) filters objects based on the delimiter (for e.g '.csv')
:param match_glob: (Optional) filters objects based on the glob pattern given by the string
(e.g, ``'**/*/.json'``).
:return: a stream of object names matching the filtering criteria
"""
if delimiter and delimiter != "/":
warnings.warn(
"Usage of 'delimiter' param is deprecated, please use 'match_glob' instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
if match_glob and delimiter and delimiter != "/":
raise AirflowException("'match_glob' param cannot be used with 'delimiter' that differs than '/'")
objects = []
if isinstance(prefix, list):
for prefix_item in prefix:
Expand All @@ -730,6 +742,7 @@ def list(
max_results=max_results,
prefix=prefix_item,
delimiter=delimiter,
match_glob=match_glob,
)
)
else:
Expand All @@ -740,6 +753,7 @@ def list(
max_results=max_results,
prefix=prefix,
delimiter=delimiter,
match_glob=match_glob,
)
)
return objects
Expand All @@ -751,6 +765,7 @@ def _list(
max_results: int | None = None,
prefix: str | None = None,
delimiter: str | None = None,
match_glob: str | None = None,
) -> List:
"""
List all objects from the bucket with the given string prefix in name.
Expand All @@ -759,7 +774,9 @@ def _list(
:param versions: if true, list all versions of the objects
:param max_results: max count of items to return in a single page of responses
:param prefix: string which filters objects whose name begin with it
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
:param delimiter: (Deprecated) filters objects based on the delimiter (for e.g '.csv')
:param match_glob: (Optional) filters objects based on the glob pattern given by the string
(e.g, ``'**/*/.json'``).
:return: a stream of object names matching the filtering criteria
"""
client = self.get_conn()
Expand All @@ -768,13 +785,25 @@ def _list(
ids = []
page_token = None
while True:
blobs = bucket.list_blobs(
max_results=max_results,
page_token=page_token,
prefix=prefix,
delimiter=delimiter,
versions=versions,
)
if match_glob:
blobs = self._list_blobs_with_match_glob(
bucket=bucket,
client=client,
match_glob=match_glob,
max_results=max_results,
page_token=page_token,
path=bucket.path + "/o",
prefix=prefix,
versions=versions,
)
else:
blobs = bucket.list_blobs(
max_results=max_results,
page_token=page_token,
prefix=prefix,
delimiter=delimiter,
versions=versions,
)

blob_names = []
for blob in blobs:
Expand All @@ -792,6 +821,52 @@ def _list(
break
return ids

@staticmethod
def _list_blobs_with_match_glob(
    bucket,
    client,
    path: str,
    max_results: int | None = None,
    page_token: str | None = None,
    match_glob: str | None = None,
    prefix: str | None = None,
    versions: bool | None = None,
) -> Any:
    """
    Build a blob page iterator whose listing request carries the ``matchGlob`` parameter.

    This method is a patched version of google.cloud.storage Client.list_blobs().
    It is used as a temporary workaround to support "match_glob" param,
    as it isn't officially supported by GCS Python client.
    (follow `issue #1035 <https://github.com/googleapis/python-storage/issues/1035>`__).

    :param bucket: bucket object attached to the returned iterator as its ``bucket`` attribute
    :param client: storage client handed to the iterator; its connection performs the API requests
    :param path: API path the iterator requests pages from
    :param max_results: forwarded unchanged to the iterator as ``max_results``
    :param page_token: forwarded unchanged to the iterator as ``page_token``
    :param match_glob: sent as the ``matchGlob`` query parameter when it is not None
    :param prefix: sent as the ``prefix`` query parameter when it is not None
    :param versions: sent as the ``versions`` query parameter when it is not None
    :return: the page iterator, with ``prefixes`` set to an empty set and ``bucket`` set to ``bucket``
    """
    from google.api_core import page_iterator
    from google.cloud.storage.bucket import _blobs_page_start, _item_to_blob

    # Only filters that were actually supplied become query parameters;
    # insertion order (prefix, matchGlob, versions) is kept as-is.
    candidate_params = {
        "prefix": prefix,
        "matchGlob": match_glob,
        "versions": versions,
    }
    query_params: Any = {name: value for name, value in candidate_params.items() if value is not None}

    # Every page request goes through the client's connection with the default timeout and retry.
    request_with_defaults = functools.partial(
        client._connection.api_request,
        timeout=DEFAULT_TIMEOUT,
        retry=DEFAULT_RETRY,
    )

    iterator: Any = page_iterator.HTTPIterator(
        client=client,
        api_request=request_with_defaults,
        path=path,
        item_to_value=_item_to_blob,
        page_token=page_token,
        max_results=max_results,
        extra_params=query_params,
        page_start=_blobs_page_start,
    )
    # Attach the same extra attributes the original sets on the iterator before returning it.
    iterator.prefixes = set()
    iterator.bucket = bucket
    return iterator

def list_by_timespan(
self,
bucket_name: str,
Expand All @@ -801,6 +876,7 @@ def list_by_timespan(
max_results: int | None = None,
prefix: str | None = None,
delimiter: str | None = None,
match_glob: str | None = None,
) -> List[str]:
"""
List all objects from the bucket with the given string prefix in name that were
Expand All @@ -813,7 +889,9 @@ def list_by_timespan(
:param max_results: max count of items to return in a single page of responses
:param prefix: prefix string which filters objects whose name begin with
this prefix
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
:param delimiter: (Deprecated) filters objects based on the delimiter (for e.g '.csv')
:param match_glob: (Optional) filters objects based on the glob pattern given by the string
(e.g, ``'**/*/.json'``).
:return: a stream of object names matching the filtering criteria
"""
client = self.get_conn()
Expand All @@ -823,13 +901,25 @@ def list_by_timespan(
page_token = None

while True:
blobs = bucket.list_blobs(
max_results=max_results,
page_token=page_token,
prefix=prefix,
delimiter=delimiter,
versions=versions,
)
if match_glob:
blobs = self._list_blobs_with_match_glob(
bucket=bucket,
client=client,
match_glob=match_glob,
max_results=max_results,
page_token=page_token,
path=bucket.path + "/o",
prefix=prefix,
versions=versions,
)
else:
blobs = bucket.list_blobs(
max_results=max_results,
page_token=page_token,
prefix=prefix,
delimiter=delimiter,
versions=versions,
)

blob_names = []
for blob in blobs:
Expand Down
Loading

0 comments on commit d6e254d

Please sign in to comment.