Skip to content

Commit 489ac0a

Browse files
test(db_api): increase coverage of db_api (#231)
* pref: increase coverage of db_api * fix: lint * fix: added missing unit tetst
1 parent a2b53a3 commit 489ac0a

File tree

3 files changed

+193
-0
lines changed

3 files changed

+193
-0
lines changed

tests/unit/spanner_dbapi/test_connection.py

+152
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,10 @@ def test_close(self):
183183
mock_transaction.rollback = mock_rollback = mock.MagicMock()
184184
connection.close()
185185
mock_rollback.assert_called_once_with()
186+
connection._transaction = mock.MagicMock()
187+
connection._own_pool = False
188+
connection.close()
189+
self.assertTrue(connection.is_closed)
186190

187191
@mock.patch.object(warnings, "warn")
188192
def test_commit(self, mock_warn):
@@ -379,6 +383,25 @@ def test_run_statement_dont_remember_retried_statements(self):
379383

380384
self.assertEqual(len(connection._statements), 0)
381385

386+
def test_run_statement_w_heterogenous_insert_statements(self):
387+
"""Check that Connection executed heterogenous insert statements."""
388+
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
389+
from google.cloud.spanner_dbapi.cursor import Statement
390+
391+
sql = "INSERT INTO T (f1, f2) VALUES (1, 2)"
392+
params = None
393+
param_types = None
394+
395+
connection = self._make_connection()
396+
397+
statement = Statement(sql, params, param_types, ResultsChecksum(), True)
398+
with mock.patch(
399+
"google.cloud.spanner_dbapi.connection.Connection.transaction_checkout"
400+
):
401+
connection.run_statement(statement, retried=True)
402+
403+
self.assertEqual(len(connection._statements), 0)
404+
382405
def test_run_statement_w_homogeneous_insert_statements(self):
383406
"""Check that Connection executed homogeneous insert statements."""
384407
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
@@ -582,3 +605,132 @@ def test_retry_aborted_retry(self):
582605
mock.call(statement, retried=True),
583606
)
584607
)
608+
609+
def test_retry_transaction_raise_max_internal_retries(self):
610+
"""Check retrying raise an error of max internal retries."""
611+
from google.cloud.spanner_dbapi import connection as conn
612+
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
613+
from google.cloud.spanner_dbapi.cursor import Statement
614+
615+
conn.MAX_INTERNAL_RETRIES = 0
616+
row = ["field1", "field2"]
617+
connection = self._make_connection()
618+
619+
checksum = ResultsChecksum()
620+
checksum.consume_result(row)
621+
622+
statement = Statement("SELECT 1", [], {}, checksum, False)
623+
connection._statements.append(statement)
624+
625+
with self.assertRaises(Exception):
626+
connection.retry_transaction()
627+
628+
conn.MAX_INTERNAL_RETRIES = 50
629+
630+
def test_retry_aborted_retry_without_delay(self):
631+
"""
632+
Check that in case of a retried transaction failed,
633+
the connection will retry it once again.
634+
"""
635+
from google.api_core.exceptions import Aborted
636+
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
637+
from google.cloud.spanner_dbapi.connection import connect
638+
from google.cloud.spanner_dbapi.cursor import Statement
639+
640+
row = ["field1", "field2"]
641+
642+
with mock.patch(
643+
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True,
644+
):
645+
with mock.patch(
646+
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
647+
):
648+
connection = connect("test-instance", "test-database")
649+
650+
cursor = connection.cursor()
651+
cursor._checksum = ResultsChecksum()
652+
cursor._checksum.consume_result(row)
653+
654+
statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
655+
connection._statements.append(statement)
656+
657+
metadata_mock = mock.Mock()
658+
metadata_mock.trailing_metadata.return_value = {}
659+
660+
with mock.patch(
661+
"google.cloud.spanner_dbapi.connection.Connection.run_statement",
662+
side_effect=(
663+
Aborted("Aborted", errors=[metadata_mock]),
664+
([row], ResultsChecksum()),
665+
),
666+
) as retry_mock:
667+
with mock.patch(
668+
"google.cloud.spanner_dbapi.connection._get_retry_delay",
669+
return_value=False,
670+
):
671+
connection.retry_transaction()
672+
673+
retry_mock.assert_has_calls(
674+
(
675+
mock.call(statement, retried=True),
676+
mock.call(statement, retried=True),
677+
)
678+
)
679+
680+
def test_retry_transaction_w_multiple_statement(self):
681+
"""Check retrying an aborted transaction."""
682+
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
683+
from google.cloud.spanner_dbapi.cursor import Statement
684+
685+
row = ["field1", "field2"]
686+
connection = self._make_connection()
687+
688+
checksum = ResultsChecksum()
689+
checksum.consume_result(row)
690+
retried_checkum = ResultsChecksum()
691+
692+
statement = Statement("SELECT 1", [], {}, checksum, False)
693+
statement1 = Statement("SELECT 2", [], {}, checksum, False)
694+
connection._statements.append(statement)
695+
connection._statements.append(statement1)
696+
697+
with mock.patch(
698+
"google.cloud.spanner_dbapi.connection.Connection.run_statement",
699+
return_value=([row], retried_checkum),
700+
) as run_mock:
701+
with mock.patch(
702+
"google.cloud.spanner_dbapi.connection._compare_checksums"
703+
) as compare_mock:
704+
connection.retry_transaction()
705+
706+
compare_mock.assert_called_with(checksum, retried_checkum)
707+
708+
run_mock.assert_called_with(statement1, retried=True)
709+
710+
def test_retry_transaction_w_empty_response(self):
711+
"""Check retrying an aborted transaction."""
712+
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
713+
from google.cloud.spanner_dbapi.cursor import Statement
714+
715+
row = []
716+
connection = self._make_connection()
717+
718+
checksum = ResultsChecksum()
719+
checksum.count = 1
720+
retried_checkum = ResultsChecksum()
721+
722+
statement = Statement("SELECT 1", [], {}, checksum, False)
723+
connection._statements.append(statement)
724+
725+
with mock.patch(
726+
"google.cloud.spanner_dbapi.connection.Connection.run_statement",
727+
return_value=(row, retried_checkum),
728+
) as run_mock:
729+
with mock.patch(
730+
"google.cloud.spanner_dbapi.connection._compare_checksums"
731+
) as compare_mock:
732+
connection.retry_transaction()
733+
734+
compare_mock.assert_called_with(checksum, retried_checkum)
735+
736+
run_mock.assert_called_with(statement, retried=True)

