Retry streaming exceptions (this time for sure, Rocky!) #4016

Merged 6 commits on Sep 21, 2017
17 changes: 4 additions & 13 deletions spanner/google/cloud/spanner/session.py
@@ -165,8 +165,7 @@ def snapshot(self, **kw):

return Snapshot(self, **kw)

-    def read(self, table, columns, keyset, index='', limit=0,
-             resume_token=b''):
+    def read(self, table, columns, keyset, index='', limit=0):
"""Perform a ``StreamingRead`` API request for rows in a table.

:type table: str
@@ -185,17 +184,12 @@ def read(self, table, columns, keyset, index='', limit=0,
:type limit: int
        :param limit: (Optional) maximum number of rows to return

-        :type resume_token: bytes
-        :param resume_token: token for resuming previously-interrupted read

:rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
-        return self.snapshot().read(
-            table, columns, keyset, index, limit, resume_token)
+        return self.snapshot().read(table, columns, keyset, index, limit)

-    def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
-                    resume_token=b''):
+    def execute_sql(self, sql, params=None, param_types=None, query_mode=None):
"""Perform an ``ExecuteStreamingSql`` API request.

:type sql: str
@@ -216,14 +210,11 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
:param query_mode: Mode governing return of results / query plan. See
https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1

-        :type resume_token: bytes
-        :param resume_token: token for resuming previously-interrupted query

:rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
return self.snapshot().execute_sql(
-            sql, params, param_types, query_mode, resume_token)
+            sql, params, param_types, query_mode)

def batch(self):
"""Factory to create a batch for this session.
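With the `resume_token` plumbing gone, the session-level wrappers reduce to thin delegation: each call builds a throwaway snapshot and forwards the trimmed argument list. A minimal sketch of the call site, assuming an existing `session` and `keyset` (placeholder names, not from this diff):

```python
# Hypothetical usage; 'citizens' and the column tuple are invented.
# Neither call accepts resume_token any more -- recovery from dropped
# streams now happens inside the snapshot machinery.
result = session.read('citizens', ('email', 'age'), keyset)
# ...which is now exactly equivalent to:
result = session.snapshot().read('citizens', ('email', 'age'), keyset)
result.consume_next()  # pull the first PartialResultSet from the stream
```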
60 changes: 46 additions & 14 deletions spanner/google/cloud/spanner/snapshot.py
@@ -14,10 +14,13 @@

"""Model a set of read-only queries to a database as a snapshot."""

+import functools
+
from google.protobuf.struct_pb2 import Struct
from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionOptions
from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionSelector

+from google.api.core.exceptions import ServiceUnavailable
from google.cloud._helpers import _datetime_to_pb_timestamp
from google.cloud._helpers import _timedelta_to_duration_pb
from google.cloud.spanner._helpers import _make_value_pb
@@ -26,6 +29,36 @@
from google.cloud.spanner.streamed import StreamedResultSet


+def _restart_on_unavailable(restart):
+    """Restart iteration after :exc:`.ServiceUnavailable`.
+
+    :type restart: callable
+    :param restart: curried function returning iterator
+    """
+    resume_token = b''
+    item_buffer = []
+    iterator = restart()
+    while True:
+        try:
+            for item in iterator:
+                item_buffer.append(item)
+                if item.resume_token:
+                    resume_token = item.resume_token
+                    break
+        except ServiceUnavailable:
+            del item_buffer[:]
+            iterator = restart(resume_token=resume_token)
+            continue
+
+        if len(item_buffer) == 0:
+            break
+
+        for item in item_buffer:
+            yield item
+
+        del item_buffer[:]
+
+
class _SnapshotBase(_SessionWrapper):
"""Base class for Snapshot.

@@ -49,8 +82,7 @@ def _make_txn_selector(self): # pylint: disable=redundant-returns-doc
"""
raise NotImplementedError

-    def read(self, table, columns, keyset, index='', limit=0,
-             resume_token=b''):
+    def read(self, table, columns, keyset, index='', limit=0):
"""Perform a ``StreamingRead`` API request for rows in a table.

:type table: str
@@ -69,9 +101,6 @@ def read(self, table, columns, keyset, index='', limit=0,
:type limit: int
        :param limit: (Optional) maximum number of rows to return

-        :type resume_token: bytes
-        :param resume_token: token for resuming previously-interrupted read

:rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
:raises ValueError:
@@ -89,10 +118,13 @@ def read(self, table, columns, keyset, index='', limit=0,
options = _options_with_prefix(database.name)
transaction = self._make_txn_selector()

-        iterator = api.streaming_read(
+        restart = functools.partial(
+            api.streaming_read,
            self._session.name, table, columns, keyset.to_pb(),
            transaction=transaction, index=index, limit=limit,
-            resume_token=resume_token, options=options)
+            options=options)
+
+        iterator = _restart_on_unavailable(restart)

self._read_request_count += 1

@@ -101,8 +133,7 @@ def read(self, table, columns, keyset, index='', limit=0,
else:
return StreamedResultSet(iterator)

-    def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
-                    resume_token=b''):
+    def execute_sql(self, sql, params=None, param_types=None, query_mode=None):
"""Perform an ``ExecuteStreamingSql`` API request for rows in a table.

:type sql: str
@@ -122,9 +153,6 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
:param query_mode: Mode governing return of results / query plan. See
https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1

-        :type resume_token: bytes
-        :param resume_token: token for resuming previously-interrupted query

:rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
:raises ValueError:
@@ -150,10 +178,14 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
options = _options_with_prefix(database.name)
transaction = self._make_txn_selector()
api = database.spanner_api
-        iterator = api.execute_streaming_sql(
+
+        restart = functools.partial(
+            api.execute_streaming_sql,
            self._session.name, sql,
            transaction=transaction, params=params_pb, param_types=param_types,
-            query_mode=query_mode, resume_token=resume_token, options=options)
+            query_mode=query_mode, options=options)
+
+        iterator = _restart_on_unavailable(restart)

self._read_request_count += 1

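To make the new control flow concrete: items are buffered until the server supplies a resume token, a mid-stream `ServiceUnavailable` discards only the unflushed buffer, and the stream restarts from the last checkpoint, so consumers see every row exactly once. Below is a self-contained sketch under those assumptions; `FakeItem`, `flaky_stream`, and the token positions are inventions for illustration, with only the loop itself mirroring `_restart_on_unavailable`:

```python
from collections import namedtuple


class Unavailable(Exception):
    """Local stand-in for google.api.core.exceptions.ServiceUnavailable."""


FakeItem = namedtuple('FakeItem', ['value', 'resume_token'])

ITEMS = [
    FakeItem('a', b''),       # not yet checkpointed by the server
    FakeItem('b', b'tok-1'),  # server checkpoint after 'b'
    FakeItem('c', b''),
    FakeItem('d', b'tok-2'),
]
STARTS = {b'': 0, b'tok-1': 2, b'tok-2': 4}  # where each token resumes
CALLS = []


def flaky_stream(resume_token=b''):
    """Replay ITEMS from a checkpoint, dropping the connection once."""
    CALLS.append(resume_token)
    for index in range(STARTS[resume_token], len(ITEMS)):
        if len(CALLS) == 1 and index == 2:
            raise Unavailable()  # simulated mid-stream outage
        yield ITEMS[index]


def restart_on_unavailable(restart):
    """Same loop as ``_restart_on_unavailable`` in the diff above."""
    resume_token = b''
    item_buffer = []
    iterator = restart()
    while True:
        try:
            for item in iterator:
                item_buffer.append(item)
                if item.resume_token:          # checkpoint: safe to flush
                    resume_token = item.resume_token
                    break
        except Unavailable:
            del item_buffer[:]                 # unflushed items will be re-sent
            iterator = restart(resume_token=resume_token)
            continue
        if not item_buffer:                    # stream exhausted
            break
        for item in item_buffer:
            yield item
        del item_buffer[:]


values = [item.value for item in restart_on_unavailable(flaky_stream)]
assert values == ['a', 'b', 'c', 'd']  # no duplicates, no gaps
assert CALLS == [b'', b'tok-1']        # one restart, from the checkpoint
```

The design choice worth noting: nothing is yielded until a checkpoint (or end of stream) arrives, which is exactly what makes discarding the buffer on restart safe.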
11 changes: 0 additions & 11 deletions spanner/google/cloud/spanner/streamed.py
@@ -43,7 +43,6 @@ def __init__(self, response_iterator, source=None):
self._counter = 0 # Counter for processed responses
self._metadata = None # Until set from first PRS
self._stats = None # Until set from last PRS
-        self._resume_token = None  # To resume from last received PRS
self._current_row = [] # Accumulated values for incomplete row
self._pending_chunk = None # Incomplete value
self._source = source # Source snapshot
@@ -85,15 +84,6 @@ def stats(self):
"""
return self._stats

-    @property
-    def resume_token(self):
-        """Token for resuming interrupted read / query.
-
-        :rtype: bytes
-        :returns: token from last chunk of results.
-        """
-        return self._resume_token
-
def _merge_chunk(self, value):
"""Merge pending chunk with next value.

@@ -132,7 +122,6 @@ def consume_next(self):
"""
response = six.next(self._response_iterator)
self._counter += 1
-        self._resume_token = response.resume_token

if self._metadata is None: # first response
metadata = self._metadata = response.metadata
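Because the retry helper now owns the token, `StreamedResultSet` no longer needs to expose it, and the property is removed rather than deprecated. A before/after sketch of hypothetical caller code (not from this diff):

```python
# Hypothetical caller; ``session`` and ``keyset`` are assumed to exist.
result = session.read('citizens', ('email',), keyset)
result.consume_next()

# Before this change a caller could checkpoint manually:
#
#     token = result.resume_token          # property removed above
#     result = session.read('citizens', ('email',), keyset,
#                           resume_token=token)
#
# Now there is nothing to save: a dropped stream is restarted inside
# _restart_on_unavailable, and consume_next() sees only responses that
# were never delivered.
```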
21 changes: 8 additions & 13 deletions spanner/tests/unit/test_session.py
@@ -265,7 +265,6 @@ def test_read(self):
KEYSET = KeySet(keys=KEYS)
INDEX = 'email-address-index'
LIMIT = 20
-        TOKEN = b'DEADBEEF'
database = _Database(self.DATABASE_NAME)
session = self._make_one(database)
session._session_id = 'DEADBEEF'
@@ -279,28 +278,26 @@ def __init__(self, session, **kwargs):
self._session = session
self._kwargs = kwargs.copy()

-            def read(self, table, columns, keyset, index='', limit=0,
-                     resume_token=b''):
+            def read(self, table, columns, keyset, index='', limit=0):
_read_with.append(
-                    (table, columns, keyset, index, limit, resume_token))
+                    (table, columns, keyset, index, limit))
return expected

with _Monkey(MUT, Snapshot=_Snapshot):
found = session.read(
TABLE_NAME, COLUMNS, KEYSET,
-                index=INDEX, limit=LIMIT, resume_token=TOKEN)
+                index=INDEX, limit=LIMIT)

self.assertIs(found, expected)

self.assertEqual(len(_read_with), 1)
-        (table, columns, key_set, index, limit, resume_token) = _read_with[0]
+        (table, columns, key_set, index, limit) = _read_with[0]

self.assertEqual(table, TABLE_NAME)
self.assertEqual(columns, COLUMNS)
self.assertEqual(key_set, KEYSET)
self.assertEqual(index, INDEX)
self.assertEqual(limit, LIMIT)
-        self.assertEqual(resume_token, TOKEN)

def test_execute_sql_not_created(self):
SQL = 'SELECT first_name, age FROM citizens'
@@ -330,25 +327,23 @@ def __init__(self, session, **kwargs):
self._kwargs = kwargs.copy()

def execute_sql(
-                    self, sql, params=None, param_types=None, query_mode=None,
-                    resume_token=None):
+                    self, sql, params=None, param_types=None, query_mode=None):
_executed_sql_with.append(
-                    (sql, params, param_types, query_mode, resume_token))
+                    (sql, params, param_types, query_mode))
return expected

with _Monkey(MUT, Snapshot=_Snapshot):
-            found = session.execute_sql(SQL, resume_token=TOKEN)
+            found = session.execute_sql(SQL)

self.assertIs(found, expected)

self.assertEqual(len(_executed_sql_with), 1)
-        sql, params, param_types, query_mode, token = _executed_sql_with[0]
+        sql, params, param_types, query_mode = _executed_sql_with[0]

self.assertEqual(sql, SQL)
self.assertEqual(params, None)
self.assertEqual(param_types, None)
self.assertEqual(query_mode, None)
-        self.assertEqual(token, TOKEN)

def test_batch_not_created(self):
database = _Database(self.DATABASE_NAME)
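The session tests above only lose their token assertions; the retry behavior itself belongs with the snapshot tests. A sketch of such a test; the test class, `FakeItem`, and the fake `restart` callable are inventions, not part of this diff:

```python
import unittest


class FakeItem(object):
    """Minimal stand-in for PartialResultSet: a value plus resume token."""

    def __init__(self, value, resume_token=b''):
        self.value = value
        self.resume_token = resume_token


class Test_restart_on_unavailable(unittest.TestCase):

    def test_resumes_after_unavailable(self):
        from google.api.core.exceptions import ServiceUnavailable
        from google.cloud.spanner.snapshot import _restart_on_unavailable

        calls = []

        def restart(resume_token=b''):
            calls.append(resume_token)
            if not resume_token:
                yield FakeItem('a', resume_token=b'tok')
                raise ServiceUnavailable('stream dropped')
            yield FakeItem('b')

        values = [item.value for item in _restart_on_unavailable(restart)]

        self.assertEqual(values, ['a', 'b'])           # 'a' is not re-yielded
        self.assertEqual(calls, [b'', b'tok'])         # restarted from checkpoint
```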