From 94043b1990433f0cfc78cd00fc6b135ef9291143 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Wed, 17 Jul 2019 13:03:13 -0400 Subject: [PATCH] Firestore: Add 'should_terminate' predicate for clean BiDi shutdown. (#8650) Closes #7826. --- api_core/google/api_core/bidi.py | 20 +++- api_core/tests/unit/test_bidi.py | 110 +++++++++++++++++- firestore/google/cloud/firestore_v1/watch.py | 28 ++--- .../tests/unit/v1/test_cross_language.py | 10 +- firestore/tests/unit/v1/test_watch.py | 67 ++++++++++- 5 files changed, 207 insertions(+), 28 deletions(-) diff --git a/api_core/google/api_core/bidi.py b/api_core/google/api_core/bidi.py index 3b69e91be16c..7d3716dccfe3 100644 --- a/api_core/google/api_core/bidi.py +++ b/api_core/google/api_core/bidi.py @@ -349,6 +349,11 @@ def pending_requests(self): return self._request_queue.qsize() +def _never_terminate(future_or_error): + """By default, no errors cause BiDi termination.""" + return False + + class ResumableBidiRpc(BidiRpc): """A :class:`BidiRpc` that can automatically resume the stream on errors. @@ -391,6 +396,9 @@ def should_recover(exc): should_recover (Callable[[Exception], bool]): A function that returns True if the stream should be recovered. This will be called whenever an error is encountered on the stream. + should_terminate (Callable[[Exception], bool]): A function that returns + True if the stream should be terminated. This will be called + whenever an error is encountered on the stream. metadata Sequence[Tuple(str, str)]: RPC metadata to include in the request. throttle_reopen (bool): If ``True``, throttling will be applied to @@ -401,12 +409,14 @@ def __init__( self, start_rpc, should_recover, + should_terminate=_never_terminate, initial_request=None, metadata=None, throttle_reopen=False, ): super(ResumableBidiRpc, self).__init__(start_rpc, initial_request, metadata) self._should_recover = should_recover + self._should_terminate = should_terminate self._operational_lock = threading.RLock() self._finalized = False self._finalize_lock = threading.Lock() @@ -433,7 +443,9 @@ def _on_call_done(self, future): # error, not for errors that we can recover from. Note that grpc's # "future" here is also a grpc.RpcError. with self._operational_lock: - if not self._should_recover(future): + if self._should_terminate(future): + self._finalize(future) + elif not self._should_recover(future): self._finalize(future) else: _LOGGER.debug("Re-opening stream from gRPC callback.") @@ -496,6 +508,12 @@ def _recoverable(self, method, *args, **kwargs): with self._operational_lock: _LOGGER.debug("Call to retryable %r caused %s.", method, exc) + if self._should_terminate(exc): + self.close() + _LOGGER.debug("Terminating %r due to %s.", method, exc) + self._finalize(exc) + break + if not self._should_recover(exc): self.close() _LOGGER.debug("Not retrying %r due to %s.", method, exc) diff --git a/api_core/tests/unit/test_bidi.py b/api_core/tests/unit/test_bidi.py index 8e9f26202fde..4d185d3158e4 100644 --- a/api_core/tests/unit/test_bidi.py +++ b/api_core/tests/unit/test_bidi.py @@ -370,16 +370,65 @@ def cancel(self): class TestResumableBidiRpc(object): - def test_initial_state(self): - callback = mock.Mock() - callback.return_value = True - bidi_rpc = bidi.ResumableBidiRpc(None, callback) + def test_ctor_defaults(self): + start_rpc = mock.Mock() + should_recover = mock.Mock() + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover) + + assert bidi_rpc.is_active is False + assert bidi_rpc._finalized is False + assert bidi_rpc._start_rpc is start_rpc + assert bidi_rpc._should_recover is should_recover + assert bidi_rpc._should_terminate is bidi._never_terminate + assert bidi_rpc._initial_request is None + assert bidi_rpc._rpc_metadata is None + assert bidi_rpc._reopen_throttle is None + + def test_ctor_explicit(self): + start_rpc = mock.Mock() + should_recover = mock.Mock() + should_terminate = mock.Mock() + initial_request = mock.Mock() + metadata = {"x-foo": "bar"} + bidi_rpc = bidi.ResumableBidiRpc( + start_rpc, + should_recover, + should_terminate=should_terminate, + initial_request=initial_request, + metadata=metadata, + throttle_reopen=True, + ) assert bidi_rpc.is_active is False + assert bidi_rpc._finalized is False + assert bidi_rpc._should_recover is should_recover + assert bidi_rpc._should_terminate is should_terminate + assert bidi_rpc._initial_request is initial_request + assert bidi_rpc._rpc_metadata == metadata + assert isinstance(bidi_rpc._reopen_throttle, bidi._Throttle) + + def test_done_callbacks_terminate(self): + cancellation = mock.Mock() + start_rpc = mock.Mock() + should_recover = mock.Mock(spec=["__call__"], return_value=True) + should_terminate = mock.Mock(spec=["__call__"], return_value=True) + bidi_rpc = bidi.ResumableBidiRpc( + start_rpc, should_recover, should_terminate=should_terminate + ) + callback = mock.Mock(spec=["__call__"]) + + bidi_rpc.add_done_callback(callback) + bidi_rpc._on_call_done(cancellation) + + should_terminate.assert_called_once_with(cancellation) + should_recover.assert_not_called() + callback.assert_called_once_with(cancellation) + assert not bidi_rpc.is_active def test_done_callbacks_recoverable(self): start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True) - bidi_rpc = bidi.ResumableBidiRpc(start_rpc, lambda _: True) + should_recover = mock.Mock(spec=["__call__"], return_value=True) + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover) callback = mock.Mock(spec=["__call__"]) bidi_rpc.add_done_callback(callback) @@ -387,16 +436,45 @@ def test_done_callbacks_recoverable(self): callback.assert_not_called() start_rpc.assert_called_once() + should_recover.assert_called_once_with(mock.sentinel.future) assert bidi_rpc.is_active def test_done_callbacks_non_recoverable(self): - bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False) + start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True) + should_recover = mock.Mock(spec=["__call__"], return_value=False) + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover) callback = mock.Mock(spec=["__call__"]) bidi_rpc.add_done_callback(callback) bidi_rpc._on_call_done(mock.sentinel.future) callback.assert_called_once_with(mock.sentinel.future) + should_recover.assert_called_once_with(mock.sentinel.future) + assert not bidi_rpc.is_active + + def test_send_terminate(self): + cancellation = ValueError() + call_1 = CallStub([cancellation], active=False) + call_2 = CallStub([]) + start_rpc = mock.create_autospec( + grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2] + ) + should_recover = mock.Mock(spec=["__call__"], return_value=False) + should_terminate = mock.Mock(spec=["__call__"], return_value=True) + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate) + + bidi_rpc.open() + + bidi_rpc.send(mock.sentinel.request) + + assert bidi_rpc.pending_requests == 1 + assert bidi_rpc._request_queue.get() is None + + should_recover.assert_not_called() + should_terminate.assert_called_once_with(cancellation) + assert bidi_rpc.call == call_1 + assert bidi_rpc.is_active is False + assert call_1.cancelled is True def test_send_recover(self): error = ValueError() @@ -441,6 +519,26 @@ def test_send_failure(self): assert bidi_rpc.pending_requests == 1 assert bidi_rpc._request_queue.get() is None + def test_recv_terminate(self): + cancellation = ValueError() + call = CallStub([cancellation]) + start_rpc = mock.create_autospec( + grpc.StreamStreamMultiCallable, instance=True, return_value=call + ) + should_recover = mock.Mock(spec=["__call__"], return_value=False) + should_terminate = mock.Mock(spec=["__call__"], return_value=True) + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate) + + bidi_rpc.open() + + bidi_rpc.recv() + + should_recover.assert_not_called() + should_terminate.assert_called_once_with(cancellation) + assert bidi_rpc.call == call + assert bidi_rpc.is_active is False + assert call.cancelled is True + def test_recv_recover(self): error = ValueError() call_1 = CallStub([1, error]) diff --git a/firestore/google/cloud/firestore_v1/watch.py b/firestore/google/cloud/firestore_v1/watch.py index ac20b98bfe33..10a4f6dfebf5 100644 --- a/firestore/google/cloud/firestore_v1/watch.py +++ b/firestore/google/cloud/firestore_v1/watch.py @@ -57,13 +57,8 @@ "DO_NOT_USE": -1, } _RPC_ERROR_THREAD_NAME = "Thread-OnRpcTerminated" -_RETRYABLE_STREAM_ERRORS = ( - exceptions.DeadlineExceeded, - exceptions.ServiceUnavailable, - exceptions.InternalServerError, - exceptions.Unknown, - exceptions.GatewayTimeout, -) +_RECOVERABLE_STREAM_EXCEPTIONS = (exceptions.ServiceUnavailable,) +_TERMINATING_STREAM_EXCEPTIONS = (exceptions.Cancelled,) DocTreeEntry = collections.namedtuple("DocTreeEntry", ["value", "index"]) @@ -153,6 +148,16 @@ def document_watch_comparator(doc1, doc2): return 0 +def _should_recover(exception): + wrapped = _maybe_wrap_exception(exception) + return isinstance(wrapped, _RECOVERABLE_STREAM_EXCEPTIONS) + + +def _should_terminate(exception): + wrapped = _maybe_wrap_exception(exception) + return isinstance(wrapped, _TERMINATING_STREAM_EXCEPTIONS) + + class Watch(object): BackgroundConsumer = BackgroundConsumer # FBO unit tests @@ -199,12 +204,6 @@ def __init__( self._closing = threading.Lock() self._closed = False - def should_recover(exc): # pragma: NO COVER - return ( - isinstance(exc, grpc.RpcError) - and exc.code() == grpc.StatusCode.UNAVAILABLE - ) - initial_request = firestore_pb2.ListenRequest( database=self._firestore._database_string, add_target=self._targets ) @@ -214,8 +213,9 @@ def should_recover(exc): # pragma: NO COVER self._rpc = ResumableBidiRpc( self._api.transport.listen, + should_recover=_should_recover, + should_terminate=_should_terminate, initial_request=initial_request, - should_recover=should_recover, metadata=self._firestore._rpc_metadata, ) diff --git a/firestore/tests/unit/v1/test_cross_language.py b/firestore/tests/unit/v1/test_cross_language.py index 6bc4b7cc4b4e..2cfb68d967d8 100644 --- a/firestore/tests/unit/v1/test_cross_language.py +++ b/firestore/tests/unit/v1/test_cross_language.py @@ -343,10 +343,18 @@ def convert_precondition(precond): class DummyRpc(object): # pragma: NO COVER - def __init__(self, listen, initial_request, should_recover, metadata=None): + def __init__( + self, + listen, + should_recover, + should_terminate=None, + initial_request=None, + metadata=None, + ): self.listen = listen self.initial_request = initial_request self.should_recover = should_recover + self.should_terminate = should_terminate self.closed = False self.callbacks = [] self._metadata = metadata diff --git a/firestore/tests/unit/v1/test_watch.py b/firestore/tests/unit/v1/test_watch.py index 2e31f9a77009..363d7d1284a4 100644 --- a/firestore/tests/unit/v1/test_watch.py +++ b/firestore/tests/unit/v1/test_watch.py @@ -110,6 +110,44 @@ def test_diff_doc(self): self.assertRaises(AssertionError, self._callFUT, 1, 2) +class Test_should_recover(unittest.TestCase): + def _callFUT(self, exception): + from google.cloud.firestore_v1.watch import _should_recover + + return _should_recover(exception) + + def test_w_unavailable(self): + from google.api_core.exceptions import ServiceUnavailable + + exception = ServiceUnavailable("testing") + + self.assertTrue(self._callFUT(exception)) + + def test_w_non_recoverable(self): + exception = ValueError("testing") + + self.assertFalse(self._callFUT(exception)) + + +class Test_should_terminate(unittest.TestCase): + def _callFUT(self, exception): + from google.cloud.firestore_v1.watch import _should_terminate + + return _should_terminate(exception) + + def test_w_unavailable(self): + from google.api_core.exceptions import Cancelled + + exception = Cancelled("testing") + + self.assertTrue(self._callFUT(exception)) + + def test_w_non_recoverable(self): + exception = ValueError("testing") + + self.assertFalse(self._callFUT(exception)) + + class TestWatch(unittest.TestCase): def _makeOne( self, @@ -161,17 +199,26 @@ def _snapshot_callback(self, docs, changes, read_time): self.snapshotted = (docs, changes, read_time) def test_ctor(self): + from google.cloud.firestore_v1.proto import firestore_pb2 + from google.cloud.firestore_v1.watch import _should_recover + from google.cloud.firestore_v1.watch import _should_terminate + inst = self._makeOne() self.assertTrue(inst._consumer.started) self.assertTrue(inst._rpc.callbacks, [inst._on_rpc_done]) + self.assertIs(inst._rpc.start_rpc, inst._api.transport.listen) + self.assertIs(inst._rpc.should_recover, _should_recover) + self.assertIs(inst._rpc.should_terminate, _should_terminate) + self.assertIsInstance(inst._rpc.initial_request, firestore_pb2.ListenRequest) + self.assertEqual(inst._rpc.metadata, DummyFirestore._rpc_metadata) def test__on_rpc_done(self): + from google.cloud.firestore_v1.watch import _RPC_ERROR_THREAD_NAME + inst = self._makeOne() threading = DummyThreading() with mock.patch("google.cloud.firestore_v1.watch.threading", threading): inst._on_rpc_done(True) - from google.cloud.firestore_v1.watch import _RPC_ERROR_THREAD_NAME - self.assertTrue(threading.threads[_RPC_ERROR_THREAD_NAME].started) def test_close(self): @@ -835,13 +882,21 @@ def Thread(self, name, target, kwargs): class DummyRpc(object): - def __init__(self, listen, initial_request, should_recover, metadata=None): - self.listen = listen - self.initial_request = initial_request + def __init__( + self, + start_rpc, + should_recover, + should_terminate=None, + initial_request=None, + metadata=None, + ): + self.start_rpc = start_rpc self.should_recover = should_recover + self.should_terminate = should_terminate + self.initial_request = initial_request + self.metadata = metadata self.closed = False self.callbacks = [] - self._metadata = metadata def add_done_callback(self, callback): self.callbacks.append(callback)