diff --git a/google/cloud/spanner_dbapi/client_side_statement_executor.py b/google/cloud/spanner_dbapi/client_side_statement_executor.py index 4ef43e9d74..2d8eeed4a5 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_executor.py +++ b/google/cloud/spanner_dbapi/client_side_statement_executor.py @@ -15,10 +15,27 @@ if TYPE_CHECKING: from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_dbapi import ProgrammingError + from google.cloud.spanner_dbapi.parsed_statement import ( ParsedStatement, ClientSideStatementType, ) +from google.cloud.spanner_v1 import ( + Type, + StructType, + TypeCode, + ResultSetMetadata, + PartialResultSet, +) + +from google.cloud.spanner_v1._helpers import _make_value_pb +from google.cloud.spanner_v1.streamed import StreamedResultSet + +CONNECTION_CLOSED_ERROR = "This connection is closed" +TRANSACTION_NOT_STARTED_WARNING = ( + "This method is non-operational as a transaction has not been started." +) def execute(connection: "Connection", parsed_statement: ParsedStatement): @@ -32,9 +49,46 @@ def execute(connection: "Connection", parsed_statement: ParsedStatement): :type parsed_statement: ParsedStatement :param parsed_statement: parsed_statement based on the sql query """ - if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT: - return connection.commit() - if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN: - return connection.begin() - if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK: - return connection.rollback() + if connection.is_closed: + raise ProgrammingError(CONNECTION_CLOSED_ERROR) + statement_type = parsed_statement.client_side_statement_type + if statement_type == ClientSideStatementType.COMMIT: + connection.commit() + return None + if statement_type == ClientSideStatementType.BEGIN: + connection.begin() + return None + if statement_type == ClientSideStatementType.ROLLBACK: + connection.rollback() + return None + if statement_type == ClientSideStatementType.SHOW_COMMIT_TIMESTAMP: + if connection._transaction is None: + committed_timestamp = None + else: + committed_timestamp = connection._transaction.committed + return _get_streamed_result_set( + ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name, + TypeCode.TIMESTAMP, + committed_timestamp, + ) + if statement_type == ClientSideStatementType.SHOW_READ_TIMESTAMP: + if connection._snapshot is None: + read_timestamp = None + else: + read_timestamp = connection._snapshot._transaction_read_timestamp + return _get_streamed_result_set( + ClientSideStatementType.SHOW_READ_TIMESTAMP.name, + TypeCode.TIMESTAMP, + read_timestamp, + ) + + +def _get_streamed_result_set(column_name, type_code, column_value): + struct_type_pb = StructType( + fields=[StructType.Field(name=column_name, type_=Type(code=type_code))] + ) + + result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb)) + if column_value is not None: + result_set.values.extend([_make_value_pb(column_value)]) + return StreamedResultSet(iter([result_set])) diff --git a/google/cloud/spanner_dbapi/client_side_statement_parser.py b/google/cloud/spanner_dbapi/client_side_statement_parser.py index ce1474e809..35d0e4e609 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_parser.py +++ b/google/cloud/spanner_dbapi/client_side_statement_parser.py @@ -23,6 +23,12 @@ RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE) RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE) RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE) +RE_SHOW_COMMIT_TIMESTAMP = re.compile( + r"^\s*(SHOW)\s+(VARIABLE)\s+(COMMIT_TIMESTAMP)", re.IGNORECASE +) +RE_SHOW_READ_TIMESTAMP = re.compile( + r"^\s*(SHOW)\s+(VARIABLE)\s+(READ_TIMESTAMP)", re.IGNORECASE +) def parse_stmt(query): @@ -37,16 +43,19 @@ def parse_stmt(query): :rtype: ParsedStatement :returns: ParsedStatement object. """ + client_side_statement_type = None if RE_COMMIT.match(query): - return ParsedStatement( - StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT - ) + client_side_statement_type = ClientSideStatementType.COMMIT if RE_BEGIN.match(query): - return ParsedStatement( - StatementType.CLIENT_SIDE, query, ClientSideStatementType.BEGIN - ) + client_side_statement_type = ClientSideStatementType.BEGIN if RE_ROLLBACK.match(query): + client_side_statement_type = ClientSideStatementType.ROLLBACK + if RE_SHOW_COMMIT_TIMESTAMP.match(query): + client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP + if RE_SHOW_READ_TIMESTAMP.match(query): + client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP + if client_side_statement_type is not None: return ParsedStatement( - StatementType.CLIENT_SIDE, query, ClientSideStatementType.ROLLBACK + StatementType.CLIENT_SIDE, query, client_side_statement_type ) return None diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index a3306b316c..f60913fd14 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -23,6 +23,7 @@ from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1.session import _get_retry_delay from google.cloud.spanner_v1.snapshot import Snapshot +from deprecated import deprecated from google.cloud.spanner_dbapi.checksum import _compare_checksums from google.cloud.spanner_dbapi.checksum import ResultsChecksum @@ -35,7 +36,7 @@ CLIENT_TRANSACTION_NOT_STARTED_WARNING = ( - "This method is non-operational as transaction has not started" + "This method is non-operational as a transaction has not been started." ) MAX_INTERNAL_RETRIES = 50 @@ -107,6 +108,9 @@ def __init__(self, instance, database=None, read_only=False): self._staleness = None self.request_priority = None self._transaction_begin_marked = False + # whether transaction started at Spanner. This means that we had + # made atleast one call to Spanner. + self._spanner_transaction_started = False @property def autocommit(self): @@ -140,26 +144,15 @@ def database(self): return self._database @property - def _spanner_transaction_started(self): - """Flag: whether transaction started at Spanner. This means that we had - made atleast one call to Spanner. Property client_transaction_started - would always be true if this is true as transaction has to start first - at clientside than at Spanner - - Returns: - bool: True if Spanner transaction started, False otherwise. - """ + @deprecated( + reason="This method is deprecated. Use _spanner_transaction_started field" + ) + def inside_transaction(self): return ( self._transaction and not self._transaction.committed and not self._transaction.rolled_back - ) or (self._snapshot is not None) - - @property - def inside_transaction(self): - """Deprecated property which won't be supported in future versions. - Please use spanner_transaction_started property instead.""" - return self._spanner_transaction_started + ) @property def _client_transaction_started(self): @@ -277,7 +270,8 @@ def _release_session(self): """ if self.database is None: raise ValueError("Database needs to be passed for this operation") - self.database._pool.put(self._session) + if self._session is not None: + self.database._pool.put(self._session) self._session = None def retry_transaction(self): @@ -293,7 +287,7 @@ def retry_transaction(self): """ attempt = 0 while True: - self._transaction = None + self._spanner_transaction_started = False attempt += 1 if attempt > MAX_INTERNAL_RETRIES: raise @@ -319,7 +313,6 @@ def _rerun_previous_statements(self): status, res = transaction.batch_update(statements) if status.code == ABORTED: - self.connection._transaction = None raise Aborted(status.details) retried_checksum = ResultsChecksum() @@ -363,6 +356,8 @@ def transaction_checkout(self): if not self.read_only and self._client_transaction_started: if not self._spanner_transaction_started: self._transaction = self._session_checkout().transaction() + self._snapshot = None + self._spanner_transaction_started = True self._transaction.begin() return self._transaction @@ -377,11 +372,13 @@ def snapshot_checkout(self): :returns: A Cloud Spanner snapshot object, ready to use. """ if self.read_only and self._client_transaction_started: - if not self._snapshot: + if not self._spanner_transaction_started: self._snapshot = Snapshot( self._session_checkout(), multi_use=True, **self.staleness ) + self._transaction = None self._snapshot.begin() + self._spanner_transaction_started = True return self._snapshot @@ -391,7 +388,7 @@ def close(self): The connection will be unusable from this point forward. If the connection has an active transaction, it will be rolled back. """ - if self._spanner_transaction_started and not self.read_only: + if self._spanner_transaction_started and not self._read_only: self._transaction.rollback() if self._own_pool and self.database: @@ -405,13 +402,15 @@ def begin(self): Marks the transaction as started. :raises: :class:`InterfaceError`: if this connection is closed. - :raises: :class:`OperationalError`: if there is an existing transaction that has begin or is running + :raises: :class:`OperationalError`: if there is an existing transaction + that has been started """ if self._transaction_begin_marked: raise OperationalError("A transaction has already started") if self._spanner_transaction_started: raise OperationalError( - "Beginning a new transaction is not allowed when a transaction is already running" + "Beginning a new transaction is not allowed when a transaction " + "is already running" ) self._transaction_begin_marked = True @@ -430,41 +429,37 @@ def commit(self): return self.run_prior_DDL_statements() - if self._spanner_transaction_started: - try: - if self.read_only: - self._snapshot = None - else: - self._transaction.commit() - - self._release_session() - self._statements = [] - self._transaction_begin_marked = False - except Aborted: - self.retry_transaction() - self.commit() + try: + if self._spanner_transaction_started and not self._read_only: + self._transaction.commit() + except Aborted: + self.retry_transaction() + self.commit() + finally: + self._release_session() + self._statements = [] + self._transaction_begin_marked = False + self._spanner_transaction_started = False def rollback(self): """Rolls back any pending transaction. This is a no-op if there is no active client transaction. """ - if not self._client_transaction_started: warnings.warn( CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2 ) return - if self._spanner_transaction_started: - if self.read_only: - self._snapshot = None - else: + try: + if self._spanner_transaction_started and not self._read_only: self._transaction.rollback() - + finally: self._release_session() self._statements = [] self._transaction_begin_marked = False + self._spanner_transaction_started = False @check_not_closed def cursor(self): diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 023149eeb0..726dd26cb4 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -178,7 +178,10 @@ def close(self): """Closes this cursor.""" self._is_closed = True - def _do_execute_update(self, transaction, sql, params): + def _do_execute_update_in_autocommit(self, transaction, sql, params): + """This function should only be used in autocommit mode.""" + self.connection._transaction = transaction + self.connection._snapshot = None self._result_set = transaction.execute_sql( sql, params=params, param_types=get_param_types(params) ) @@ -239,65 +242,72 @@ def execute(self, sql, args=None): self._row_count = _UNSET_COUNT try: - if self.connection.read_only: - self._handle_DQL(sql, args or None) - return - parsed_statement = parse_utils.classify_statement(sql) + if parsed_statement.statement_type == StatementType.CLIENT_SIDE: - return client_side_statement_executor.execute( + self._result_set = client_side_statement_executor.execute( self.connection, parsed_statement ) - if parsed_statement.statement_type == StatementType.DDL: + if self._result_set is not None: + self._itr = PeekIterator(self._result_set) + elif self.connection.read_only or ( + not self.connection._client_transaction_started + and parsed_statement.statement_type == StatementType.QUERY + ): + self._handle_DQL(sql, args or None) + elif parsed_statement.statement_type == StatementType.DDL: self._batch_DDLs(sql) if not self.connection._client_transaction_started: self.connection.run_prior_DDL_statements() - return - - # For every other operation, we've got to ensure that - # any prior DDL statements were run. - # self._run_prior_DDL_statements() - self.connection.run_prior_DDL_statements() - - if parsed_statement.statement_type == StatementType.UPDATE: - sql = parse_utils.ensure_where_clause(sql) - - sql, args = sql_pyformat_args_to_spanner(sql, args or None) - - if self.connection._client_transaction_started: - statement = Statement( - sql, - args, - get_param_types(args or None), - ResultsChecksum(), - ) - - ( - self._result_set, - self._checksum, - ) = self.connection.run_statement(statement) - while True: - try: - self._itr = PeekIterator(self._result_set) - break - except Aborted: - self.connection.retry_transaction() - return - - if parsed_statement.statement_type == StatementType.QUERY: - self._handle_DQL(sql, args or None) else: - self.connection.database.run_in_transaction( - self._do_execute_update, - sql, - args or None, - ) + self._execute_in_rw_transaction(parsed_statement, sql, args) + except (AlreadyExists, FailedPrecondition, OutOfRange) as e: raise IntegrityError(getattr(e, "details", e)) from e except InvalidArgument as e: raise ProgrammingError(getattr(e, "details", e)) from e except InternalServerError as e: raise OperationalError(getattr(e, "details", e)) from e + finally: + if self.connection._client_transaction_started is False: + self.connection._spanner_transaction_started = False + + def _execute_in_rw_transaction(self, parsed_statement, sql, args): + # For every other operation, we've got to ensure that + # any prior DDL statements were run. + self.connection.run_prior_DDL_statements() + if parsed_statement.statement_type == StatementType.UPDATE: + sql = parse_utils.ensure_where_clause(sql) + sql, args = sql_pyformat_args_to_spanner(sql, args or None) + + if self.connection._client_transaction_started: + statement = Statement( + sql, + args, + get_param_types(args or None), + ResultsChecksum(), + ) + + ( + self._result_set, + self._checksum, + ) = self.connection.run_statement(statement) + + while True: + try: + self._itr = PeekIterator(self._result_set) + break + except Aborted: + self.connection.retry_transaction() + except Exception as ex: + self.connection._statements.remove(statement) + raise ex + else: + self.connection.database.run_in_transaction( + self._do_execute_update_in_autocommit, + sql, + args or None, + ) @check_not_closed def executemany(self, operation, seq_of_params): @@ -477,6 +487,10 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params): # Unfortunately, Spanner doesn't seem to send back # information about the number of rows available. self._row_count = _UNSET_COUNT + if self._result_set.metadata.transaction.read_timestamp is not None: + snapshot._transaction_read_timestamp = ( + self._result_set.metadata.transaction.read_timestamp + ) def _handle_DQL(self, sql, params): if self.connection.database is None: @@ -492,6 +506,8 @@ def _handle_DQL(self, sql, params): with self.connection.database.snapshot( **self.connection.staleness ) as snapshot: + self.connection._snapshot = snapshot + self.connection._transaction = None self._handle_DQL_with_snapshot(snapshot, sql, params) def __enter__(self): diff --git a/google/cloud/spanner_dbapi/parsed_statement.py b/google/cloud/spanner_dbapi/parsed_statement.py index 28705b69ed..30f4c1630f 100644 --- a/google/cloud/spanner_dbapi/parsed_statement.py +++ b/google/cloud/spanner_dbapi/parsed_statement.py @@ -28,6 +28,8 @@ class ClientSideStatementType(Enum): COMMIT = 1 BEGIN = 2 ROLLBACK = 3 + SHOW_COMMIT_TIMESTAMP = 4 + SHOW_READ_TIMESTAMP = 5 @dataclass diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 573042aa11..1e515bd8e6 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -447,31 +447,19 @@ def execute_sql( if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: - iterator = _restart_on_unavailable( - restart, - request, - "CloudSpanner.ReadWriteTransaction", - self._session, - trace_attributes, - transaction=self, - ) - self._read_request_count += 1 - self._execute_sql_count += 1 - - if self._multi_use: - return StreamedResultSet(iterator, source=self) - else: - return StreamedResultSet(iterator) + return self._get_streamed_result_set(restart, request, trace_attributes) else: - iterator = _restart_on_unavailable( - restart, - request, - "CloudSpanner.ReadWriteTransaction", - self._session, - trace_attributes, - transaction=self, - ) + return self._get_streamed_result_set(restart, request, trace_attributes) + def _get_streamed_result_set(self, restart, request, trace_attributes): + iterator = _restart_on_unavailable( + restart, + request, + "CloudSpanner.ReadWriteTransaction", + self._session, + trace_attributes, + transaction=self, + ) self._read_request_count += 1 self._execute_sql_count += 1 @@ -739,6 +727,7 @@ def __init__( "'min_read_timestamp' / 'max_staleness'" ) + self._transaction_read_timestamp = None self._strong = len(flagged) == 0 self._read_timestamp = read_timestamp self._min_read_timestamp = min_read_timestamp @@ -768,7 +757,9 @@ def _make_txn_selector(self): value = True options = TransactionOptions( - read_only=TransactionOptions.ReadOnly(**{key: value}) + read_only=TransactionOptions.ReadOnly( + **{key: value, "return_read_timestamp": True} + ) ) if self._multi_use: @@ -814,4 +805,5 @@ def begin(self): allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) self._transaction_id = response.id + self._transaction_read_timestamp = response.read_timestamp return self._transaction_id diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index 26af9e5e0f..6a6cc385f6 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -25,6 +25,7 @@ from google.cloud.spanner_dbapi.exceptions import ProgrammingError, OperationalError from google.cloud.spanner_v1 import JsonObject from google.cloud.spanner_v1 import gapic_version as package_version +from google.api_core.datetime_helpers import DatetimeWithNanoseconds from . import _helpers DATABASE_NAME = "dbapi-txn" @@ -109,7 +110,7 @@ def _execute_common_statements(self, cursor): "test.email_updated@domen.ru", ) - @pytest.mark.parametrize("client_side", [False, True]) + @pytest.mark.parametrize("client_side", [True, False]) def test_commit(self, client_side): """Test committing a transaction with several statements.""" updated_row = self._execute_common_statements(self._cursor) @@ -125,6 +126,109 @@ def test_commit(self, client_side): assert got_rows == [updated_row] + @pytest.mark.skip(reason="b/315807641") + def test_commit_exception(self): + """Test that if exception during commit method is caught, then + subsequent operations on same Cursor and Connection object works + properly.""" + self._execute_common_statements(self._cursor) + # deleting the session to fail the commit + self._conn._session.delete() + try: + self._conn.commit() + except Exception: + pass + + # Testing that the connection and Cursor are in proper state post commit + # and a new transaction is started + updated_row = self._execute_common_statements(self._cursor) + self._cursor.execute("SELECT * FROM contacts") + got_rows = self._cursor.fetchall() + self._conn.commit() + + assert got_rows == [updated_row] + + @pytest.mark.skip(reason="b/315807641") + def test_rollback_exception(self): + """Test that if exception during rollback method is caught, then + subsequent operations on same Cursor and Connection object works + properly.""" + self._execute_common_statements(self._cursor) + # deleting the session to fail the rollback + self._conn._session.delete() + try: + self._conn.rollback() + except Exception: + pass + + # Testing that the connection and Cursor are in proper state post + # exception in rollback and a new transaction is started + updated_row = self._execute_common_statements(self._cursor) + self._cursor.execute("SELECT * FROM contacts") + got_rows = self._cursor.fetchall() + self._conn.commit() + + assert got_rows == [updated_row] + + @pytest.mark.skip(reason="b/315807641") + def test_cursor_execute_exception(self): + """Test that if exception in Cursor's execute method is caught when + Connection is not in autocommit mode, then subsequent operations on + same Cursor and Connection object works properly.""" + updated_row = self._execute_common_statements(self._cursor) + try: + self._cursor.execute("SELECT * FROM unknown_table") + except Exception: + pass + self._cursor.execute("SELECT * FROM contacts") + got_rows = self._cursor.fetchall() + self._conn.commit() + assert got_rows == [updated_row] + + # Testing that the connection and Cursor are in proper state post commit + # and a new transaction is started + self._cursor.execute("SELECT * FROM contacts") + got_rows = self._cursor.fetchall() + self._conn.commit() + assert got_rows == [updated_row] + + def test_cursor_execute_exception_autocommit(self): + """Test that if exception in Cursor's execute method is caught when + Connection is in autocommit mode, then subsequent operations on + same Cursor and Connection object works properly.""" + self._conn.autocommit = True + updated_row = self._execute_common_statements(self._cursor) + try: + self._cursor.execute("SELECT * FROM unknown_table") + except Exception: + pass + self._cursor.execute("SELECT * FROM contacts") + got_rows = self._cursor.fetchall() + assert got_rows == [updated_row] + + def test_cursor_execute_exception_begin_client_side(self): + """Test that if exception in Cursor's execute method is caught when + beginning a transaction using client side statement, then subsequent + operations on same Cursor and Connection object works properly.""" + self._conn.autocommit = True + self._cursor.execute("begin transaction") + updated_row = self._execute_common_statements(self._cursor) + try: + self._cursor.execute("SELECT * FROM unknown_table") + except Exception: + pass + self._cursor.execute("SELECT * FROM contacts") + got_rows = self._cursor.fetchall() + self._conn.commit() + assert got_rows == [updated_row] + + # Testing that the connection and Cursor are in proper state post commit + self._conn.autocommit = False + self._cursor.execute("SELECT * FROM contacts") + got_rows = self._cursor.fetchall() + self._conn.commit() + assert got_rows == [updated_row] + @pytest.mark.noautofixt def test_begin_client_side(self, shared_instance, dbapi_database): """Test beginning a transaction using client side statement, @@ -152,6 +256,175 @@ def test_begin_client_side(self, shared_instance, dbapi_database): conn3.close() assert got_rows == [updated_row] + def test_begin_and_commit(self): + """Test beginning and then committing a transaction is a Noop""" + self._cursor.execute("begin transaction") + self._cursor.execute("commit transaction") + self._cursor.execute("SELECT * FROM contacts") + self._conn.commit() + assert self._cursor.fetchall() == [] + + def test_begin_and_rollback(self): + """Test beginning and then rolling back a transaction is a Noop""" + self._cursor.execute("begin transaction") + self._cursor.execute("rollback transaction") + self._cursor.execute("SELECT * FROM contacts") + self._conn.commit() + assert self._cursor.fetchall() == [] + + def test_read_and_commit_timestamps(self): + """Test COMMIT_TIMESTAMP is not available after read statement and + READ_TIMESTAMP is not available after write statement in autocommit + mode.""" + self._conn.autocommit = True + self._cursor.execute("SELECT * FROM contacts") + self._cursor.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + + self._cursor.execute("SHOW VARIABLE COMMIT_TIMESTAMP") + got_rows = self._cursor.fetchall() + assert len(got_rows) == 1 + + self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") + got_rows = self._cursor.fetchall() + assert len(got_rows) == 0 + + self._cursor.execute("SELECT * FROM contacts") + + self._cursor.execute("SHOW VARIABLE COMMIT_TIMESTAMP") + got_rows = self._cursor.fetchall() + assert len(got_rows) == 0 + + self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") + got_rows = self._cursor.fetchall() + assert len(got_rows) == 1 + + def test_commit_timestamp_client_side_transaction(self): + """Test executing SHOW_COMMIT_TIMESTAMP client side statement in a + transaction.""" + + self._cursor.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + self._cursor.execute("SHOW VARIABLE COMMIT_TIMESTAMP") + got_rows = self._cursor.fetchall() + # As the connection is not committed we will get 0 rows + assert len(got_rows) == 0 + assert len(self._cursor.description) == 1 + + self._cursor.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + self._conn.commit() + self._cursor.execute("SHOW VARIABLE COMMIT_TIMESTAMP") + + got_rows = self._cursor.fetchall() + assert len(got_rows) == 1 + assert len(got_rows[0]) == 1 + assert len(self._cursor.description) == 1 + assert self._cursor.description[0].name == "SHOW_COMMIT_TIMESTAMP" + assert isinstance(got_rows[0][0], DatetimeWithNanoseconds) + + def test_commit_timestamp_client_side_autocommit(self): + """Test executing SHOW_COMMIT_TIMESTAMP client side statement in a + transaction when connection is in autocommit mode.""" + + self._conn.autocommit = True + self._cursor.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + self._cursor.execute("SHOW VARIABLE COMMIT_TIMESTAMP") + + got_rows = self._cursor.fetchall() + assert len(got_rows) == 1 + assert len(got_rows[0]) == 1 + assert len(self._cursor.description) == 1 + assert self._cursor.description[0].name == "SHOW_COMMIT_TIMESTAMP" + assert isinstance(got_rows[0][0], DatetimeWithNanoseconds) + + def test_read_timestamp_client_side(self): + """Test executing SHOW_READ_TIMESTAMP client side statement in a + transaction.""" + + self._conn.read_only = True + self._cursor.execute("SELECT * FROM contacts") + assert self._cursor.fetchall() == [] + + self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") + read_timestamp_query_result_1 = self._cursor.fetchall() + + self._cursor.execute("SELECT * FROM contacts") + assert self._cursor.fetchall() == [] + + self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") + read_timestamp_query_result_2 = self._cursor.fetchall() + + self._conn.commit() + + self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") + read_timestamp_query_result_3 = self._cursor.fetchall() + assert len(self._cursor.description) == 1 + assert self._cursor.description[0].name == "SHOW_READ_TIMESTAMP" + + assert ( + read_timestamp_query_result_1 + == read_timestamp_query_result_2 + == read_timestamp_query_result_3 + ) + assert len(read_timestamp_query_result_1) == 1 + assert len(read_timestamp_query_result_1[0]) == 1 + assert isinstance(read_timestamp_query_result_1[0][0], DatetimeWithNanoseconds) + + self._cursor.execute("SELECT * FROM contacts") + self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") + read_timestamp_query_result_4 = self._cursor.fetchall() + self._conn.commit() + assert read_timestamp_query_result_1 != read_timestamp_query_result_4 + + def test_read_timestamp_client_side_autocommit(self): + """Test executing SHOW_READ_TIMESTAMP client side statement in a + transaction when connection is in autocommit mode.""" + + self._conn.autocommit = True + + self._cursor.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + self._conn.read_only = True + self._cursor.execute("SELECT * FROM contacts") + assert self._cursor.fetchall() == [ + (2, "first-name", "last-name", "test.email@domen.ru") + ] + self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") + read_timestamp_query_result_1 = self._cursor.fetchall() + + assert len(read_timestamp_query_result_1) == 1 + assert len(read_timestamp_query_result_1[0]) == 1 + assert len(self._cursor.description) == 1 + assert self._cursor.description[0].name == "SHOW_READ_TIMESTAMP" + assert isinstance(read_timestamp_query_result_1[0][0], DatetimeWithNanoseconds) + + self._cursor.execute("SELECT * FROM contacts") + self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") + read_timestamp_query_result_2 = self._cursor.fetchall() + assert read_timestamp_query_result_1 != read_timestamp_query_result_2 + def test_begin_success_post_commit(self): """Test beginning a new transaction post commiting an existing transaction is possible on a connection, when connection is in autocommit mode.""" @@ -643,6 +916,17 @@ def test_read_only(self): ReadOnly transactions. """ + self._conn.read_only = True + self._cursor.execute("SELECT * FROM contacts") + assert self._cursor.fetchall() == [] + self._conn.commit() + + def test_read_only_dml(self): + """ + Check that connection set to `read_only=True` leads to exception when + executing dml statements. + """ + self._conn.read_only = True with pytest.raises(ProgrammingError): self._cursor.execute( @@ -653,9 +937,6 @@ def test_read_only(self): """ ) - self._cursor.execute("SELECT * FROM contacts") - self._conn.commit() - def test_staleness(self): """Check the DB API `staleness` option.""" diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 91b2e3d5e8..853b78a936 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -20,6 +20,8 @@ import warnings import pytest from google.cloud.spanner_dbapi.exceptions import InterfaceError, OperationalError +from google.cloud.spanner_dbapi import Connection +from google.cloud.spanner_dbapi.connection import CLIENT_TRANSACTION_NOT_STARTED_WARNING PROJECT = "test-project" INSTANCE = "test-instance" @@ -46,7 +48,6 @@ def _get_client_info(self): return ClientInfo(user_agent=USER_AGENT) def _make_connection(self, **kwargs): - from google.cloud.spanner_dbapi import Connection from google.cloud.spanner_v1.instance import Instance from google.cloud.spanner_v1.client import Client @@ -71,33 +72,13 @@ def test_autocommit_setter_transaction_not_started(self, mock_commit): @mock.patch("google.cloud.spanner_dbapi.connection.Connection.commit") def test_autocommit_setter_transaction_started(self, mock_commit): connection = self._make_connection() - connection._transaction = mock.Mock(committed=False, rolled_back=False) + connection._spanner_transaction_started = True connection.autocommit = True mock_commit.assert_called_once() self.assertTrue(connection._autocommit) - @mock.patch("google.cloud.spanner_dbapi.connection.Connection.commit") - def test_autocommit_setter_transaction_started_commited_rolled_back( - self, mock_commit - ): - connection = self._make_connection() - - connection._transaction = mock.Mock(committed=True, rolled_back=False) - - connection.autocommit = True - mock_commit.assert_not_called() - self.assertTrue(connection._autocommit) - - connection.autocommit = False - - connection._transaction = mock.Mock(committed=False, rolled_back=True) - - connection.autocommit = True - mock_commit.assert_not_called() - self.assertTrue(connection._autocommit) - def test_property_database(self): from google.cloud.spanner_v1.database import Database @@ -116,7 +97,7 @@ def test_read_only_connection(self): connection = self._make_connection(read_only=True) self.assertTrue(connection.read_only) - connection._transaction = mock.Mock(committed=False, rolled_back=False) + connection._spanner_transaction_started = True with self.assertRaisesRegex( ValueError, "Connection read/write mode can't be changed while a transaction is in progress. " @@ -124,7 +105,7 @@ def test_read_only_connection(self): ): connection.read_only = False - connection._transaction = None + connection._spanner_transaction_started = False connection.read_only = False self.assertFalse(connection.read_only) @@ -160,8 +141,6 @@ def _make_pool(): @mock.patch("google.cloud.spanner_v1.database.Database") def test__session_checkout(self, mock_database): - from google.cloud.spanner_dbapi import Connection - pool = self._make_pool() mock_database._pool = pool connection = Connection(INSTANCE, mock_database) @@ -175,8 +154,6 @@ def test__session_checkout(self, mock_database): self.assertEqual(connection._session, "db_session") def test_session_checkout_database_error(self): - from google.cloud.spanner_dbapi import Connection - connection = Connection(INSTANCE) with pytest.raises(ValueError): @@ -184,8 +161,6 @@ def test_session_checkout_database_error(self): @mock.patch("google.cloud.spanner_v1.database.Database") def test__release_session(self, mock_database): - from google.cloud.spanner_dbapi import Connection - pool = self._make_pool() mock_database._pool = pool connection = Connection(INSTANCE, mock_database) @@ -196,15 +171,11 @@ def test__release_session(self, mock_database): self.assertIsNone(connection._session) def test_release_session_database_error(self): - from google.cloud.spanner_dbapi import Connection - connection = Connection(INSTANCE) with pytest.raises(ValueError): connection._release_session() def test_transaction_checkout(self): - from google.cloud.spanner_dbapi import Connection - connection = Connection(INSTANCE, DATABASE) mock_checkout = mock.MagicMock(autospec=True) connection._session_checkout = mock_checkout @@ -214,8 +185,8 @@ def test_transaction_checkout(self): mock_checkout.assert_called_once_with() mock_transaction = mock.MagicMock() - mock_transaction.committed = mock_transaction.rolled_back = False connection._transaction = mock_transaction + connection._spanner_transaction_started = True self.assertEqual(connection.transaction_checkout(), mock_transaction) @@ -223,8 +194,6 @@ def test_transaction_checkout(self): self.assertIsNone(connection.transaction_checkout()) def test_snapshot_checkout(self): - from google.cloud.spanner_dbapi import Connection - connection = Connection(INSTANCE, DATABASE, read_only=True) connection.autocommit = False @@ -239,20 +208,20 @@ def test_snapshot_checkout(self): self.assertEqual(snapshot, connection.snapshot_checkout()) connection.commit() - self.assertIsNone(connection._snapshot) + self.assertIsNotNone(connection._snapshot) release_session.assert_called_once() connection.snapshot_checkout() self.assertIsNotNone(connection._snapshot) connection.rollback() - self.assertIsNone(connection._snapshot) + self.assertIsNotNone(connection._snapshot) + self.assertEqual(release_session.call_count, 2) connection.autocommit = True self.assertIsNone(connection.snapshot_checkout()) - @mock.patch("google.cloud.spanner_v1.Client") - def test_close(self, mock_client): + def test_close(self): from google.cloud.spanner_dbapi import connect from google.cloud.spanner_dbapi import InterfaceError @@ -268,8 +237,8 @@ def test_close(self, mock_client): connection.cursor() mock_transaction = mock.MagicMock() - mock_transaction.committed = mock_transaction.rolled_back = False connection._transaction = mock_transaction + connection._spanner_transaction_started = True mock_rollback = mock.MagicMock() mock_transaction.rollback = mock_rollback @@ -285,36 +254,35 @@ def test_close(self, mock_client): self.assertTrue(connection.is_closed) @mock.patch.object(warnings, "warn") - def test_commit(self, mock_warn): - from google.cloud.spanner_dbapi import Connection - from google.cloud.spanner_dbapi.connection import ( - CLIENT_TRANSACTION_NOT_STARTED_WARNING, - ) - - connection = Connection(INSTANCE, DATABASE) + def test_commit_with_spanner_transaction_not_started(self, mock_warn): + self._under_test._spanner_transaction_started = False with mock.patch( "google.cloud.spanner_dbapi.connection.Connection._release_session" ) as mock_release: - connection.commit() + self._under_test.commit() - mock_release.assert_not_called() + mock_release.assert_called() - connection._transaction = mock_transaction = mock.MagicMock( - rolled_back=False, committed=False - ) + def test_commit(self): + self._under_test._transaction = mock_transaction = mock.MagicMock() + self._under_test._spanner_transaction_started = True mock_transaction.commit = mock_commit = mock.MagicMock() with mock.patch( "google.cloud.spanner_dbapi.connection.Connection._release_session" ) as mock_release: - connection.commit() + self._under_test.commit() mock_commit.assert_called_once_with() mock_release.assert_called_once_with() - connection._autocommit = True - connection.commit() + @mock.patch.object(warnings, "warn") + def test_commit_in_autocommit_mode(self, mock_warn): + self._under_test._autocommit = True + + self._under_test.commit() + mock_warn.assert_called_once_with( CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2 ) @@ -328,37 +296,38 @@ def test_commit_database_error(self): connection.commit() @mock.patch.object(warnings, "warn") - def test_rollback(self, mock_warn): - from google.cloud.spanner_dbapi import Connection - from google.cloud.spanner_dbapi.connection import ( - CLIENT_TRANSACTION_NOT_STARTED_WARNING, - ) - - connection = Connection(INSTANCE, DATABASE) + def test_rollback_spanner_transaction_not_started(self, mock_warn): + self._under_test._spanner_transaction_started = False with mock.patch( "google.cloud.spanner_dbapi.connection.Connection._release_session" ) as mock_release: - connection.rollback() + self._under_test.rollback() - mock_release.assert_not_called() + mock_release.assert_called() + @mock.patch.object(warnings, "warn") + def test_rollback(self, mock_warn): mock_transaction = mock.MagicMock() - mock_transaction.committed = mock_transaction.rolled_back = False - connection._transaction = mock_transaction + self._under_test._spanner_transaction_started = True + self._under_test._transaction = mock_transaction mock_rollback = mock.MagicMock() mock_transaction.rollback = mock_rollback with mock.patch( "google.cloud.spanner_dbapi.connection.Connection._release_session" ) as mock_release: - connection.rollback() + self._under_test.rollback() mock_rollback.assert_called_once_with() mock_release.assert_called_once_with() - connection._autocommit = True - connection.rollback() + @mock.patch.object(warnings, "warn") + def test_rollback_in_autocommit_mode(self, mock_warn): + self._under_test._autocommit = True + + self._under_test.rollback() + mock_warn.assert_called_once_with( CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2 ) @@ -412,9 +381,7 @@ def test_begin_transaction_begin_marked(self): self._under_test.begin() def test_begin_transaction_started(self): - mock_transaction = mock.MagicMock() - mock_transaction.committed = mock_transaction.rolled_back = False - self._under_test._transaction = mock_transaction + self._under_test._spanner_transaction_started = True with self.assertRaises(OperationalError): self._under_test.begin() @@ -510,7 +477,8 @@ def test_commit_clears_statements(self, mock_transaction): cleared, when the transaction is commited. """ connection = self._make_connection() - connection._transaction = mock.Mock(rolled_back=False, committed=False) + connection._spanner_transaction_started = True + connection._transaction = mock.Mock() connection._statements = [{}, {}] self.assertEqual(len(connection._statements), 2) @@ -526,7 +494,7 @@ def test_rollback_clears_statements(self, mock_transaction): cleared, when the transaction is roll backed. """ connection = self._make_connection() - mock_transaction.committed = mock_transaction.rolled_back = False + connection._spanner_transaction_started = True connection._transaction = mock_transaction connection._statements = [{}, {}] @@ -604,7 +572,8 @@ def test_commit_retry_aborted_statements(self, mock_client): statement = Statement("SELECT 1", [], {}, cursor._checksum) connection._statements.append(statement) - mock_transaction = mock.Mock(rolled_back=False, committed=False) + mock_transaction = mock.Mock() + connection._spanner_transaction_started = True connection._transaction = mock_transaction mock_transaction.commit.side_effect = [Aborted("Aborted"), None] run_mock = connection.run_statement = mock.Mock() @@ -614,20 +583,6 @@ def test_commit_retry_aborted_statements(self, mock_client): run_mock.assert_called_with(statement, retried=True) - def test_retry_transaction_drop_transaction(self): - """ - Check that before retrying an aborted transaction - connection drops the original aborted transaction. - """ - connection = self._make_connection() - transaction_mock = mock.Mock() - connection._transaction = transaction_mock - - # as we didn't set any statements, the method - # will only drop the transaction object - connection.retry_transaction() - self.assertIsNone(connection._transaction) - @mock.patch("google.cloud.spanner_v1.Client") def test_retry_aborted_retry(self, mock_client): """ @@ -874,7 +829,8 @@ def test_staleness_inside_transaction(self): option if a transaction is in progress. """ connection = self._make_connection() - connection._transaction = mock.Mock(committed=False, rolled_back=False) + connection._spanner_transaction_started = True + connection._transaction = mock.Mock() with self.assertRaises(ValueError): connection.staleness = {"read_timestamp": datetime.datetime(2021, 9, 21)} @@ -902,7 +858,8 @@ def test_staleness_multi_use(self): "session", multi_use=True, read_timestamp=timestamp ) - def test_staleness_single_use_autocommit(self): + @mock.patch("google.cloud.spanner_dbapi.cursor.PeekIterator") + def test_staleness_single_use_autocommit(self, MockedPeekIterator): """ Check that `staleness` option is correctly sent to the snapshot context manager. @@ -919,7 +876,8 @@ def test_staleness_single_use_autocommit(self): # mock snapshot context manager snapshot_obj = mock.Mock() - snapshot_obj.execute_sql = mock.Mock(return_value=[1]) + _result_set = mock.Mock() + snapshot_obj.execute_sql.return_value = _result_set snapshot_ctx = mock.Mock() snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj) @@ -933,7 +891,8 @@ def test_staleness_single_use_autocommit(self): connection.database.snapshot.assert_called_with(read_timestamp=timestamp) - def test_staleness_single_use_readonly_autocommit(self): + @mock.patch("google.cloud.spanner_dbapi.cursor.PeekIterator") + def test_staleness_single_use_readonly_autocommit(self, MockedPeekIterator): """ Check that `staleness` option is correctly sent to the snapshot context manager while in `autocommit` mode. @@ -951,7 +910,8 @@ def test_staleness_single_use_readonly_autocommit(self): # mock snapshot context manager snapshot_obj = mock.Mock() - snapshot_obj.execute_sql = mock.Mock(return_value=[1]) + _result_set = mock.Mock() + snapshot_obj.execute_sql.return_value = _result_set snapshot_ctx = mock.Mock() snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj) @@ -976,7 +936,8 @@ def test_request_priority(self): priority = 2 connection = self._make_connection() - connection._transaction = mock.Mock(committed=False, rolled_back=False) + connection._spanner_transaction_started = True + connection._transaction = mock.Mock() connection._transaction.execute_sql = mock.Mock() connection.request_priority = priority diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 972816f47a..dfa0a0ac17 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -13,7 +13,6 @@ # limitations under the License. """Cursor() class unit tests.""" - from unittest import mock import sys import unittest @@ -107,7 +106,7 @@ def test_do_execute_update(self): result_set.stats = ResultSetStats(row_count_exact=1234) transaction.execute_sql.return_value = result_set - cursor._do_execute_update( + cursor._do_execute_update_in_autocommit( transaction=transaction, sql="SELECT * WHERE true", params={}, @@ -255,7 +254,7 @@ def test_execute_statement(self): mock_db.run_in_transaction = mock_run_in = mock.MagicMock() cursor.execute(sql="sql") mock_run_in.assert_called_once_with( - cursor._do_execute_update, "sql WHERE 1=1", None + cursor._do_execute_update_in_autocommit, "sql WHERE 1=1", None ) def test_execute_integrity_error(self): @@ -272,6 +271,8 @@ def test_execute_integrity_error(self): with self.assertRaises(IntegrityError): cursor.execute(sql="sql") + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_statement", side_effect=exceptions.FailedPrecondition("message"), @@ -279,6 +280,8 @@ def test_execute_integrity_error(self): with self.assertRaises(IntegrityError): cursor.execute(sql="sql") + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_statement", side_effect=exceptions.OutOfRange("message"), @@ -747,8 +750,8 @@ def test_setoutputsize(self): with self.assertRaises(exceptions.InterfaceError): cursor.setoutputsize(size=None) - def test_handle_dql(self): - from google.cloud.spanner_dbapi import utils + @mock.patch("google.cloud.spanner_dbapi.cursor.PeekIterator") + def test_handle_dql(self, MockedPeekIterator): from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT connection = self._make_connection(self.INSTANCE, mock.MagicMock()) @@ -757,14 +760,15 @@ def test_handle_dql(self): ) = mock.MagicMock() cursor = self._make_one(connection) - mock_snapshot.execute_sql.return_value = ["0"] + _result_set = mock.Mock() + mock_snapshot.execute_sql.return_value = _result_set cursor._handle_DQL("sql", params=None) - self.assertEqual(cursor._result_set, ["0"]) - self.assertIsInstance(cursor._itr, utils.PeekIterator) + self.assertEqual(cursor._result_set, _result_set) + self.assertEqual(cursor._itr, MockedPeekIterator()) self.assertEqual(cursor._row_count, _UNSET_COUNT) - def test_handle_dql_priority(self): - from google.cloud.spanner_dbapi import utils + @mock.patch("google.cloud.spanner_dbapi.cursor.PeekIterator") + def test_handle_dql_priority(self, MockedPeekIterator): from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT from google.cloud.spanner_v1 import RequestOptions @@ -777,10 +781,11 @@ def test_handle_dql_priority(self): cursor = self._make_one(connection) sql = "sql" - mock_snapshot.execute_sql.return_value = ["0"] + _result_set = mock.Mock() + mock_snapshot.execute_sql.return_value = _result_set cursor._handle_DQL(sql, params=None) - self.assertEqual(cursor._result_set, ["0"]) - self.assertIsInstance(cursor._itr, utils.PeekIterator) + self.assertEqual(cursor._result_set, _result_set) + self.assertEqual(cursor._itr, MockedPeekIterator()) self.assertEqual(cursor._row_count, _UNSET_COUNT) mock_snapshot.execute_sql.assert_called_with( sql, None, None, request_options=RequestOptions(priority=1) diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 06819c3a3d..7f179d6d31 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -52,13 +52,15 @@ def test_classify_stmt(self): ), ("CREATE ROLE parent", StatementType.DDL), ("commit", StatementType.CLIENT_SIDE), - (" commit TRANSACTION ", StatementType.CLIENT_SIDE), ("begin", StatementType.CLIENT_SIDE), ("start", StatementType.CLIENT_SIDE), ("begin transaction", StatementType.CLIENT_SIDE), ("start transaction", StatementType.CLIENT_SIDE), ("rollback", StatementType.CLIENT_SIDE), + (" commit TRANSACTION ", StatementType.CLIENT_SIDE), (" rollback TRANSACTION ", StatementType.CLIENT_SIDE), + (" SHOW VARIABLE COMMIT_TIMESTAMP ", StatementType.CLIENT_SIDE), + ("SHOW VARIABLE READ_TIMESTAMP", StatementType.CLIENT_SIDE), ("GRANT SELECT ON TABLE Singers TO ROLE parent", StatementType.DDL), ("REVOKE SELECT ON TABLE Singers TO ROLE parent", StatementType.DDL), ("GRANT ROLE parent TO ROLE child", StatementType.DDL), diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 0010877396..a2799262dc 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -48,6 +48,13 @@ } +def _makeTimestamp(): + import datetime + from google.cloud._helpers import UTC + + return datetime.datetime.utcnow().replace(tzinfo=UTC) + + class Test_restart_on_unavailable(OpenTelemetryBase): def _getTargetClass(self): from google.cloud.spanner_v1.snapshot import _SnapshotBase @@ -1376,12 +1383,6 @@ def _make_spanner_api(self): return mock.create_autospec(SpannerClient, instance=True) - def _makeTimestamp(self): - import datetime - from google.cloud._helpers import UTC - - return datetime.datetime.utcnow().replace(tzinfo=UTC) - def _makeDuration(self, seconds=1, microseconds=0): import datetime @@ -1399,7 +1400,7 @@ def test_ctor_defaults(self): self.assertFalse(snapshot._multi_use) def test_ctor_w_multiple_options(self): - timestamp = self._makeTimestamp() + timestamp = _makeTimestamp() duration = self._makeDuration() session = _Session() @@ -1407,7 +1408,7 @@ def test_ctor_w_multiple_options(self): self._make_one(session, read_timestamp=timestamp, max_staleness=duration) def test_ctor_w_read_timestamp(self): - timestamp = self._makeTimestamp() + timestamp = _makeTimestamp() session = _Session() snapshot = self._make_one(session, read_timestamp=timestamp) self.assertIs(snapshot._session, session) @@ -1419,7 +1420,7 @@ def test_ctor_w_read_timestamp(self): self.assertFalse(snapshot._multi_use) def test_ctor_w_min_read_timestamp(self): - timestamp = self._makeTimestamp() + timestamp = _makeTimestamp() session = _Session() snapshot = self._make_one(session, min_read_timestamp=timestamp) self.assertIs(snapshot._session, session) @@ -1466,7 +1467,7 @@ def test_ctor_w_multi_use(self): self.assertTrue(snapshot._multi_use) def test_ctor_w_multi_use_and_read_timestamp(self): - timestamp = self._makeTimestamp() + timestamp = _makeTimestamp() session = _Session() snapshot = self._make_one(session, read_timestamp=timestamp, multi_use=True) self.assertTrue(snapshot._session is session) @@ -1478,7 +1479,7 @@ def test_ctor_w_multi_use_and_read_timestamp(self): self.assertTrue(snapshot._multi_use) def test_ctor_w_multi_use_and_min_read_timestamp(self): - timestamp = self._makeTimestamp() + timestamp = _makeTimestamp() session = _Session() with self.assertRaises(ValueError): @@ -1520,7 +1521,7 @@ def test__make_txn_selector_strong(self): def test__make_txn_selector_w_read_timestamp(self): from google.cloud._helpers import _pb_timestamp_to_datetime - timestamp = self._makeTimestamp() + timestamp = _makeTimestamp() session = _Session() snapshot = self._make_one(session, read_timestamp=timestamp) selector = snapshot._make_txn_selector() @@ -1535,7 +1536,7 @@ def test__make_txn_selector_w_read_timestamp(self): def test__make_txn_selector_w_min_read_timestamp(self): from google.cloud._helpers import _pb_timestamp_to_datetime - timestamp = self._makeTimestamp() + timestamp = _makeTimestamp() session = _Session() snapshot = self._make_one(session, min_read_timestamp=timestamp) selector = snapshot._make_txn_selector() @@ -1579,7 +1580,7 @@ def test__make_txn_selector_strong_w_multi_use(self): def test__make_txn_selector_w_read_timestamp_w_multi_use(self): from google.cloud._helpers import _pb_timestamp_to_datetime - timestamp = self._makeTimestamp() + timestamp = _makeTimestamp() session = _Session() snapshot = self._make_one(session, read_timestamp=timestamp, multi_use=True) selector = snapshot._make_txn_selector() @@ -1626,7 +1627,7 @@ def test_begin_w_other_error(self): database = _Database() database.spanner_api = self._make_spanner_api() database.spanner_api.begin_transaction.side_effect = RuntimeError() - timestamp = self._makeTimestamp() + timestamp = _makeTimestamp() session = _Session(database) snapshot = self._make_one(session, read_timestamp=timestamp, multi_use=True) @@ -1651,7 +1652,7 @@ def test_begin_w_retry(self): InternalServerError("Received unexpected EOS on DATA frame from server"), TransactionPB(id=TXN_ID), ] - timestamp = self._makeTimestamp() + timestamp = _makeTimestamp() session = _Session(database) snapshot = self._make_one(session, read_timestamp=timestamp, multi_use=True) @@ -1680,7 +1681,9 @@ def test_begin_ok_exact_staleness(self): expected_duration = Duration(seconds=SECONDS, nanos=MICROS * 1000) expected_txn_options = TransactionOptions( - read_only=TransactionOptions.ReadOnly(exact_staleness=expected_duration) + read_only=TransactionOptions.ReadOnly( + exact_staleness=expected_duration, return_read_timestamp=True + ) ) api.begin_transaction.assert_called_once_with( @@ -1714,7 +1717,9 @@ def test_begin_ok_exact_strong(self): self.assertEqual(snapshot._transaction_id, TXN_ID) expected_txn_options = TransactionOptions( - read_only=TransactionOptions.ReadOnly(strong=True) + read_only=TransactionOptions.ReadOnly( + strong=True, return_read_timestamp=True + ) ) api.begin_transaction.assert_called_once_with(