Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implementation for batch dml in dbapi #1055

Merged
merged 4 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions google/cloud/spanner_dbapi/batch_dml_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
olavloite marked this conversation as resolved.
Show resolved Hide resolved

from enum import Enum
from typing import TYPE_CHECKING, List
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
StatementType,
Statement,
)
from google.rpc.code_pb2 import ABORTED, OK
from google.api_core.exceptions import Aborted

from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

if TYPE_CHECKING:
from google.cloud.spanner_dbapi.cursor import Cursor


class BatchDmlExecutor:
"""Executor that is used when a DML batch is started. These batches only
accept DML statements. All DML statements are buffered locally and sent to
Spanner when runBatch() is called.

:type "Cursor": :class:`~google.cloud.spanner_dbapi.cursor.Cursor`
:param cursor:
"""

def __init__(self, cursor: "Cursor"):
self._cursor = cursor
self._connection = cursor.connection
self._statements: List[Statement] = []

def execute_statement(self, parsed_statement: ParsedStatement):
"""Executes the statement when dml batch is active by buffering the
statement in-memory.

:type parsed_statement: ParsedStatement
:param parsed_statement: parsed statement containing sql query and query
params
"""
from google.cloud.spanner_dbapi import ProgrammingError

if (
parsed_statement.statement_type != StatementType.UPDATE
and parsed_statement.statement_type != StatementType.INSERT
):
raise ProgrammingError("Only DML statements are allowed in batch DML mode.")
self._statements.append(parsed_statement.statement)

def run_batch_dml(self):
"""Executes all the buffered statements on the active dml batch by
making a call to Spanner.
"""
return run_batch_dml(self._cursor, self._statements)


def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
"""Executes all the dml statements by making a batch call to Spanner.

:type cursor: Cursor
:param cursor: Database Cursor object

:type statements: List[Statement]
:param statements: list of statements to execute in batch
"""
from google.cloud.spanner_dbapi import OperationalError

connection = cursor.connection
olavloite marked this conversation as resolved.
Show resolved Hide resolved
many_result_set = StreamedManyResultSets()
statements_tuple = []
for statement in statements:
statements_tuple.append(statement.get_tuple())
if not connection._client_transaction_started:
res = connection.database.run_in_transaction(_do_batch_update, statements_tuple)
many_result_set.add_iter(res)
cursor._row_count = sum([max(val, 0) for val in res])
else:
retried = False
while True:
try:
transaction = connection.transaction_checkout()
status, res = transaction.batch_update(statements_tuple)
many_result_set.add_iter(res)
res_checksum = ResultsChecksum()
res_checksum.consume_result(res)
res_checksum.consume_result(status.code)
if not retried:
connection._statements.append((statements, res_checksum))
cursor._row_count = sum([max(val, 0) for val in res])

if status.code == ABORTED:
connection._transaction = None
raise Aborted(status.message)
elif status.code != OK:
raise OperationalError(status.message)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should (could) this also include the status code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will take it in a follow up PR

return many_result_set
except Aborted:
connection.retry_transaction()
retried = True


def _do_batch_update(transaction, statements):
from google.cloud.spanner_dbapi import OperationalError

status, res = transaction.batch_update(statements)
if status.code == ABORTED:
raise Aborted(status.message)
elif status.code != OK:
raise OperationalError(status.message)
return res


class BatchMode(Enum):
DML = 1
DDL = 2
NONE = 3
16 changes: 12 additions & 4 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi import ProgrammingError

from google.cloud.spanner_dbapi.parsed_statement import (
Expand All @@ -38,17 +38,18 @@
)


def execute(connection: "Connection", parsed_statement: ParsedStatement):
def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
"""Executes the client side statements by calling the relevant method.

It is an internal method that can make backwards-incompatible changes.

:type connection: Connection
:param connection: Connection object of the dbApi
:type cursor: Cursor
:param cursor: Cursor object of the dbApi

:type parsed_statement: ParsedStatement
:param parsed_statement: parsed_statement based on the sql query
"""
connection = cursor.connection
if connection.is_closed:
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
statement_type = parsed_statement.client_side_statement_type
Expand Down Expand Up @@ -81,6 +82,13 @@ def execute(connection: "Connection", parsed_statement: ParsedStatement):
TypeCode.TIMESTAMP,
read_timestamp,
)
if statement_type == ClientSideStatementType.START_BATCH_DML:
connection.start_batch_dml(cursor)
return None
if statement_type == ClientSideStatementType.RUN_BATCH:
return connection.run_batch()
if statement_type == ClientSideStatementType.ABORT_BATCH:
return connection.abort_batch()


