Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implementation of client side statements that return #1046

Merged
merged 11 commits into from
Dec 12, 2023
66 changes: 60 additions & 6 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,27 @@

if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_dbapi import ProgrammingError

from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
ClientSideStatementType,
)
from google.cloud.spanner_v1 import (
Type,
StructType,
TypeCode,
ResultSetMetadata,
PartialResultSet,
)

from google.cloud.spanner_v1._helpers import _make_value_pb
from google.cloud.spanner_v1.streamed import StreamedResultSet

CONNECTION_CLOSED_ERROR = "This connection is closed"
TRANSACTION_NOT_STARTED_WARNING = (
"This method is non-operational as a transaction has not been started."
)


def execute(connection: "Connection", parsed_statement: ParsedStatement):
Expand All @@ -32,9 +49,46 @@ def execute(connection: "Connection", parsed_statement: ParsedStatement):
:type parsed_statement: ParsedStatement
:param parsed_statement: parsed_statement based on the sql query
"""
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
return connection.commit()
if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN:
return connection.begin()
if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK:
return connection.rollback()
if connection.is_closed:
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
statement_type = parsed_statement.client_side_statement_type
if statement_type == ClientSideStatementType.COMMIT:
connection.commit()
return None
if statement_type == ClientSideStatementType.BEGIN:
connection.begin()
return None
if statement_type == ClientSideStatementType.ROLLBACK:
connection.rollback()
return None
if statement_type == ClientSideStatementType.SHOW_COMMIT_TIMESTAMP:
if connection._transaction is None:
committed_timestamp = None
else:
committed_timestamp = connection._transaction.committed
return _get_streamed_result_set(
ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name,
TypeCode.TIMESTAMP,
committed_timestamp,
)
if statement_type == ClientSideStatementType.SHOW_READ_TIMESTAMP:
if connection._snapshot is None:
read_timestamp = None
else:
read_timestamp = connection._snapshot._transaction_read_timestamp
return _get_streamed_result_set(
ClientSideStatementType.SHOW_READ_TIMESTAMP.name,
TypeCode.TIMESTAMP,
read_timestamp,
)


def _get_streamed_result_set(column_name, type_code, column_value):
struct_type_pb = StructType(
fields=[StructType.Field(name=column_name, type_=Type(code=type_code))]
)

result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb))
if column_value is not None:
olavloite marked this conversation as resolved.
Show resolved Hide resolved
result_set.values.extend([_make_value_pb(column_value)])
return StreamedResultSet(iter([result_set]))
23 changes: 16 additions & 7 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE)
RE_SHOW_COMMIT_TIMESTAMP = re.compile(
r"^\s*(SHOW)\s+(VARIABLE)\s+(COMMIT_TIMESTAMP)", re.IGNORECASE
)
RE_SHOW_READ_TIMESTAMP = re.compile(
r"^\s*(SHOW)\s+(VARIABLE)\s+(READ_TIMESTAMP)", re.IGNORECASE
)


def parse_stmt(query):
Expand All @@ -37,16 +43,19 @@ def parse_stmt(query):
:rtype: ParsedStatement
:returns: ParsedStatement object.
"""
client_side_statement_type = None
if RE_COMMIT.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT
)
client_side_statement_type = ClientSideStatementType.COMMIT
if RE_BEGIN.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.BEGIN
)
client_side_statement_type = ClientSideStatementType.BEGIN
if RE_ROLLBACK.match(query):
client_side_statement_type = ClientSideStatementType.ROLLBACK
if RE_SHOW_COMMIT_TIMESTAMP.match(query):
client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
if RE_SHOW_READ_TIMESTAMP.match(query):
client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP
if client_side_statement_type is not None:
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.ROLLBACK
StatementType.CLIENT_SIDE, query, client_side_statement_type
)
return None
81 changes: 38 additions & 43 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1.snapshot import Snapshot
from deprecated import deprecated

from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
Expand All @@ -35,7 +36,7 @@


CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
"This method is non-operational as transaction has not started"
"This method is non-operational as a transaction has not been started."
)
MAX_INTERNAL_RETRIES = 50

Expand Down Expand Up @@ -107,6 +108,9 @@ def __init__(self, instance, database=None, read_only=False):
self._staleness = None
self.request_priority = None
self._transaction_begin_marked = False
# whether transaction started at Spanner. This means that we had
# made atleast one call to Spanner.
self._spanner_transaction_started = False

@property
def autocommit(self):
Expand Down Expand Up @@ -140,26 +144,15 @@ def database(self):
return self._database

