@@ -1840,6 +1840,25 @@ def test_to_arrow(self):
         self.assertIsInstance(tbl, pyarrow.Table)
         self.assertEqual(tbl.num_rows, 0)

+    @mock.patch("google.cloud.bigquery.table.pyarrow", new=None)
+    def test_to_arrow_iterable_error_if_pyarrow_is_none(self):
+        row_iterator = self._make_one()
+        with self.assertRaises(ValueError):
+            row_iterator.to_arrow_iterable()
+
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_to_arrow_iterable(self):
+        row_iterator = self._make_one()
+        arrow_iter = row_iterator.to_arrow_iterable()
+
+        result = list(arrow_iter)
+
+        self.assertEqual(len(result), 1)
+        record_batch = result[0]
+        self.assertIsInstance(record_batch, pyarrow.RecordBatch)
+        self.assertEqual(record_batch.num_rows, 0)
+        self.assertEqual(record_batch.num_columns, 0)
+
     @mock.patch("google.cloud.bigquery.table.pandas", new=None)
     def test_to_dataframe_error_if_pandas_is_none(self):
         row_iterator = self._make_one()
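For context, `to_arrow_iterable` streams results as `pyarrow.RecordBatch` objects instead of materializing a single `pyarrow.Table`. A minimal usage sketch (the project and table names are hypothetical; assumes google-cloud-bigquery with this change and pyarrow are installed):

    from google.cloud import bigquery

    client = bigquery.Client()
    # Hypothetical table; list_rows() returns a RowIterator.
    row_iterator = client.list_rows("my-project.my_dataset.my_table")

    # Each iteration yields one pyarrow.RecordBatch per result page (or per
    # BQ Storage frame), so large results are never held in memory at once.
    for record_batch in row_iterator.to_arrow_iterable():
        print(record_batch.num_rows, record_batch.schema.names)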
@@ -2151,6 +2170,205 @@ def test__validate_bqstorage_returns_false_w_warning_if_obsolete_version(self):
         ]
         assert matching_warnings, "Obsolete dependency warning not raised."

+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_to_arrow_iterable(self):
+        from google.cloud.bigquery.schema import SchemaField
+
+        schema = [
+            SchemaField("name", "STRING", mode="REQUIRED"),
+            SchemaField("age", "INTEGER", mode="REQUIRED"),
+            SchemaField(
+                "child",
+                "RECORD",
+                mode="REPEATED",
+                fields=[
+                    SchemaField("name", "STRING", mode="REQUIRED"),
+                    SchemaField("age", "INTEGER", mode="REQUIRED"),
+                ],
+            ),
+        ]
+        rows = [
+            {
+                "f": [
+                    {"v": "Bharney Rhubble"},
+                    {"v": "33"},
+                    {
+                        "v": [
+                            {"v": {"f": [{"v": "Whamm-Whamm Rhubble"}, {"v": "3"}]}},
+                            {"v": {"f": [{"v": "Hoppy"}, {"v": "1"}]}},
+                        ]
+                    },
+                ]
+            },
+            {
+                "f": [
+                    {"v": "Wylma Phlyntstone"},
+                    {"v": "29"},
+                    {
+                        "v": [
+                            {"v": {"f": [{"v": "Bepples Phlyntstone"}, {"v": "0"}]}},
+                            {"v": {"f": [{"v": "Dino"}, {"v": "4"}]}},
+                        ]
+                    },
+                ]
+            },
+        ]
+        path = "/foo"
+        api_request = mock.Mock(
+            side_effect=[
+                {"rows": [rows[0]], "pageToken": "NEXTPAGE"},
+                {"rows": [rows[1]]},
+            ]
+        )
+        row_iterator = self._make_one(
+            _mock_client(), api_request, path, schema, page_size=1, max_results=5
+        )
+
+        record_batches = row_iterator.to_arrow_iterable()
+        self.assertIsInstance(record_batches, types.GeneratorType)
+        record_batches = list(record_batches)
+        self.assertEqual(len(record_batches), 2)
+
+        # Check the schema.
+        for record_batch in record_batches:
+            self.assertIsInstance(record_batch, pyarrow.RecordBatch)
+            self.assertEqual(record_batch.schema[0].name, "name")
+            self.assertTrue(pyarrow.types.is_string(record_batch.schema[0].type))
+            self.assertEqual(record_batch.schema[1].name, "age")
+            self.assertTrue(pyarrow.types.is_int64(record_batch.schema[1].type))
+            child_field = record_batch.schema[2]
+            self.assertEqual(child_field.name, "child")
+            self.assertTrue(pyarrow.types.is_list(child_field.type))
+            self.assertTrue(pyarrow.types.is_struct(child_field.type.value_type))
+            self.assertEqual(child_field.type.value_type[0].name, "name")
+            self.assertEqual(child_field.type.value_type[1].name, "age")
+
+        # Check the data.
+        record_batch_1 = record_batches[0].to_pydict()
+        names = record_batch_1["name"]
+        ages = record_batch_1["age"]
+        children = record_batch_1["child"]
+        self.assertEqual(names, ["Bharney Rhubble"])
+        self.assertEqual(ages, [33])
+        self.assertEqual(
+            children,
+            [
+                [
+                    {"name": "Whamm-Whamm Rhubble", "age": 3},
+                    {"name": "Hoppy", "age": 1},
+                ],
+            ],
+        )
+
+        record_batch_2 = record_batches[1].to_pydict()
+        names = record_batch_2["name"]
+        ages = record_batch_2["age"]
+        children = record_batch_2["child"]
+        self.assertEqual(names, ["Wylma Phlyntstone"])
+        self.assertEqual(ages, [29])
+        self.assertEqual(
+            children,
+            [[{"name": "Bepples Phlyntstone", "age": 0}, {"name": "Dino", "age": 4}]],
+        )
+
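The nested `rows` fixtures above mimic BigQuery's REST `tabledata.list` wire format: each row is an object with an `f` list of cells, each cell wraps its value in `v`, and a REPEATED RECORD cell holds a list of `{"v": {"f": [...]}}` entries. A small illustrative decoder (a hypothetical helper, not part of the library) makes the mapping explicit:

    def decode_cell(value):
        # A dict with "f" is a row or STRUCT: decode each of its cells.
        if isinstance(value, dict) and "f" in value:
            return [decode_cell(cell["v"]) for cell in value["f"]]
        # A list is a REPEATED field: decode each wrapped element.
        if isinstance(value, list):
            return [decode_cell(item["v"]) for item in value]
        # Scalars arrive as strings; the real library casts them per the schema.
        return value

    print(decode_cell({"f": [{"v": "Hoppy"}, {"v": "1"}]}))  # ['Hoppy', '1']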
2274
+    @mock.patch("google.cloud.bigquery.table.pyarrow", new=None)
+    def test_to_arrow_iterable_error_if_pyarrow_is_none(self):
+        from google.cloud.bigquery.schema import SchemaField
+
+        schema = [
+            SchemaField("name", "STRING", mode="REQUIRED"),
+            SchemaField("age", "INTEGER", mode="REQUIRED"),
+        ]
+        rows = [
+            {"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]},
+            {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]},
+        ]
+        path = "/foo"
+        api_request = mock.Mock(return_value={"rows": rows})
+        row_iterator = self._make_one(_mock_client(), api_request, path, schema)
+
+        with pytest.raises(ValueError, match="pyarrow"):
+            row_iterator.to_arrow_iterable()
+
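Both `_error_if_pyarrow_is_none` tests rely on the optional-dependency guard pattern used in google/cloud/bigquery/table.py: the module binds `pyarrow` to None when the import fails, so patching that module attribute to None simulates a missing dependency. A rough sketch of the pattern (illustrative; the exact error message is assumed, not quoted from the source):

    try:
        import pyarrow
    except ImportError:
        pyarrow = None  # the tests patch this module attribute directly

    def to_arrow_iterable():
        # Fail fast before any network work, which is what the tests assert.
        if pyarrow is None:
            raise ValueError("the pyarrow library is required for this operation")
        return iter([])  # placeholder body for the sketch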
2293
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    @unittest.skipIf(
+        bigquery_storage is None, "Requires `google-cloud-bigquery-storage`"
+    )
+    def test_to_arrow_iterable_w_bqstorage(self):
+        from google.cloud.bigquery import schema
+        from google.cloud.bigquery import table as mut
+        from google.cloud.bigquery_storage_v1 import reader
+
+        bqstorage_client = mock.create_autospec(bigquery_storage.BigQueryReadClient)
+        bqstorage_client._transport = mock.create_autospec(
+            big_query_read_grpc_transport.BigQueryReadGrpcTransport
+        )
+        streams = [
+            # Use two streams so we can check that frames are read from each stream.
+            {"name": "/projects/proj/dataset/dset/tables/tbl/streams/1234"},
+            {"name": "/projects/proj/dataset/dset/tables/tbl/streams/5678"},
+        ]
+        session = bigquery_storage.types.ReadSession(streams=streams)
+        arrow_schema = pyarrow.schema(
+            [
+                pyarrow.field("colA", pyarrow.int64()),
+                # Not alphabetical to test column order.
+                pyarrow.field("colC", pyarrow.float64()),
+                pyarrow.field("colB", pyarrow.string()),
+            ]
+        )
+        session.arrow_schema.serialized_schema = arrow_schema.serialize().to_pybytes()
+        bqstorage_client.create_read_session.return_value = session
+
+        mock_rowstream = mock.create_autospec(reader.ReadRowsStream)
+        bqstorage_client.read_rows.return_value = mock_rowstream
+
+        mock_rows = mock.create_autospec(reader.ReadRowsIterable)
+        mock_rowstream.rows.return_value = mock_rows
+        page_items = [
+            pyarrow.array([1, -1]),
+            pyarrow.array([2.0, 4.0]),
+            pyarrow.array(["abc", "def"]),
+        ]
+
+        expected_record_batch = pyarrow.RecordBatch.from_arrays(
+            page_items, schema=arrow_schema
+        )
+        expected_num_record_batches = 3
+
+        mock_page = mock.create_autospec(reader.ReadRowsPage)
+        mock_page.to_arrow.return_value = expected_record_batch
+        mock_pages = (mock_page,) * expected_num_record_batches
+        type(mock_rows).pages = mock.PropertyMock(return_value=mock_pages)
+
+        schema = [
+            schema.SchemaField("colA", "INTEGER"),
+            schema.SchemaField("colC", "FLOAT"),
+            schema.SchemaField("colB", "STRING"),
+        ]
+
+        row_iterator = mut.RowIterator(
+            _mock_client(),
+            None,  # api_request: ignored
+            None,  # path: ignored
+            schema,
+            table=mut.TableReference.from_string("proj.dset.tbl"),
+            selected_fields=schema,
+        )
+
+        record_batches = list(
+            row_iterator.to_arrow_iterable(bqstorage_client=bqstorage_client)
+        )
+        total_record_batches = len(streams) * len(mock_pages)
+        self.assertEqual(len(record_batches), total_record_batches)
+
+        for record_batch in record_batches:
+            # Are the record batches returned as expected?
+            self.assertEqual(record_batch, expected_record_batch)
+
+        # Don't close the client if it was passed in.
+        bqstorage_client._transport.grpc_channel.close.assert_not_called()
+
     @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
     def test_to_arrow(self):
         from google.cloud.bigquery.schema import SchemaField
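For readers unfamiliar with the Storage API path mocked in test_to_arrow_iterable_w_bqstorage, the fan-in it simulates looks roughly like the sketch below: every page of every read stream becomes one record batch, which is why the test expects len(streams) * len(mock_pages) batches. This is an illustrative sketch under those assumptions, not the library's actual implementation:

    def iter_record_batches(bqstorage_client, session):
        # One ReadRowsStream per session stream; each page maps to one
        # pyarrow.RecordBatch via ReadRowsPage.to_arrow().
        for stream in session.streams:
            rowstream = bqstorage_client.read_rows(stream.name)
            for page in rowstream.rows().pages:
                yield page.to_arrow()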