Skip to content

Commit 4e31455

Browse files
committed
fix: add support for row_count in cursor. (googleapis#675)
* fix: add support for row_count * docs: update rowcount property doc * fix: updated tests for cursor to check row_count * refactor: lint fixes * test: add test for do_batch_update * refactor: Empty commit
1 parent 77552f3 commit 4e31455

File tree

2 files changed

+67
-15
lines changed

2 files changed

+67
-15
lines changed

google/cloud/spanner_dbapi/cursor.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444

4545
from google.rpc.code_pb2 import ABORTED, OK
4646

47+
_UNSET_COUNT = -1
48+
4749
ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
4850
Statement = namedtuple("Statement", "sql, params, param_types, checksum, is_insert")
4951

@@ -80,6 +82,7 @@ class Cursor(object):
8082
def __init__(self, connection):
8183
self._itr = None
8284
self._result_set = None
85+
self._row_count = _UNSET_COUNT
8386
self.lastrowid = None
8487
self.connection = connection
8588
self._is_closed = False
@@ -134,13 +137,14 @@ def description(self):
134137

135138
@property
136139
def rowcount(self):
137-
"""The number of rows produced by the last `execute()` call.
140+
"""The number of rows updated by the last UPDATE, DELETE request's `execute()` call.
141+
For SELECT requests the rowcount returns -1.
138142
139-
The property is non-operational and always returns -1. Request
140-
resulting rows are streamed by the `fetch*()` methods and
141-
can't be counted before they are all streamed.
143+
:rtype: int
144+
:returns: The number of rows updated by the last UPDATE, DELETE request's .execute*() call.
142145
"""
143-
return -1
146+
147+
return self._row_count
144148

145149
@check_not_closed
146150
def callproc(self, procname, args=None):
@@ -170,7 +174,11 @@ def _do_execute_update(self, transaction, sql, params):
170174
result = transaction.execute_update(
171175
sql, params=params, param_types=get_param_types(params)
172176
)
173-
self._itr = iter([result])
177+
self._itr = None
178+
if type(result) == int:
179+
self._row_count = result
180+
181+
return result
174182

175183
def _do_batch_update(self, transaction, statements, many_result_set):
176184
status, res = transaction.batch_update(statements)
@@ -181,6 +189,8 @@ def _do_batch_update(self, transaction, statements, many_result_set):
181189
elif status.code != OK:
182190
raise OperationalError(status.message)
183191

192+
self._row_count = sum([max(val, 0) for val in res])
193+
184194
def _batch_DDLs(self, sql):
185195
"""
186196
Check that the given operation contains only DDL
@@ -414,6 +424,9 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params):
414424
# Read the first element so that the StreamedResultSet can
415425
# return the metadata after a DQL statement.
416426
self._itr = PeekIterator(self._result_set)
427+
# Unfortunately, Spanner doesn't seem to send back
428+
# information about the number of rows available.
429+
self._row_count = _UNSET_COUNT
417430

418431
def _handle_DQL(self, sql, params):
419432
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)

tests/unit/spanner_dbapi/test_cursor.py

+48-9
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,13 @@ def _make_connection(self, *args, **kwargs):
3737

3838
return Connection(*args, **kwargs)
3939

40-
def _transaction_mock(self):
40+
def _transaction_mock(self, mock_response=[]):
4141
from google.rpc.code_pb2 import OK
4242

4343
transaction = mock.Mock(committed=False, rolled_back=False)
44-
transaction.batch_update = mock.Mock(return_value=[mock.Mock(code=OK), []])
44+
transaction.batch_update = mock.Mock(
45+
return_value=[mock.Mock(code=OK), mock_response]
46+
)
4547
return transaction
4648

4749
def test_property_connection(self):
@@ -62,10 +64,12 @@ def test_property_description(self):
6264
self.assertIsInstance(cursor.description[0], ColumnInfo)
6365