@property
def _spanner_transaction_started(self):
"""Flag: whether transaction started at Spanner. This means that we had
made atleast one call to Spanner. Property client_transaction_started
would always be true if this is true as transaction has to start first
at clientside than at Spanner

Returns:
bool: True if Spanner transaction started, False otherwise.
"""
@deprecated(
reason="This method is deprecated. Use _spanner_transaction_started field"
)
def inside_transaction(self):
return (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
) or (self._snapshot is not None)
olavloite marked this conversation as resolved.
Show resolved Hide resolved

@property
def inside_transaction(self):
"""Deprecated property which won't be supported in future versions.
Please use spanner_transaction_started property instead."""
return self._spanner_transaction_started
)

@property
def _client_transaction_started(self):
Expand Down Expand Up @@ -277,7 +270,8 @@ def _release_session(self):
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
self.database._pool.put(self._session)
if self._session is not None:
self.database._pool.put(self._session)
self._session = None

def retry_transaction(self):
Expand All @@ -293,7 +287,7 @@ def retry_transaction(self):
"""
attempt = 0
while True:
self._transaction = None
self._spanner_transaction_started = False
attempt += 1
if attempt > MAX_INTERNAL_RETRIES:
raise
Expand All @@ -319,7 +313,6 @@ def _rerun_previous_statements(self):
status, res = transaction.batch_update(statements)

if status.code == ABORTED:
self.connection._transaction = None
raise Aborted(status.details)

retried_checksum = ResultsChecksum()
Expand Down Expand Up @@ -363,6 +356,8 @@ def transaction_checkout(self):
if not self.read_only and self._client_transaction_started:
if not self._spanner_transaction_started:
self._transaction = self._session_checkout().transaction()
self._snapshot = None
self._spanner_transaction_started = True
self._transaction.begin()

return self._transaction
Expand All @@ -377,11 +372,13 @@ def snapshot_checkout(self):
:returns: A Cloud Spanner snapshot object, ready to use.
"""
if self.read_only and self._client_transaction_started:
if not self._snapshot:
if not self._spanner_transaction_started:
self._snapshot = Snapshot(
self._session_checkout(), multi_use=True, **self.staleness
)
self._transaction = None
self._snapshot.begin()
olavloite marked this conversation as resolved.
Show resolved Hide resolved
self._spanner_transaction_started = True

return self._snapshot

Expand All @@ -391,7 +388,7 @@ def close(self):
The connection will be unusable from this point forward. If the
connection has an active transaction, it will be rolled back.
"""
if self._spanner_transaction_started and not self.read_only:
if self._spanner_transaction_started and not self._read_only:
self._transaction.rollback()

if self._own_pool and self.database:
Expand All @@ -405,13 +402,15 @@ def begin(self):
Marks the transaction as started.

:raises: :class:`InterfaceError`: if this connection is closed.
:raises: :class:`OperationalError`: if there is an existing transaction that has begin or is running
:raises: :class:`OperationalError`: if there is an existing transaction
that has been started
"""
if self._transaction_begin_marked:
raise OperationalError("A transaction has already started")
if self._spanner_transaction_started:
raise OperationalError(
"Beginning a new transaction is not allowed when a transaction is already running"
"Beginning a new transaction is not allowed when a transaction "
"is already running"
)
self._transaction_begin_marked = True

Expand All @@ -430,41 +429,37 @@ def commit(self):
return

self.run_prior_DDL_statements()
if self._spanner_transaction_started:
try:
if self.read_only:
self._snapshot = None
else:
self._transaction.commit()

self._release_session()
self._statements = []
self._transaction_begin_marked = False
except Aborted:
self.retry_transaction()
self.commit()
try:
if self._spanner_transaction_started and not self._read_only:
self._transaction.commit()
except Aborted:
self.retry_transaction()
self.commit()
finally:
self._release_session()
self._statements = []
self._transaction_begin_marked = False
self._spanner_transaction_started = False

def rollback(self):
"""Rolls back any pending transaction.

This is a no-op if there is no active client transaction.
"""

if not self._client_transaction_started:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)
return

if self._spanner_transaction_started:
if self.read_only:
self._snapshot = None
else:
try:
if self._spanner_transaction_started and not self._read_only:
self._transaction.rollback()

finally:
self._release_session()
self._statements = []
self._transaction_begin_marked = False
self._spanner_transaction_started = False

@check_not_closed
def cursor(self):
Expand Down
Loading