Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implementation for Begin and Rollback clientside statements #1041

Merged
merged 13 commits into from
Dec 4, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,15 @@ def execute(connection, parsed_statement: ParsedStatement):
It is an internal method that can make backwards-incompatible changes.
:type connection: Connection
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
:param connection: Connection object of the dbApi
:type parsed_statement: ParsedStatement
:param parsed_statement: parsed_statement based on the sql query
"""
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
return connection.commit()
if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN:
return connection.begin()
if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK:
return connection.rollback()
10 changes: 10 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
ClientSideStatementType,
)

RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE)


def parse_stmt(query):
Expand All @@ -39,4 +41,12 @@ def parse_stmt(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT
)
if RE_BEGIN.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.BEGIN
)
if RE_ROLLBACK.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.ROLLBACK
)
return None
71 changes: 53 additions & 18 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
from google.rpc.code_pb2 import ABORTED


AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"
CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
"This method is non-operational as transaction has not begun"
)
MAX_INTERNAL_RETRIES = 50


Expand Down Expand Up @@ -104,6 +106,7 @@ def __init__(self, instance, database=None, read_only=False):
self._read_only = read_only
self._staleness = None
self.request_priority = None
self._transaction_begin_marked = False

@property
def autocommit(self):
Expand All @@ -122,7 +125,7 @@ def autocommit(self, value):
:type value: bool
:param value: New autocommit mode state.
"""
if value and not self._autocommit and self.inside_transaction:
if value and not self._autocommit and self.spanner_transaction_started:
self.commit()

self._autocommit = value
Expand All @@ -137,18 +140,30 @@ def database(self):
return self._database

@property
def inside_transaction(self):
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
"""Flag: transaction is started.
def spanner_transaction_started(self):
"""Flag: whether transaction started at SpanFE. This means that we had
made atleast one call to SpanFE. Property client_transaction_started
would always be true if this is true as transaction has to start first
at clientside than at Spanner (SpanFE)
ankiaga marked this conversation as resolved.
Show resolved Hide resolved

Returns:
bool: True if transaction begun, False otherwise.
bool: True if SpanFE transaction started, False otherwise.
"""
return (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
)

@property
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
def client_transaction_started(self):
"""Flag: whether transaction started at client side.

