Skip to content

Commit e981adb

Browse files
Ilya Gurovc24tlarkeeAVaksman
authored
fix(dbapi): autocommit enabling fails if no transactions begun (#177)
* fix(dbapi): autocommit enabling fails if no transactions begun * remove unused import * don't calculate checksums in autocommit mode * try using dummy WHERE clause * revert where clause * unveil error * fix where clauses * add print * don't log * print failed exceptions * don't print * separate insert statements * don't return * re-run * don't pyformat insert args * args * re-run * fix * fix error in transactions.tests.NonAutocommitTests.test_orm_query_without_autocommit * fix "already committed" error * fix for AttributeError: 'tuple' object has no attribute 'items' * fix * fix KeyError: 'type' Co-authored-by: Chris Kleinknecht <[email protected]> Co-authored-by: larkee <[email protected]> Co-authored-by: Alex <[email protected]>
1 parent 4ef793c commit e981adb

File tree

6 files changed

+116
-50
lines changed

6 files changed

+116
-50
lines changed

google/cloud/spanner_dbapi/connection.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from google.cloud import spanner_v1 as spanner
2323
from google.cloud.spanner_v1.session import _get_retry_delay
2424

25+
from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous
26+
from google.cloud.spanner_dbapi._helpers import _execute_insert_homogenous
27+
from google.cloud.spanner_dbapi._helpers import parse_insert
2528
from google.cloud.spanner_dbapi.checksum import _compare_checksums
2629
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
2730
from google.cloud.spanner_dbapi.cursor import Cursor
@@ -82,7 +85,7 @@ def autocommit(self, value):
8285
:type value: bool
8386
:param value: New autocommit mode state.
8487
"""
85-
if value and not self._autocommit:
88+
if value and not self._autocommit and self.inside_transaction:
8689
self.commit()
8790

8891
self._autocommit = value
@@ -96,6 +99,19 @@ def database(self):
9699
"""
97100
return self._database
98101

102+
@property
103+
def inside_transaction(self):
104+
"""Flag: transaction is started.
105+
106+
Returns:
107+
bool: True if transaction begun, False otherwise.
108+
"""
109+
return (
110+
self._transaction
111+
and not self._transaction.committed
112+
and not self._transaction.rolled_back
113+
)
114+
99115
@property
100116
def instance(self):
101117
"""Instance to which this connection relates.
@@ -191,11 +207,7 @@ def transaction_checkout(self):
191207
:returns: A Cloud Spanner transaction object, ready to use.
192208
"""
193209
if not self.autocommit:
194-
if (
195-
not self._transaction
196-
or self._transaction.committed
197-
or self._transaction.rolled_back
198-
):
210+
if not self.inside_transaction:
199211
self._transaction = self._session_checkout().transaction()
200212
self._transaction.begin()
201213

@@ -216,11 +228,7 @@ def close(self):
216228
The connection will be unusable from this point forward. If the
217229
connection has an active transaction, it will be rolled back.
218230
"""
219-
if (
220-
self._transaction
221-
and not self._transaction.committed
222-
and not self._transaction.rolled_back
223-
):
231+
if self.inside_transaction:
224232
self._transaction.rollback()
225233

226234
if self._own_pool:
@@ -235,7 +243,7 @@ def commit(self):
235243
"""
236244
if self._autocommit:
237245
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
238-
elif self._transaction:
246+
elif self.inside_transaction:
239247
try:
240248
self._transaction.commit()
241249
self._release_session()
@@ -291,6 +299,24 @@ def run_statement(self, statement, retried=False):
291299
if not retried:
292300
self._statements.append(statement)
293301

302+
if statement.is_insert:
303+
parts = parse_insert(statement.sql, statement.params)
304+
305+
if parts.get("homogenous"):
306+
_execute_insert_homogenous(transaction, parts)
307+
return (
308+
iter(()),
309+
ResultsChecksum() if retried else statement.checksum,
310+
)
311+
else:
312+
_execute_insert_heterogenous(
313+
transaction, parts.get("sql_params_list"),
314+
)
315+
return (
316+
iter(()),
317+
ResultsChecksum() if retried else statement.checksum,
318+
)
319+
294320
return (
295321
transaction.execute_sql(
296322
statement.sql, statement.params, param_types=statement.param_types,

google/cloud/spanner_dbapi/cursor.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
_UNSET_COUNT = -1
4343

4444
ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
45-
Statement = namedtuple("Statement", "sql, params, param_types, checksum")
45+
Statement = namedtuple("Statement", "sql, params, param_types, checksum, is_insert")
4646

4747

4848
class Cursor(object):
@@ -95,9 +95,9 @@ def description(self):
9595
for field in row_type.fields:
9696
column_info = ColumnInfo(
9797
name=field.name,
98-
type_code=field.type.code,
98+
type_code=field.type_.code,
9999
# Size of the SQL type of the column.
100-
display_size=code_to_display_size.get(field.type.code),
100+
display_size=code_to_display_size.get(field.type_.code),
101101
# Client perceived size of the column.
102102
internal_size=field.ByteSize(),
103103
)
@@ -172,10 +172,20 @@ def execute(self, sql, args=None):
172172
self.connection.run_prior_DDL_statements()
173173

174174
if not self.connection.autocommit:
175-
sql, params = sql_pyformat_args_to_spanner(sql, args)
175+
if classification == parse_utils.STMT_UPDATING:
176+
sql = parse_utils.ensure_where_clause(sql)
177+
178+
if classification != parse_utils.STMT_INSERT:
179+
sql, args = sql_pyformat_args_to_spanner(sql, args or None)
176180

177181
statement = Statement(
178-
sql, params, get_param_types(params), ResultsChecksum(),
182+
sql,
183+
args,
184+
get_param_types(args or None)
185+
if classification != parse_utils.STMT_INSERT
186+
else {},
187+
ResultsChecksum(),
188+
classification == parse_utils.STMT_INSERT,
179189
)
180190
(self._result_set, self._checksum,) = self.connection.run_statement(
181191
statement
@@ -233,7 +243,8 @@ def fetchone(self):
233243

234244
try:
235245
res = next(self)
236-
self._checksum.consume_result(res)
246+
if not self.connection.autocommit:
247+
self._checksum.consume_result(res)
237248
return res
238249
except StopIteration:
239250
return
@@ -250,7 +261,8 @@ def fetchall(self):
250261
res = []
251262
try:
252263
for row in self:
253-
self._checksum.consume_result(row)
264+
if not self.connection.autocommit:
265+
self._checksum.consume_result(row)
254266
res.append(row)
255267
except Aborted:
256268
self._connection.retry_transaction()
@@ -278,7 +290,8 @@ def fetchmany(self, size=None):
278290
for i in range(size):
279291
try:
280292
res = next(self)
281-
self._checksum.consume_result(res)
293+
if not self.connection.autocommit:
294+
self._checksum.consume_result(res)
282295
items.append(res)
283296
except StopIteration:
284297
break

google/cloud/spanner_dbapi/parse_utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -523,19 +523,15 @@ def get_param_types(params):
523523
def ensure_where_clause(sql):
524524
"""
525525
Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements.
526-
Raise an error, if the given sql doesn't include it.
526+
Add a dummy WHERE clause if non detected.
527527
528528
:type sql: `str`
529529
:param sql: SQL code to check.
530-
531-
:raises: :class:`ProgrammingError` if the given sql doesn't include a WHERE clause.
532530
"""
533531
if any(isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0]):
534532
return sql
535533

536-
raise ProgrammingError(
537-
"Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query"
538-
)
534+
return sql + " WHERE 1=1"
539535

540536

541537
def escape_name(name):

tests/unit/spanner_dbapi/test_connection.py

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""Cloud Spanner DB-API Connection class unit tests."""
1616

1717
import mock
18-
import sys
1918
import unittest
2019
import warnings
2120

@@ -51,25 +50,57 @@ def _make_connection(self):
5150
database = instance.database(self.DATABASE)
5251
return Connection(instance, database)
5352

54-
@unittest.skipIf(sys.version_info[0] < 3, "Python 2 patching is outdated")
55-
def test_property_autocommit_setter(self):
56-
from google.cloud.spanner_dbapi import Connection
57-
58-
connection = Connection(self.INSTANCE, self.DATABASE)
53+
def test_autocommit_setter_transaction_not_started(self):
54+
connection = self._make_connection()
5955

6056
with mock.patch(
6157
"google.cloud.spanner_dbapi.connection.Connection.commit"
6258
) as mock_commit:
6359
connection.autocommit = True
64-
mock_commit.assert_called_once_with()
65-
self.assertEqual(connection._autocommit, True)
60+
mock_commit.assert_not_called()
61+
self.assertTrue(connection._autocommit)
6662

6763
with mock.patch(
6864
"google.cloud.spanner_dbapi.connection.Connection.commit"
6965
) as mock_commit:
7066
connection.autocommit = False
7167
mock_commit.assert_not_called()
72-
self.assertEqual(connection._autocommit, False)
68+
self.assertFalse(connection._autocommit)
69+
70+
def test_autocommit_setter_transaction_started(self):
71+
connection = self._make_connection()
72+
73+
with mock.patch(
74+
"google.cloud.spanner_dbapi.connection.Connection.commit"
75+
) as mock_commit:
76+
connection._transaction = mock.Mock(committed=False, rolled_back=False)
77+
78+
connection.autocommit = True
79+
mock_commit.assert_called_once()
80+
self.assertTrue(connection._autocommit)
81+
82+
def test_autocommit_setter_transaction_started_commited_rolled_back(self):
83+
connection = self._make_connection()
84+
85+
with mock.patch(
86+
"google.cloud.spanner_dbapi.connection.Connection.commit"
87+
) as mock_commit:
88+
connection._transaction = mock.Mock(committed=True, rolled_back=False)
89+
90+
connection.autocommit = True
91+
mock_commit.assert_not_called()
92+
self.assertTrue(connection._autocommit)
93+
94+
connection.autocommit = False
95+
96+
with mock.patch(
97+
"google.cloud.spanner_dbapi.connection.Connection.commit"
98+
) as mock_commit:
99+
connection._transaction = mock.Mock(committed=False, rolled_back=True)
100+
101+
connection.autocommit = True
102+
mock_commit.assert_not_called()
103+
self.assertTrue(connection._autocommit)
73104

74105
def test_property_database(self):
75106
from google.cloud.spanner_v1.database import Database
@@ -166,7 +197,9 @@ def test_commit(self, mock_warn):
166197
connection.commit()
167198
mock_release.assert_not_called()
168199

169-
connection._transaction = mock_transaction = mock.MagicMock()
200+
connection._transaction = mock_transaction = mock.MagicMock(
201+
rolled_back=False, committed=False
202+
)
170203
mock_transaction.commit = mock_commit = mock.MagicMock()
171204

172205
with mock.patch(
@@ -316,7 +349,7 @@ def test_run_statement_remember_statements(self):
316349

317350
connection = self._make_connection()
318351

319-
statement = Statement(sql, params, param_types, ResultsChecksum(),)
352+
statement = Statement(sql, params, param_types, ResultsChecksum(), False)
320353
with mock.patch(
321354
"google.cloud.spanner_dbapi.connection.Connection.transaction_checkout"
322355
):
@@ -338,7 +371,7 @@ def test_run_statement_dont_remember_retried_statements(self):
338371

339372
connection = self._make_connection()
340373

341-
statement = Statement(sql, params, param_types, ResultsChecksum(),)
374+
statement = Statement(sql, params, param_types, ResultsChecksum(), False)
342375
with mock.patch(
343376
"google.cloud.spanner_dbapi.connection.Connection.transaction_checkout"
344377
):
@@ -352,7 +385,7 @@ def test_clear_statements_on_commit(self):
352385
cleared, when the transaction is commited.
353386
"""
354387
connection = self._make_connection()
355-
connection._transaction = mock.Mock()
388+
connection._transaction = mock.Mock(rolled_back=False, committed=False)
356389
connection._statements = [{}, {}]
357390

358391
self.assertEqual(len(connection._statements), 2)
@@ -390,7 +423,7 @@ def test_retry_transaction(self):
390423
checksum.consume_result(row)
391424
retried_checkum = ResultsChecksum()
392425

393-
statement = Statement("SELECT 1", [], {}, checksum,)
426+
statement = Statement("SELECT 1", [], {}, checksum, False)
394427
connection._statements.append(statement)
395428

396429
with mock.patch(
@@ -423,7 +456,7 @@ def test_retry_transaction_checksum_mismatch(self):
423456
checksum.consume_result(row)
424457
retried_checkum = ResultsChecksum()
425458

426-
statement = Statement("SELECT 1", [], {}, checksum,)
459+
statement = Statement("SELECT 1", [], {}, checksum, False)
427460
connection._statements.append(statement)
428461

429462
with mock.patch(
@@ -453,9 +486,9 @@ def test_commit_retry_aborted_statements(self):
453486
cursor._checksum = ResultsChecksum()
454487
cursor._checksum.consume_result(row)
455488

456-
statement = Statement("SELECT 1", [], {}, cursor._checksum,)
489+
statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
457490
connection._statements.append(statement)
458-
connection._transaction = mock.Mock()
491+
connection._transaction = mock.Mock(rolled_back=False, committed=False)
459492

460493
with mock.patch.object(
461494
connection._transaction, "commit", side_effect=(Aborted("Aborted"), None),
@@ -507,7 +540,7 @@ def test_retry_aborted_retry(self):
507540
cursor._checksum = ResultsChecksum()
508541
cursor._checksum.consume_result(row)
509542

510-
statement = Statement("SELECT 1", [], {}, cursor._checksum,)
543+
statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
511544
connection._statements.append(statement)
512545

513546
metadata_mock = mock.Mock()

tests/unit/spanner_dbapi/test_cursor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def test_execute_attribute_error(self):
126126
cursor = self._make_one(connection)
127127

128128
with self.assertRaises(AttributeError):
129-
cursor.execute(sql="")
129+
cursor.execute(sql="SELECT 1")
130130

131131
def test_execute_autocommit_off(self):
132132
from google.cloud.spanner_dbapi.utils import PeekIterator
@@ -531,7 +531,7 @@ def test_fetchone_retry_aborted_statements(self):
531531
cursor._checksum = ResultsChecksum()
532532
cursor._checksum.consume_result(row)
533533

534-
statement = Statement("SELECT 1", [], {}, cursor._checksum,)
534+
statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
535535
connection._statements.append(statement)
536536

537537
with mock.patch(
@@ -570,7 +570,7 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self):
570570
cursor._checksum = ResultsChecksum()
571571
cursor._checksum.consume_result(row)
572572

573-
statement = Statement("SELECT 1", [], {}, cursor._checksum,)
573+
statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
574574
connection._statements.append(statement)
575575

576576
with mock.patch(

tests/unit/spanner_dbapi/test_parse_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,6 @@ def test_get_param_types_none(self):
391391

392392
@unittest.skipIf(skip_condition, skip_message)
393393
def test_ensure_where_clause(self):
394-
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
395394
from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause
396395

397396
cases = (
@@ -409,8 +408,7 @@ def test_ensure_where_clause(self):
409408

410409
for sql in err_cases:
411410
with self.subTest(sql=sql):
412-
with self.assertRaises(ProgrammingError):
413-
ensure_where_clause(sql)
411+
self.assertEqual(ensure_where_clause(sql), sql + " WHERE 1=1")
414412

415413
@unittest.skipIf(skip_condition, skip_message)
416414
def test_escape_name(self):

0 commit comments

Comments
 (0)