diff --git a/api_core/google/api_core/future/polling.py b/api_core/google/api_core/future/polling.py index b5aecde73b6b..01adc21a09eb 100644 --- a/api_core/google/api_core/future/polling.py +++ b/api_core/google/api_core/future/polling.py @@ -28,7 +28,12 @@ class _OperationNotComplete(Exception): pass -RETRY_PREDICATE = retry.if_exception_type(_OperationNotComplete) +RETRY_PREDICATE = retry.if_exception_type( + _OperationNotComplete, + exceptions.TooManyRequests, + exceptions.InternalServerError, + exceptions.BadGateway, +) DEFAULT_RETRY = retry.Retry(predicate=RETRY_PREDICATE) diff --git a/api_core/tests/unit/future/test_polling.py b/api_core/tests/unit/future/test_polling.py index f84e7ad3d576..f56f0c5304a0 100644 --- a/api_core/tests/unit/future/test_polling.py +++ b/api_core/tests/unit/future/test_polling.py @@ -19,6 +19,7 @@ import mock import pytest +from google.api_core import exceptions from google.api_core.future import polling @@ -118,6 +119,34 @@ def test_result_timeout(): future.result(timeout=1) +class PollingFutureImplTransient(PollingFutureImplWithPoll): + def __init__(self, errors): + super(PollingFutureImplTransient, self).__init__() + self._errors = errors + + def done(self): + if self._errors: + error, self._errors = self._errors[0], self._errors[1:] + raise error('testing') + self.poll_count += 1 + self.set_result(42) + return True + + +def test_result_transient_error(): + future = PollingFutureImplTransient(( + exceptions.TooManyRequests, + exceptions.InternalServerError, + exceptions.BadGateway, + )) + result = future.result() + assert result == 42 + assert future.poll_count == 1 + # Repeated calls should not cause additional polling + assert future.result() == result + assert future.poll_count == 1 + + def test_callback_background_thread(): future = PollingFutureImplWithPoll() callback = mock.Mock()