diff --git a/google/api_core/page_iterator.py b/google/api_core/page_iterator.py index 11a92d38..fff3b556 100644 --- a/google/api_core/page_iterator.py +++ b/google/api_core/page_iterator.py @@ -179,7 +179,7 @@ def __init__( single item. """ self.max_results = max_results - """int: The maximum number of results to fetch.""" + """int: The maximum number of results to fetch""" # The attributes below will change over the life of the iterator. self.page_number = 0 @@ -298,7 +298,8 @@ class HTTPIterator(Iterator): can be found. page_token (str): A token identifying a page in a result set to start fetching results from. - max_results (int): The maximum number of results to fetch. + page_size (int): The maximum number of results to fetch per page + max_results (int): The maximum number of results to fetch extra_params (dict): Extra query string parameters for the API call. page_start (Callable[ @@ -329,6 +330,7 @@ def __init__( item_to_value, items_key=_DEFAULT_ITEMS_KEY, page_token=None, + page_size=None, max_results=None, extra_params=None, page_start=_do_nothing_page_start, @@ -341,6 +343,7 @@ def __init__( self.path = path self._items_key = items_key self.extra_params = extra_params + self._page_size = page_size self._page_start = page_start self._next_token = next_token # Verify inputs / provide defaults. @@ -399,8 +402,18 @@ def _get_query_params(self): result = {} if self.next_page_token is not None: result[self._PAGE_TOKEN] = self.next_page_token + + page_size = None if self.max_results is not None: - result[self._MAX_RESULTS] = self.max_results - self.num_results + page_size = self.max_results - self.num_results + if self._page_size is not None: + page_size = min(page_size, self._page_size) + elif self._page_size is not None: + page_size = self._page_size + + if page_size is not None: + result[self._MAX_RESULTS] = page_size + result.update(self.extra_params) return result diff --git a/tests/unit/test_page_iterator.py b/tests/unit/test_page_iterator.py index 2bf74249..83595376 100644 --- a/tests/unit/test_page_iterator.py +++ b/tests/unit/test_page_iterator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import types import mock @@ -235,6 +236,7 @@ def test_constructor(self): assert iterator.page_number == 0 assert iterator.next_page_token is None assert iterator.num_results == 0 + assert iterator._page_size is None def test_constructor_w_extra_param_collision(self): extra_params = {"pageToken": "val"} @@ -432,6 +434,66 @@ def test__get_next_page_bad_http_method(self): with pytest.raises(ValueError): iterator._get_next_page_response() + @pytest.mark.parametrize( + "page_size,max_results,pages", + [(3, None, False), (3, 8, False), (3, None, True), (3, 8, True)]) + def test_page_size_items(self, page_size, max_results, pages): + path = "/foo" + NITEMS = 10 + + n = [0] # blast you python 2! + + def api_request(*args, **kw): + assert not args + query_params = dict( + maxResults=( + page_size if max_results is None + else min(page_size, max_results - n[0])) + ) + if n[0]: + query_params.update(pageToken='test') + assert kw == {'method': 'GET', 'path': '/foo', + 'query_params': query_params} + n_items = min(kw['query_params']['maxResults'], NITEMS - n[0]) + items = [dict(name=str(i + n[0])) for i in range(n_items)] + n[0] += n_items + result = dict(items=items) + if n[0] < NITEMS: + result.update(nextPageToken='test') + return result + + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + api_request, + path=path, + item_to_value=page_iterator._item_to_value_identity, + page_size=page_size, + max_results=max_results, + ) + + assert iterator.num_results == 0 + + n_results = max_results if max_results is not None else NITEMS + if pages: + items_iter = iter(iterator.pages) + npages = int(math.ceil(float(n_results) / page_size)) + for ipage in range(npages): + assert ( + list(six.next(items_iter)) == [ + dict(name=str(i)) + for i in range(ipage * page_size, + min((ipage + 1) * page_size, n_results), + ) + ]) + else: + items_iter = iter(iterator) + for i in range(n_results): + assert six.next(items_iter) == dict(name=str(i)) + assert iterator.num_results == i + 1 + + with pytest.raises(StopIteration): + six.next(items_iter) + class TestGRPCIterator(object): def test_constructor(self):