diff --git a/spanner/google/cloud/spanner/session.py b/spanner/google/cloud/spanner/session.py
index d513889053a7..94fd0f092366 100644
--- a/spanner/google/cloud/spanner/session.py
+++ b/spanner/google/cloud/spanner/session.py
@@ -165,8 +165,7 @@ def snapshot(self, **kw):
         """
         return Snapshot(self, **kw)
 
-    def read(self, table, columns, keyset, index='', limit=0,
-             resume_token=b''):
+    def read(self, table, columns, keyset, index='', limit=0):
         """Perform a ``StreamingRead`` API request for rows in a table.
 
         :type table: str
@@ -185,17 +184,12 @@ def read(self, table, columns, keyset, index='', limit=0,
         :type limit: int
         :param limit: (Optional) maximum number of rows to return
 
-        :type resume_token: bytes
-        :param resume_token: token for resuming previously-interrupted read
-
         :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
         """
-        return self.snapshot().read(
-            table, columns, keyset, index, limit, resume_token)
+        return self.snapshot().read(table, columns, keyset, index, limit)
 
-    def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
-                    resume_token=b''):
+    def execute_sql(self, sql, params=None, param_types=None, query_mode=None):
         """Perform an ``ExecuteStreamingSql`` API request.
 
         :type sql: str
@@ -216,14 +210,11 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
         :param query_mode: Mode governing return of results / query plan. See
             https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1
 
-        :type resume_token: bytes
-        :param resume_token: token for resuming previously-interrupted query
-
         :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
         """
         return self.snapshot().execute_sql(
-            sql, params, param_types, query_mode, resume_token)
+            sql, params, param_types, query_mode)
 
     def batch(self):
         """Factory to create a batch for this session.
diff --git a/spanner/google/cloud/spanner/snapshot.py b/spanner/google/cloud/spanner/snapshot.py
index 89bd840000dc..7c5ff449448c 100644
--- a/spanner/google/cloud/spanner/snapshot.py
+++ b/spanner/google/cloud/spanner/snapshot.py
@@ -14,10 +14,13 @@
 
 """Model a set of read-only queries to a database as a snapshot."""
 
+import functools
+
 from google.protobuf.struct_pb2 import Struct
 from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionOptions
 from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionSelector
+from google.api.core.exceptions import ServiceUnavailable
 from google.cloud._helpers import _datetime_to_pb_timestamp
 from google.cloud._helpers import _timedelta_to_duration_pb
 from google.cloud.spanner._helpers import _make_value_pb
@@ -26,6 +29,36 @@
 from google.cloud.spanner.streamed import StreamedResultSet
 
 
+def _restart_on_unavailable(restart):
+    """Restart iteration after :exc:`.ServiceUnavailable`.
+
+    :type restart: callable
+    :param restart: curried function returning iterator
+    """
+    resume_token = ''
+    item_buffer = []
+    iterator = restart()
+    while True:
+        try:
+            for item in iterator:
+                item_buffer.append(item)
+                if item.resume_token:
+                    resume_token = item.resume_token
+                    break
+        except ServiceUnavailable:
+            del item_buffer[:]
+            iterator = restart(resume_token=resume_token)
+            continue
+
+        if len(item_buffer) == 0:
+            break
+
+        for item in item_buffer:
+            yield item
+
+        del item_buffer[:]
+
+
 class _SnapshotBase(_SessionWrapper):
     """Base class for Snapshot.
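For illustration, here is the contract the new helper implements, end to end: items are buffered until a resume token arrives, a 503 discards the unacknowledged buffer and restarts from the last token, and the consumer never sees the interruption. This sketch is not part of the patch; `_Item` and `_FlakyRestart` are made-up stand-ins for `PartialResultSet` and the curried GAX call, and the exception import path is the one this diff itself uses:

    from google.api.core.exceptions import ServiceUnavailable
    from google.cloud.spanner.snapshot import _restart_on_unavailable

    class _Item(object):
        # Hypothetical stand-in for PartialResultSet: a value plus a token.
        def __init__(self, value, resume_token=''):
            self.value = value
            self.resume_token = resume_token

    class _FlakyRestart(object):
        # Raises 503 exactly once, after the first item's token was seen.
        def __init__(self):
            self.calls = []

        def __call__(self, **kwargs):
            self.calls.append(kwargs)
            if len(self.calls) == 1:
                def _first():
                    yield _Item(0, resume_token='TOKEN-A')
                    raise ServiceUnavailable('transient')
                return _first()

            def _second():
                yield _Item(1)
            return _second()

    restart = _FlakyRestart()
    values = [item.value for item in _restart_on_unavailable(restart)]
    assert values == [0, 1]
    assert restart.calls == [{}, {'resume_token': 'TOKEN-A'}]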
@@ -49,8 +82,7 @@ def _make_txn_selector(self): # pylint: disable=redundant-returns-doc
         """
         raise NotImplementedError
 
-    def read(self, table, columns, keyset, index='', limit=0,
-             resume_token=b''):
+    def read(self, table, columns, keyset, index='', limit=0):
         """Perform a ``StreamingRead`` API request for rows in a table.
 
         :type table: str
@@ -69,9 +101,6 @@ def read(self, table, columns, keyset, index='', limit=0,
         :type limit: int
         :param limit: (Optional) maximum number of rows to return
 
-        :type resume_token: bytes
-        :param resume_token: token for resuming previously-interrupted read
-
         :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
         :raises ValueError:
@@ -89,10 +118,13 @@ def read(self, table, columns, keyset, index='', limit=0,
         options = _options_with_prefix(database.name)
         transaction = self._make_txn_selector()
 
-        iterator = api.streaming_read(
+        restart = functools.partial(
+            api.streaming_read,
             self._session.name, table, columns, keyset.to_pb(),
             transaction=transaction, index=index, limit=limit,
-            resume_token=resume_token, options=options)
+            options=options)
+
+        iterator = _restart_on_unavailable(restart)
 
         self._read_request_count += 1
 
@@ -101,8 +133,7 @@
         else:
             return StreamedResultSet(iterator)
 
-    def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
-                    resume_token=b''):
+    def execute_sql(self, sql, params=None, param_types=None, query_mode=None):
         """Perform an ``ExecuteStreamingSql`` API request for rows in a table.
 
         :type sql: str
@@ -122,9 +153,6 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
         :param query_mode: Mode governing return of results / query plan. See
             https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1
 
-        :type resume_token: bytes
-        :param resume_token: token for resuming previously-interrupted query
-
         :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
         :raises ValueError:
@@ -150,10 +178,14 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
         options = _options_with_prefix(database.name)
         transaction = self._make_txn_selector()
         api = database.spanner_api
-        iterator = api.execute_streaming_sql(
+
+        restart = functools.partial(
+            api.execute_streaming_sql,
             self._session.name, sql, transaction=transaction, params=params_pb,
             param_types=param_types,
-            query_mode=query_mode, resume_token=resume_token, options=options)
+            query_mode=query_mode, options=options)
+
+        iterator = _restart_on_unavailable(restart)
 
         self._read_request_count += 1
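The retry hook in both methods works because `functools.partial` freezes every argument except `resume_token`, which `_restart_on_unavailable` supplies only when it restarts; keyword arguments passed to the partial later merge into the frozen set. A minimal stdlib-only sketch with a stand-in function, since the real callable is the generated GAX method:

    import functools

    def streaming_read(session, table, columns, keyset,
                       transaction=None, index='', limit=0,
                       resume_token=b'', options=None):
        # Stand-in with the same keyword shape as the GAX method.
        return resume_token

    restart = functools.partial(
        streaming_read, 'session-name', 'citizens', ['email'], 'keyset-pb',
        transaction='selector', index='', limit=0, options=None)

    assert restart() == b''                        # first attempt: default token
    assert restart(resume_token=b'XYZ') == b'XYZ'  # restart: token slots in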
diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py
index c7d950d766d7..ef6ba2e3bcfd 100644
--- a/spanner/google/cloud/spanner/streamed.py
+++ b/spanner/google/cloud/spanner/streamed.py
@@ -43,7 +43,6 @@ def __init__(self, response_iterator, source=None):
         self._counter = 0           # Counter for processed responses
         self._metadata = None       # Until set from first PRS
         self._stats = None          # Until set from last PRS
-        self._resume_token = None   # To resume from last received PRS
         self._current_row = []      # Accumulated values for incomplete row
         self._pending_chunk = None  # Incomplete value
         self._source = source       # Source snapshot
@@ -85,15 +84,6 @@ def stats(self):
         """
         return self._stats
 
-    @property
-    def resume_token(self):
-        """Token for resuming interrupted read / query.
-
-        :rtype: bytes
-        :returns: token from last chunk of results.
-        """
-        return self._resume_token
-
     def _merge_chunk(self, value):
         """Merge pending chunk with next value.
 
@@ -132,7 +122,6 @@ def consume_next(self):
         """
         response = six.next(self._response_iterator)
         self._counter += 1
-        self._resume_token = response.resume_token
 
         if self._metadata is None:  # first response
             metadata = self._metadata = response.metadata
diff --git a/spanner/tests/unit/test_session.py b/spanner/tests/unit/test_session.py
index 826369079d29..a045e94d35de 100644
--- a/spanner/tests/unit/test_session.py
+++ b/spanner/tests/unit/test_session.py
@@ -265,7 +265,6 @@ def test_read(self):
         KEYSET = KeySet(keys=KEYS)
         INDEX = 'email-address-index'
         LIMIT = 20
-        TOKEN = b'DEADBEEF'
         database = _Database(self.DATABASE_NAME)
         session = self._make_one(database)
         session._session_id = 'DEADBEEF'
@@ -279,28 +278,26 @@ def __init__(self, session, **kwargs):
                 self._session = session
                 self._kwargs = kwargs.copy()
 
-            def read(self, table, columns, keyset, index='', limit=0,
-                     resume_token=b''):
+            def read(self, table, columns, keyset, index='', limit=0):
                 _read_with.append(
-                    (table, columns, keyset, index, limit, resume_token))
+                    (table, columns, keyset, index, limit))
                 return expected
 
         with _Monkey(MUT, Snapshot=_Snapshot):
             found = session.read(
                 TABLE_NAME, COLUMNS, KEYSET,
-                index=INDEX, limit=LIMIT, resume_token=TOKEN)
+                index=INDEX, limit=LIMIT)
 
         self.assertIs(found, expected)
 
         self.assertEqual(len(_read_with), 1)
-        (table, columns, key_set, index, limit, resume_token) = _read_with[0]
+        (table, columns, key_set, index, limit) = _read_with[0]
 
         self.assertEqual(table, TABLE_NAME)
         self.assertEqual(columns, COLUMNS)
         self.assertEqual(key_set, KEYSET)
         self.assertEqual(index, INDEX)
         self.assertEqual(limit, LIMIT)
-        self.assertEqual(resume_token, TOKEN)
 
     def test_execute_sql_not_created(self):
         SQL = 'SELECT first_name, age FROM citizens'
@@ -330,25 +327,23 @@ def __init__(self, session, **kwargs):
                 self._kwargs = kwargs.copy()
 
             def execute_sql(
-                    self, sql, params=None, param_types=None, query_mode=None,
-                    resume_token=None):
+                    self, sql, params=None, param_types=None, query_mode=None):
                 _executed_sql_with.append(
-                    (sql, params, param_types, query_mode, resume_token))
+                    (sql, params, param_types, query_mode))
                 return expected
 
         with _Monkey(MUT, Snapshot=_Snapshot):
-            found = session.execute_sql(SQL, resume_token=TOKEN)
+            found = session.execute_sql(SQL)
 
         self.assertIs(found, expected)
 
         self.assertEqual(len(_executed_sql_with), 1)
-        sql, params, param_types, query_mode, token = _executed_sql_with[0]
+        sql, params, param_types, query_mode = _executed_sql_with[0]
 
         self.assertEqual(sql, SQL)
         self.assertEqual(params, None)
         self.assertEqual(param_types, None)
         self.assertEqual(query_mode, None)
-        self.assertEqual(token, TOKEN)
 
     def test_batch_not_created(self):
         database = _Database(self.DATABASE_NAME)
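The monkey-patched `_Snapshot` above spies on the trimmed delegation from `Session.read` to `Snapshot.read`. For reference, the same check can be written standalone with `mock.patch.object` instead of `_Monkey`; this sketch is hypothetical and assumes only what the tests above already show, namely that `Session(database)` constructs and that setting `_session_id` marks the session as created:

    import mock

    from google.cloud.spanner import session as MUT

    class _Database(object):
        name = 'projects/p/instances/i/databases/d'

    sess = MUT.Session(_Database())
    sess._session_id = 'DEADBEEF'  # pretend the session was created

    with mock.patch.object(MUT, 'Snapshot') as snapshot_cls:
        sess.read('citizens', ['email'], 'KEYSET', index='idx', limit=20)

    # Session.read forwards positionally, with resume_token gone:
    snapshot_cls.return_value.read.assert_called_once_with(
        'citizens', ['email'], 'KEYSET', 'idx', 20)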
diff --git a/spanner/tests/unit/test_snapshot.py b/spanner/tests/unit/test_snapshot.py
index 4717a14c2f24..a9b03a397910 100644
--- a/spanner/tests/unit/test_snapshot.py
+++ b/spanner/tests/unit/test_snapshot.py
@@ -15,6 +15,8 @@
 
 import unittest
 
+import mock
+
 from google.cloud._testing import _GAXBaseAPI
 
 
@@ -31,6 +33,85 @@
 PARAMS_WITH_BYTES = {'bytes': b'DEADBEEF'}
 
 
+class Test_restart_on_unavailable(unittest.TestCase):
+
+    def _call_fut(self, restart):
+        from google.cloud.spanner.snapshot import _restart_on_unavailable
+
+        return _restart_on_unavailable(restart)
+
+    def _make_item(self, value, resume_token=''):
+        return mock.Mock(
+            value=value, resume_token=resume_token,
+            spec=['value', 'resume_token'])
+
+    def test_iteration_w_empty_raw(self):
+        ITEMS = ()
+        raw = _MockIterator()
+        restart = mock.Mock(spec=[], return_value=raw)
+        resumable = self._call_fut(restart)
+        self.assertEqual(list(resumable), [])
+
+    def test_iteration_w_non_empty_raw(self):
+        ITEMS = (self._make_item(0), self._make_item(1))
+        raw = _MockIterator(*ITEMS)
+        restart = mock.Mock(spec=[], return_value=raw)
+        resumable = self._call_fut(restart)
+        self.assertEqual(list(resumable), list(ITEMS))
+        restart.assert_called_once_with()
+
+    def test_iteration_w_raw_w_resume_token(self):
+        ITEMS = (
+            self._make_item(0),
+            self._make_item(1, resume_token='DEADBEEF'),
+            self._make_item(2),
+            self._make_item(3),
+        )
+        raw = _MockIterator(*ITEMS)
+        restart = mock.Mock(spec=[], return_value=raw)
+        resumable = self._call_fut(restart)
+        self.assertEqual(list(resumable), list(ITEMS))
+        restart.assert_called_once_with()
+
+    def test_iteration_w_raw_raising_unavailable(self):
+        FIRST = (
+            self._make_item(0),
+            self._make_item(1, resume_token='DEADBEEF'),
+        )
+        SECOND = (  # discarded after 503
+            self._make_item(2),
+        )
+        LAST = (
+            self._make_item(3),
+        )
+        before = _MockIterator(*(FIRST + SECOND), fail_after=True)
+        after = _MockIterator(*LAST)
+        restart = mock.Mock(spec=[], side_effect=[before, after])
+        resumable = self._call_fut(restart)
+        self.assertEqual(list(resumable), list(FIRST + LAST))
+        self.assertEqual(
+            restart.mock_calls,
+            [mock.call(), mock.call(resume_token='DEADBEEF')])
+
+    def test_iteration_w_raw_raising_unavailable_after_token(self):
+        FIRST = (
+            self._make_item(0),
+            self._make_item(1, resume_token='DEADBEEF'),
+        )
+        SECOND = (
+            self._make_item(2),
+            self._make_item(3),
+        )
+        before = _MockIterator(*FIRST, fail_after=True)
+        after = _MockIterator(*SECOND)
+        restart = mock.Mock(spec=[], side_effect=[before, after])
+        resumable = self._call_fut(restart)
+        self.assertEqual(list(resumable), list(FIRST + SECOND))
+        self.assertEqual(
+            restart.mock_calls,
+            [mock.call(), mock.call(resume_token='DEADBEEF')])
+
+
 class Test_SnapshotBase(unittest.TestCase):
 
     PROJECT_ID = 'project-id'
@@ -95,7 +176,7 @@ def test_read_grpc_error(self):
         derived = self._makeDerived(session)
 
         with self.assertRaises(GaxError):
-            derived.read(TABLE_NAME, COLUMNS, KEYSET)
+            list(derived.read(TABLE_NAME, COLUMNS, KEYSET))
 
         (r_session, table, columns, key_set, transaction, index,
          limit, resume_token, options) = api._streaming_read_with
@@ -152,7 +233,7 @@ def _read_helper(self, multi_use, first=True, count=0):
         TOKEN = b'DEADBEEF'
         database = _Database()
         api = database.spanner_api = _FauxSpannerAPI(
-            _streaming_read_response=_MockCancellableIterator(*result_sets))
+            _streaming_read_response=_MockIterator(*result_sets))
         session = _Session(database)
         derived = self._makeDerived(session)
         derived._multi_use = multi_use
@@ -162,7 +243,7 @@ def _read_helper(self, multi_use, first=True, count=0):
 
         result_set = derived.read(
             TABLE_NAME, COLUMNS, KEYSET,
-            index=INDEX, limit=LIMIT, resume_token=TOKEN)
+            index=INDEX, limit=LIMIT)
 
         self.assertEqual(derived._read_request_count, count + 1)
 
@@ -172,6 +253,7 @@ def _read_helper(self, multi_use, first=True, count=0):
             self.assertIsNone(result_set._source)
 
         result_set.consume_all()
+        self.assertEqual(list(result_set.rows), VALUES)
 
         self.assertEqual(result_set.metadata, metadata_pb)
         self.assertEqual(result_set.stats, stats_pb)
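Note the `list(...)` wrappers added around `derived.read(...)` and `derived.execute_sql(...)` in the gRPC-error tests: `_restart_on_unavailable` is a generator, so the underlying API call, and with it the `GaxError`, now fires on first iteration rather than at call time. A minimal reminder of that semantics in plain Python:

    def _stream():
        raise ValueError('boom')
        yield  # unreachable; its presence makes this a generator function

    gen = _stream()   # no error yet: generator bodies run lazily
    try:
        next(gen)     # the exception surfaces on first iteration
    except ValueError:
        pass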
@@ -193,7 +275,7 @@ def _read_helper(self, multi_use, first=True, count=0):
             self.assertTrue(transaction.single_use.read_only.strong)
         self.assertEqual(index, INDEX)
         self.assertEqual(limit, LIMIT)
-        self.assertEqual(resume_token, TOKEN)
+        self.assertEqual(resume_token, b'')
         self.assertEqual(options.kwargs['metadata'],
                          [('google-cloud-resource-prefix', database.name)])
 
@@ -229,7 +311,7 @@ def test_execute_sql_grpc_error(self):
         derived = self._makeDerived(session)
 
         with self.assertRaises(GaxError):
-            derived.execute_sql(SQL_QUERY)
+            list(derived.execute_sql(SQL_QUERY))
 
         (r_session, sql, transaction, params, param_types,
          resume_token, query_mode, options) = api._executed_streaming_sql_with
@@ -288,7 +370,7 @@ def _execute_sql_helper(self, multi_use, first=True, count=0):
             PartialResultSet(values=VALUE_PBS[0], metadata=metadata_pb),
             PartialResultSet(values=VALUE_PBS[1], stats=stats_pb),
         ]
-        iterator = _MockCancellableIterator(*result_sets)
+        iterator = _MockIterator(*result_sets)
         database = _Database()
         api = database.spanner_api = _FauxSpannerAPI(
             _execute_streaming_sql_response=iterator)
@@ -301,7 +383,7 @@ def _execute_sql_helper(self, multi_use, first=True, count=0):
 
         result_set = derived.execute_sql(
             SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES,
-            query_mode=MODE, resume_token=TOKEN)
+            query_mode=MODE)
 
         self.assertEqual(derived._read_request_count, count + 1)
 
@@ -311,6 +393,7 @@ def _execute_sql_helper(self, multi_use, first=True, count=0):
             self.assertIsNone(result_set._source)
 
         result_set.consume_all()
+        self.assertEqual(list(result_set.rows), VALUES)
 
         self.assertEqual(result_set.metadata, metadata_pb)
         self.assertEqual(result_set.stats, stats_pb)
@@ -333,7 +416,7 @@ def _execute_sql_helper(self, multi_use, first=True, count=0):
         self.assertEqual(params, expected_params)
         self.assertEqual(param_types, PARAM_TYPES)
         self.assertEqual(query_mode, MODE)
-        self.assertEqual(resume_token, TOKEN)
+        self.assertEqual(resume_token, b'')
         self.assertEqual(options.kwargs['metadata'],
                          [('google-cloud-resource-prefix', database.name)])
 
@@ -358,20 +441,6 @@ def test_execute_sql_w_multi_use_w_first_w_count_gt_0(self):
         self._execute_sql_helper(multi_use=True, first=True, count=1)
 
 
-class _MockCancellableIterator(object):
-
-    cancel_calls = 0
-
-    def __init__(self, *values):
-        self.iter_values = iter(values)
-
-    def next(self):
-        return next(self.iter_values)
-
-    def __next__(self):  # pragma: NO COVER Py3k
-        return self.next()
-
-
 class TestSnapshot(unittest.TestCase):
 
     PROJECT_ID = 'project-id'
@@ -725,7 +794,7 @@ def begin_transaction(self, session, options_, options=None):
     # pylint: disable=too-many-arguments
     def streaming_read(self, session, table, columns, key_set,
                        transaction=None, index='', limit=0,
-                       resume_token='', options=None):
+                       resume_token=b'', options=None):
        from google.gax.errors import GaxError
 
        self._streaming_read_with = (
@@ -738,7 +807,7 @@ def streaming_read(self, session, table, columns, key_set,
 
     def execute_streaming_sql(self, session, sql, transaction=None,
                               params=None, param_types=None,
-                              resume_token='', query_mode=None, options=None):
+                              resume_token=b'', query_mode=None, options=None):
         from google.gax.errors import GaxError
 
         self._executed_streaming_sql_with = (
@@ -747,3 +816,25 @@ def execute_streaming_sql(self, session, sql, transaction=None,
         if self._random_gax_error:
             raise GaxError('error')
         return self._execute_streaming_sql_response
+
+
+class _MockIterator(object):
+
+    def __init__(self, *values, **kw):
+        self._iter_values = iter(values)
+        self._fail_after = kw.pop('fail_after', False)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        from google.api.core.exceptions import ServiceUnavailable
+
+        try:
+            return next(self._iter_values)
+        except StopIteration:
+            if self._fail_after:
+                raise ServiceUnavailable('testing')
+            raise
+
+    next = __next__
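`_MockIterator` ends with `next = __next__` because Python 2's iterator protocol looks up `next()` while Python 3's looks up `__next__()`; the alias lets one class serve both interpreters. The same pattern in isolation, with a throwaway counter class:

    class _Counter(object):
        # Minimal dual-protocol iterator, the pattern _MockIterator uses.
        def __init__(self, stop):
            self._value = 0
            self._stop = stop

        def __iter__(self):
            return self

        def __next__(self):
            if self._value >= self._stop:
                raise StopIteration
            self._value += 1
            return self._value

        next = __next__  # Python 2 calls next(); Python 3 calls __next__()

    assert list(_Counter(3)) == [1, 2, 3]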
diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py
index 0e0bcb7aff6b..c02c80466db7 100644
--- a/spanner/tests/unit/test_streamed.py
+++ b/spanner/tests/unit/test_streamed.py
@@ -36,7 +36,6 @@ def test_ctor_defaults(self):
         self.assertEqual(streamed.rows, [])
         self.assertIsNone(streamed.metadata)
         self.assertIsNone(streamed.stats)
-        self.assertIsNone(streamed.resume_token)
 
     def test_ctor_w_source(self):
         iterator = _MockCancellableIterator()
@@ -47,7 +46,6 @@ def test_ctor_w_source(self):
         self.assertEqual(streamed.rows, [])
         self.assertIsNone(streamed.metadata)
         self.assertIsNone(streamed.stats)
-        self.assertIsNone(streamed.resume_token)
 
     def test_fields_unset(self):
         iterator = _MockCancellableIterator()
@@ -669,7 +667,6 @@ def test_consume_next_first_set_partial(self):
         self.assertEqual(streamed.rows, [])
         self.assertEqual(streamed._current_row, BARE)
         self.assertEqual(streamed.metadata, metadata)
-        self.assertEqual(streamed.resume_token, result_set.resume_token)
         self.assertEqual(source._transaction_id, TXN_ID)
 
     def test_consume_next_first_set_partial_existing_txn_id(self):
@@ -691,7 +688,6 @@ def test_consume_next_first_set_partial_existing_txn_id(self):
         self.assertEqual(streamed.rows, [])
         self.assertEqual(streamed._current_row, BARE)
         self.assertEqual(streamed.metadata, metadata)
-        self.assertEqual(streamed.resume_token, result_set.resume_token)
         self.assertEqual(source._transaction_id, TXN_ID)
 
     def test_consume_next_w_partial_result(self):
@@ -711,7 +707,6 @@ def test_consume_next_w_partial_result(self):
         self.assertEqual(streamed.rows, [])
         self.assertEqual(streamed._current_row, [])
         self.assertEqual(streamed._pending_chunk, VALUES[0])
-        self.assertEqual(streamed.resume_token, result_set.resume_token)
 
     def test_consume_next_w_pending_chunk(self):
         FIELDS = [
@@ -737,7 +732,6 @@ def test_consume_next_w_pending_chunk(self):
         ])
         self.assertEqual(streamed._current_row, [BARE[6]])
         self.assertIsNone(streamed._pending_chunk)
-        self.assertEqual(streamed.resume_token, result_set.resume_token)
 
     def test_consume_next_last_set(self):
         FIELDS = [
@@ -761,7 +755,6 @@ def test_consume_next_last_set(self):
         self.assertEqual(streamed.rows, [BARE])
         self.assertEqual(streamed._current_row, [])
         self.assertEqual(streamed._stats, stats)
-        self.assertEqual(streamed.resume_token, result_set.resume_token)
 
     def test_consume_all_empty(self):
         iterator = _MockCancellableIterator()