From 31ebd5e3d4ae516ef0baf63ffae841afb36e7b52 Mon Sep 17 00:00:00 2001 From: Jordan Guymon Date: Thu, 11 Aug 2016 09:47:47 -0700 Subject: [PATCH] Load all pages for operations that paginate. --- awsshell/wizard.py | 19 ++++++++++++++----- tests/unit/test_wizard.py | 21 ++++++++++++++++++++- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/awsshell/wizard.py b/awsshell/wizard.py index a32769a..dac81bc 100644 --- a/awsshell/wizard.py +++ b/awsshell/wizard.py @@ -249,16 +249,25 @@ def _handle_request_retrieval(self): req = self.retrieval['Resource'] # get client from wizard's cache client = self._cached_creator.create_client(req['Service']) - # get the operation from the client - operation = getattr(client, xform_name(req['Operation'])) + operation_name = xform_name(req['Operation']) # get any parameters parameters = req.get('Parameters', {}) env_parameters = \ self._env.resolve_parameters(req.get('EnvParameters', {})) - # union of parameters and env_parameters, conflicts favor env_params + # union of parameters and env_parameters, conflicts favor env params parameters = dict(parameters, **env_parameters) - # execute operation passing all parameters - return operation(**parameters) + # if the operation supports pagination, load all results upfront + if client.can_paginate(operation_name): + # get paginator and create iterator + paginator = client.get_paginator(operation_name) + page_iterator = paginator.paginate(**parameters) + # scroll through all pages combining them + return page_iterator.build_full_result() + else: + # get the operation from the client + operation = getattr(client, operation_name) + # execute operation passing all parameters + return operation(**parameters) def _handle_retrieval(self): # In case of no retrieval, empty dict diff --git a/tests/unit/test_wizard.py b/tests/unit/test_wizard.py index 6e05374..98331d3 100644 --- a/tests/unit/test_wizard.py +++ b/tests/unit/test_wizard.py @@ -1,5 +1,6 @@ import mock import pytest +from botocore.session import Session from awsshell.utils import FileReadError from awsshell.wizard import stage_error_handler from awsshell.interaction import InteractionException @@ -138,7 +139,8 @@ def test_static_retrieval_with_query(wizard_spec, loader): def test_request_retrieval(wizard_spec_request): # Tests that retrieval requests are parsed and call the correct operation - mock_session = mock.Mock() + mock_session = mock.Mock(spec=Session) + mock_session.create_client.return_value.can_paginate.return_value = False mock_request = mock_session.create_client.return_value.create_rest_api mock_request.return_value = {'id': 'api id', 'name': 'random name'} @@ -148,6 +150,23 @@ def test_request_retrieval(wizard_spec_request): mock_request.assert_called_once_with(param='value', name='new api name') +def test_request_retrieval_paginate(wizard_spec_request): + # Tests that retrieval requests are parsed and call the correct operation + mock_session = mock.Mock(spec=Session) + mock_client = mock_session.create_client.return_value + mock_client.can_paginate.return_value = True + mock_paginator = mock_client.get_paginator.return_value + mock_iterator = mock_paginator.paginate.return_value + result = {'id': 'api id', 'name': 'random name'} + mock_iterator.build_full_result.return_value = result + paginate = mock_paginator.paginate + + loader = WizardLoader(mock_session) + wizard = loader.create_wizard(wizard_spec_request) + wizard.execute() + paginate.assert_called_once_with(param='value', name='new api name') + + def test_next_stage_resolution(wizard_spec, loader): # Test that the stage can resolve the next stage from env wizard_spec['Stages'][0]['Retrieval']['Path'] = '[0]'