From 4645a76ee2a4b04f93ddb91b5f0a94d6713021fc Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Wed, 7 Feb 2024 10:06:58 -0500 Subject: [PATCH] gcsio: reduce number of get requests in function calls (#30205) * Reduce the number of get requests in gcsio. * Apply formatter. * Replace get_bucket with bucket in _gcs_object --- sdks/python/apache_beam/io/gcp/gcsio.py | 51 +++++------- sdks/python/apache_beam/io/gcp/gcsio_test.py | 85 +++++++++++++++----- 2 files changed, 86 insertions(+), 50 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py index a6ba82a6e07c..b5a291428767 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio.py +++ b/sdks/python/apache_beam/io/gcp/gcsio.py @@ -175,17 +175,14 @@ def open( ValueError: Invalid open file mode. """ bucket_name, blob_name = parse_gcs_path(filename) - bucket = self.client.get_bucket(bucket_name) + bucket = self.client.bucket(bucket_name) if mode == 'r' or mode == 'rb': - blob = bucket.get_blob(blob_name) + blob = bucket.blob(blob_name) return BeamBlobReader(blob, chunk_size=read_buffer_size) elif mode == 'w' or mode == 'wb': - blob = bucket.get_blob(blob_name) - if not blob: - blob = storage.Blob(blob_name, bucket) + blob = bucket.blob(blob_name) return BeamBlobWriter(blob, mime_type) - else: raise ValueError('Invalid file open mode: %s.' % mode) @@ -199,7 +196,7 @@ def delete(self, path): """ bucket_name, blob_name = parse_gcs_path(path) try: - bucket = self.client.get_bucket(bucket_name) + bucket = self.client.bucket(bucket_name) bucket.delete_blob(blob_name) except NotFound: return @@ -228,16 +225,15 @@ def delete_batch(self, paths): with current_batch: for path in current_paths: bucket_name, blob_name = parse_gcs_path(path) - bucket = self.client.get_bucket(bucket_name) + bucket = self.client.bucket(bucket_name) bucket.delete_blob(blob_name) for i, path in enumerate(current_paths): error_code = None - for j in range(2): - resp = current_batch._responses[2 * i + j] - if resp.status_code >= 400 and resp.status_code != 404: - error_code = resp.status_code - break + resp = current_batch._responses[i] + if resp.status_code >= 400 and resp.status_code != 404: + error_code = resp.status_code + break final_results.append((path, error_code)) s += MAX_BATCH_OPERATION_SIZE @@ -258,11 +254,9 @@ def copy(self, src, dest): """ src_bucket_name, src_blob_name = parse_gcs_path(src) dest_bucket_name, dest_blob_name= parse_gcs_path(dest, object_optional=True) - src_bucket = self.get_bucket(src_bucket_name) - src_blob = src_bucket.get_blob(src_blob_name) - if not src_blob: - raise NotFound("Source %s not found", src) - dest_bucket = self.get_bucket(dest_bucket_name) + src_bucket = self.client.bucket(src_bucket_name) + src_blob = src_bucket.blob(src_blob_name) + dest_bucket = self.client.bucket(dest_bucket_name) if not dest_blob_name: dest_blob_name = None src_bucket.copy_blob(src_blob, dest_bucket, new_name=dest_blob_name) @@ -291,19 +285,18 @@ def copy_batch(self, src_dest_pairs): for pair in current_pairs: src_bucket_name, src_blob_name = parse_gcs_path(pair[0]) dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1]) - src_bucket = self.client.get_bucket(src_bucket_name) - src_blob = src_bucket.get_blob(src_blob_name) - dest_bucket = self.client.get_bucket(dest_bucket_name) + src_bucket = self.client.bucket(src_bucket_name) + src_blob = src_bucket.blob(src_blob_name) + dest_bucket = self.client.bucket(dest_bucket_name) src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name) for i, pair in enumerate(current_pairs): error_code = None - for j in range(4): - resp = current_batch._responses[4 * i + j] - if resp.status_code >= 400: - error_code = resp.status_code - break + resp = current_batch._responses[i] + if resp.status_code >= 400: + error_code = resp.status_code + break final_results.append((pair[0], pair[1], error_code)) s += MAX_BATCH_OPERATION_SIZE @@ -417,12 +410,12 @@ def _gcs_object(self, path): """Returns a gcs object for the given path This method does not perform glob expansion. Hence the given path must be - for a single GCS object. + for a single GCS object. The method will make HTTP requests. Returns: GCS object. """ bucket_name, blob_name = parse_gcs_path(path) - bucket = self.client.get_bucket(bucket_name) + bucket = self.client.bucket(bucket_name) blob = bucket.get_blob(blob_name) if blob: return blob @@ -470,7 +463,7 @@ def list_files(self, path, with_metadata=False): _LOGGER.debug("Starting the file information of the input") else: _LOGGER.debug("Starting the size estimation of the input") - bucket = self.client.get_bucket(bucket_name) + bucket = self.client.bucket(bucket_name) response = self.client.list_blobs(bucket, prefix=prefix) for item in response: file_name = 'gs://%s/%s' % (item.bucket.name, item.name) diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py index f8b580c91c95..c9a7fb72f779 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py @@ -43,9 +43,15 @@ class FakeGcsClient(object): def __init__(self): self.buckets = {} + def _add_bucket(self, bucket): + self.buckets[bucket.name] = bucket + return self.buckets[bucket.name] + + def bucket(self, name): + return FakeBucket(self, name) + def create_bucket(self, name): - self.buckets[name] = FakeBucket(self, name) - return self.buckets[name] + return self._add_bucket(self.bucket(name)) def get_bucket(self, name): if name in self.buckets: @@ -92,40 +98,51 @@ def __init__(self, client, name): self.name = name self.blobs = {} self.default_kms_key_name = None - self.client.buckets[name] = self - def add_blob(self, blob): - self.blobs[blob.name] = blob + def _get_canonical_bucket(self): + return self.client.get_bucket(self.name) - def create_blob(self, name): + def _create_blob(self, name): return FakeBlob(name, self) + def add_blob(self, blob): + bucket = self._get_canonical_bucket() + bucket.blobs[blob.name] = blob + return bucket.blobs[blob.name] + + def blob(self, name): + return self._create_blob(name) + def copy_blob(self, blob, dest, new_name=None): + if self.get_blob(blob.name) is None: + raise NotFound("source blob not found") if not new_name: new_name = blob.name - dest.blobs[new_name] = blob - dest.blobs[new_name].name = new_name - dest.blobs[new_name].bucket = dest - return dest.blobs[new_name] + new_blob = FakeBlob(new_name, dest) + dest.add_blob(new_blob) + return new_blob def get_blob(self, blob_name): - if blob_name in self.blobs: - return self.blobs[blob_name] + bucket = self._get_canonical_bucket() + if blob_name in bucket.blobs: + return bucket.blobs[blob_name] else: return None def lookup_blob(self, name): - if name in self.blobs: - return self.blobs[name] + bucket = self._get_canonical_bucket() + if name in bucket.blobs: + return bucket.blobs[name] else: - return self.create_blob(name) + return bucket.create_blob(name) def set_default_kms_key_name(self, name): self.default_kms_key_name = name def delete_blob(self, name): - if name in self.blobs: - del self.blobs[name] + bucket = self._get_canonical_bucket() + if name in bucket.blobs: + del bucket.blobs[name] class FakeBlob(object): @@ -151,11 +168,18 @@ def __init__( self.updated = updated self._fail_when_getting_metadata = fail_when_getting_metadata self._fail_when_reading = fail_when_reading - self.bucket.add_blob(self) def delete(self): - if self.name in self.bucket.blobs: - del self.bucket.blobs[self.name] + self.bucket.delete_blob(self.name) + + def download_as_bytes(self, **kwargs): + blob = self.bucket.get_blob(self.name) + if blob is None: + raise NotFound("blob not found") + return blob.contents + + def __eq__(self, other): + return self.bucket.get_blob(self.name) is other.bucket.get_blob(other.name) @unittest.skipIf(NotFound is None, 'GCP dependencies are not installed') @@ -224,6 +248,7 @@ def _insert_random_file( updated=updated, fail_when_getting_metadata=fail_when_getting_metadata, fail_when_reading=fail_when_reading) + bucket.add_blob(blob) return blob def setUp(self): @@ -475,7 +500,25 @@ def test_list_prefix(self): def test_downloader_fail_non_existent_object(self): file_name = 'gs://gcsio-metrics-test/dummy_mode_file' with self.assertRaises(NotFound): - self.gcs.open(file_name, 'r') + with self.gcs.open(file_name, 'r') as f: + f.read(1) + + def test_blob_delete(self): + file_name = 'gs://gcsio-test/delete_me' + file_size = 1024 + bucket_name, blob_name = gcsio.parse_gcs_path(file_name) + # Test deletion of non-existent file. + bucket = self.client.get_bucket(bucket_name) + self.gcs.delete(file_name) + + self._insert_random_file(self.client, file_name, file_size) + self.assertTrue(blob_name in bucket.blobs) + + blob = bucket.get_blob(blob_name) + self.assertIsNotNone(blob) + + blob.delete() + self.assertFalse(blob_name in bucket.blobs) if __name__ == '__main__':