Skip to content

Commit

Permalink
gcsio: reduce number of get requests in function calls (#30205)
Browse files Browse the repository at this point in the history
* Reduce the number of get requests in gcsio.

* Apply formatter.

* Replace get_bucket with bucket in _gcs_object
  • Loading branch information
shunping authored Feb 7, 2024
1 parent 7a46686 commit 4645a76
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 50 deletions.
51 changes: 22 additions & 29 deletions sdks/python/apache_beam/io/gcp/gcsio.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,14 @@ def open(
ValueError: Invalid open file mode.
"""
bucket_name, blob_name = parse_gcs_path(filename)
bucket = self.client.get_bucket(bucket_name)
bucket = self.client.bucket(bucket_name)

if mode == 'r' or mode == 'rb':
blob = bucket.get_blob(blob_name)
blob = bucket.blob(blob_name)
return BeamBlobReader(blob, chunk_size=read_buffer_size)
elif mode == 'w' or mode == 'wb':
blob = bucket.get_blob(blob_name)
if not blob:
blob = storage.Blob(blob_name, bucket)
blob = bucket.blob(blob_name)
return BeamBlobWriter(blob, mime_type)

else:
raise ValueError('Invalid file open mode: %s.' % mode)

Expand All @@ -199,7 +196,7 @@ def delete(self, path):
"""
bucket_name, blob_name = parse_gcs_path(path)
try:
bucket = self.client.get_bucket(bucket_name)
bucket = self.client.bucket(bucket_name)
bucket.delete_blob(blob_name)
except NotFound:
return
Expand Down Expand Up @@ -228,16 +225,15 @@ def delete_batch(self, paths):
with current_batch:
for path in current_paths:
bucket_name, blob_name = parse_gcs_path(path)
bucket = self.client.get_bucket(bucket_name)
bucket = self.client.bucket(bucket_name)
bucket.delete_blob(blob_name)

for i, path in enumerate(current_paths):
error_code = None
for j in range(2):
resp = current_batch._responses[2 * i + j]
if resp.status_code >= 400 and resp.status_code != 404:
error_code = resp.status_code
break
resp = current_batch._responses[i]
if resp.status_code >= 400 and resp.status_code != 404:
error_code = resp.status_code
break
final_results.append((path, error_code))

s += MAX_BATCH_OPERATION_SIZE
Expand All @@ -258,11 +254,9 @@ def copy(self, src, dest):
"""
src_bucket_name, src_blob_name = parse_gcs_path(src)
dest_bucket_name, dest_blob_name= parse_gcs_path(dest, object_optional=True)
src_bucket = self.get_bucket(src_bucket_name)
src_blob = src_bucket.get_blob(src_blob_name)
if not src_blob:
raise NotFound("Source %s not found", src)
dest_bucket = self.get_bucket(dest_bucket_name)
src_bucket = self.client.bucket(src_bucket_name)
src_blob = src_bucket.blob(src_blob_name)
dest_bucket = self.client.bucket(dest_bucket_name)
if not dest_blob_name:
dest_blob_name = None
src_bucket.copy_blob(src_blob, dest_bucket, new_name=dest_blob_name)
Expand Down Expand Up @@ -291,19 +285,18 @@ def copy_batch(self, src_dest_pairs):
for pair in current_pairs:
src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
src_bucket = self.client.get_bucket(src_bucket_name)
src_blob = src_bucket.get_blob(src_blob_name)
dest_bucket = self.client.get_bucket(dest_bucket_name)
src_bucket = self.client.bucket(src_bucket_name)
src_blob = src_bucket.blob(src_blob_name)
dest_bucket = self.client.bucket(dest_bucket_name)

src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)

for i, pair in enumerate(current_pairs):
error_code = None
for j in range(4):
resp = current_batch._responses[4 * i + j]
if resp.status_code >= 400:
error_code = resp.status_code
break
resp = current_batch._responses[i]
if resp.status_code >= 400:
error_code = resp.status_code
break
final_results.append((pair[0], pair[1], error_code))

s += MAX_BATCH_OPERATION_SIZE
Expand Down Expand Up @@ -417,12 +410,12 @@ def _gcs_object(self, path):
"""Returns a gcs object for the given path
This method does not perform glob expansion. Hence the given path must be
for a single GCS object.
for a single GCS object. The method will make HTTP requests.
Returns: GCS object.
"""
bucket_name, blob_name = parse_gcs_path(path)
bucket = self.client.get_bucket(bucket_name)
bucket = self.client.bucket(bucket_name)
blob = bucket.get_blob(blob_name)
if blob:
return blob
Expand Down Expand Up @@ -470,7 +463,7 @@ def list_files(self, path, with_metadata=False):
_LOGGER.debug("Starting the file information of the input")
else:
_LOGGER.debug("Starting the size estimation of the input")
bucket = self.client.get_bucket(bucket_name)
bucket = self.client.bucket(bucket_name)
response = self.client.list_blobs(bucket, prefix=prefix)
for item in response:
file_name = 'gs://%s/%s' % (item.bucket.name, item.name)
Expand Down
85 changes: 64 additions & 21 deletions sdks/python/apache_beam/io/gcp/gcsio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,15 @@ class FakeGcsClient(object):
def __init__(self):
self.buckets = {}

def _add_bucket(self, bucket):
self.buckets[bucket.name] = bucket
return self.buckets[bucket.name]

def bucket(self, name):
return FakeBucket(self, name)

def create_bucket(self, name):
self.buckets[name] = FakeBucket(self, name)
return self.buckets[name]
return self._add_bucket(self.bucket(name))

def get_bucket(self, name):
if name in self.buckets:
Expand Down Expand Up @@ -92,40 +98,51 @@ def __init__(self, client, name):
self.name = name
self.blobs = {}
self.default_kms_key_name = None
self.client.buckets[name] = self

def add_blob(self, blob):
self.blobs[blob.name] = blob
def _get_canonical_bucket(self):
return self.client.get_bucket(self.name)

def create_blob(self, name):
def _create_blob(self, name):
return FakeBlob(name, self)

def add_blob(self, blob):
bucket = self._get_canonical_bucket()
bucket.blobs[blob.name] = blob
return bucket.blobs[blob.name]

def blob(self, name):
return self._create_blob(name)

def copy_blob(self, blob, dest, new_name=None):
if self.get_blob(blob.name) is None:
raise NotFound("source blob not found")
if not new_name:
new_name = blob.name
dest.blobs[new_name] = blob
dest.blobs[new_name].name = new_name
dest.blobs[new_name].bucket = dest
return dest.blobs[new_name]
new_blob = FakeBlob(new_name, dest)
dest.add_blob(new_blob)
return new_blob

def get_blob(self, blob_name):
if blob_name in self.blobs:
return self.blobs[blob_name]
bucket = self._get_canonical_bucket()
if blob_name in bucket.blobs:
return bucket.blobs[blob_name]
else:
return None

def lookup_blob(self, name):
if name in self.blobs:
return self.blobs[name]
bucket = self._get_canonical_bucket()
if name in bucket.blobs:
return bucket.blobs[name]
else:
return self.create_blob(name)
return bucket.create_blob(name)

def set_default_kms_key_name(self, name):
self.default_kms_key_name = name

def delete_blob(self, name):
if name in self.blobs:
del self.blobs[name]
bucket = self._get_canonical_bucket()
if name in bucket.blobs:
del bucket.blobs[name]


class FakeBlob(object):
Expand All @@ -151,11 +168,18 @@ def __init__(
self.updated = updated
self._fail_when_getting_metadata = fail_when_getting_metadata
self._fail_when_reading = fail_when_reading
self.bucket.add_blob(self)

def delete(self):
if self.name in self.bucket.blobs:
del self.bucket.blobs[self.name]
self.bucket.delete_blob(self.name)

def download_as_bytes(self, **kwargs):
blob = self.bucket.get_blob(self.name)
if blob is None:
raise NotFound("blob not found")
return blob.contents

def __eq__(self, other):
return self.bucket.get_blob(self.name) is other.bucket.get_blob(other.name)


@unittest.skipIf(NotFound is None, 'GCP dependencies are not installed')
Expand Down Expand Up @@ -224,6 +248,7 @@ def _insert_random_file(
updated=updated,
fail_when_getting_metadata=fail_when_getting_metadata,
fail_when_reading=fail_when_reading)
bucket.add_blob(blob)
return blob

def setUp(self):
Expand Down Expand Up @@ -475,7 +500,25 @@ def test_list_prefix(self):
def test_downloader_fail_non_existent_object(self):
file_name = 'gs://gcsio-metrics-test/dummy_mode_file'
with self.assertRaises(NotFound):
self.gcs.open(file_name, 'r')
with self.gcs.open(file_name, 'r') as f:
f.read(1)

def test_blob_delete(self):
file_name = 'gs://gcsio-test/delete_me'
file_size = 1024
bucket_name, blob_name = gcsio.parse_gcs_path(file_name)
# Test deletion of non-existent file.
bucket = self.client.get_bucket(bucket_name)
self.gcs.delete(file_name)

self._insert_random_file(self.client, file_name, file_size)
self.assertTrue(blob_name in bucket.blobs)

blob = bucket.get_blob(blob_name)
self.assertIsNotNone(blob)

blob.delete()
self.assertFalse(blob_name in bucket.blobs)


if __name__ == '__main__':
Expand Down

0 comments on commit 4645a76

Please sign in to comment.