Skip to content

Commit 7a92315

Browse files
authored
feat: Implementation for batch dml in dbapi (#1055)
* feat: Implementation for batch dml in dbapi * Few changes * Incorporated comments
1 parent c70d7da commit 7a92315

11 files changed

+574
-122
lines changed
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright 2023 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from enum import Enum
18+
from typing import TYPE_CHECKING, List
19+
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
20+
from google.cloud.spanner_dbapi.parsed_statement import (
21+
ParsedStatement,
22+
StatementType,
23+
Statement,
24+
)
25+
from google.rpc.code_pb2 import ABORTED, OK
26+
from google.api_core.exceptions import Aborted
27+
28+
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
29+
30+
if TYPE_CHECKING:
31+
from google.cloud.spanner_dbapi.cursor import Cursor
32+
33+
34+
class BatchDmlExecutor:
    """Collects DML statements locally while a DML batch is active.

    Statements handed to :meth:`execute_statement` are only buffered in
    memory; nothing is sent to Spanner until :meth:`run_batch_dml` is called,
    which submits every buffered statement in one batch call.

    :type cursor: :class:`~google.cloud.spanner_dbapi.cursor.Cursor`
    :param cursor: cursor the batch belongs to; the batch is later executed
        through this cursor (and its connection).
    """

    def __init__(self, cursor: "Cursor"):
        self._cursor = cursor
        self._connection = cursor.connection
        # Buffered statements, kept in the order they were submitted.
        self._statements: List[Statement] = []

    def execute_statement(self, parsed_statement: ParsedStatement):
        """Buffer a statement in memory for the active DML batch.

        :type parsed_statement: ParsedStatement
        :param parsed_statement: parsed statement containing sql query and
            query params

        :raises ProgrammingError: if the statement type is neither UPDATE
            nor INSERT.
        """
        from google.cloud.spanner_dbapi import ProgrammingError

        allowed_types = (StatementType.UPDATE, StatementType.INSERT)
        if parsed_statement.statement_type not in allowed_types:
            raise ProgrammingError("Only DML statements are allowed in batch DML mode.")
        self._statements.append(parsed_statement.statement)

    def run_batch_dml(self):
        """Send all buffered statements to Spanner as a single batch.

        Delegates to the module-level ``run_batch_dml`` function and returns
        whatever that function returns.
        """
        return run_batch_dml(self._cursor, self._statements)
70+
71+
72+
def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
    """Execute all the DML statements by making a batch call to Spanner.

    When the connection has no client-side transaction started, the batch is
    run through ``database.run_in_transaction``. Otherwise it is run on the
    connection's current transaction, and the call is repeated after
    ``connection.retry_transaction()`` whenever the attempt is aborted.

    In both cases ``cursor._row_count`` is set to the sum of the per-statement
    row counts, with negative entries counted as zero.

    :type cursor: Cursor
    :param cursor: Database Cursor object

    :type statements: List[Statement]
    :param statements: list of statements to execute in batch

    :rtype: StreamedManyResultSets or None
    :returns: result set holding the per-statement row counts when the batch
        ran inside a client-side transaction, or an empty result set when
        ``statements`` is empty. ``None`` when a non-empty batch ran without
        a client-side transaction (see the review note in the body).

    :raises OperationalError: if Spanner reports a status other than OK or
        ABORTED for the batch.
    """
    from google.cloud.spanner_dbapi import OperationalError

    many_result_set = StreamedManyResultSets()
    if not statements:
        # Nothing was buffered, so there is nothing to send: skip the round
        # trip to Spanner and hand back an empty result set.
        return many_result_set

    connection = cursor.connection
    statements_tuple = [statement.get_tuple() for statement in statements]

    if not connection._client_transaction_started:
        res = connection.database.run_in_transaction(_do_batch_update, statements_tuple)
        many_result_set.add_iter(res)
        cursor._row_count = sum(max(val, 0) for val in res)
        # NOTE(review): this path has never returned the populated result
        # set; kept as-is because callers may rely on None here. Confirm
        # whether it should return many_result_set like the branch below.
        return None

    retried = False
    while True:
        try:
            transaction = connection.transaction_checkout()
            status, res = transaction.batch_update(statements_tuple)

            # Record the batch with a checksum of its outcome in the
            # connection's statement log, but only on the first attempt so
            # that a retry does not log it twice.
            res_checksum = ResultsChecksum()
            res_checksum.consume_result(res)
            res_checksum.consume_result(status.code)
            if not retried:
                connection._statements.append((statements, res_checksum))
            cursor._row_count = sum(max(val, 0) for val in res)

            if status.code == ABORTED:
                # Drop the aborted transaction so a new one is checked out.
                connection._transaction = None
                raise Aborted(status.message)
            elif status.code != OK:
                raise OperationalError(status.message)

            # Only expose results of the attempt that actually succeeded;
            # adding them before the status check would leave the results of
            # an aborted attempt in the returned set after a retry.
            many_result_set.add_iter(res)
            return many_result_set
        except Aborted:
            connection.retry_transaction()
            retried = True
115+
116+
117+
def _do_batch_update(transaction, statements):
    """Run ``statements`` as one batch update on ``transaction``.

    Used as the unit of work passed to ``database.run_in_transaction``.

    :param transaction: object exposing ``batch_update(statements)`` that
        returns a ``(status, row_counts)`` pair.
    :param statements: statements forwarded unchanged to ``batch_update``.

    :returns: the second element returned by ``batch_update`` when the
        status code is OK.

    :raises Aborted: if the status code is ABORTED.
    :raises OperationalError: for any other non-OK status code.
    """
    from google.cloud.spanner_dbapi import OperationalError

    status, row_counts = transaction.batch_update(statements)
    if status.code == ABORTED:
        raise Aborted(status.message)
    if status.code != OK:
        raise OperationalError(status.message)
    return row_counts
126+
127+
128+
class BatchMode(Enum):
    """Kind of statement batch that is currently active, if any."""

    # A DML batch is active.
    DML = 1
    # A DDL batch is active.
    DDL = 2
    # No batch is active.
    NONE = 3

google/cloud/spanner_dbapi/client_side_statement_executor.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import TYPE_CHECKING
1515

1616
if TYPE_CHECKING:
17-
from google.cloud.spanner_dbapi import Connection
17+
from google.cloud.spanner_dbapi.cursor import Cursor
1818
from google.cloud.spanner_dbapi import ProgrammingError
1919

2020
from google.cloud.spanner_dbapi.parsed_statement import (
@@ -38,17 +38,18 @@
3838
)
3939

4040

41-
def execute(connection: "Connection", parsed_statement: ParsedStatement):
41+
def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
4242
"""Executes the client side statements by calling the relevant method.
4343
4444
It is an internal method that can make backwards-incompatible changes.
4545
46-
:type connection: Connection
47-
:param connection: Connection object of the dbApi
46+
:type cursor: Cursor
47+
:param cursor: Cursor object of the dbApi
4848
4949
:type parsed_statement: ParsedStatement
5050
:param parsed_statement: parsed_statement based on the sql query
5151
"""
52+
connection = cursor.connection
5253
if connection.is_closed:
5354
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
5455
statement_type = parsed_statement.client_side_statement_type
@@ -81,6 +82,13 @@ def execute(connection: "Connection", parsed_statement: ParsedStatement):
8182
TypeCode.TIMESTAMP,
8283
read_timestamp,
8384
)
85+
if statement_type == ClientSideStatementType.START_BATCH_DML:
86+
connection.start_batch_dml(cursor)
87+
return None
88+
if statement_type == ClientSideStatementType.RUN_BATCH:
89+
return connection.run_batch()
90+
if statement_type == ClientSideStatementType.ABORT_BATCH:
91+
return connection.abort_batch()
8492