6466
def test_property_rowcount(self):
67+
from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT
68+
6569
connection = self._make_connection(self.INSTANCE, self.DATABASE)
6670
cursor = self._make_one(connection)
6771

68-
assert cursor.rowcount == -1
72+
self.assertEqual(cursor.rowcount, _UNSET_COUNT)
6973

7074
def test_callproc(self):
7175
from google.cloud.spanner_dbapi.exceptions import InterfaceError
@@ -93,25 +97,58 @@ def test_close(self, mock_client):
9397
cursor.execute("SELECT * FROM database")
9498

9599
def test_do_execute_update(self):
96-
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
100+
from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT
97101

98102
connection = self._make_connection(self.INSTANCE, self.DATABASE)
99103
cursor = self._make_one(connection)
100-
cursor._checksum = ResultsChecksum()
101104
transaction = mock.MagicMock()
102105

103106
def run_helper(ret_value):
104107
transaction.execute_update.return_value = ret_value
105-
cursor._do_execute_update(
108+
res = cursor._do_execute_update(
106109
transaction=transaction, sql="SELECT * WHERE true", params={},
107110
)
108-
return cursor.fetchall()
111+
return res
109112

110113
expected = "good"
111-
self.assertEqual(run_helper(expected), [expected])
114+
self.assertEqual(run_helper(expected), expected)
115+
self.assertEqual(cursor._row_count, _UNSET_COUNT)
112116

113117
expected = 1234
114-
self.assertEqual(run_helper(expected), [expected])
118+
self.assertEqual(run_helper(expected), expected)
119+
self.assertEqual(cursor._row_count, expected)
120+
121+
def test_do_batch_update(self):
122+
from google.cloud.spanner_dbapi import connect
123+
from google.cloud.spanner_v1.param_types import INT64
124+
from google.cloud.spanner_v1.types.spanner import Session
125+
126+
sql = "DELETE FROM table WHERE col1 = %s"
127+
128+
connection = connect("test-instance", "test-database")
129+
130+
connection.autocommit = True
131+
transaction = self._transaction_mock(mock_response=[1, 1, 1])
132+
cursor = connection.cursor()
133+
134+
with mock.patch(
135+
"google.cloud.spanner_v1.services.spanner.client.SpannerClient.create_session",
136+
return_value=Session(),
137+
):
138+
with mock.patch(
139+
"google.cloud.spanner_v1.session.Session.transaction",
140+
return_value=transaction,
141+
):
142+
cursor.executemany(sql, [(1,), (2,), (3,)])
143+
144+
transaction.batch_update.assert_called_once_with(
145+
[
146+
("DELETE FROM table WHERE col1 = @a0", {"a0": 1}, {"a0": INT64}),
147+
("DELETE FROM table WHERE col1 = @a0", {"a0": 2}, {"a0": INT64}),
148+
("DELETE FROM table WHERE col1 = @a0", {"a0": 3}, {"a0": INT64}),
149+
]
150+
)
151+
self.assertEqual(cursor._row_count, 3)
115152

116153
def test_execute_programming_error(self):
117154
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
@@ -704,6 +741,7 @@ def test_setoutputsize(self):
704741

705742
def test_handle_dql(self):
706743
from google.cloud.spanner_dbapi import utils
744+
from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT
707745

708746
connection = self._make_connection(self.INSTANCE, mock.MagicMock())
709747
connection.database.snapshot.return_value.__enter__.return_value = (
@@ -715,6 +753,7 @@ def test_handle_dql(self):
715753
cursor._handle_DQL("sql", params=None)
716754
self.assertEqual(cursor._result_set, ["0"])
717755
self.assertIsInstance(cursor._itr, utils.PeekIterator)
756+
self.assertEqual(cursor._row_count, _UNSET_COUNT)
718757

719758
def test_context(self):
720759
connection = self._make_connection(self.INSTANCE, self.DATABASE)

0 commit comments

Comments
 (0)