Skip to content

Commit 5e013ca

Browse files
committed
feat: Fix and refactor transaction retry logic in dbapi; add interceptor support for testing
1 parent 7a92315 commit 5e013ca

File tree

16 files changed: 1561 additions and 851 deletions.

google/cloud/spanner_dbapi/batch_dml_executor.py

Lines changed: 10 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616

1717
from enum import Enum
1818
from typing import TYPE_CHECKING, List
19-
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
2019
from google.cloud.spanner_dbapi.parsed_statement import (
2120
ParsedStatement,
2221
StatementType,
@@ -80,8 +79,10 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
8079
"""
8180
from google.cloud.spanner_dbapi import OperationalError
8281

83-
connection = cursor.connection
8482
many_result_set = StreamedManyResultSets()
83+
if not statements:
84+
return many_result_set
85+
connection = cursor.connection
8586
statements_tuple = []
8687
for statement in statements:
8788
statements_tuple.append(statement.get_tuple())
@@ -90,28 +91,24 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
9091
many_result_set.add_iter(res)
9192
cursor._row_count = sum([max(val, 0) for val in res])
9293
else:
93-
retried = False
9494
while True:
9595
try:
9696
transaction = connection.transaction_checkout()
9797
status, res = transaction.batch_update(statements_tuple)
98-
many_result_set.add_iter(res)
99-
res_checksum = ResultsChecksum()
100-
res_checksum.consume_result(res)
101-
res_checksum.consume_result(status.code)
102-
if not retried:
103-
connection._statements.append((statements, res_checksum))
104-
cursor._row_count = sum([max(val, 0) for val in res])
105-
10698
if status.code == ABORTED:
10799
connection._transaction = None
108100
raise Aborted(status.message)
109101
elif status.code != OK:
110102
raise OperationalError(status.message)
103+
104+
many_result_set.add_iter(res)
105+
cursor._row_count = sum([max(val, 0) for val in res])
111106
return many_result_set
112107
except Aborted:
113-
connection.retry_transaction()
114-
retried = True
108+
if cursor._in_retry_mode:
109+
raise
110+
else:
111+
connection._transaction_helper.retry_transaction()
115112

116113

117114
def _do_batch_update(transaction, statements):

google/cloud/spanner_dbapi/checksum.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,8 @@ def consume_result(self, result):
6262

6363

6464
def _compare_checksums(original, retried):
65+
from google.cloud.spanner_dbapi.transaction_helper import RETRY_ABORTED_ERROR
66+
6567
"""Compare the given checksums.
6668
6769
Raise an error if the given checksums are not equal.
@@ -75,6 +77,4 @@ def _compare_checksums(original, retried):
7577
:raises: :exc:`google.cloud.spanner_dbapi.exceptions.RetryAborted` if the checksums are not equal.
7678
"""
7779
if retried != original:
78-
raise RetryAborted(
79-
"The transaction was aborted and could not be retried due to a concurrent modification."
80-
)
80+
raise RetryAborted(RETRY_ABORTED_ERROR)

google/cloud/spanner_dbapi/connection.py

Lines changed: 11 additions & 95 deletions
Original file line number | Diff line number | Diff line change
@@ -13,21 +13,18 @@
1313
# limitations under the License.
1414

1515
"""DB-API Connection for the Google Cloud Spanner."""
16-
import time
1716
import warnings
1817

1918
from google.api_core.exceptions import Aborted
2019
from google.api_core.gapic_v1.client_info import ClientInfo
2120
from google.cloud import spanner_v1 as spanner
2221
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
2322
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
23+
from google.cloud.spanner_dbapi.transaction_helper import TransactionRetryHelper
2424
from google.cloud.spanner_v1 import RequestOptions
25-
from google.cloud.spanner_v1.session import _get_retry_delay
2625
from google.cloud.spanner_v1.snapshot import Snapshot
2726
from deprecated import deprecated
2827

29-
from google.cloud.spanner_dbapi.checksum import _compare_checksums
30-
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
3128
from google.cloud.spanner_dbapi.cursor import Cursor
3229
from google.cloud.spanner_dbapi.exceptions import (
3330
InterfaceError,
@@ -37,13 +34,10 @@
3734
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
3835
from google.cloud.spanner_dbapi.version import PY_VERSION
3936

40-
from google.rpc.code_pb2 import ABORTED
41-
4237

4338
CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
4439
"This method is non-operational as a transaction has not been started."
4540
)
46-
MAX_INTERNAL_RETRIES = 50
4741

4842

4943
def check_not_closed(function):
@@ -99,9 +93,6 @@ def __init__(self, instance, database=None, read_only=False):
9993
self._transaction = None
10094
self._session = None
10195
self._snapshot = None
102-
# SQL statements, which were executed
103-
# within the current transaction
104-
self._statements = []
10596

10697
self.is_closed = False
10798
self._autocommit = False
@@ -118,6 +109,7 @@ def __init__(self, instance, database=None, read_only=False):
118109
self._spanner_transaction_started = False
119110
self._batch_mode = BatchMode.NONE
120111
self._batch_dml_executor: BatchDmlExecutor = None
112+
self._transaction_helper = TransactionRetryHelper(self)
121113

122114
@property
123115
def autocommit(self):
@@ -281,76 +273,6 @@ def _release_session(self):
281273
self.database._pool.put(self._session)
282274
self._session = None
283275

284-
def retry_transaction(self):
285-
"""Retry the aborted transaction.
286-
287-
All the statements executed in the original transaction
288-
will be re-executed in new one. Results checksums of the
289-
original statements and the retried ones will be compared.
290-
291-
:raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
292-
If results checksum of the retried statement is
293-
not equal to the checksum of the original one.
294-
"""
295-
attempt = 0
296-
while True:
297-
self._spanner_transaction_started = False
298-
attempt += 1
299-
if attempt > MAX_INTERNAL_RETRIES:
300-
raise
301-
302-
try:
303-
self._rerun_previous_statements()
304-
break
305-
except Aborted as exc:
306-
delay = _get_retry_delay(exc.errors[0], attempt)
307-
if delay:
308-
time.sleep(delay)
309-
310-
def _rerun_previous_statements(self):
311-
"""
312-
Helper to run all the remembered statements
313-
from the last transaction.
314-
"""
315-
for statement in self._statements:
316-
if isinstance(statement, list):
317-
statements, checksum = statement
318-
319-
transaction = self.transaction_checkout()
320-
statements_tuple = []
321-
for single_statement in statements:
322-
statements_tuple.append(single_statement.get_tuple())
323-
status, res = transaction.batch_update(statements_tuple)
324-
325-
if status.code == ABORTED:
326-
raise Aborted(status.details)
327-
328-
retried_checksum = ResultsChecksum()
329-
retried_checksum.consume_result(res)
330-
retried_checksum.consume_result(status.code)
331-
332-
_compare_checksums(checksum, retried_checksum)
333-
else:
334-
res_iter, retried_checksum = self.run_statement(statement, retried=True)
335-
# executing all the completed statements
336-
if statement != self._statements[-1]:
337-
for res in res_iter:
338-
retried_checksum.consume_result(res)
339-
340-
_compare_checksums(statement.checksum, retried_checksum)
341-
# executing the failed statement
342-
else:
343-
# streaming up to the failed result or
344-
# to the end of the streaming iterator
345-
while len(retried_checksum) < len(statement.checksum):
346-
try:
347-
res = next(iter(res_iter))
348-
retried_checksum.consume_result(res)
349-
except StopIteration:
350-
break
351-
352-
_compare_checksums(statement.checksum, retried_checksum)
353-
354276
def transaction_checkout(self):
355277
"""Get a Cloud Spanner transaction.
356278
@@ -443,11 +365,11 @@ def commit(self):
443365
if self._spanner_transaction_started and not self._read_only:
444366
self._transaction.commit()
445367
except Aborted:
446-
self.retry_transaction()
368+
self._transaction_helper.retry_transaction()
447369
self.commit()
448370
finally:
449371
self._release_session()
450-
self._statements = []
372+
self._transaction_helper.reset()
451373
self._transaction_begin_marked = False
452374
self._spanner_transaction_started = False
453375

@@ -467,7 +389,7 @@ def rollback(self):
467389
self._transaction.rollback()
468390
finally:
469391
self._release_session()
470-
self._statements = []
392+
self._transaction_helper.reset()
471393
self._transaction_begin_marked = False
472394
self._spanner_transaction_started = False
473395

@@ -486,7 +408,7 @@ def run_prior_DDL_statements(self):
486408

487409
return self.database.update_ddl(ddl_statements).result()
488410

489-
def run_statement(self, statement: Statement, retried=False):
411+
def run_statement(self, statement: Statement):
490412
"""Run single SQL statement in begun transaction.
491413
492414
This method is never used in autocommit mode. In
@@ -506,17 +428,11 @@ def run_statement(self, statement: Statement, retried=False):
506428
checksum of this statement results.
507429
"""
508430
transaction = self.transaction_checkout()
509-
if not retried:
510-
self._statements.append(statement)
511-
512-
return (
513-
transaction.execute_sql(
514-
statement.sql,
515-
statement.params,
516-
param_types=statement.param_types,
517-
request_options=self.request_options,
518-
),
519-
ResultsChecksum() if retried else statement.checksum,
431+
return transaction.execute_sql(
432+
statement.sql,
433+
statement.params,
434+
param_types=statement.param_types,
435+
request_options=self.request_options,
520436
)
521437

522438
@check_not_closed

0 commit comments

Comments
 (0)