def _get_streamed_result_set(column_name, type_code, column_value):
Expand Down
12 changes: 11 additions & 1 deletion google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ParsedStatement,
StatementType,
ClientSideStatementType,
Statement,
)

RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
Expand All @@ -29,6 +30,9 @@
RE_SHOW_READ_TIMESTAMP = re.compile(
r"^\s*(SHOW)\s+(VARIABLE)\s+(READ_TIMESTAMP)", re.IGNORECASE
)
RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE)
RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE)
olavloite marked this conversation as resolved.
Show resolved Hide resolved
RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)


def parse_stmt(query):
Expand All @@ -54,8 +58,14 @@ def parse_stmt(query):
client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
if RE_SHOW_READ_TIMESTAMP.match(query):
client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP
if RE_START_BATCH_DML.match(query):
client_side_statement_type = ClientSideStatementType.START_BATCH_DML
if RE_RUN_BATCH.match(query):
client_side_statement_type = ClientSideStatementType.RUN_BATCH
if RE_ABORT_BATCH.match(query):
client_side_statement_type = ClientSideStatementType.ABORT_BATCH
if client_side_statement_type is not None:
return ParsedStatement(
StatementType.CLIENT_SIDE, query, client_side_statement_type
StatementType.CLIENT_SIDE, Statement(query), client_side_statement_type
)
return None
61 changes: 56 additions & 5 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.

"""DB-API Connection for the Google Cloud Spanner."""

import time
import warnings

from google.api_core.exceptions import Aborted
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1.snapshot import Snapshot
Expand All @@ -28,7 +29,11 @@
from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi.exceptions import InterfaceError, OperationalError
from google.cloud.spanner_dbapi.exceptions import (
InterfaceError,
OperationalError,
ProgrammingError,
)
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
from google.cloud.spanner_dbapi.version import PY_VERSION

Expand Down Expand Up @@ -111,6 +116,8 @@ def __init__(self, instance, database=None, read_only=False):
# whether transaction started at Spanner. This means that we had
# made atleast one call to Spanner.
self._spanner_transaction_started = False
self._batch_mode = BatchMode.NONE
self._batch_dml_executor: BatchDmlExecutor = None

@property
def autocommit(self):
Expand Down Expand Up @@ -310,7 +317,10 @@ def _rerun_previous_statements(self):
statements, checksum = statement

transaction = self.transaction_checkout()
status, res = transaction.batch_update(statements)
statements_tuple = []
for single_statement in statements:
statements_tuple.append(single_statement.get_tuple())
status, res = transaction.batch_update(statements_tuple)

if status.code == ABORTED:
raise Aborted(status.details)
Expand Down Expand Up @@ -476,14 +486,14 @@ def run_prior_DDL_statements(self):

return self.database.update_ddl(ddl_statements).result()

def run_statement(self, statement, retried=False):
def run_statement(self, statement: Statement, retried=False):
"""Run single SQL statement in begun transaction.

This method is never used in autocommit mode. In
!autocommit mode however it remembers every executed
SQL statement with its parameters.

:type statement: :class:`dict`
:type statement: :class:`Statement`
:param statement: SQL statement to execute.

:type retried: bool
Expand Down Expand Up @@ -534,6 +544,47 @@ def validate(self):
"Expected: [[1]]" % result
)

@check_not_closed
def start_batch_dml(self, cursor):
if self._batch_mode is not BatchMode.NONE:
raise ProgrammingError(
"Cannot start a DML batch when a batch is already active"
)
if self.read_only:
raise ProgrammingError(
"Cannot start a DML batch when the connection is in read-only mode"
)
self._batch_mode = BatchMode.DML
self._batch_dml_executor = BatchDmlExecutor(cursor)

@check_not_closed
def execute_batch_dml_statement(self, parsed_statement: ParsedStatement):
if self._batch_mode is not BatchMode.DML:
raise ProgrammingError(
"Cannot execute statement when the BatchMode is not DML"
)
self._batch_dml_executor.execute_statement(parsed_statement)

@check_not_closed
def run_batch(self):
if self._batch_mode is BatchMode.NONE:
raise ProgrammingError("Cannot run a batch when the BatchMode is not set")
try:
if self._batch_mode is BatchMode.DML:
many_result_set = self._batch_dml_executor.run_batch_dml()
finally:
self._batch_mode = BatchMode.NONE
self._batch_dml_executor = None
return many_result_set

@check_not_closed
def abort_batch(self):
if self._batch_mode is BatchMode.NONE:
raise ProgrammingError("Cannot abort a batch when the BatchMode is not set")
if self._batch_mode is BatchMode.DML:
self._batch_dml_executor = None
self._batch_mode = BatchMode.NONE

def __enter__(self):
return self

Expand Down
Loading