diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py index 33df98e37b002..825a357d3979e 100644 --- a/tests/providers/google/cloud/hooks/test_gcs.py +++ b/tests/providers/google/cloud/hooks/test_gcs.py @@ -21,6 +21,7 @@ import logging import os import re +from collections import namedtuple from datetime import datetime, timedelta from io import BytesIO from unittest import mock @@ -799,14 +800,26 @@ def test_provide_file_upload(self, mock_upload, mock_temp_file): ) @pytest.mark.parametrize( - "prefix, result", + "prefix, blob_names, returned_prefixes, call_args, result", ( ( "prefix", + ["prefix"], + None, + [mock.call(delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None)], + ["prefix"], + ), + ( + "prefix", + ["prefix"], + {"prefix,"}, [mock.call(delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None)], + ["prefix,"], ), ( ["prefix", "prefix_2"], + ["prefix", "prefix2"], + None, [ mock.call( delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None @@ -815,19 +828,38 @@ def test_provide_file_upload(self, mock_upload, mock_temp_file): delimiter=",", prefix="prefix_2", versions=None, max_results=None, page_token=None ), ], + ["prefix", "prefix2"], ), ), ) @mock.patch(GCS_STRING.format("GCSHook.get_conn")) - def test_list__delimiter(self, mock_service, prefix, result): - mock_service.return_value.bucket.return_value.list_blobs.return_value.next_page_token = None + def test_list__delimiter(self, mock_service, prefix, blob_names, returned_prefixes, call_args, result): + Blob = namedtuple("Blob", ["name"]) + + class BlobsIterator: + def __init__(self): + self._item_iter = (Blob(name=name) for name in blob_names) + + def __iter__(self): + return self + + def __next__(self): + try: + return next(self._item_iter) + except StopIteration: + self.prefixes = returned_prefixes + self.next_page_token = None + raise + + mock_service.return_value.bucket.return_value.list_blobs.return_value = BlobsIterator() with pytest.deprecated_call(): - self.gcs_hook.list( + blobs = self.gcs_hook.list( bucket_name="test_bucket", prefix=prefix, delimiter=",", ) - assert mock_service.return_value.bucket.return_value.list_blobs.call_args_list == result + assert mock_service.return_value.bucket.return_value.list_blobs.call_args_list == call_args + assert blobs == result @mock.patch(GCS_STRING.format("GCSHook.get_conn")) @mock.patch("airflow.providers.google.cloud.hooks.gcs.functools")