@@ -37,11 +37,13 @@ def _make_connection(self, *args, **kwargs):
37
37
38
38
return Connection (* args , ** kwargs )
39
39
40
- def _transaction_mock (self ):
40
+ def _transaction_mock (self , mock_response = [] ):
41
41
from google .rpc .code_pb2 import OK
42
42
43
43
transaction = mock .Mock (committed = False , rolled_back = False )
44
- transaction .batch_update = mock .Mock (return_value = [mock .Mock (code = OK ), []])
44
+ transaction .batch_update = mock .Mock (
45
+ return_value = [mock .Mock (code = OK ), mock_response ]
46
+ )
45
47
return transaction
46
48
47
49
def test_property_connection (self ):
@@ -62,10 +64,12 @@ def test_property_description(self):
62
64
self .assertIsInstance (cursor .description [0 ], ColumnInfo )
63
65
64
66
def test_property_rowcount (self ):
67
+ from google .cloud .spanner_dbapi .cursor import _UNSET_COUNT
68
+
65
69
connection = self ._make_connection (self .INSTANCE , self .DATABASE )
66
70
cursor = self ._make_one (connection )
67
71
68
- assert cursor .rowcount == - 1
72
+ self . assertEqual ( cursor .rowcount , _UNSET_COUNT )
69
73
70
74
def test_callproc (self ):
71
75
from google .cloud .spanner_dbapi .exceptions import InterfaceError
@@ -93,25 +97,58 @@ def test_close(self, mock_client):
93
97
cursor .execute ("SELECT * FROM database" )
94
98
95
99
def test_do_execute_update (self ):
96
- from google .cloud .spanner_dbapi .checksum import ResultsChecksum
100
+ from google .cloud .spanner_dbapi .cursor import _UNSET_COUNT
97
101
98
102
connection = self ._make_connection (self .INSTANCE , self .DATABASE )
99
103
cursor = self ._make_one (connection )
100
- cursor ._checksum = ResultsChecksum ()
101
104
transaction = mock .MagicMock ()
102
105
103
106
def run_helper (ret_value ):
104
107
transaction .execute_update .return_value = ret_value
105
- cursor ._do_execute_update (
108
+ res = cursor ._do_execute_update (
106
109
transaction = transaction , sql = "SELECT * WHERE true" , params = {},
107
110
)
108
- return cursor . fetchall ()
111
+ return res
109
112
110
113
expected = "good"
111
- self .assertEqual (run_helper (expected ), [expected ])
114
+ self .assertEqual (run_helper (expected ), expected )
115
+ self .assertEqual (cursor ._row_count , _UNSET_COUNT )
112
116
113
117
expected = 1234
114
- self .assertEqual (run_helper (expected ), [expected ])
118
+ self .assertEqual (run_helper (expected ), expected )
119
+ self .assertEqual (cursor ._row_count , expected )
120
+
121
+ def test_do_batch_update (self ):
122
+ from google .cloud .spanner_dbapi import connect
123
+ from google .cloud .spanner_v1 .param_types import INT64
124
+ from google .cloud .spanner_v1 .types .spanner import Session
125
+
126
+ sql = "DELETE FROM table WHERE col1 = %s"
127
+
128
+ connection = connect ("test-instance" , "test-database" )
129
+
130
+ connection .autocommit = True
131
+ transaction = self ._transaction_mock (mock_response = [1 , 1 , 1 ])
132
+ cursor = connection .cursor ()
133
+
134
+ with mock .patch (
135
+ "google.cloud.spanner_v1.services.spanner.client.SpannerClient.create_session" ,
136
+ return_value = Session (),
137
+ ):
138
+ with mock .patch (
139
+ "google.cloud.spanner_v1.session.Session.transaction" ,
140
+ return_value = transaction ,
141
+ ):
142
+ cursor .executemany (sql , [(1 ,), (2 ,), (3 ,)])
143
+
144
+ transaction .batch_update .assert_called_once_with (
145
+ [
146
+ ("DELETE FROM table WHERE col1 = @a0" , {"a0" : 1 }, {"a0" : INT64 }),
147
+ ("DELETE FROM table WHERE col1 = @a0" , {"a0" : 2 }, {"a0" : INT64 }),
148
+ ("DELETE FROM table WHERE col1 = @a0" , {"a0" : 3 }, {"a0" : INT64 }),
149
+ ]
150
+ )
151
+ self .assertEqual (cursor ._row_count , 3 )
115
152
116
153
def test_execute_programming_error (self ):
117
154
from google .cloud .spanner_dbapi .exceptions import ProgrammingError
@@ -704,6 +741,7 @@ def test_setoutputsize(self):
704
741
705
742
def test_handle_dql (self ):
706
743
from google .cloud .spanner_dbapi import utils
744
+ from google .cloud .spanner_dbapi .cursor import _UNSET_COUNT
707
745
708
746
connection = self ._make_connection (self .INSTANCE , mock .MagicMock ())
709
747
connection .database .snapshot .return_value .__enter__ .return_value = (
@@ -715,6 +753,7 @@ def test_handle_dql(self):
715
753
cursor ._handle_DQL ("sql" , params = None )
716
754
self .assertEqual (cursor ._result_set , ["0" ])
717
755
self .assertIsInstance (cursor ._itr , utils .PeekIterator )
756
+ self .assertEqual (cursor ._row_count , _UNSET_COUNT )
718
757
719
758
def test_context (self ):
720
759
connection = self ._make_connection (self .INSTANCE , self .DATABASE )
0 commit comments