Skip to content

Commit

Permalink
fix: harden 'query.stream' against retriable exceptions (#456)
Browse files Browse the repository at this point in the history
Closes #223.
  • Loading branch information
tseaver authored Sep 23, 2021
1 parent 335e2c4 commit 0dca32f
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 7 deletions.
48 changes: 43 additions & 5 deletions google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""
from google.cloud import firestore_v1
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.api_core import exceptions # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore

Expand Down Expand Up @@ -208,6 +209,29 @@ def _chunkify(
):
return

def _get_stream_iterator(self, transaction, retry, timeout):
"""Helper method for :meth:`stream`."""
request, expected_prefix, kwargs = self._prep_stream(
transaction, retry, timeout,
)

response_iterator = self._client._firestore_api.run_query(
request=request, metadata=self._client._rpc_metadata, **kwargs,
)

return response_iterator, expected_prefix

def _retry_query_after_exception(self, exc, retry, transaction):
"""Helper method for :meth:`stream`."""
if transaction is None: # no snapshot-based retry inside transaction
if retry is gapic_v1.method.DEFAULT:
transport = self._client._firestore_api._transport
gapic_callable = transport.run_query
retry = gapic_callable._retry
return retry._predicate(exc)

return False

def stream(
self,
transaction=None,
Expand Down Expand Up @@ -244,15 +268,28 @@ def stream(
:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`:
The next document that fulfills the query.
"""
request, expected_prefix, kwargs = self._prep_stream(
response_iterator, expected_prefix = self._get_stream_iterator(
transaction, retry, timeout,
)

response_iterator = self._client._firestore_api.run_query(
request=request, metadata=self._client._rpc_metadata, **kwargs,
)
last_snapshot = None

while True:
try:
response = next(response_iterator, None)
except exceptions.GoogleAPICallError as exc:
if self._retry_query_after_exception(exc, retry, transaction):
new_query = self.start_after(last_snapshot)
response_iterator, _ = new_query._get_stream_iterator(
transaction, retry, timeout,
)
continue
else:
raise

if response is None: # EOI
break

for response in response_iterator:
if self._all_descendants:
snapshot = _collection_group_query_response_to_snapshot(
response, self._parent
Expand All @@ -262,6 +299,7 @@ def stream(
response, self._parent, expected_prefix
)
if snapshot is not None:
last_snapshot = snapshot
yield snapshot

def on_snapshot(self, callback: Callable) -> Watch:
Expand Down
123 changes: 121 additions & 2 deletions tests/unit/v1/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud.firestore_v1.types.document import Document
from google.cloud.firestore_v1.types.firestore import RunQueryResponse
import types
import unittest

import mock
import pytest

from google.api_core import gapic_v1
from google.cloud.firestore_v1.types.document import Document
from google.cloud.firestore_v1.types.firestore import RunQueryResponse
from tests.unit.v1.test_base_query import _make_credentials
from tests.unit.v1.test_base_query import _make_cursor_pb
from tests.unit.v1.test_base_query import _make_query_response
Expand Down Expand Up @@ -456,6 +457,124 @@ def test_stream_w_collection_group(self):
metadata=client._rpc_metadata,
)

def _stream_w_retriable_exc_helper(
self,
retry=gapic_v1.method.DEFAULT,
timeout=None,
transaction=None,
expect_retry=True,
):
from google.api_core import exceptions
from google.cloud.firestore_v1 import _helpers

if transaction is not None:
expect_retry = False

# Create a minimal fake GAPIC.
firestore_api = mock.Mock(spec=["run_query", "_transport"])
transport = firestore_api._transport = mock.Mock(spec=["run_query"])
stub = transport.run_query = mock.create_autospec(
gapic_v1.method._GapicCallable
)
stub._retry = mock.Mock(spec=["_predicate"])
stub._predicate = lambda exc: True # pragma: NO COVER

# Attach the fake GAPIC to a real client.
client = _make_client()
client._firestore_api_internal = firestore_api

# Make a **real** collection reference as parent.
parent = client.collection("dee")

# Add a dummy response to the minimal fake GAPIC.
_, expected_prefix = parent._parent_info()
name = "{}/sleep".format(expected_prefix)
data = {"snooze": 10}
response_pb = _make_query_response(name=name, data=data)
retriable_exc = exceptions.ServiceUnavailable("testing")

def _stream_w_exception(*_args, **_kw):
yield response_pb
raise retriable_exc

firestore_api.run_query.side_effect = [_stream_w_exception(), iter([])]
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

# Execute the query and check the response.
query = self._make_one(parent)

get_response = query.stream(transaction=transaction, **kwargs)

self.assertIsInstance(get_response, types.GeneratorType)
if expect_retry:
returned = list(get_response)
else:
returned = [next(get_response)]
with self.assertRaises(exceptions.ServiceUnavailable):
next(get_response)

self.assertEqual(len(returned), 1)
snapshot = returned[0]
self.assertEqual(snapshot.reference._path, ("dee", "sleep"))
self.assertEqual(snapshot.to_dict(), data)

# Verify the mock call.
parent_path, _ = parent._parent_info()
calls = firestore_api.run_query.call_args_list

if expect_retry:
self.assertEqual(len(calls), 2)
else:
self.assertEqual(len(calls), 1)

if transaction is not None:
expected_transaction_id = transaction.id
else:
expected_transaction_id = None

self.assertEqual(
calls[0],
mock.call(
request={
"parent": parent_path,
"structured_query": query._to_protobuf(),
"transaction": expected_transaction_id,
},
metadata=client._rpc_metadata,
**kwargs,
),
)

if expect_retry:
new_query = query.start_after(snapshot)
self.assertEqual(
calls[1],
mock.call(
request={
"parent": parent_path,
"structured_query": new_query._to_protobuf(),
"transaction": None,
},
metadata=client._rpc_metadata,
**kwargs,
),
)

def test_stream_w_retriable_exc_w_defaults(self):
self._stream_w_retriable_exc_helper()

def test_stream_w_retriable_exc_w_retry(self):
retry = mock.Mock(spec=["_predicate"])
retry._predicate = lambda exc: False
self._stream_w_retriable_exc_helper(retry=retry, expect_retry=False)

def test_stream_w_retriable_exc_w_transaction(self):
from google.cloud.firestore_v1 import transaction

txn = transaction.Transaction(client=mock.Mock(spec=[]))
txn._id = b"DEADBEEF"
self._stream_w_retriable_exc_helper(transaction=txn)

@mock.patch("google.cloud.firestore_v1.query.Watch", autospec=True)
def test_on_snapshot(self, watch):
query = self._make_one(mock.sentinel.parent)
Expand Down

0 comments on commit 0dca32f

Please sign in to comment.