Returns:
bool: True if transaction begun, False otherwise.
"""
return (not self._autocommit) or self._transaction_begin_marked
olavloite marked this conversation as resolved.
Show resolved Hide resolved

@property
def instance(self):
"""Instance to which this connection relates.
Expand All @@ -175,7 +190,7 @@ def read_only(self, value):
Args:
value (bool): True for ReadOnly mode, False for ReadWrite.
"""
if self.inside_transaction:
if self.spanner_transaction_started:
raise ValueError(
"Connection read/write mode can't be changed while a transaction is in progress. "
"Commit or rollback the current transaction and try again."
Expand Down Expand Up @@ -213,7 +228,7 @@ def staleness(self, value):
Args:
value (dict): Staleness type and value.
"""
if self.inside_transaction:
if self.spanner_transaction_started:
raise ValueError(
"`staleness` option can't be changed while a transaction is in progress. "
"Commit or rollback the current transaction and try again."
Expand Down Expand Up @@ -333,13 +348,11 @@ def transaction_checkout(self):
Begin a new transaction, if there is no transaction in
this connection yet. Return the begun one otherwise.

The method is non operational in autocommit mode.
ankiaga marked this conversation as resolved.
Show resolved Hide resolved

:rtype: :class:`google.cloud.spanner_v1.transaction.Transaction`
:returns: A Cloud Spanner transaction object, ready to use.
"""
if not self.autocommit:
if not self.inside_transaction:
if self.client_transaction_started:
if not self.spanner_transaction_started:
self._transaction = self._session_checkout().transaction()
self._transaction.begin()

Expand All @@ -354,7 +367,7 @@ def snapshot_checkout(self):
:rtype: :class:`google.cloud.spanner_v1.snapshot.Snapshot`
:returns: A Cloud Spanner snapshot object, ready to use.
"""
if self.read_only and not self.autocommit:
if self.read_only and self.client_transaction_started:
if not self._snapshot:
self._snapshot = Snapshot(
self._session_checkout(), multi_use=True, **self.staleness
Expand All @@ -369,14 +382,30 @@ def close(self):
The connection will be unusable from this point forward. If the
connection has an active transaction, it will be rolled back.
"""
if self.inside_transaction:
if self.spanner_transaction_started:
self._transaction.rollback()
ankiaga marked this conversation as resolved.
Show resolved Hide resolved

if self._own_pool and self.database:
self.database._pool.clear()

self.is_closed = True

@check_not_closed
def begin(self):
"""
Marks the transaction as started.
olavloite marked this conversation as resolved.
Show resolved Hide resolved

:raises: :class:`InterfaceError`: if this connection is closed.
:raises: :class:`OperationalError`: if there is an existing transaction that has begin or is running
"""
if self._transaction_begin_marked:
raise OperationalError("A transaction has already begun")
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
if self.spanner_transaction_started:
raise OperationalError(
"Beginning a new transaction is not allowed when a transaction is already running"
)
self._transaction_begin_marked = True

def commit(self):
"""Commits any pending transaction to the database.

Expand All @@ -386,18 +415,21 @@ def commit(self):
raise ValueError("Database needs to be passed for this operation")
self._snapshot = None

if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
if not self.client_transaction_started:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)
return

self.run_prior_DDL_statements()
if self.inside_transaction:
if self.spanner_transaction_started:
try:
if not self.read_only:
self._transaction.commit()

self._release_session()
self._statements = []
self._transaction_begin_marked = False
except Aborted:
self.retry_transaction()
self.commit()
Expand All @@ -410,14 +442,17 @@ def rollback(self):
"""
self._snapshot = None

if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
if not self.client_transaction_started:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)
elif self._transaction:
if not self.read_only:
self._transaction.rollback()

self._release_session()
self._statements = []
self._transaction_begin_marked = False

@check_not_closed
def cursor(self):
Expand Down
23 changes: 16 additions & 7 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def execute(self, sql, args=None):
)
if parsed_statement.statement_type == StatementType.DDL:
self._batch_DDLs(sql)
if self.connection.autocommit:
if not self.connection.client_transaction_started:
self.connection.run_prior_DDL_statements()
return

Expand All @@ -264,7 +264,7 @@ def execute(self, sql, args=None):

sql, args = sql_pyformat_args_to_spanner(sql, args or None)

if not self.connection.autocommit:
if self.connection.client_transaction_started:
statement = Statement(
sql,
args,
Expand Down Expand Up @@ -348,7 +348,7 @@ def executemany(self, operation, seq_of_params):
)
statements.append((sql, params, get_param_types(params)))

if self.connection.autocommit:
if not self.connection.client_transaction_started:
self.connection.database.run_in_transaction(
self._do_batch_update, statements, many_result_set
)
Expand Down Expand Up @@ -396,7 +396,10 @@ def fetchone(self):
sequence, or None when no more data is available."""
try:
res = next(self)
if not self.connection.autocommit and not self.connection.read_only:
if (
self.connection.client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(res)
return res
except StopIteration:
Expand All @@ -414,7 +417,10 @@ def fetchall(self):
res = []
try:
for row in self:
if not self.connection.autocommit and not self.connection.read_only:
if (
self.connection.client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(row)
res.append(row)
except Aborted:
Expand Down Expand Up @@ -443,7 +449,10 @@ def fetchmany(self, size=None):
for _ in range(size):
try:
res = next(self)
if not self.connection.autocommit and not self.connection.read_only:
if (
self.connection.client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(res)
items.append(res)
except StopIteration:
Expand Down Expand Up @@ -473,7 +482,7 @@ def _handle_DQL(self, sql, params):
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
if self.connection.read_only and not self.connection.autocommit:
if self.connection.read_only and self.connection.client_transaction_started:
# initiate or use the existing multi-use snapshot
self._handle_DQL_with_snapshot(
self.connection.snapshot_checkout(), sql, params
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_dbapi/parsed_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class StatementType(Enum):
class ClientSideStatementType(Enum):
COMMIT = 1
BEGIN = 2
ROLLBACK = 3


@dataclass
Expand Down
Loading
Loading