tests/unit/spanner_dbapi/test_cursor.py

+25
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,31 @@ def test_execute_autocommit_off(self):
140140
self.assertIsInstance(cursor._result_set, mock.MagicMock)
141141
self.assertIsInstance(cursor._itr, PeekIterator)
142142

143+
def test_execute_insert_statement_autocommit_off(self):
144+
from google.cloud.spanner_dbapi import parse_utils
145+
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
146+
from google.cloud.spanner_dbapi.utils import PeekIterator
147+
148+
connection = self._make_connection(self.INSTANCE, mock.MagicMock())
149+
cursor = self._make_one(connection)
150+
cursor.connection._autocommit = False
151+
cursor.connection.transaction_checkout = mock.MagicMock(autospec=True)
152+
153+
cursor._checksum = ResultsChecksum()
154+
with mock.patch(
155+
"google.cloud.spanner_dbapi.parse_utils.classify_stmt",
156+
return_value=parse_utils.STMT_INSERT,
157+
):
158+
with mock.patch(
159+
"google.cloud.spanner_dbapi.connection.Connection.run_statement",
160+
return_value=(mock.MagicMock(), ResultsChecksum()),
161+
):
162+
cursor.execute(
163+
sql="INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)"
164+
)
165+
self.assertIsInstance(cursor._result_set, mock.MagicMock)
166+
self.assertIsInstance(cursor._itr, PeekIterator)
167+
143168
def test_execute_statement(self):
144169
from google.cloud.spanner_dbapi import parse_utils
145170

tests/unit/spanner_dbapi/test_utils.py

+16
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,19 @@ def test_backtick_unicode(self):
8585
with self.subTest(sql=sql):
8686
got = backtick_unicode(sql)
8787
self.assertEqual(got, want)
88+
89+
@unittest.skipIf(skip_condition, skip_message)
90+
def test_StreamedManyResultSets(self):
91+
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
92+
93+
cases = [
94+
("iter_from_list", iter([1, 2, 3, 4, 6, 7]), [1, 2, 3, 4, 6, 7]),
95+
("iter_from_tuple", iter(("a", 12, 0xFF)), ["a", 12, 0xFF]),
96+
]
97+
98+
for name, data_in, expected in cases:
99+
with self.subTest(name=name):
100+
stream_result = StreamedManyResultSets()
101+
stream_result._iterators.append(data_in)
102+
actual = list(stream_result)
103+
self.assertEqual(actual, expected)

0 commit comments

Comments
 (0)