gcsio: reduce number of get requests in function calls #30205

Status: Merged (3 commits, Feb 7, 2024). Changes from all commits are shown below.
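The theme of the diff is replacing the storage client's eager accessors (get_bucket, get_blob), each of which issues an HTTP GET, with the lazy handle constructors (bucket, blob), which build objects locally. A minimal sketch of the difference using the standard google-cloud-storage API (bucket and object names are made up):

    from google.cloud import storage

    client = storage.Client()

    # Eager: one HTTP GET per call; raises NotFound (get_bucket) or
    # returns None (get_blob) when the resource is missing.
    bucket = client.get_bucket("my-bucket")
    blob = bucket.get_blob("my-object")

    # Lazy: plain local handles, no HTTP at all; the request (and any
    # NotFound) happens only when the handle is actually used.
    bucket = client.bucket("my-bucket")
    blob = bucket.blob("my-object")
    data = blob.download_as_bytes()  # the single GET happens here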
51 changes: 22 additions & 29 deletions sdks/python/apache_beam/io/gcp/gcsio.py
@@ -175,17 +175,14 @@ def open(
       ValueError: Invalid open file mode.
     """
     bucket_name, blob_name = parse_gcs_path(filename)
-    bucket = self.client.get_bucket(bucket_name)
+    bucket = self.client.bucket(bucket_name)
 
     if mode == 'r' or mode == 'rb':
-      blob = bucket.get_blob(blob_name)
+      blob = bucket.blob(blob_name)
       return BeamBlobReader(blob, chunk_size=read_buffer_size)
     elif mode == 'w' or mode == 'wb':
-      blob = bucket.get_blob(blob_name)
-      if not blob:
-        blob = storage.Blob(blob_name, bucket)
+      blob = bucket.blob(blob_name)
       return BeamBlobWriter(blob, mime_type)
-
     else:
       raise ValueError('Invalid file open mode: %s.' % mode)
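The write path drops its existence probe entirely. A sketch of why that is safe, assuming BeamBlobWriter (like the google-cloud-storage BlobWriter it extends) creates or overwrites the object when the buffered bytes are uploaded, so a prior GET adds nothing:

    from google.cloud import storage

    from apache_beam.io.gcp.gcsio import BeamBlobWriter

    client = storage.Client()
    blob = client.bucket("my-bucket").blob("new-object")  # no HTTP yet

    # Writing through the writer creates the object server-side either
    # way, whether or not it already existed.
    with BeamBlobWriter(blob, "application/octet-stream") as writer:
        writer.write(b"payload")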

@@ -199,7 +196,7 @@ def delete(self, path):
"""
bucket_name, blob_name = parse_gcs_path(path)
try:
bucket = self.client.get_bucket(bucket_name)
bucket = self.client.bucket(bucket_name)
bucket.delete_blob(blob_name)
except NotFound:
return
@@ -228,16 +225,15 @@ def delete_batch(self, paths):
       with current_batch:
         for path in current_paths:
           bucket_name, blob_name = parse_gcs_path(path)
-          bucket = self.client.get_bucket(bucket_name)
+          bucket = self.client.bucket(bucket_name)
           bucket.delete_blob(blob_name)
 
       for i, path in enumerate(current_paths):
         error_code = None
-        for j in range(2):
[Inline review comment from a Contributor] Good point. That was why it needed to be range(2) or range(4) in delete_batch and copy_batch.
-          resp = current_batch._responses[2 * i + j]
-          if resp.status_code >= 400 and resp.status_code != 404:
-            error_code = resp.status_code
-            break
+        resp = current_batch._responses[i]
+        if resp.status_code >= 400 and resp.status_code != 404:
+          error_code = resp.status_code
+          break
         final_results.append((path, error_code))
 
       s += MAX_BATCH_OPERATION_SIZE
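The index arithmetic changes because each path now contributes exactly one sub-request to the batch instead of a GET/DELETE pair, so response i corresponds to path i; the same reasoning removes range(4) from copy_batch further down. A minimal sketch of the pattern, assuming (as gcsio does) that the client's batch context records one response per issued call, in call order:

    from google.cloud import storage

    client = storage.Client()
    names = ["a.txt", "b.txt"]  # hypothetical object names

    batch = client.batch(raise_exception=False)
    with batch:
        for name in names:
            # One sub-request per object: just the DELETE, no lookup first.
            client.bucket("my-bucket").delete_blob(name)

    # _responses is the private list gcsio itself reads; entry i is path i.
    for i, name in enumerate(names):
        print(name, batch._responses[i].status_code)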
@@ -258,11 +254,9 @@ def copy(self, src, dest):
"""
src_bucket_name, src_blob_name = parse_gcs_path(src)
dest_bucket_name, dest_blob_name= parse_gcs_path(dest, object_optional=True)
src_bucket = self.get_bucket(src_bucket_name)
src_blob = src_bucket.get_blob(src_blob_name)
if not src_blob:
raise NotFound("Source %s not found", src)
dest_bucket = self.get_bucket(dest_bucket_name)
src_bucket = self.client.bucket(src_bucket_name)
src_blob = src_bucket.blob(src_blob_name)
dest_bucket = self.client.bucket(dest_bucket_name)
if not dest_blob_name:
dest_blob_name = None
src_bucket.copy_blob(src_blob, dest_bucket, new_name=dest_blob_name)
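The explicit source check disappears because the copy request itself fails with a 404 when the source object is missing; roughly, under the standard client semantics (bucket names below are placeholders):

    from google.cloud import storage
    from google.cloud.exceptions import NotFound

    client = storage.Client()
    src_bucket = client.bucket("src-bucket")
    dest_bucket = client.bucket("dest-bucket")

    try:
        # Single request; raises NotFound if the source object is absent.
        src_bucket.copy_blob(src_bucket.blob("missing.txt"), dest_bucket)
    except NotFound:
        print("source object does not exist")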
@@ -291,19 +285,18 @@ def copy_batch(self, src_dest_pairs):
         for pair in current_pairs:
           src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
           dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
-          src_bucket = self.client.get_bucket(src_bucket_name)
-          src_blob = src_bucket.get_blob(src_blob_name)
-          dest_bucket = self.client.get_bucket(dest_bucket_name)
+          src_bucket = self.client.bucket(src_bucket_name)
+          src_blob = src_bucket.blob(src_blob_name)
+          dest_bucket = self.client.bucket(dest_bucket_name)
 
           src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)
 
       for i, pair in enumerate(current_pairs):
         error_code = None
-        for j in range(4):
-          resp = current_batch._responses[4 * i + j]
-          if resp.status_code >= 400:
-            error_code = resp.status_code
-            break
+        resp = current_batch._responses[i]
+        if resp.status_code >= 400:
+          error_code = resp.status_code
+          break
         final_results.append((pair[0], pair[1], error_code))
 
       s += MAX_BATCH_OPERATION_SIZE
@@ -417,12 +410,12 @@ def _gcs_object(self, path):
"""Returns a gcs object for the given path

This method does not perform glob expansion. Hence the given path must be
for a single GCS object.
for a single GCS object. The method will make HTTP requests.

Returns: GCS object.
"""
bucket_name, blob_name = parse_gcs_path(path)
bucket = self.client.get_bucket(bucket_name)
bucket = self.client.bucket(bucket_name)
blob = bucket.get_blob(blob_name)
if blob:
return blob
@@ -470,7 +463,7 @@ def list_files(self, path, with_metadata=False):
       _LOGGER.debug("Starting the file information of the input")
     else:
       _LOGGER.debug("Starting the size estimation of the input")
-    bucket = self.client.get_bucket(bucket_name)
+    bucket = self.client.bucket(bucket_name)
     response = self.client.list_blobs(bucket, prefix=prefix)
     for item in response:
       file_name = 'gs://%s/%s' % (item.bucket.name, item.name)
85 changes: 64 additions & 21 deletions sdks/python/apache_beam/io/gcp/gcsio_test.py
@@ -43,9 +43,15 @@ class FakeGcsClient(object):
   def __init__(self):
     self.buckets = {}
 
+  def _add_bucket(self, bucket):
+    self.buckets[bucket.name] = bucket
+    return self.buckets[bucket.name]
+
+  def bucket(self, name):
+    return FakeBucket(self, name)
+
   def create_bucket(self, name):
-    self.buckets[name] = FakeBucket(self, name)
-    return self.buckets[name]
+    return self._add_bucket(self.bucket(name))
 
   def get_bucket(self, name):
     if name in self.buckets:
@@ -92,40 +98,51 @@ def __init__(self, client, name):
     self.name = name
     self.blobs = {}
     self.default_kms_key_name = None
-    self.client.buckets[name] = self
 
-  def add_blob(self, blob):
-    self.blobs[blob.name] = blob
+  def _get_canonical_bucket(self):
+    return self.client.get_bucket(self.name)
 
-  def create_blob(self, name):
+  def _create_blob(self, name):
     return FakeBlob(name, self)
 
+  def add_blob(self, blob):
+    bucket = self._get_canonical_bucket()
+    bucket.blobs[blob.name] = blob
+    return bucket.blobs[blob.name]
+
+  def blob(self, name):
+    return self._create_blob(name)
+
   def copy_blob(self, blob, dest, new_name=None):
+    if self.get_blob(blob.name) is None:
+      raise NotFound("source blob not found")
     if not new_name:
       new_name = blob.name
-    dest.blobs[new_name] = blob
-    dest.blobs[new_name].name = new_name
-    dest.blobs[new_name].bucket = dest
-    return dest.blobs[new_name]
+    new_blob = FakeBlob(new_name, dest)
+    dest.add_blob(new_blob)
+    return new_blob
 
   def get_blob(self, blob_name):
-    if blob_name in self.blobs:
-      return self.blobs[blob_name]
+    bucket = self._get_canonical_bucket()
+    if blob_name in bucket.blobs:
+      return bucket.blobs[blob_name]
     else:
       return None
 
   def lookup_blob(self, name):
-    if name in self.blobs:
-      return self.blobs[name]
+    bucket = self._get_canonical_bucket()
+    if name in bucket.blobs:
+      return bucket.blobs[name]
     else:
-      return self.create_blob(name)
+      return bucket._create_blob(name)
 
   def set_default_kms_key_name(self, name):
     self.default_kms_key_name = name
 
   def delete_blob(self, name):
-    if name in self.blobs:
-      del self.blobs[name]
+    bucket = self._get_canonical_bucket()
+    if name in bucket.blobs:
+      del bucket.blobs[name]
 
class FakeBlob(object):
@@ -151,11 +168,18 @@ def __init__(
     self.updated = updated
     self._fail_when_getting_metadata = fail_when_getting_metadata
     self._fail_when_reading = fail_when_reading
-    self.bucket.add_blob(self)
 
   def delete(self):
-    if self.name in self.bucket.blobs:
-      del self.bucket.blobs[self.name]
+    self.bucket.delete_blob(self.name)
+
+  def download_as_bytes(self, **kwargs):
+    blob = self.bucket.get_blob(self.name)
+    if blob is None:
+      raise NotFound("blob not found")
+    return blob.contents
+
+  def __eq__(self, other):
+    return self.bucket.get_blob(self.name) is other.bucket.get_blob(other.name)
 
@unittest.skipIf(NotFound is None, 'GCP dependencies are not installed')
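The reworked fakes mimic the lazy-handle model: FakeGcsClient.bucket() and FakeBucket.blob() hand back unregistered stand-ins, and every operation resolves through the one canonical bucket registered via create_bucket(). A short illustration of the intended behavior (the driver code below is mine, not part of the PR):

    client = FakeGcsClient()
    client.create_bucket("bkt")             # registers the canonical bucket

    handle = client.bucket("bkt")           # cheap handle, not registered
    blob = handle.blob("obj")               # unregistered FakeBlob
    assert handle.get_blob("obj") is None   # nothing stored yet

    handle.add_blob(blob)                   # stored in the canonical bucket
    assert client.get_bucket("bkt").get_blob("obj") is blob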
@@ -224,6 +248,7 @@ def _insert_random_file(
         updated=updated,
         fail_when_getting_metadata=fail_when_getting_metadata,
         fail_when_reading=fail_when_reading)
+    bucket.add_blob(blob)
     return blob

def setUp(self):
@@ -475,7 +500,25 @@ def test_list_prefix(self):
   def test_downloader_fail_non_existent_object(self):
     file_name = 'gs://gcsio-metrics-test/dummy_mode_file'
     with self.assertRaises(NotFound):
-      self.gcs.open(file_name, 'r')
+      with self.gcs.open(file_name, 'r') as f:
+        f.read(1)
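Because open(..., 'r') now returns a reader over a lazy blob handle instead of fetching metadata up front, the NotFound for a missing object surfaces at the first read rather than at open time; the updated test reads a byte inside the assertRaises block to account for that.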

+  def test_blob_delete(self):
+    file_name = 'gs://gcsio-test/delete_me'
+    file_size = 1024
+    bucket_name, blob_name = gcsio.parse_gcs_path(file_name)
+    # Test deletion of non-existent file.
+    bucket = self.client.get_bucket(bucket_name)
+    self.gcs.delete(file_name)
+
+    self._insert_random_file(self.client, file_name, file_size)
+    self.assertTrue(blob_name in bucket.blobs)
+
+    blob = bucket.get_blob(blob_name)
+    self.assertIsNotNone(blob)
+
+    blob.delete()
+    self.assertFalse(blob_name in bucket.blobs)


if __name__ == '__main__':