Skip to content

Commit

Permalink
feat: Implementation for Begin and Rollback clientside statements
Browse files Browse the repository at this point in the history
  • Loading branch information
ankiaga committed Nov 24, 2023
1 parent 5fb5610 commit 3e80473
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 39 deletions.
7 changes: 7 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,15 @@ def execute(connection, parsed_statement: ParsedStatement):
It is an internal method that can make backwards-incompatible changes.
:type connection: Connection
:param connection: Connection object of the dbApi
:type parsed_statement: ParsedStatement
:param parsed_statement: parsed_statement based on the sql query
"""
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
return connection.commit()
if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN:
return connection.begin()
if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK:
return connection.rollback()
10 changes: 10 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
ClientSideStatementType,
)

RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE)


def parse_stmt(query):
Expand All @@ -39,4 +41,12 @@ def parse_stmt(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT
)
if RE_BEGIN.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.BEGIN
)
if RE_ROLLBACK.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.ROLLBACK
)
return None
48 changes: 38 additions & 10 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
from google.rpc.code_pb2 import ABORTED


AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"
TRANSACTION_NOT_BEGUN_WARNING = (
"This method is non-operational as transaction has not begun"
)
MAX_INTERNAL_RETRIES = 50


Expand Down Expand Up @@ -104,6 +106,7 @@ def __init__(self, instance, database=None, read_only=False):
self._read_only = read_only
self._staleness = None
self.request_priority = None
self._transaction_begin_marked = False

@property
def autocommit(self):
Expand Down Expand Up @@ -141,14 +144,23 @@ def inside_transaction(self):
"""Flag: transaction is started.
Returns:
bool: True if transaction begun, False otherwise.
bool: True if transaction started, False otherwise.
"""
return (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
)

@property
def transaction_begun(self):
"""Flag: transaction has begun
Returns:
bool: True if transaction begun, False otherwise.
"""
return (not self._autocommit) or self._transaction_begin_marked

@property
def instance(self):
"""Instance to which this connection relates.
Expand Down Expand Up @@ -333,12 +345,10 @@ def transaction_checkout(self):
Begin a new transaction, if there is no transaction in
this connection yet. Return the begun one otherwise.
The method is non operational in autocommit mode.
:rtype: :class:`google.cloud.spanner_v1.transaction.Transaction`
:returns: A Cloud Spanner transaction object, ready to use.
"""
if not self.autocommit:
if self.transaction_begun:
if not self.inside_transaction:
self._transaction = self._session_checkout().transaction()
self._transaction.begin()
Expand All @@ -354,7 +364,7 @@ def snapshot_checkout(self):
:rtype: :class:`google.cloud.spanner_v1.snapshot.Snapshot`
:returns: A Cloud Spanner snapshot object, ready to use.
"""
if self.read_only and not self.autocommit:
if self.read_only and self.transaction_begun:
if not self._snapshot:
self._snapshot = Snapshot(
self._session_checkout(), multi_use=True, **self.staleness
Expand All @@ -377,6 +387,22 @@ def close(self):

self.is_closed = True

@check_not_closed
def begin(self):
"""
Marks the transaction as started.
:raises: :class:`InterfaceError`: if this connection is closed.
:raises: :class:`OperationalError`: if there is an existing transaction that has begin or is running
"""
if self._transaction_begin_marked:
raise OperationalError("A transaction has already begun")
if self.inside_transaction:
raise OperationalError(
"Beginning a new transaction is not allowed when a transaction is already running"
)
self._transaction_begin_marked = True

def commit(self):
"""Commits any pending transaction to the database.
Expand All @@ -386,8 +412,8 @@ def commit(self):
raise ValueError("Database needs to be passed for this operation")
self._snapshot = None

if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
if not self.transaction_begun:
warnings.warn(TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2)
return

self.run_prior_DDL_statements()
Expand All @@ -398,6 +424,7 @@ def commit(self):

self._release_session()
self._statements = []
self._transaction_begin_marked = False
except Aborted:
self.retry_transaction()
self.commit()
Expand All @@ -410,14 +437,15 @@ def rollback(self):
"""
self._snapshot = None

if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
if not self.transaction_begun:
warnings.warn(TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2)
elif self._transaction:
if not self.read_only:
self._transaction.rollback()

self._release_session()
self._statements = []
self._transaction_begin_marked = False

@check_not_closed
def cursor(self):
Expand Down
14 changes: 7 additions & 7 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def execute(self, sql, args=None):
)
if parsed_statement.statement_type == StatementType.DDL:
self._batch_DDLs(sql)
if self.connection.autocommit:
if not self.connection.transaction_begun:
self.connection.run_prior_DDL_statements()
return

Expand All @@ -264,7 +264,7 @@ def execute(self, sql, args=None):

sql, args = sql_pyformat_args_to_spanner(sql, args or None)

if not self.connection.autocommit:
if self.connection.transaction_begun:
statement = Statement(
sql,
args,
Expand Down Expand Up @@ -348,7 +348,7 @@ def executemany(self, operation, seq_of_params):
)
statements.append((sql, params, get_param_types(params)))

