
Commit 21cd710

judahrand and tswast authored
feat: promote RowIterator.to_arrow_iterable to public method (#1073)
* feat: promote `to_arrow_iterable` to public method
* use correct version number
* Update google/cloud/bigquery/table.py

Co-authored-by: Tim Swast <[email protected]>
1 parent 1b5dc5c commit 21cd710

File tree

3 files changed: +297 -4 lines changed

  google/cloud/bigquery/_pandas_helpers.py
  google/cloud/bigquery/table.py
  tests/unit/test_table.py


google/cloud/bigquery/_pandas_helpers.py (+7 -1)

@@ -838,7 +838,12 @@ def _download_table_bqstorage(
 
 
 def download_arrow_bqstorage(
-    project_id, table, bqstorage_client, preserve_order=False, selected_fields=None,
+    project_id,
+    table,
+    bqstorage_client,
+    preserve_order=False,
+    selected_fields=None,
+    max_queue_size=_MAX_QUEUE_SIZE_DEFAULT,
 ):
     return _download_table_bqstorage(
         project_id,
@@ -847,6 +852,7 @@ def download_arrow_bqstorage(
         preserve_order=preserve_order,
         selected_fields=selected_fields,
         page_to_item=_bqstorage_page_to_arrow,
+        max_queue_size=max_queue_size,
     )
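For context on what the new argument controls: `max_queue_size` is forwarded to the private `_download_table_bqstorage` helper, which buffers result pages from the BigQuery Storage streams in a bounded queue. The sketch below only illustrates the general bounded-queue backpressure pattern the parameter governs; the function and variable names are hypothetical and are not part of the library's internals.

    import concurrent.futures
    import queue

    def stream_pages(streams, fetch_pages, max_queue_size=10):
        # Illustrative only: a bounded queue caps how many downloaded pages
        # can wait in memory. Workers block on put() once the queue holds
        # max_queue_size items, which throttles the download threads.
        page_queue = queue.Queue(maxsize=max_queue_size)
        done = object()  # sentinel marking the end of one stream

        def worker(stream):
            for page in fetch_pages(stream):
                page_queue.put(page)  # blocks while the queue is full
            page_queue.put(done)

        with concurrent.futures.ThreadPoolExecutor(max_workers=len(streams)) as pool:
            for stream in streams:
                pool.submit(worker, stream)
            finished = 0
            while finished < len(streams):
                item = page_queue.get()
                if item is done:
                    finished += 1
                else:
                    yield item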

google/cloud/bigquery/table.py (+72 -3)

@@ -1629,15 +1629,57 @@ def _to_page_iterable(
         )
         yield from result_pages
 
-    def _to_arrow_iterable(self, bqstorage_client=None):
-        """Create an iterable of arrow RecordBatches, to process the table as a stream."""
+    def to_arrow_iterable(
+        self,
+        bqstorage_client: "bigquery_storage.BigQueryReadClient" = None,
+        max_queue_size: int = _pandas_helpers._MAX_QUEUE_SIZE_DEFAULT,  # type: ignore
+    ) -> Iterator["pyarrow.RecordBatch"]:
+        """[Beta] Create an iterable of :class:`pyarrow.RecordBatch`, to process the table as a stream.
+
+        Args:
+            bqstorage_client (Optional[google.cloud.bigquery_storage_v1.BigQueryReadClient]):
+                A BigQuery Storage API client. If supplied, use the faster
+                BigQuery Storage API to fetch rows from BigQuery.
+
+                This method requires the ``pyarrow`` and
+                ``google-cloud-bigquery-storage`` libraries.
+
+                This method only exposes a subset of the capabilities of the
+                BigQuery Storage API. For full access to all features
+                (projections, filters, snapshots) use the Storage API directly.
+
+            max_queue_size (Optional[int]):
+                The maximum number of result pages to hold in the internal queue when
+                streaming query results over the BigQuery Storage API. Ignored if
+                Storage API is not used.
+
+                By default, the max queue size is set to the number of BQ Storage streams
+                created by the server. If ``max_queue_size`` is :data:`None`, the queue
+                size is infinite.
+
+        Returns:
+            pyarrow.RecordBatch:
+                A generator of :class:`~pyarrow.RecordBatch`.
+
+        Raises:
+            ValueError:
+                If the :mod:`pyarrow` library cannot be imported.
+
+        .. versionadded:: 2.31.0
+        """
+        if pyarrow is None:
+            raise ValueError(_NO_PYARROW_ERROR)
+
+        self._maybe_warn_max_results(bqstorage_client)
+
         bqstorage_download = functools.partial(
             _pandas_helpers.download_arrow_bqstorage,
             self._project,
             self._table,
             bqstorage_client,
             preserve_order=self._preserve_order,
             selected_fields=self._selected_fields,
+            max_queue_size=max_queue_size,
         )
         tabledata_list_download = functools.partial(
             _pandas_helpers.download_arrow_row_iterator, iter(self.pages), self.schema
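With the method now public, a typical call looks like the following. This is a minimal usage sketch: the project, dataset, and query are placeholders, and constructing a `BigQueryReadClient` is optional (without it the method falls back to the REST `tabledata.list` download path shown above).

    from google.cloud import bigquery
    from google.cloud import bigquery_storage

    client = bigquery.Client()
    bqstorage_client = bigquery_storage.BigQueryReadClient()

    # A placeholder query; any query result (RowIterator) works the same way.
    rows = client.query("SELECT name, age FROM `my-project.my_dataset.people`").result()

    # Stream results as pyarrow.RecordBatch objects instead of materializing
    # the entire table in memory at once.
    for record_batch in rows.to_arrow_iterable(
        bqstorage_client=bqstorage_client, max_queue_size=10
    ):
        print(record_batch.num_rows, record_batch.schema.names)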
@@ -1729,7 +1771,7 @@ def to_arrow(
         )
 
         record_batches = []
-        for record_batch in self._to_arrow_iterable(
+        for record_batch in self.to_arrow_iterable(
             bqstorage_client=bqstorage_client
         ):
             record_batches.append(record_batch)
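Since `to_arrow` now simply collects batches from the public iterable, the same result can be reproduced by hand. A rough sketch, assuming `rows` is a `RowIterator` as in the example above and the result is non-empty:

    import pyarrow

    # Roughly what to_arrow() does with the collected batches: concatenate the
    # streamed RecordBatches into a single pyarrow.Table.
    batches = list(rows.to_arrow_iterable())
    arrow_table = pyarrow.Table.from_batches(batches)
    print(arrow_table.num_rows)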
@@ -2225,6 +2267,33 @@ def to_dataframe_iterable(
             raise ValueError(_NO_PANDAS_ERROR)
         return iter((pandas.DataFrame(),))
 
+    def to_arrow_iterable(
+        self,
+        bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None,
+        max_queue_size: Optional[int] = None,
+    ) -> Iterator["pyarrow.RecordBatch"]:
+        """Create an iterable of pyarrow RecordBatches, to process the table as a stream.
+
+        .. versionadded:: 2.31.0
+
+        Args:
+            bqstorage_client:
+                Ignored. Added for compatibility with RowIterator.
+
+            max_queue_size:
+                Ignored. Added for compatibility with RowIterator.
+
+        Returns:
+            An iterator yielding a single empty :class:`~pyarrow.RecordBatch`.
+
+        Raises:
+            ValueError:
+                If the :mod:`pyarrow` library cannot be imported.
+        """
+        if pyarrow is None:
+            raise ValueError(_NO_PYARROW_ERROR)
+        return iter((pyarrow.record_batch([]),))
+
     def __iter__(self):
         return iter(())
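The `_EmptyRowIterator` variant keeps the same contract by yielding a single empty batch, so callers can iterate over an empty result without special-casing it. A quick illustrative check of what that yielded batch looks like, mirroring the unit test below:

    import pyarrow

    # What _EmptyRowIterator.to_arrow_iterable() yields: one RecordBatch with
    # zero rows and zero columns.
    batch = pyarrow.record_batch([])
    assert batch.num_rows == 0
    assert batch.num_columns == 0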

tests/unit/test_table.py (+218)

@@ -1840,6 +1840,25 @@ def test_to_arrow(self):
         self.assertIsInstance(tbl, pyarrow.Table)
         self.assertEqual(tbl.num_rows, 0)
 
+    @mock.patch("google.cloud.bigquery.table.pyarrow", new=None)
+    def test_to_arrow_iterable_error_if_pyarrow_is_none(self):
+        row_iterator = self._make_one()
+        with self.assertRaises(ValueError):
+            row_iterator.to_arrow_iterable()
+
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_to_arrow_iterable(self):
+        row_iterator = self._make_one()
+        arrow_iter = row_iterator.to_arrow_iterable()
+
+        result = list(arrow_iter)
+
+        self.assertEqual(len(result), 1)
+        record_batch = result[0]
+        self.assertIsInstance(record_batch, pyarrow.RecordBatch)
+        self.assertEqual(record_batch.num_rows, 0)
+        self.assertEqual(record_batch.num_columns, 0)
+
     @mock.patch("google.cloud.bigquery.table.pandas", new=None)
     def test_to_dataframe_error_if_pandas_is_none(self):
         row_iterator = self._make_one()
@@ -2151,6 +2170,205 @@ def test__validate_bqstorage_returns_false_w_warning_if_obsolete_version(self):
         ]
         assert matching_warnings, "Obsolete dependency warning not raised."
 
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_to_arrow_iterable(self):
+        from google.cloud.bigquery.schema import SchemaField
+
+        schema = [
+            SchemaField("name", "STRING", mode="REQUIRED"),
+            SchemaField("age", "INTEGER", mode="REQUIRED"),
+            SchemaField(
+                "child",
+                "RECORD",
+                mode="REPEATED",
+                fields=[
+                    SchemaField("name", "STRING", mode="REQUIRED"),
+                    SchemaField("age", "INTEGER", mode="REQUIRED"),
+                ],
+            ),
+        ]
+        rows = [
+            {
+                "f": [
+                    {"v": "Bharney Rhubble"},
+                    {"v": "33"},
+                    {
+                        "v": [
+                            {"v": {"f": [{"v": "Whamm-Whamm Rhubble"}, {"v": "3"}]}},
+                            {"v": {"f": [{"v": "Hoppy"}, {"v": "1"}]}},
+                        ]
+                    },
+                ]
+            },
+            {
+                "f": [
+                    {"v": "Wylma Phlyntstone"},
+                    {"v": "29"},
+                    {
+                        "v": [
+                            {"v": {"f": [{"v": "Bepples Phlyntstone"}, {"v": "0"}]}},
+                            {"v": {"f": [{"v": "Dino"}, {"v": "4"}]}},
+                        ]
+                    },
+                ]
+            },
+        ]
+        path = "/foo"
+        api_request = mock.Mock(
+            side_effect=[
+                {"rows": [rows[0]], "pageToken": "NEXTPAGE"},
+                {"rows": [rows[1]]},
+            ]
+        )
+        row_iterator = self._make_one(
+            _mock_client(), api_request, path, schema, page_size=1, max_results=5
+        )
+
+        record_batches = row_iterator.to_arrow_iterable()
+        self.assertIsInstance(record_batches, types.GeneratorType)
+        record_batches = list(record_batches)
+        self.assertEqual(len(record_batches), 2)
+
+        # Check the schema.
+        for record_batch in record_batches:
+            self.assertIsInstance(record_batch, pyarrow.RecordBatch)
+            self.assertEqual(record_batch.schema[0].name, "name")
+            self.assertTrue(pyarrow.types.is_string(record_batch.schema[0].type))
+            self.assertEqual(record_batch.schema[1].name, "age")
+            self.assertTrue(pyarrow.types.is_int64(record_batch.schema[1].type))
+            child_field = record_batch.schema[2]
+            self.assertEqual(child_field.name, "child")
+            self.assertTrue(pyarrow.types.is_list(child_field.type))
+            self.assertTrue(pyarrow.types.is_struct(child_field.type.value_type))
+            self.assertEqual(child_field.type.value_type[0].name, "name")
+            self.assertEqual(child_field.type.value_type[1].name, "age")
+
+        # Check the data.
+        record_batch_1 = record_batches[0].to_pydict()
+        names = record_batch_1["name"]
+        ages = record_batch_1["age"]
+        children = record_batch_1["child"]
+        self.assertEqual(names, ["Bharney Rhubble"])
+        self.assertEqual(ages, [33])
+        self.assertEqual(
+            children,
+            [
+                [
+                    {"name": "Whamm-Whamm Rhubble", "age": 3},
+                    {"name": "Hoppy", "age": 1},
+                ],
+            ],
+        )
+
+        record_batch_2 = record_batches[1].to_pydict()
+        names = record_batch_2["name"]
+        ages = record_batch_2["age"]
+        children = record_batch_2["child"]
+        self.assertEqual(names, ["Wylma Phlyntstone"])
+        self.assertEqual(ages, [29])
+        self.assertEqual(
+            children,
+            [[{"name": "Bepples Phlyntstone", "age": 0}, {"name": "Dino", "age": 4}]],
+        )
+
+    @mock.patch("google.cloud.bigquery.table.pyarrow", new=None)
+    def test_to_arrow_iterable_error_if_pyarrow_is_none(self):
+        from google.cloud.bigquery.schema import SchemaField
+
+        schema = [
+            SchemaField("name", "STRING", mode="REQUIRED"),
+            SchemaField("age", "INTEGER", mode="REQUIRED"),
+        ]
+        rows = [
+            {"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]},
+            {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]},
+        ]
+        path = "/foo"
+        api_request = mock.Mock(return_value={"rows": rows})
+        row_iterator = self._make_one(_mock_client(), api_request, path, schema)
+
+        with pytest.raises(ValueError, match="pyarrow"):
+            row_iterator.to_arrow_iterable()
+
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    @unittest.skipIf(
+        bigquery_storage is None, "Requires `google-cloud-bigquery-storage`"
+    )
+    def test_to_arrow_iterable_w_bqstorage(self):
+        from google.cloud.bigquery import schema
+        from google.cloud.bigquery import table as mut
+        from google.cloud.bigquery_storage_v1 import reader
+
+        bqstorage_client = mock.create_autospec(bigquery_storage.BigQueryReadClient)
+        bqstorage_client._transport = mock.create_autospec(
+            big_query_read_grpc_transport.BigQueryReadGrpcTransport
+        )
+        streams = [
+            # Use two streams; we want to check that frames are read from each stream.
+            {"name": "/projects/proj/dataset/dset/tables/tbl/streams/1234"},
+            {"name": "/projects/proj/dataset/dset/tables/tbl/streams/5678"},
+        ]
+        session = bigquery_storage.types.ReadSession(streams=streams)
+        arrow_schema = pyarrow.schema(
+            [
+                pyarrow.field("colA", pyarrow.int64()),
+                # Not alphabetical to test column order.
+                pyarrow.field("colC", pyarrow.float64()),
+                pyarrow.field("colB", pyarrow.string()),
+            ]
+        )
+        session.arrow_schema.serialized_schema = arrow_schema.serialize().to_pybytes()
+        bqstorage_client.create_read_session.return_value = session
+
+        mock_rowstream = mock.create_autospec(reader.ReadRowsStream)
+        bqstorage_client.read_rows.return_value = mock_rowstream
+
+        mock_rows = mock.create_autospec(reader.ReadRowsIterable)
+        mock_rowstream.rows.return_value = mock_rows
+        page_items = [
+            pyarrow.array([1, -1]),
+            pyarrow.array([2.0, 4.0]),
+            pyarrow.array(["abc", "def"]),
+        ]
+
+        expected_record_batch = pyarrow.RecordBatch.from_arrays(
+            page_items, schema=arrow_schema
+        )
+        expected_num_record_batches = 3
+
+        mock_page = mock.create_autospec(reader.ReadRowsPage)
+        mock_page.to_arrow.return_value = expected_record_batch
+        mock_pages = (mock_page,) * expected_num_record_batches
+        type(mock_rows).pages = mock.PropertyMock(return_value=mock_pages)
+
+        schema = [
+            schema.SchemaField("colA", "INTEGER"),
+            schema.SchemaField("colC", "FLOAT"),
+            schema.SchemaField("colB", "STRING"),
+        ]
+
+        row_iterator = mut.RowIterator(
+            _mock_client(),
+            None,  # api_request: ignored
+            None,  # path: ignored
+            schema,
+            table=mut.TableReference.from_string("proj.dset.tbl"),
+            selected_fields=schema,
+        )
+
+        record_batches = list(
+            row_iterator.to_arrow_iterable(bqstorage_client=bqstorage_client)
+        )
+        total_record_batches = len(streams) * len(mock_pages)
+        self.assertEqual(len(record_batches), total_record_batches)
+
+        for record_batch in record_batches:
+            # Were the record batches returned as expected?
+            self.assertEqual(record_batch, expected_record_batch)
+
+        # Don't close the client if it was passed in.
+        bqstorage_client._transport.grpc_channel.close.assert_not_called()
+
     @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
     def test_to_arrow(self):
         from google.cloud.bigquery.schema import SchemaField
