diff --git a/spanner/google/cloud/spanner_v1/database.py b/spanner/google/cloud/spanner_v1/database.py
index 6d86e98ab42e..9299639f68f5 100644
--- a/spanner/google/cloud/spanner_v1/database.py
+++ b/spanner/google/cloud/spanner_v1/database.py
@@ -27,6 +27,7 @@
 from google.cloud.spanner_v1._helpers import _metadata_with_prefix
 from google.cloud.spanner_v1.batch import Batch
 from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient
+from google.cloud.spanner_v1.keyset import KeySet
 from google.cloud.spanner_v1.pool import BurstyPool
 from google.cloud.spanner_v1.pool import SessionCheckout
 from google.cloud.spanner_v1.session import Session
@@ -308,6 +309,14 @@ def batch(self):
         """
         return BatchCheckout(self)
 
+    def batch_transaction(self):
+        """Return an object which wraps a batch read / query.
+
+        :rtype: :class:`~google.cloud.spanner_v1.database.BatchTransaction`
+        :returns: new wrapper
+        """
+        return BatchTransaction(self)
+
     def run_in_transaction(self, func, *args, **kw):
         """Perform a unit of work in a transaction, retrying on abort.
 
@@ -406,6 +415,263 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         self._database._pool.put(self._session)
 
+
+class BatchTransaction(object):
+    """Wrapper for generating and processing read / query batches.
+
+    :type database: :class:`~google.cloud.spanner_v1.database.Database`
+    :param database: database to use
+
+    :type read_timestamp: :class:`datetime.datetime`
+    :param read_timestamp: Execute all reads at the given timestamp.
+
+    :type min_read_timestamp: :class:`datetime.datetime`
+    :param min_read_timestamp: Execute all reads at a
+        timestamp >= ``min_read_timestamp``.
+
+    :type max_staleness: :class:`datetime.timedelta`
+    :param max_staleness: Read data at a
+        timestamp >= NOW - ``max_staleness`` seconds.
+
+    :type exact_staleness: :class:`datetime.timedelta`
+    :param exact_staleness: Execute all reads at a timestamp that is
+        ``exact_staleness`` old.
+    """
+    def __init__(
+            self, database,
+            read_timestamp=None,
+            min_read_timestamp=None,
+            max_staleness=None,
+            exact_staleness=None):
+
+        self._database = database
+        self._session = None
+        self._snapshot = None
+        self._read_timestamp = read_timestamp
+        self._min_read_timestamp = min_read_timestamp
+        self._max_staleness = max_staleness
+        self._exact_staleness = exact_staleness
+
+    @classmethod
+    def from_dict(cls, database, mapping):
+        """Reconstruct an instance from a mapping.
+
+        :type database: :class:`~google.cloud.spanner_v1.database.Database`
+        :param database: database to use
+
+        :type mapping: mapping
+        :param mapping: serialized state of the instance
+
+        :rtype: :class:`BatchTransaction`
+        """
+        instance = cls(database)
+        session = instance._session = database.session()
+        session._session_id = mapping['session_id']
+        txn = session.transaction()
+        txn._transaction_id = mapping['transaction_id']
+        return instance
+
+    def to_dict(self):
+        """Return state as a dictionary.
+
+        Result can be used to serialize the instance and reconstitute
+        it later using :meth:`from_dict`.
+
+        :rtype: dict
+        """
+        session = self._get_session()
+        return {
+            'session_id': session._session_id,
+            'transaction_id': session._transaction._transaction_id,
+        }
+
+    def _get_session(self):
+        """Create session as needed.
+
+        .. note::
+
+           Caller is responsible for cleaning up the session after
+           all partitions have been processed.
+        """
+        if self._session is None:
+            session = self._session = self._database.session()
+            session.create()
+            txn = session.transaction()
+            txn.begin()
+        return self._session
+
+    def _get_snapshot(self):
+        """Create snapshot if needed."""
+        if self._snapshot is None:
+            self._snapshot = self._get_session().snapshot(
+                read_timestamp=self._read_timestamp,
+                min_read_timestamp=self._min_read_timestamp,
+                max_staleness=self._max_staleness,
+                exact_staleness=self._exact_staleness,
+                multi_use=True)
+        return self._snapshot
+
+    def generate_read_batches(
+            self, table, columns, keyset,
+            index='', partition_size_bytes=None, max_partitions=None):
+        """Start a partitioned batch read operation.
+
+        Uses the ``PartitionRead`` API request to initiate the partitioned
+        read. Returns a list of batch information needed to perform the
+        actual reads.
+
+        :type table: str
+        :param table: name of the table from which to fetch data
+
+        :type columns: list of str
+        :param columns: names of columns to be retrieved
+
+        :type keyset: :class:`~google.cloud.spanner_v1.keyset.KeySet`
+        :param keyset: keys / ranges identifying rows to be retrieved
+
+        :type index: str
+        :param index: (Optional) name of index to use, rather than the
+            table's primary key
+
+        :type partition_size_bytes: int
+        :param partition_size_bytes:
+            (Optional) desired size for each partition generated. The service
+            uses this as a hint; the actual partition size may differ.
+
+        :type max_partitions: int
+        :param max_partitions:
+            (Optional) desired maximum number of partitions generated. The
+            service uses this as a hint; the actual number of partitions may
+            differ.
+
+        :rtype: iterable of dict
+        :returns:
+            mappings of information used to perform actual partitioned
+            reads via :meth:`process_read_batch`.
+        """
+        partitions = self._get_snapshot().partition_read(
+            table=table, columns=columns, keyset=keyset, index=index,
+            partition_size_bytes=partition_size_bytes,
+            max_partitions=max_partitions)
+
+        read_info = {
+            'table': table,
+            'columns': columns,
+            'keyset': keyset._to_dict(),
+            'index': index,
+        }
+        for partition in partitions:
+            yield {'partition': partition, 'read': read_info.copy()}
+
+    def process_read_batch(self, batch):
+        """Process a single, partitioned read.
+
+        :type batch: mapping
+        :param batch:
+            one of the mappings returned from an earlier call to
+            :meth:`generate_read_batches`.
+
+        :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
+        :returns: a result set instance which can be used to consume rows.
+        """
+        kwargs = batch['read']
+        keyset_dict = kwargs.pop('keyset')
+        kwargs['keyset'] = KeySet._from_dict(keyset_dict)
+        return self._get_snapshot().read(
+            partition=batch['partition'], **kwargs)
+
+    def generate_query_batches(
+            self, sql, params=None, param_types=None,
+            partition_size_bytes=None, max_partitions=None):
+        """Start a partitioned query operation.
+
+        Uses the ``PartitionQuery`` API request to start a partitioned
+        query operation. Returns a list of batch information needed to
+        perform the actual queries.
+
+        :type sql: str
+        :param sql: SQL query statement
+
+        :type params: dict, {str -> column value}
+        :param params: values for parameter replacement. Keys must match
+            the names used in ``sql``.
+
+        :type param_types: dict[str -> Union[dict, .types.Type]]
+        :param param_types:
+            (Optional) maps explicit types for one or more param values;
+            required if parameters are passed.
+
+        :type partition_size_bytes: int
+        :param partition_size_bytes:
+            (Optional) desired size for each partition generated. The service
+            uses this as a hint; the actual partition size may differ.
+
+        :type max_partitions: int
+        :param max_partitions:
+            (Optional) desired maximum number of partitions generated. The
+            service uses this as a hint; the actual number of partitions may
+            differ.
+
+        :rtype: iterable of dict
+        :returns:
+            mappings of information used to perform actual partitioned
+            queries via :meth:`process_query_batch`.
+        """
+        partitions = self._get_snapshot().partition_query(
+            sql=sql, params=params, param_types=param_types,
+            partition_size_bytes=partition_size_bytes,
+            max_partitions=max_partitions)
+
+        query_info = {'sql': sql}
+        if params:
+            query_info['params'] = params
+            query_info['param_types'] = param_types
+
+        for partition in partitions:
+            yield {'partition': partition, 'query': query_info}
+
+    def process_query_batch(self, batch):
+        """Process a single, partitioned query.
+
+        :type batch: mapping
+        :param batch:
+            one of the mappings returned from an earlier call to
+            :meth:`generate_query_batches`.
+
+        :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
+        :returns: a result set instance which can be used to consume rows.
+        """
+        return self._get_snapshot().execute_sql(
+            partition=batch['partition'], **batch['query'])
+
+    def process(self, batch):
+        """Process a single, partitioned query or read.
+
+        :type batch: mapping
+        :param batch:
+            one of the mappings returned from an earlier call to
+            :meth:`generate_read_batches` or :meth:`generate_query_batches`.
+
+        :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
+        :returns: a result set instance which can be used to consume rows.
+        :raises ValueError: if batch does not contain either 'read' or 'query'
+        """
+        if 'query' in batch:
+            return self.process_query_batch(batch)
+        if 'read' in batch:
+            return self.process_read_batch(batch)
+        raise ValueError("Invalid batch")
+
+    def close(self):
+        """Clean up underlying session."""
+        if self._session is not None:
+            self._session.delete()
+
+
 def _check_ddl_statements(value):
     """Validate DDL Statements used to define database schema.
 
diff --git a/spanner/google/cloud/spanner_v1/keyset.py b/spanner/google/cloud/spanner_v1/keyset.py
index 141388ba83a8..7527770a2c76 100644
--- a/spanner/google/cloud/spanner_v1/keyset.py
+++ b/spanner/google/cloud/spanner_v1/keyset.py
@@ -85,6 +85,35 @@ def _to_pb(self):
 
         return KeyRangePB(**kwargs)
 
+    def _to_dict(self):
+        """Return keyrange's state as a dict.
+
+        :rtype: dict
+        :returns: state of this instance.
+        """
+        mapping = {}
+
+        if self.start_open:
+            mapping['start_open'] = self.start_open
+
+        if self.start_closed:
+            mapping['start_closed'] = self.start_closed
+
+        if self.end_open:
+            mapping['end_open'] = self.end_open
+
+        if self.end_closed:
+            mapping['end_closed'] = self.end_closed
+
+        return mapping
+
+    def __eq__(self, other):
+        """Compare by serialized state."""
+        if not isinstance(other, self.__class__):
+            return NotImplemented
+        return self._to_dict() == other._to_dict()
+
+
 class KeySet(object):
     """Identify table rows via keys / ranges.
 
@@ -122,3 +151,41 @@ def _to_pb(self):
             kwargs['ranges'] = [krange._to_pb() for krange in self.ranges]
 
         return KeySetPB(**kwargs)
+
+    def _to_dict(self):
+        """Return keyset's state as a dict.
+
+        The result can be used to serialize the instance and reconstitute
+        it later using :meth:`_from_dict`.
+ + :rtype: dict + :returns: state of this instance. + """ + if self.all_: + return {'all': True} + + return { + 'keys': self.keys, + 'ranges': [keyrange._to_dict() for keyrange in self.ranges], + } + + def __eq__(self, other): + """Compare by serialized state.""" + if not isinstance(other, self.__class__): + return NotImplemented + return self._to_dict() == other._to_dict() + + @classmethod + def _from_dict(cls, mapping): + """Create an instance from the corresponding state mapping. + + :type mapping: dict + :param mapping: the instance state. + """ + if mapping.get('all'): + return cls(all_=True) + + r_mappings = mapping.get('ranges', ()) + ranges = [KeyRange(**r_mapping) for r_mapping in r_mappings] + + return cls(keys=mapping.get('keys', ()), ranges=ranges) diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index c5d0e68ba542..9cb3d1b9c16c 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -1276,20 +1276,14 @@ def test_partition_read_w_index(self): keyset = [[expected[row][0], expected[row][1]]] union = [] - with self._db.snapshot(multi_use=True) as snapshot: - partitions = snapshot.partition_read( - self.TABLE, columns, KeySet(all_=True), index='name') - for partition in partitions: - p_results_iter = snapshot.read( - self.TABLE, - columns, - KeySet(all_=True), - index='name', - partition=partition, - ) - union.extend(list(p_results_iter)) + batch_txn = self._db.batch_transaction() + for batch in batch_txn.generate_read_batches( + self.TABLE, columns, KeySet(all_=True), index='name'): + p_results_iter = batch_txn.process(batch) + union.extend(list(p_results_iter)) self.assertEqual(union, expected) + batch_txn.close() def test_execute_sql_w_manual_consume(self): ROW_COUNT = 3000 @@ -1542,13 +1536,13 @@ def test_partition_query(self): all_data_rows = list(self._row_data(row_count)) union = [] - - with self._db.snapshot(multi_use=True) as snapshot: - for partition in snapshot.partition_query(sql): - p_results_iter = snapshot.execute_sql(sql, partition=partition) - union.extend(list(p_results_iter)) + batch_txn = self._db.batch_transaction() + for batch in batch_txn.generate_query_batches(sql): + p_results_iter = batch_txn.process(batch) + union.extend(list(p_results_iter)) self.assertEqual(union, all_data_rows) + batch_txn.close() class TestStreamingChunking(unittest.TestCase, _TestData): diff --git a/spanner/tests/unit/test_database.py b/spanner/tests/unit/test_database.py index 920d0a01b6a7..a5317f32a00e 100644 --- a/spanner/tests/unit/test_database.py +++ b/spanner/tests/unit/test_database.py @@ -39,14 +39,15 @@ class _BaseTest(unittest.TestCase): DATABASE_NAME = INSTANCE_NAME + '/databases/' + DATABASE_ID SESSION_ID = 'session_id' SESSION_NAME = DATABASE_NAME + '/sessions/' + SESSION_ID + TRANSACTION_ID = 'transaction_id' def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) + return self._get_target_class()(*args, **kwargs) class TestDatabase(_BaseTest): - def _getTargetClass(self): + def _get_target_class(self): from google.cloud.spanner_v1.database import Database return Database @@ -107,7 +108,7 @@ def test_from_pb_bad_database_name(self): database_name = 'INCORRECT_FORMAT' database_pb = admin_v1_pb2.Database(name=database_name) - klass = self._getTargetClass() + klass = self._get_target_class() with self.assertRaises(ValueError): klass.from_pb(database_pb, None) @@ -120,7 +121,7 @@ def test_from_pb_project_mistmatch(self): client = _Client(project=ALT_PROJECT) instance 
= _Instance(self.INSTANCE_NAME, client) database_pb = admin_v1_pb2.Database(name=self.DATABASE_NAME) - klass = self._getTargetClass() + klass = self._get_target_class() with self.assertRaises(ValueError): klass.from_pb(database_pb, instance) @@ -134,7 +135,7 @@ def test_from_pb_instance_mistmatch(self): client = _Client() instance = _Instance(ALT_INSTANCE, client) database_pb = admin_v1_pb2.Database(name=self.DATABASE_NAME) - klass = self._getTargetClass() + klass = self._get_target_class() with self.assertRaises(ValueError): klass.from_pb(database_pb, instance) @@ -146,7 +147,7 @@ def test_from_pb_success_w_explicit_pool(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client) database_pb = admin_v1_pb2.Database(name=self.DATABASE_NAME) - klass = self._getTargetClass() + klass = self._get_target_class() pool = _Pool() database = klass.from_pb(database_pb, instance, pool=pool) @@ -167,7 +168,7 @@ def test_from_pb_success_w_hyphen_w_default_pool(self): client = _Client() instance = _Instance(self.INSTANCE_NAME, client) database_pb = admin_v1_pb2.Database(name=DATABASE_NAME_HYPHEN) - klass = self._getTargetClass() + klass = self._get_target_class() database = klass.from_pb(database_pb, instance) @@ -645,6 +646,16 @@ def test_batch(self): self.assertIsInstance(checkout, BatchCheckout) self.assertIs(checkout._database, database) + def test_batch_transaction(self): + from google.cloud.spanner_v1.database import BatchTransaction + + database = self._make_one( + self.DATABASE_ID, instance=object(), pool=_Pool()) + + batch_txn = database.batch_transaction() + self.assertIsInstance(batch_txn, BatchTransaction) + self.assertIs(batch_txn._database, database) + def test_run_in_transaction_wo_args(self): import datetime @@ -713,7 +724,7 @@ def nested_unit_of_work(): class TestBatchCheckout(_BaseTest): - def _getTargetClass(self): + def _get_target_class(self): from google.cloud.spanner_v1.database import BatchCheckout return BatchCheckout @@ -784,7 +795,7 @@ class Testing(Exception): class TestSnapshotCheckout(_BaseTest): - def _getTargetClass(self): + def _get_target_class(self): from google.cloud.spanner_v1.database import SnapshotCheckout return SnapshotCheckout @@ -857,6 +868,513 @@ class Testing(Exception): self.assertIs(pool._session, session) +class TestBatchTransaction(_BaseTest): + TABLE = 'table_name' + COLUMNS = ['column_one', 'column_two'] + TOKENS = [b'TOKEN1', b'TOKEN2'] + INDEX = 'index' + + def _get_target_class(self): + from google.cloud.spanner_v1.database import BatchTransaction + + return BatchTransaction + + @staticmethod + def _make_database(**kwargs): + from google.cloud.spanner_v1.database import Database + + return mock.create_autospec(Database, instance=True, **kwargs) + + @staticmethod + def _make_session(**kwargs): + from google.cloud.spanner_v1.session import Session + + return mock.create_autospec(Session, instance=True, **kwargs) + + @staticmethod + def _make_transaction(**kwargs): + from google.cloud.spanner_v1.transaction import Transaction + + return mock.create_autospec(Transaction, instance=True, **kwargs) + + @staticmethod + def _make_snapshot(**kwargs): + from google.cloud.spanner_v1.snapshot import Snapshot + + return mock.create_autospec(Snapshot, instance=True, **kwargs) + + @staticmethod + def _make_keyset(): + from google.cloud.spanner_v1.keyset import KeySet + + return KeySet(all_=True) + + @staticmethod + def _make_timestamp(): + import datetime + from google.cloud._helpers import UTC + + return 
datetime.datetime.utcnow().replace(tzinfo=UTC) + + @staticmethod + def _make_duration(seconds=1, microseconds=0): + import datetime + + return datetime.timedelta(seconds=seconds, microseconds=microseconds) + + def test_ctor_no_staleness(self): + database = self._make_database() + + batch_txn = self._make_one(database) + + self.assertIs(batch_txn._database, database) + self.assertIsNone(batch_txn._session) + self.assertIsNone(batch_txn._snapshot) + self.assertIsNone(batch_txn._read_timestamp) + self.assertIsNone(batch_txn._min_read_timestamp) + self.assertIsNone(batch_txn._max_staleness) + self.assertIsNone(batch_txn._exact_staleness) + + def test_ctor_w_read_timestamp(self): + database = self._make_database() + timestamp = self._make_timestamp() + + batch_txn = self._make_one(database, read_timestamp=timestamp) + + self.assertIs(batch_txn._database, database) + self.assertIsNone(batch_txn._session) + self.assertIsNone(batch_txn._snapshot) + self.assertEqual(batch_txn._read_timestamp, timestamp) + self.assertIsNone(batch_txn._min_read_timestamp) + self.assertIsNone(batch_txn._max_staleness) + self.assertIsNone(batch_txn._exact_staleness) + + def test_ctor_w_min_read_timestamp(self): + database = self._make_database() + timestamp = self._make_timestamp() + + batch_txn = self._make_one(database, min_read_timestamp=timestamp) + + self.assertIs(batch_txn._database, database) + self.assertIsNone(batch_txn._session) + self.assertIsNone(batch_txn._snapshot) + self.assertIsNone(batch_txn._read_timestamp) + self.assertEqual(batch_txn._min_read_timestamp, timestamp) + self.assertIsNone(batch_txn._max_staleness) + self.assertIsNone(batch_txn._exact_staleness) + + def test_ctor_w_max_staleness(self): + database = self._make_database() + duration = self._make_duration() + + batch_txn = self._make_one(database, max_staleness=duration) + + self.assertIs(batch_txn._database, database) + self.assertIsNone(batch_txn._session) + self.assertIsNone(batch_txn._snapshot) + self.assertIsNone(batch_txn._read_timestamp) + self.assertIsNone(batch_txn._min_read_timestamp) + self.assertEqual(batch_txn._max_staleness, duration) + self.assertIsNone(batch_txn._exact_staleness) + + def test_ctor_w_exact_staleness(self): + database = self._make_database() + duration = self._make_duration() + + batch_txn = self._make_one(database, exact_staleness=duration) + + self.assertIs(batch_txn._database, database) + self.assertIsNone(batch_txn._session) + self.assertIsNone(batch_txn._snapshot) + self.assertIsNone(batch_txn._read_timestamp) + self.assertIsNone(batch_txn._min_read_timestamp) + self.assertIsNone(batch_txn._max_staleness) + self.assertEqual(batch_txn._exact_staleness, duration) + + def test_from_dict(self): + klass = self._get_target_class() + database = self._make_database() + session = database.session.return_value = self._make_session() + txn = session.transaction.return_value = self._make_transaction() + api_repr = { + 'session_id': self.SESSION_ID, + 'transaction_id': self.TRANSACTION_ID, + } + + batch_txn = klass.from_dict(database, api_repr) + self.assertIs(batch_txn._database, database) + self.assertIs(batch_txn._session, session) + self.assertEqual(session._session_id, self.SESSION_ID) + self.assertEqual(txn._transaction_id, self.TRANSACTION_ID) + txn.begin.assert_not_called() + self.assertIsNone(batch_txn._snapshot) + + def test_to_dict(self): + database = self._make_database() + batch_txn = self._make_one(database) + txn = self._make_transaction(_transaction_id=self.TRANSACTION_ID) + session = 
batch_txn._session = self._make_session( + _session_id=self.SESSION_ID, _transaction=txn) + + expected = { + 'session_id': self.SESSION_ID, + 'transaction_id': self.TRANSACTION_ID, + } + self.assertEqual(batch_txn.to_dict(), expected) + + def test__get_session_already(self): + database = self._make_database() + batch_txn = self._make_one(database) + already = batch_txn._session = object() + self.assertIs(batch_txn._get_session(), already) + + def test__get_session_new(self): + database = self._make_database() + session = database.session.return_value = self._make_session() + txn = session.transaction.return_value = self._make_transaction() + batch_txn = self._make_one(database) + self.assertIs(batch_txn._get_session(), session) + session.create.assert_called_once_with() + txn.begin.assert_called_once_with() + + def test__get_snapshot_already(self): + database = self._make_database() + batch_txn = self._make_one(database) + already = batch_txn._snapshot = object() + self.assertIs(batch_txn._get_snapshot(), already) + + def test__get_snapshot_new_wo_staleness(self): + database = self._make_database() + batch_txn = self._make_one(database) + session = batch_txn._session = self._make_session() + snapshot = session.snapshot.return_value = self._make_snapshot() + self.assertIs(batch_txn._get_snapshot(), snapshot) + session.snapshot.assert_called_once_with( + read_timestamp=None, + min_read_timestamp=None, + max_staleness=None, + exact_staleness=None, + multi_use=True) + + def test__get_snapshot_w_read_timestamp(self): + database = self._make_database() + timestamp = self._make_timestamp() + batch_txn = self._make_one(database, read_timestamp=timestamp) + session = batch_txn._session = self._make_session() + snapshot = session.snapshot.return_value = self._make_snapshot() + self.assertIs(batch_txn._get_snapshot(), snapshot) + session.snapshot.assert_called_once_with( + read_timestamp=timestamp, + min_read_timestamp=None, + max_staleness=None, + exact_staleness=None, + multi_use=True) + + def test__get_snapshot_w_min_read_timestamp(self): + database = self._make_database() + timestamp = self._make_timestamp() + batch_txn = self._make_one(database, min_read_timestamp=timestamp) + session = batch_txn._session = self._make_session() + snapshot = session.snapshot.return_value = self._make_snapshot() + self.assertIs(batch_txn._get_snapshot(), snapshot) + session.snapshot.assert_called_once_with( + read_timestamp=None, + min_read_timestamp=timestamp, + max_staleness=None, + exact_staleness=None, + multi_use=True) + + def test__get_snapshot_w_max_staleness(self): + database = self._make_database() + duration = self._make_duration() + batch_txn = self._make_one(database, max_staleness=duration) + session = batch_txn._session = self._make_session() + snapshot = session.snapshot.return_value = self._make_snapshot() + self.assertIs(batch_txn._get_snapshot(), snapshot) + session.snapshot.assert_called_once_with( + read_timestamp=None, + min_read_timestamp=None, + max_staleness=duration, + exact_staleness=None, + multi_use=True) + + def test__get_snapshot_w_exact_staleness(self): + database = self._make_database() + duration = self._make_duration() + batch_txn = self._make_one(database, exact_staleness=duration) + session = batch_txn._session = self._make_session() + snapshot = session.snapshot.return_value = self._make_snapshot() + self.assertIs(batch_txn._get_snapshot(), snapshot) + session.snapshot.assert_called_once_with( + read_timestamp=None, + min_read_timestamp=None, + max_staleness=None, + 
exact_staleness=duration, + multi_use=True) + + def test_generate_read_batches_w_max_partitions(self): + max_partitions = len(self.TOKENS) + keyset = self._make_keyset() + database = self._make_database() + batch_txn = self._make_one(database) + snapshot = batch_txn._snapshot = self._make_snapshot() + snapshot.partition_read.return_value = self.TOKENS + + batches = list( + batch_txn.generate_read_batches( + self.TABLE, self.COLUMNS, keyset, + max_partitions=max_partitions)) + + expected_read = { + 'table': self.TABLE, + 'columns': self.COLUMNS, + 'keyset': {'all': True}, + 'index': '', + } + self.assertEqual(len(batches), len(self.TOKENS)) + for batch, token in zip(batches, self.TOKENS): + self.assertEqual(batch['partition'], token) + self.assertEqual(batch['read'], expected_read) + + snapshot.partition_read.assert_called_once_with( + table=self.TABLE, columns=self.COLUMNS, keyset=keyset, + index='', partition_size_bytes=None, max_partitions=max_partitions) + + def test_generate_read_batches_w_index_w_partition_size_bytes(self): + size = 1 << 20 + keyset = self._make_keyset() + database = self._make_database() + batch_txn = self._make_one(database) + snapshot = batch_txn._snapshot = self._make_snapshot() + snapshot.partition_read.return_value = self.TOKENS + + batches = list( + batch_txn.generate_read_batches( + self.TABLE, self.COLUMNS, keyset, index=self.INDEX, + partition_size_bytes=size)) + + expected_read = { + 'table': self.TABLE, + 'columns': self.COLUMNS, + 'keyset': {'all': True}, + 'index': self.INDEX, + } + self.assertEqual(len(batches), len(self.TOKENS)) + for batch, token in zip(batches, self.TOKENS): + self.assertEqual(batch['partition'], token) + self.assertEqual(batch['read'], expected_read) + + snapshot.partition_read.assert_called_once_with( + table=self.TABLE, columns=self.COLUMNS, keyset=keyset, + index=self.INDEX, partition_size_bytes=size, max_partitions=None) + + def test_process_read_batch(self): + keyset = self._make_keyset() + token = b'TOKEN' + batch = { + 'partition': token, + 'read': { + 'table': self.TABLE, + 'columns': self.COLUMNS, + 'keyset': {'all': True}, + 'index': self.INDEX, + }, + } + database = self._make_database() + batch_txn = self._make_one(database) + snapshot = batch_txn._snapshot = self._make_snapshot() + expected = snapshot.read.return_value = object() + + found = batch_txn.process_read_batch(batch) + + self.assertIs(found, expected) + + snapshot.read.assert_called_once_with( + table=self.TABLE, + columns=self.COLUMNS, + keyset=keyset, + index=self.INDEX, + partition=token, + ) + + def test_generate_query_batches_w_max_partitions(self): + sql = 'SELECT COUNT(*) FROM table_name' + max_partitions = len(self.TOKENS) + database = self._make_database() + batch_txn = self._make_one(database) + snapshot = batch_txn._snapshot = self._make_snapshot() + snapshot.partition_query.return_value = self.TOKENS + + batches = list( + batch_txn.generate_query_batches( + sql, max_partitions=max_partitions)) + + expected_query = { + 'sql': sql, + } + self.assertEqual(len(batches), len(self.TOKENS)) + for batch, token in zip(batches, self.TOKENS): + self.assertEqual(batch['partition'], token) + self.assertEqual(batch['query'], expected_query) + + snapshot.partition_query.assert_called_once_with( + sql=sql, params=None, param_types=None, + partition_size_bytes=None, max_partitions=max_partitions) + + def test_generate_query_batches_w_params_w_partition_size_bytes(self): + sql = ( + "SELECT first_name, last_name, email FROM citizens " + "WHERE age <= @max_age" + ) + 
params = {'max_age': 30} + param_types = {'max_age': 'INT64'} + size = 1 << 20 + database = self._make_database() + batch_txn = self._make_one(database) + snapshot = batch_txn._snapshot = self._make_snapshot() + snapshot.partition_query.return_value = self.TOKENS + + batches = list( + batch_txn.generate_query_batches( + sql, params=params, param_types=param_types, + partition_size_bytes=size)) + + expected_query = { + 'sql': sql, + 'params': params, + 'param_types': param_types, + } + self.assertEqual(len(batches), len(self.TOKENS)) + for batch, token in zip(batches, self.TOKENS): + self.assertEqual(batch['partition'], token) + self.assertEqual(batch['query'], expected_query) + + snapshot.partition_query.assert_called_once_with( + sql=sql, params=params, param_types=param_types, + partition_size_bytes=size, max_partitions=None) + + def test_process_query_batch(self): + sql = ( + "SELECT first_name, last_name, email FROM citizens " + "WHERE age <= @max_age" + ) + params = {'max_age': 30} + param_types = {'max_age': 'INT64'} + token = b'TOKEN' + batch = { + 'partition': token, + 'query': { + 'sql': sql, + 'params': params, + 'param_types': param_types, + }, + } + database = self._make_database() + batch_txn = self._make_one(database) + snapshot = batch_txn._snapshot = self._make_snapshot() + expected = snapshot.execute_sql.return_value = object() + + found = batch_txn.process_query_batch(batch) + + self.assertIs(found, expected) + + snapshot.execute_sql.assert_called_once_with( + sql=sql, + params=params, + param_types=param_types, + partition=token, + ) + + def test_close_wo_session(self): + database = self._make_database() + batch_txn = self._make_one(database) + + batch_txn.close() # no raise + + def test_close_w_session(self): + database = self._make_database() + batch_txn = self._make_one(database) + session = batch_txn._session = self._make_session() + + batch_txn.close() + + session.delete.assert_called_once_with() + + def test_process_w_invalid_batch(self): + keyset = self._make_keyset() + token = b'TOKEN' + batch = { + 'partition': token, + 'bogus': b'BOGUS', + } + database = self._make_database() + batch_txn = self._make_one(database) + + with self.assertRaises(ValueError): + batch_txn.process(batch) + + def test_process_w_read_batch(self): + keyset = self._make_keyset() + token = b'TOKEN' + batch = { + 'partition': token, + 'read': { + 'table': self.TABLE, + 'columns': self.COLUMNS, + 'keyset': {'all': True}, + 'index': self.INDEX, + }, + } + database = self._make_database() + batch_txn = self._make_one(database) + snapshot = batch_txn._snapshot = self._make_snapshot() + expected = snapshot.read.return_value = object() + + found = batch_txn.process(batch) + + self.assertIs(found, expected) + + snapshot.read.assert_called_once_with( + table=self.TABLE, + columns=self.COLUMNS, + keyset=keyset, + index=self.INDEX, + partition=token, + ) + + def test_process_w_query_batch(self): + sql = ( + "SELECT first_name, last_name, email FROM citizens " + "WHERE age <= @max_age" + ) + params = {'max_age': 30} + param_types = {'max_age': 'INT64'} + token = b'TOKEN' + batch = { + 'partition': token, + 'query': { + 'sql': sql, + 'params': params, + 'param_types': param_types, + }, + } + database = self._make_database() + batch_txn = self._make_one(database) + snapshot = batch_txn._snapshot = self._make_snapshot() + expected = snapshot.execute_sql.return_value = object() + + found = batch_txn.process(batch) + + self.assertIs(found, expected) + + snapshot.execute_sql.assert_called_once_with( + 
sql=sql, + params=params, + param_types=param_types, + partition=token, + ) + + class _Client(object): def __init__(self, project=TestDatabase.PROJECT_ID): @@ -897,6 +1415,9 @@ def put(self, session): class _Session(object): _rows = () + _created = False + _transaction = None + _snapshot = None def __init__(self, database=None, name=_BaseTest.SESSION_NAME, run_transaction_function=False): diff --git a/spanner/tests/unit/test_keyset.py b/spanner/tests/unit/test_keyset.py index a96bb1dad13f..49e98b784c5b 100644 --- a/spanner/tests/unit/test_keyset.py +++ b/spanner/tests/unit/test_keyset.py @@ -18,13 +18,13 @@ class TestKeyRange(unittest.TestCase): - def _getTargetClass(self): + def _get_target_class(self): from google.cloud.spanner_v1.keyset import KeyRange return KeyRange def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) + return self._get_target_class()(*args, **kwargs) def test_ctor_no_start_no_end(self): with self.assertRaises(ValueError): @@ -92,6 +92,29 @@ def test_ctor_w_start_closed_and_end_open(self): self.assertEqual(krange.end_open, KEY_2) self.assertEqual(krange.end_closed, None) + def test___eq___self(self): + key_1 = [u'key_1'] + krange = self._make_one(end_open=key_1) + self.assertEqual(krange, krange) + + def test___eq___other_type(self): + key_1 = [u'key_1'] + krange = self._make_one(end_open=key_1) + self.assertNotEqual(krange, object()) + + def test___eq___other_hit(self): + key_1 = [u'key_1'] + krange = self._make_one(end_open=key_1) + other = self._make_one(end_open=key_1) + self.assertEqual(krange, other) + + def test___eq___other(self): + key_1 = [u'key_1'] + key_2 = [u'key_2'] + krange = self._make_one(end_open=key_1) + other = self._make_one(start_closed=key_2, end_open=key_1) + self.assertNotEqual(krange, other) + def test_to_pb_w_start_closed_and_end_open(self): from google.protobuf.struct_pb2 import ListValue from google.protobuf.struct_pb2 import Value @@ -146,16 +169,36 @@ def test_to_pb_w_empty_list(self): ) self.assertEqual(key_range_pb, expected) + def test_to_dict_w_start_closed_and_end_open(self): + key1 = u'key_1' + key2 = u'key_2' + key_range = self._make_one(start_closed=[key1], end_open=[key2]) + expected = {'start_closed': [key1], 'end_open': [key2]} + self.assertEqual(key_range._to_dict(), expected) + + def test_to_dict_w_start_open_and_end_closed(self): + key1 = u'key_1' + key2 = u'key_2' + key_range = self._make_one(start_open=[key1], end_closed=[key2]) + expected = {'start_open': [key1], 'end_closed': [key2]} + self.assertEqual(key_range._to_dict(), expected) + + def test_to_dict_w_end_closed(self): + key = u'key' + key_range = self._make_one(end_closed=[key]) + expected = {'end_closed': [key]} + self.assertEqual(key_range._to_dict(), expected) + class TestKeySet(unittest.TestCase): - def _getTargetClass(self): + def _get_target_class(self): from google.cloud.spanner_v1.keyset import KeySet return KeySet def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) + return self._get_target_class()(*args, **kwargs) def test_ctor_w_all(self): keyset = self._make_one(all_=True) @@ -199,6 +242,63 @@ def test_ctor_w_all_and_ranges(self): with self.assertRaises(ValueError): self._make_one(all_=True, ranges=[range_1, range_2]) + def test___eq___w_self(self): + keyset = self._make_one(all_=True) + self.assertEqual(keyset, keyset) + + def test___eq___w_other_type(self): + keyset = self._make_one(all_=True) + self.assertNotEqual(keyset, object()) + + def test___eq___w_all_hit(self): + keyset = 
self._make_one(all_=True) + other = self._make_one(all_=True) + self.assertEqual(keyset, other) + + def test___eq___w_all_miss(self): + keys = [[u'key1'], [u'key2']] + keyset = self._make_one(all_=True) + other = self._make_one(keys=keys) + self.assertNotEqual(keyset, other) + + def test___eq___w_keys_hit(self): + keys = [[u'key1'], [u'key2']] + + keyset = self._make_one(keys=keys) + other = self._make_one(keys=keys) + + self.assertEqual(keyset, other) + + def test___eq___w_keys_miss(self): + keys = [[u'key1'], [u'key2']] + + keyset = self._make_one(keys=keys[:1]) + other = self._make_one(keys=keys[1:]) + + self.assertNotEqual(keyset, other) + + def test___eq___w_ranges_hit(self): + from google.cloud.spanner_v1.keyset import KeyRange + + range_1 = KeyRange(start_closed=[u'key1'], end_open=[u'key3']) + range_2 = KeyRange(start_open=[u'key5'], end_closed=[u'key6']) + + keyset = self._make_one(ranges=[range_1, range_2]) + other = self._make_one(ranges=[range_1, range_2]) + + self.assertEqual(keyset, other) + + def test___eq___w_ranges_miss(self): + from google.cloud.spanner_v1.keyset import KeyRange + + range_1 = KeyRange(start_closed=[u'key1'], end_open=[u'key3']) + range_2 = KeyRange(start_open=[u'key5'], end_closed=[u'key6']) + + keyset = self._make_one(ranges=[range_1]) + other = self._make_one(ranges=[range_2]) + + self.assertNotEqual(keyset, other) + def test_to_pb_w_all(self): from google.cloud.spanner_v1.proto.keys_pb2 import KeySet @@ -252,3 +352,89 @@ def test_to_pb_w_only_ranges(self): for found, expected in zip(result.ranges, RANGES): self.assertEqual(found, expected._to_pb()) + + def test_to_dict_w_all(self): + keyset = self._make_one(all_=True) + expected = {'all': True} + self.assertEqual(keyset._to_dict(), expected) + + def test_to_dict_w_only_keys(self): + KEYS = [[u'key1'], [u'key2']] + keyset = self._make_one(keys=KEYS) + + expected = { + 'keys': KEYS, + 'ranges': [], + } + self.assertEqual(keyset._to_dict(), expected) + + def test_to_dict_w_only_ranges(self): + from google.cloud.spanner_v1.keyset import KeyRange + + key_1 = u'KEY_1' + key_2 = u'KEY_2' + key_3 = u'KEY_3' + key_4 = u'KEY_4' + ranges = [ + KeyRange(start_open=[key_1], end_closed=[key_2]), + KeyRange(start_closed=[key_3], end_open=[key_4]), + ] + keyset = self._make_one(ranges=ranges) + + expected = { + 'keys': [], + 'ranges': [ + {'start_open': [key_1], 'end_closed': [key_2]}, + {'start_closed': [key_3], 'end_open': [key_4]}, + ] + } + self.assertEqual(keyset._to_dict(), expected) + + def test_from_dict_w_all(self): + klass = self._get_target_class() + mapping = { + 'all': True, + } + + keyset = klass._from_dict(mapping) + + self.assertTrue(keyset.all_) + self.assertEqual(keyset.keys, []) + self.assertEqual(keyset.ranges, []) + + def test_from_dict_w_keys(self): + klass = self._get_target_class() + keys = [[u'key1'], [u'key2']] + mapping = { + 'keys': keys, + } + + keyset = klass._from_dict(mapping) + + self.assertFalse(keyset.all_) + self.assertEqual(keyset.keys, keys) + self.assertEqual(keyset.ranges, []) + + def test_from_dict_w_ranges(self): + from google.cloud.spanner_v1.keyset import KeyRange + + klass = self._get_target_class() + key_1 = u'KEY_1' + key_2 = u'KEY_2' + key_3 = u'KEY_3' + key_4 = u'KEY_4' + mapping = { + 'ranges': [ + {'start_open': [key_1], 'end_closed': [key_2]}, + {'start_closed': [key_3], 'end_open': [key_4]}, + ], + } + + keyset = klass._from_dict(mapping) + + range_1 = KeyRange(start_open=[key_1], end_closed=[key_2]) + range_2 = KeyRange(start_closed=[key_3], end_open=[key_4]) + + 
self.assertFalse(keyset.all_) + self.assertEqual(keyset.keys, []) + self.assertEqual(keyset.ranges, [range_1, range_2])
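
Usage sketch (reviewer note, not part of the patch): the snippet below strings together the API surface added in this change -- Database.batch_transaction(), generate_read_batches(), to_dict() / from_dict(), process(), and close(). The instance ID, database ID, table, and column names are hypothetical placeholders; the calls themselves are taken from the diff above.

    # Sketch only: exercises the BatchTransaction API added in this diff.
    # 'my-instance', 'my-database', 'citizens', and the column list are
    # hypothetical placeholders, not values from the change itself.
    from google.cloud import spanner
    from google.cloud.spanner_v1.database import BatchTransaction
    from google.cloud.spanner_v1.keyset import KeySet

    client = spanner.Client()
    database = client.instance('my-instance').database('my-database')

    # Coordinator: partition the read and capture the shared session /
    # transaction state so other workers can reuse it.
    batch_txn = database.batch_transaction()
    batches = list(batch_txn.generate_read_batches(
        'citizens', ['first_name', 'last_name', 'email'], KeySet(all_=True)))
    state = batch_txn.to_dict()  # {'session_id': ..., 'transaction_id': ...}

    # Worker (possibly a different process): rebuild the wrapper from the
    # serialized state and consume one or more batches.
    worker_txn = BatchTransaction.from_dict(database, state)
    rows = []
    for batch in batches:
        rows.extend(worker_txn.process(batch))

    # The caller owns session cleanup (see the note on _get_session).
    batch_txn.close()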
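
A second, smaller sketch (also not part of the patch) of the serialization helpers added to keyset.py -- KeyRange._to_dict, KeySet._to_dict / _from_dict, and the new __eq__ -- which generate_read_batches / process_read_batch use to round-trip the keyset through the batch mapping. The key values are illustrative only.

    # Sketch only: round-trips a KeySet through the new private helpers.
    from google.cloud.spanner_v1.keyset import KeyRange, KeySet

    krange = KeyRange(start_closed=[u'key1'], end_open=[u'key3'])
    keyset = KeySet(keys=[[u'key2']], ranges=[krange])

    state = keyset._to_dict()
    # {'keys': [['key2']],
    #  'ranges': [{'start_closed': ['key1'], 'end_open': ['key3']}]}

    restored = KeySet._from_dict(state)
    assert restored == keyset  # new __eq__ compares serialized state
    assert KeySet._from_dict({'all': True}).all_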