Skip to content

Commit

Permalink
Firestore: Add 'should_terminate' predicate for clean BiDi shutdown. (#…
Browse files Browse the repository at this point in the history
…8650)

Closes #7826.
  • Loading branch information
tseaver authored Jul 17, 2019
1 parent 562deea commit 94043b1
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 28 deletions.
20 changes: 19 additions & 1 deletion api_core/google/api_core/bidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,11 @@ def pending_requests(self):
return self._request_queue.qsize()


def _never_terminate(future_or_error):
"""By default, no errors cause BiDi termination."""
return False


class ResumableBidiRpc(BidiRpc):
"""A :class:`BidiRpc` that can automatically resume the stream on errors.
Expand Down Expand Up @@ -391,6 +396,9 @@ def should_recover(exc):
should_recover (Callable[[Exception], bool]): A function that returns
True if the stream should be recovered. This will be called
whenever an error is encountered on the stream.
should_terminate (Callable[[Exception], bool]): A function that returns
True if the stream should be terminated. This will be called
whenever an error is encountered on the stream.
metadata Sequence[Tuple(str, str)]: RPC metadata to include in
the request.
throttle_reopen (bool): If ``True``, throttling will be applied to
Expand All @@ -401,12 +409,14 @@ def __init__(
self,
start_rpc,
should_recover,
should_terminate=_never_terminate,
initial_request=None,
metadata=None,
throttle_reopen=False,
):
super(ResumableBidiRpc, self).__init__(start_rpc, initial_request, metadata)
self._should_recover = should_recover
self._should_terminate = should_terminate
self._operational_lock = threading.RLock()
self._finalized = False
self._finalize_lock = threading.Lock()
Expand All @@ -433,7 +443,9 @@ def _on_call_done(self, future):
# error, not for errors that we can recover from. Note that grpc's
# "future" here is also a grpc.RpcError.
with self._operational_lock:
if not self._should_recover(future):
if self._should_terminate(future):
self._finalize(future)
elif not self._should_recover(future):
self._finalize(future)
else:
_LOGGER.debug("Re-opening stream from gRPC callback.")
Expand Down Expand Up @@ -496,6 +508,12 @@ def _recoverable(self, method, *args, **kwargs):
with self._operational_lock:
_LOGGER.debug("Call to retryable %r caused %s.", method, exc)

if self._should_terminate(exc):
self.close()
_LOGGER.debug("Terminating %r due to %s.", method, exc)
self._finalize(exc)
break

if not self._should_recover(exc):
self.close()
_LOGGER.debug("Not retrying %r due to %s.", method, exc)
Expand Down
110 changes: 104 additions & 6 deletions api_core/tests/unit/test_bidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,33 +370,111 @@ def cancel(self):


class TestResumableBidiRpc(object):
def test_initial_state(self):
callback = mock.Mock()
callback.return_value = True
bidi_rpc = bidi.ResumableBidiRpc(None, callback)
def test_ctor_defaults(self):
start_rpc = mock.Mock()
should_recover = mock.Mock()
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)

assert bidi_rpc.is_active is False
assert bidi_rpc._finalized is False
assert bidi_rpc._start_rpc is start_rpc
assert bidi_rpc._should_recover is should_recover
assert bidi_rpc._should_terminate is bidi._never_terminate
assert bidi_rpc._initial_request is None
assert bidi_rpc._rpc_metadata is None
assert bidi_rpc._reopen_throttle is None

def test_ctor_explicit(self):
start_rpc = mock.Mock()
should_recover = mock.Mock()
should_terminate = mock.Mock()
initial_request = mock.Mock()
metadata = {"x-foo": "bar"}
bidi_rpc = bidi.ResumableBidiRpc(
start_rpc,
should_recover,
should_terminate=should_terminate,
initial_request=initial_request,
metadata=metadata,
throttle_reopen=True,
)

assert bidi_rpc.is_active is False
assert bidi_rpc._finalized is False
assert bidi_rpc._should_recover is should_recover
assert bidi_rpc._should_terminate is should_terminate
assert bidi_rpc._initial_request is initial_request
assert bidi_rpc._rpc_metadata == metadata
assert isinstance(bidi_rpc._reopen_throttle, bidi._Throttle)

def test_done_callbacks_terminate(self):
cancellation = mock.Mock()
start_rpc = mock.Mock()
should_recover = mock.Mock(spec=["__call__"], return_value=True)
should_terminate = mock.Mock(spec=["__call__"], return_value=True)
bidi_rpc = bidi.ResumableBidiRpc(
start_rpc, should_recover, should_terminate=should_terminate
)
callback = mock.Mock(spec=["__call__"])

bidi_rpc.add_done_callback(callback)
bidi_rpc._on_call_done(cancellation)

should_terminate.assert_called_once_with(cancellation)
should_recover.assert_not_called()
callback.assert_called_once_with(cancellation)
assert not bidi_rpc.is_active

def test_done_callbacks_recoverable(self):
start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, lambda _: True)
should_recover = mock.Mock(spec=["__call__"], return_value=True)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
callback = mock.Mock(spec=["__call__"])

bidi_rpc.add_done_callback(callback)
bidi_rpc._on_call_done(mock.sentinel.future)

callback.assert_not_called()
start_rpc.assert_called_once()
should_recover.assert_called_once_with(mock.sentinel.future)
assert bidi_rpc.is_active

def test_done_callbacks_non_recoverable(self):
bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
should_recover = mock.Mock(spec=["__call__"], return_value=False)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
callback = mock.Mock(spec=["__call__"])

bidi_rpc.add_done_callback(callback)
bidi_rpc._on_call_done(mock.sentinel.future)

callback.assert_called_once_with(mock.sentinel.future)
should_recover.assert_called_once_with(mock.sentinel.future)
assert not bidi_rpc.is_active

def test_send_terminate(self):
cancellation = ValueError()
call_1 = CallStub([cancellation], active=False)
call_2 = CallStub([])
start_rpc = mock.create_autospec(
grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2]
)
should_recover = mock.Mock(spec=["__call__"], return_value=False)
should_terminate = mock.Mock(spec=["__call__"], return_value=True)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate)

bidi_rpc.open()

bidi_rpc.send(mock.sentinel.request)

assert bidi_rpc.pending_requests == 1
assert bidi_rpc._request_queue.get() is None

should_recover.assert_not_called()
should_terminate.assert_called_once_with(cancellation)
assert bidi_rpc.call == call_1
assert bidi_rpc.is_active is False
assert call_1.cancelled is True

def test_send_recover(self):
error = ValueError()
Expand Down Expand Up @@ -441,6 +519,26 @@ def test_send_failure(self):
assert bidi_rpc.pending_requests == 1
assert bidi_rpc._request_queue.get() is None

def test_recv_terminate(self):
cancellation = ValueError()
call = CallStub([cancellation])
start_rpc = mock.create_autospec(
grpc.StreamStreamMultiCallable, instance=True, return_value=call
)
should_recover = mock.Mock(spec=["__call__"], return_value=False)
should_terminate = mock.Mock(spec=["__call__"], return_value=True)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate)

bidi_rpc.open()

bidi_rpc.recv()

should_recover.assert_not_called()
should_terminate.assert_called_once_with(cancellation)
assert bidi_rpc.call == call
assert bidi_rpc.is_active is False
assert call.cancelled is True

def test_recv_recover(self):
error = ValueError()
call_1 = CallStub([1, error])
Expand Down
28 changes: 14 additions & 14 deletions firestore/google/cloud/firestore_v1/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,8 @@
"DO_NOT_USE": -1,
}
_RPC_ERROR_THREAD_NAME = "Thread-OnRpcTerminated"
_RETRYABLE_STREAM_ERRORS = (
exceptions.DeadlineExceeded,
exceptions.ServiceUnavailable,
exceptions.InternalServerError,
exceptions.Unknown,
exceptions.GatewayTimeout,
)
_RECOVERABLE_STREAM_EXCEPTIONS = (exceptions.ServiceUnavailable,)
_TERMINATING_STREAM_EXCEPTIONS = (exceptions.Cancelled,)

DocTreeEntry = collections.namedtuple("DocTreeEntry", ["value", "index"])

Expand Down Expand Up @@ -153,6 +148,16 @@ def document_watch_comparator(doc1, doc2):
return 0


def _should_recover(exception):
wrapped = _maybe_wrap_exception(exception)
return isinstance(wrapped, _RECOVERABLE_STREAM_EXCEPTIONS)


def _should_terminate(exception):
wrapped = _maybe_wrap_exception(exception)
return isinstance(wrapped, _TERMINATING_STREAM_EXCEPTIONS)


class Watch(object):

BackgroundConsumer = BackgroundConsumer # FBO unit tests
Expand Down Expand Up @@ -199,12 +204,6 @@ def __init__(
self._closing = threading.Lock()
self._closed = False

def should_recover(exc): # pragma: NO COVER
return (
isinstance(exc, grpc.RpcError)
and exc.code() == grpc.StatusCode.UNAVAILABLE
)

initial_request = firestore_pb2.ListenRequest(
database=self._firestore._database_string, add_target=self._targets
)
Expand All @@ -214,8 +213,9 @@ def should_recover(exc): # pragma: NO COVER

self._rpc = ResumableBidiRpc(
self._api.transport.listen,
should_recover=_should_recover,
should_terminate=_should_terminate,
initial_request=initial_request,
should_recover=should_recover,
metadata=self._firestore._rpc_metadata,
)

Expand Down
10 changes: 9 additions & 1 deletion firestore/tests/unit/v1/test_cross_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,18 @@ def convert_precondition(precond):


class DummyRpc(object): # pragma: NO COVER
def __init__(self, listen, initial_request, should_recover, metadata=None):
def __init__(
self,
listen,
should_recover,
should_terminate=None,
initial_request=None,
metadata=None,
):
self.listen = listen
self.initial_request = initial_request
self.should_recover = should_recover
self.should_terminate = should_terminate
self.closed = False
self.callbacks = []
self._metadata = metadata
Expand Down
67 changes: 61 additions & 6 deletions firestore/tests/unit/v1/test_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,44 @@ def test_diff_doc(self):
self.assertRaises(AssertionError, self._callFUT, 1, 2)


class Test_should_recover(unittest.TestCase):
def _callFUT(self, exception):
from google.cloud.firestore_v1.watch import _should_recover

return _should_recover(exception)

def test_w_unavailable(self):
from google.api_core.exceptions import ServiceUnavailable

exception = ServiceUnavailable("testing")

self.assertTrue(self._callFUT(exception))

def test_w_non_recoverable(self):
exception = ValueError("testing")

self.assertFalse(self._callFUT(exception))


class Test_should_terminate(unittest.TestCase):
def _callFUT(self, exception):
from google.cloud.firestore_v1.watch import _should_terminate

return _should_terminate(exception)

def test_w_unavailable(self):
from google.api_core.exceptions import Cancelled

exception = Cancelled("testing")

self.assertTrue(self._callFUT(exception))

def test_w_non_recoverable(self):
exception = ValueError("testing")

self.assertFalse(self._callFUT(exception))


class TestWatch(unittest.TestCase):
def _makeOne(
self,
Expand Down Expand Up @@ -161,17 +199,26 @@ def _snapshot_callback(self, docs, changes, read_time):
self.snapshotted = (docs, changes, read_time)

def test_ctor(self):
from google.cloud.firestore_v1.proto import firestore_pb2
from google.cloud.firestore_v1.watch import _should_recover
from google.cloud.firestore_v1.watch import _should_terminate

inst = self._makeOne()
self.assertTrue(inst._consumer.started)
self.assertTrue(inst._rpc.callbacks, [inst._on_rpc_done])
self.assertIs(inst._rpc.start_rpc, inst._api.transport.listen)
self.assertIs(inst._rpc.should_recover, _should_recover)
self.assertIs(inst._rpc.should_terminate, _should_terminate)
self.assertIsInstance(inst._rpc.initial_request, firestore_pb2.ListenRequest)
self.assertEqual(inst._rpc.metadata, DummyFirestore._rpc_metadata)

def test__on_rpc_done(self):
from google.cloud.firestore_v1.watch import _RPC_ERROR_THREAD_NAME

inst = self._makeOne()
threading = DummyThreading()
with mock.patch("google.cloud.firestore_v1.watch.threading", threading):
inst._on_rpc_done(True)
from google.cloud.firestore_v1.watch import _RPC_ERROR_THREAD_NAME

self.assertTrue(threading.threads[_RPC_ERROR_THREAD_NAME].started)

def test_close(self):
Expand Down Expand Up @@ -835,13 +882,21 @@ def Thread(self, name, target, kwargs):


class DummyRpc(object):
def __init__(self, listen, initial_request, should_recover, metadata=None):
self.listen = listen
self.initial_request = initial_request
def __init__(
self,
start_rpc,
should_recover,
should_terminate=None,
initial_request=None,
metadata=None,
):
self.start_rpc = start_rpc
self.should_recover = should_recover
self.should_terminate = should_terminate
self.initial_request = initial_request
self.metadata = metadata
self.closed = False
self.callbacks = []
self._metadata = metadata

def add_done_callback(self, callback):
self.callbacks.append(callback)
Expand Down

0 comments on commit 94043b1

Please sign in to comment.