8593

8694
def _get_streamed_result_set(column_name, type_code, column_value):

google/cloud/spanner_dbapi/client_side_statement_parser.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ParsedStatement,
1919
StatementType,
2020
ClientSideStatementType,
21+
Statement,
2122
)
2223

2324
RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
@@ -29,6 +30,9 @@
2930
RE_SHOW_READ_TIMESTAMP = re.compile(
3031
r"^\s*(SHOW)\s+(VARIABLE)\s+(READ_TIMESTAMP)", re.IGNORECASE
3132
)
33+
RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE)
34+
RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE)
35+
RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)
3236

3337

3438
def parse_stmt(query):
@@ -54,8 +58,14 @@ def parse_stmt(query):
5458
client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
5559
if RE_SHOW_READ_TIMESTAMP.match(query):
5660
client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP
61+
if RE_START_BATCH_DML.match(query):
62+
client_side_statement_type = ClientSideStatementType.START_BATCH_DML
63+
if RE_RUN_BATCH.match(query):
64+
client_side_statement_type = ClientSideStatementType.RUN_BATCH
65+
if RE_ABORT_BATCH.match(query):
66+
client_side_statement_type = ClientSideStatementType.ABORT_BATCH
5767
if client_side_statement_type is not None:
5868
return ParsedStatement(
59-
StatementType.CLIENT_SIDE, query, client_side_statement_type
69+
StatementType.CLIENT_SIDE, Statement(query), client_side_statement_type
6070
)
6171
return None

google/cloud/spanner_dbapi/connection.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
# limitations under the License.
1414

1515
"""DB-API Connection for the Google Cloud Spanner."""
16-
1716
import time
1817
import warnings
1918

2019
from google.api_core.exceptions import Aborted
2120
from google.api_core.gapic_v1.client_info import ClientInfo
2221
from google.cloud import spanner_v1 as spanner
22+
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
23+
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
2324
from google.cloud.spanner_v1 import RequestOptions
2425
from google.cloud.spanner_v1.session import _get_retry_delay
2526
from google.cloud.spanner_v1.snapshot import Snapshot
@@ -28,7 +29,11 @@
2829
from google.cloud.spanner_dbapi.checksum import _compare_checksums
2930
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
3031
from google.cloud.spanner_dbapi.cursor import Cursor
31-
from google.cloud.spanner_dbapi.exceptions import InterfaceError, OperationalError
32+
from google.cloud.spanner_dbapi.exceptions import (
33+
InterfaceError,
34+
OperationalError,
35+
ProgrammingError,
36+
)
3237
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
3338
from google.cloud.spanner_dbapi.version import PY_VERSION
3439