if self.connection.autocommit:
if not self.connection.transaction_begun:
self.connection.database.run_in_transaction(
self._do_batch_update, statements, many_result_set
)
Expand Down Expand Up @@ -396,7 +396,7 @@ def fetchone(self):
sequence, or None when no more data is available."""
try:
res = next(self)
if not self.connection.autocommit and not self.connection.read_only:
if self.connection.transaction_begun and not self.connection.read_only:
self._checksum.consume_result(res)
return res
except StopIteration:
Expand All @@ -414,7 +414,7 @@ def fetchall(self):
res = []
try:
for row in self:
if not self.connection.autocommit and not self.connection.read_only:
if self.connection.transaction_begun and not self.connection.read_only:
self._checksum.consume_result(row)
res.append(row)
except Aborted:
Expand Down Expand Up @@ -443,7 +443,7 @@ def fetchmany(self, size=None):
for _ in range(size):
try:
res = next(self)
if not self.connection.autocommit and not self.connection.read_only:
if self.connection.transaction_begun and not self.connection.read_only:
self._checksum.consume_result(res)
items.append(res)
except StopIteration:
Expand Down Expand Up @@ -473,7 +473,7 @@ def _handle_DQL(self, sql, params):
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
if self.connection.read_only and not self.connection.autocommit:
if self.connection.read_only and self.connection.transaction_begun:
# initiate or use the existing multi-use snapshot
self._handle_DQL_with_snapshot(
self.connection.snapshot_checkout(), sql, params
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_dbapi/parsed_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class StatementType(Enum):
class ClientSideStatementType(Enum):
COMMIT = 1
BEGIN = 2
ROLLBACK = 3


@dataclass
Expand Down
103 changes: 85 additions & 18 deletions tests/system/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from google.cloud._helpers import UTC

from google.cloud.spanner_dbapi.connection import Connection, connect
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
from google.cloud.spanner_dbapi.exceptions import ProgrammingError, OperationalError
from google.cloud.spanner_v1 import JsonObject
from google.cloud.spanner_v1 import gapic_version as package_version
from . import _helpers
Expand Down Expand Up @@ -80,42 +80,43 @@ def init_connection(self, request, shared_instance, dbapi_database):
self._cursor.close()
self._conn.close()

@pytest.fixture
def execute_common_statements(self):
def _execute_common_statements(self, cursor):
# execute several DML statements within one transaction
self._cursor.execute(
cursor.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru')
"""
)
self._cursor.execute(
cursor.execute(
"""
UPDATE contacts
SET first_name = 'updated-first-name'
WHERE first_name = 'first-name'
"""
)
self._cursor.execute(
cursor.execute(
"""
UPDATE contacts
SET email = 'test.email_updated@domen.ru'
WHERE email = 'test.email@domen.ru'
"""
)

@pytest.fixture
def updated_row(self, execute_common_statements):
return (
1,
"updated-first-name",
"last-name",
"test.email_updated@domen.ru",
)

def test_commit(self, updated_row):
@pytest.mark.parametrize("client_side", [False, True])
def test_commit(self, client_side):
"""Test committing a transaction with several statements."""
self._conn.commit()
updated_row = self._execute_common_statements(self._cursor)
if client_side:
self._cursor.execute("""COMMIT""")
else:
self._conn.commit()

# read the resulting data from the database
self._cursor.execute("SELECT * FROM contacts")
Expand All @@ -124,18 +125,80 @@ def test_commit(self, updated_row):

assert got_rows == [updated_row]

def test_commit_client_side(self, updated_row):
"""Test committing a transaction with several statements."""
self._cursor.execute("""COMMIT""")
@pytest.mark.noautofixt
def test_begin_client_side(self, shared_instance, dbapi_database):
"""Test beginning a transaction using client side statement,
where connection is in autocommit mode."""

conn1 = Connection(shared_instance, dbapi_database)
conn1.autocommit = True
cursor1 = conn1.cursor()
cursor1.execute("begin transaction")
updated_row = self._execute_common_statements(cursor1)

# As the connection conn1 is not committed a new connection wont see its results
conn2 = Connection(shared_instance, dbapi_database)
cursor2 = conn2.cursor()
cursor2.execute("SELECT * FROM contacts")
conn2.commit()
got_rows = cursor2.fetchall()
assert got_rows != [updated_row]

assert conn1._transaction_begin_marked is True
conn1.commit()
assert conn1._transaction_begin_marked is False

# As the connection conn1 is committed a new connection should see its results
conn3 = Connection(shared_instance, dbapi_database)
cursor3 = conn3.cursor()
cursor3.execute("SELECT * FROM contacts")
conn3.commit()
got_rows = cursor3.fetchall()
assert got_rows == [updated_row]

# read the resulting data from the database
conn1.close()
conn2.close()
conn3.close()
cursor1.close()
cursor2.close()
cursor3.close()

def test_begin_success_post_commit(self):
"""Test beginning a new transaction post commiting an existing transaction
is possible on a connection, when connection is in autocommit mode."""
want_row = (2, "first-name", "last-name", "test.email@domen.ru")
self._conn.autocommit = True
self._cursor.execute("begin transaction")
self._cursor.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru')
"""
)
self._conn.commit()

self._cursor.execute("begin transaction")
self._cursor.execute("SELECT * FROM contacts")
got_rows = self._cursor.fetchall()
self._conn.commit()
assert got_rows == [want_row]

assert got_rows == [updated_row]
def test_begin_error_before_commit(self):
"""Test beginning a new transaction before commiting an existing transaction is not possible on a connection, when connection is in autocommit mode."""
self._conn.autocommit = True
self._cursor.execute("begin transaction")
self._cursor.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru')
"""
)

with pytest.raises(OperationalError):
self._cursor.execute("begin transaction")

def test_rollback(self):
@pytest.mark.parametrize("client_side", [False, True])
def test_rollback(self, client_side):
"""Test rollbacking a transaction with several statements."""
want_row = (2, "first-name", "last-name", "test.email@domen.ru")

Expand All @@ -162,7 +225,11 @@ def test_rollback(self):
WHERE email = 'test.email@domen.ru'
"""
)
self._conn.rollback()

if client_side:
self._cursor.execute("ROLLBACK")
else:
self._conn.rollback()

# read the resulting data from the database
self._cursor.execute("SELECT * FROM contacts")
Expand Down
Loading

0 comments on commit 3e80473

Please sign in to comment.