@@ -111,6 +116,8 @@ def __init__(self, instance, database=None, read_only=False):
111116
# whether transaction started at Spanner. This means that we had
112117
# made at least one call to Spanner.
113118
self._spanner_transaction_started = False
119+
self._batch_mode = BatchMode.NONE
120+
self._batch_dml_executor: BatchDmlExecutor = None
114121

115122
@property
116123
def autocommit(self):
@@ -310,7 +317,10 @@ def _rerun_previous_statements(self):
310317
statements, checksum = statement
311318

312319
transaction = self.transaction_checkout()
313-
status, res = transaction.batch_update(statements)
320+
statements_tuple = []
321+
for single_statement in statements:
322+
statements_tuple.append(single_statement.get_tuple())
323+
status, res = transaction.batch_update(statements_tuple)
314324

315325
if status.code == ABORTED:
316326
raise Aborted(status.details)
@@ -476,14 +486,14 @@ def run_prior_DDL_statements(self):
476486

477487
return self.database.update_ddl(ddl_statements).result()
478488

479-
def run_statement(self, statement, retried=False):
489+
def run_statement(self, statement: Statement, retried=False):
480490
"""Run single SQL statement in begun transaction.
481491
482492
This method is never used in autocommit mode. In
483493
!autocommit mode however it remembers every executed
484494
SQL statement with its parameters.
485495
486-
:type statement: :class:`dict`
496+
:type statement: :class:`Statement`
487497
:param statement: SQL statement to execute.
488498
489499
:type retried: bool
@@ -534,6 +544,47 @@ def validate(self):
534544
"Expected: [[1]]" % result
535545
)
536546

547+
@check_not_closed
def start_batch_dml(self, cursor):
    """Put the connection into DML batch mode.

    Creates a fresh ``BatchDmlExecutor`` bound to ``cursor`` that will
    buffer the statements of the batch.

    :type cursor: Cursor
    :param cursor: cursor the new batch executor is created for.

    :raises ProgrammingError: if a batch is already active, or if the
        connection is read-only.
    """
    problem = None
    if self._batch_mode is not BatchMode.NONE:
        problem = "Cannot start a DML batch when a batch is already active"
    elif self.read_only:
        problem = "Cannot start a DML batch when the connection is in read-only mode"
    if problem is not None:
        raise ProgrammingError(problem)

    self._batch_mode = BatchMode.DML
    self._batch_dml_executor = BatchDmlExecutor(cursor)
559+
560+
@check_not_closed
def execute_batch_dml_statement(self, parsed_statement: ParsedStatement):
    """Hand a statement to the active DML batch executor.

    :type parsed_statement: ParsedStatement
    :param parsed_statement: statement passed to the executor's
        ``execute_statement``.

    :raises ProgrammingError: if the connection is not in DML batch mode.
    """
    if self._batch_mode is BatchMode.DML:
        self._batch_dml_executor.execute_statement(parsed_statement)
        return
    raise ProgrammingError("Cannot execute statement when the BatchMode is not DML")
567+
568+
@check_not_closed
def run_batch(self):
    """Run the currently active batch and leave batch mode.

    Batch mode and the DML executor are always reset, even when running
    the batch raises.

    :returns: whatever the DML executor's ``run_batch_dml`` returns when
        the active batch is a DML batch; ``None`` for any other batch mode.

    :raises ProgrammingError: if no batch is active.
    """
    if self._batch_mode is BatchMode.NONE:
        raise ProgrammingError("Cannot run a batch when the BatchMode is not set")

    # Bind the result up front: only the DML branch assigns it, so without
    # this default any other active mode would hit an UnboundLocalError at
    # the return below.
    many_result_set = None
    try:
        if self._batch_mode is BatchMode.DML:
            many_result_set = self._batch_dml_executor.run_batch_dml()
    finally:
        self._batch_mode = BatchMode.NONE
        self._batch_dml_executor = None
    return many_result_set
579+
580+
@check_not_closed
def abort_batch(self):
    """Discard the currently active batch without running it.

    For a DML batch the executor, and with it the buffered statements it
    holds, is dropped. The connection then returns to ``BatchMode.NONE``.

    :raises ProgrammingError: if no batch is active.
    """
    active_mode = self._batch_mode
    if active_mode is BatchMode.NONE:
        raise ProgrammingError("Cannot abort a batch when the BatchMode is not set")
    if active_mode is BatchMode.DML:
        self._batch_dml_executor = None
    self._batch_mode = BatchMode.NONE
587+
537588
def __enter__(self):
538589
return self
539590

0 commit comments

Comments
 (0)