Skip to content

Commit 2a9618f

Browse files
plamut and dhlee authored
feat: add max_results parameter to some of the QueryJob methods (#698)
* feat: add max_results to a few QueryJob methods. It is now possible to cap the number of result rows returned when invoking the `to_dataframe()` or `to_arrow()` method on a `QueryJob` instance.
* Work around a pytype complaint.
* Make _EmptyRowIterator a subclass of RowIterator.

Co-authored-by: Dan Lee <[email protected]>
1 parent b35e1ad commit 2a9618f

File tree

6 files changed

+240
-23
lines changed

6 files changed

+240
-23
lines changed

google/cloud/bigquery/_tqdm_helpers.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,19 @@
1616

1717
import concurrent.futures
1818
import time
19+
import typing
20+
from typing import Optional
1921
import warnings
2022

2123
try:
2224
import tqdm
2325
except ImportError: # pragma: NO COVER
2426
tqdm = None
2527

28+
if typing.TYPE_CHECKING: # pragma: NO COVER
29+
from google.cloud.bigquery import QueryJob
30+
from google.cloud.bigquery.table import RowIterator
31+
2632
_NO_TQDM_ERROR = (
2733
"A progress bar was requested, but there was an error loading the tqdm "
2834
"library. Please install tqdm to use the progress bar functionality."
@@ -32,7 +38,7 @@
3238

3339

3440
def get_progress_bar(progress_bar_type, description, total, unit):
35-
"""Construct a tqdm progress bar object, if tqdm is ."""
41+
"""Construct a tqdm progress bar object, if tqdm is installed."""
3642
if tqdm is None:
3743
if progress_bar_type is not None:
3844
warnings.warn(_NO_TQDM_ERROR, UserWarning, stacklevel=3)
@@ -53,16 +59,34 @@ def get_progress_bar(progress_bar_type, description, total, unit):
5359
return None
5460

5561

56-
def wait_for_query(query_job, progress_bar_type=None):
57-
"""Return query result and display a progress bar while the query running, if tqdm is installed."""
62+
def wait_for_query(
63+
query_job: "QueryJob",
64+
progress_bar_type: Optional[str] = None,
65+
max_results: Optional[int] = None,
66+
) -> "RowIterator":
67+
"""Return query result and display a progress bar while the query running, if tqdm is installed.
68+
69+
Args:
70+
query_job:
71+
The job representing the execution of the query on the server.
72+
progress_bar_type:
73+
The type of progress bar to use to show query progress.
74+
max_results:
75+
The maximum number of rows the row iterator should return.
76+
77+
Returns:
78+
A row iterator over the query results.
79+
"""
5880
default_total = 1
5981
current_stage = None
6082
start_time = time.time()
83+
6184
progress_bar = get_progress_bar(
6285
progress_bar_type, "Query is running", default_total, "query"
6386
)
6487
if progress_bar is None:
65-
return query_job.result()
88+
return query_job.result(max_results=max_results)
89+
6690
i = 0
6791
while True:
6892
if query_job.query_plan:
@@ -75,7 +99,9 @@ def wait_for_query(query_job, progress_bar_type=None):
7599
),
76100
)
77101
try:
78-
query_result = query_job.result(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
102+
query_result = query_job.result(
103+
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=max_results
104+
)
79105
progress_bar.update(default_total)
80106
progress_bar.set_description(
81107
"Query complete after {:0.2f}s".format(time.time() - start_time),
@@ -89,5 +115,6 @@ def wait_for_query(query_job, progress_bar_type=None):
89115
progress_bar.update(i + 1)
90116
i += 1
91117
continue
118+
92119
progress_bar.close()
93120
return query_result

google/cloud/bigquery/job/query.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -1300,12 +1300,14 @@ def result(
13001300
return rows
13011301

13021302
# If changing the signature of this method, make sure to apply the same
1303-
# changes to table.RowIterator.to_arrow()
1303+
# changes to table.RowIterator.to_arrow(), except for the max_results parameter
1304+
# that should only exist here in the QueryJob method.
13041305
def to_arrow(
13051306
self,
13061307
progress_bar_type: str = None,
13071308
bqstorage_client: "bigquery_storage.BigQueryReadClient" = None,
13081309
create_bqstorage_client: bool = True,
1310+
max_results: Optional[int] = None,
13091311
) -> "pyarrow.Table":
13101312
"""[Beta] Create a class:`pyarrow.Table` by loading all pages of a
13111313
table or query.
@@ -1349,6 +1351,11 @@ def to_arrow(
13491351
13501352
..versionadded:: 1.24.0
13511353
1354+
max_results (Optional[int]):
1355+
Maximum number of rows to include in the result. No limit by default.
1356+
1357+
..versionadded:: 2.21.0
1358+
13521359
Returns:
13531360
pyarrow.Table
13541361
A :class:`pyarrow.Table` populated with row data and column
@@ -1361,22 +1368,24 @@ def to_arrow(
13611368
13621369
..versionadded:: 1.17.0
13631370
"""
1364-
query_result = wait_for_query(self, progress_bar_type)
1371+
query_result = wait_for_query(self, progress_bar_type, max_results=max_results)
13651372
return query_result.to_arrow(
13661373
progress_bar_type=progress_bar_type,
13671374
bqstorage_client=bqstorage_client,
13681375
create_bqstorage_client=create_bqstorage_client,
13691376
)
13701377

13711378
# If changing the signature of this method, make sure to apply the same
1372-
# changes to table.RowIterator.to_dataframe()
1379+
# changes to table.RowIterator.to_dataframe(), except for the max_results parameter
1380+
# that should only exist here in the QueryJob method.
13731381
def to_dataframe(
13741382
self,
13751383
bqstorage_client: "bigquery_storage.BigQueryReadClient" = None,
13761384
dtypes: Dict[str, Any] = None,
13771385
progress_bar_type: str = None,
13781386
create_bqstorage_client: bool = True,
13791387
date_as_object: bool = True,
1388+
max_results: Optional[int] = None,
13801389
) -> "pandas.DataFrame":
13811390
"""Return a pandas DataFrame from a QueryJob
13821391
@@ -1423,6 +1432,11 @@ def to_dataframe(
14231432
14241433
..versionadded:: 1.26.0
14251434
1435+
max_results (Optional[int]):
1436+
Maximum number of rows to include in the result. No limit by default.
1437+
1438+
..versionadded:: 2.21.0
1439+
14261440
Returns:
14271441
A :class:`~pandas.DataFrame` populated with row data and column
14281442
headers from the query results. The column headers are derived
@@ -1431,7 +1445,7 @@ def to_dataframe(
14311445
Raises:
14321446
ValueError: If the `pandas` library cannot be imported.
14331447
"""
1434-
query_result = wait_for_query(self, progress_bar_type)
1448+
query_result = wait_for_query(self, progress_bar_type, max_results=max_results)
14351449
return query_result.to_dataframe(
14361450
bqstorage_client=bqstorage_client,
14371451
dtypes=dtypes,

google/cloud/bigquery/table.py

+49-4
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import operator
2323
import pytz
2424
import typing
25-
from typing import Any, Dict, Iterable, Tuple
25+
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple
2626
import warnings
2727

2828
try:
@@ -1415,7 +1415,9 @@ class RowIterator(HTTPIterator):
14151415
"""A class for iterating through HTTP/JSON API row list responses.
14161416
14171417
Args:
1418-
client (google.cloud.bigquery.Client): The API client.
1418+
client (Optional[google.cloud.bigquery.Client]):
1419+
The API client instance. This should always be non-`None`, except for
1420+
subclasses that do not use it, namely the ``_EmptyRowIterator``.
14191421
api_request (Callable[google.cloud._http.JSONConnection.api_request]):
14201422
The function to use to make API requests.
14211423
path (str): The method path to query for the list of items.
@@ -1480,7 +1482,7 @@ def __init__(
14801482
self._field_to_index = _helpers._field_to_index_mapping(schema)
14811483
self._page_size = page_size
14821484
self._preserve_order = False
1483-
self._project = client.project
1485+
self._project = client.project if client is not None else None
14841486
self._schema = schema
14851487
self._selected_fields = selected_fields
14861488
self._table = table
@@ -1895,7 +1897,7 @@ def to_dataframe(
18951897
return df
18961898

18971899

1898-
class _EmptyRowIterator(object):
1900+
class _EmptyRowIterator(RowIterator):
18991901
"""An empty row iterator.
19001902
19011903
This class prevents API requests when there are no rows to fetch or rows
@@ -1907,6 +1909,18 @@ class _EmptyRowIterator(object):
19071909
pages = ()
19081910
total_rows = 0
19091911

1912+
def __init__(
1913+
self, client=None, api_request=None, path=None, schema=(), *args, **kwargs
1914+
):
1915+
super().__init__(
1916+
client=client,
1917+
api_request=api_request,
1918+
path=path,
1919+
schema=schema,
1920+
*args,
1921+
**kwargs,
1922+
)
1923+
19101924
def to_arrow(
19111925
self,
19121926
progress_bar_type=None,
@@ -1951,6 +1965,37 @@ def to_dataframe(
19511965
raise ValueError(_NO_PANDAS_ERROR)
19521966
return pandas.DataFrame()
19531967

1968+
def to_dataframe_iterable(
1969+
self,
1970+
bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None,
1971+
dtypes: Optional[Dict[str, Any]] = None,
1972+
max_queue_size: Optional[int] = None,
1973+
) -> Iterator["pandas.DataFrame"]:
1974+
"""Create an iterable of pandas DataFrames, to process the table as a stream.
1975+
1976+
..versionadded:: 2.21.0
1977+
1978+
Args:
1979+
bqstorage_client:
1980+
Ignored. Added for compatibility with RowIterator.
1981+
1982+
dtypes (Optional[Map[str, Union[str, pandas.Series.dtype]]]):
1983+
Ignored. Added for compatibility with RowIterator.
1984+
1985+
max_queue_size:
1986+
Ignored. Added for compatibility with RowIterator.
1987+
1988+
Returns:
1989+
An iterator yielding a single empty :class:`~pandas.DataFrame`.
1990+
1991+
Raises:
1992+
ValueError:
1993+
If the :mod:`pandas` library cannot be imported.
1994+
"""
1995+
if pandas is None:
1996+
raise ValueError(_NO_PANDAS_ERROR)
1997+
return iter((pandas.DataFrame(),))
1998+
19541999
def __iter__(self):
19552000
return iter(())
19562001

tests/unit/job/test_query_pandas.py

+97-4
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,41 @@ def test_to_arrow():
238238
]
239239

240240

241+
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
242+
def test_to_arrow_max_results_no_progress_bar():
243+
from google.cloud.bigquery import table
244+
from google.cloud.bigquery.job import QueryJob as target_class
245+
from google.cloud.bigquery.schema import SchemaField
246+
247+
connection = _make_connection({})
248+
client = _make_client(connection=connection)
249+
begun_resource = _make_job_resource(job_type="query")
250+
job = target_class.from_api_repr(begun_resource, client)
251+
252+
schema = [
253+
SchemaField("name", "STRING", mode="REQUIRED"),
254+
SchemaField("age", "INTEGER", mode="REQUIRED"),
255+
]
256+
rows = [
257+
{"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]},
258+
{"f": [{"v": "Wylma Phlyntstone"}, {"v": "29"}]},
259+
]
260+
path = "/foo"
261+
api_request = mock.Mock(return_value={"rows": rows})
262+
row_iterator = table.RowIterator(client, api_request, path, schema)
263+
264+
result_patch = mock.patch(
265+
"google.cloud.bigquery.job.QueryJob.result", return_value=row_iterator,
266+
)
267+
with result_patch as result_patch_tqdm:
268+
tbl = job.to_arrow(create_bqstorage_client=False, max_results=123)
269+
270+
result_patch_tqdm.assert_called_once_with(max_results=123)
271+
272+
assert isinstance(tbl, pyarrow.Table)
273+
assert tbl.num_rows == 2
274+
275+
241276
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
242277
@pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`")
243278
def test_to_arrow_w_tqdm_w_query_plan():
@@ -290,7 +325,9 @@ def test_to_arrow_w_tqdm_w_query_plan():
290325
assert result_patch_tqdm.call_count == 3
291326
assert isinstance(tbl, pyarrow.Table)
292327
assert tbl.num_rows == 2
293-
result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
328+
result_patch_tqdm.assert_called_with(
329+
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None
330+
)
294331

295332

296333
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
@@ -341,7 +378,9 @@ def test_to_arrow_w_tqdm_w_pending_status():
341378
assert result_patch_tqdm.call_count == 2
342379
assert isinstance(tbl, pyarrow.Table)
343380
assert tbl.num_rows == 2
344-
result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
381+
result_patch_tqdm.assert_called_with(
382+
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None
383+
)
345384

346385

347386
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
@@ -716,7 +755,9 @@ def test_to_dataframe_w_tqdm_pending():
716755
assert isinstance(df, pandas.DataFrame)
717756
assert len(df) == 4 # verify the number of rows
718757
assert list(df) == ["name", "age"] # verify the column names
719-
result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
758+
result_patch_tqdm.assert_called_with(
759+
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None
760+
)
720761

721762

722763
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@@ -774,4 +815,56 @@ def test_to_dataframe_w_tqdm():
774815
assert isinstance(df, pandas.DataFrame)
775816
assert len(df) == 4 # verify the number of rows
776817
assert list(df), ["name", "age"] # verify the column names
777-
result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
818+
result_patch_tqdm.assert_called_with(
819+
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None
820+
)
821+
822+
823+
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
824+
@pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`")
825+
def test_to_dataframe_w_tqdm_max_results():
826+
from google.cloud.bigquery import table
827+
from google.cloud.bigquery.job import QueryJob as target_class
828+
from google.cloud.bigquery.schema import SchemaField
829+
from google.cloud.bigquery._tqdm_helpers import _PROGRESS_BAR_UPDATE_INTERVAL
830+
831+
begun_resource = _make_job_resource(job_type="query")
832+
schema = [
833+
SchemaField("name", "STRING", mode="NULLABLE"),
834+
SchemaField("age", "INTEGER", mode="NULLABLE"),
835+
]
836+
rows = [{"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]}]
837+
838+
connection = _make_connection({})
839+
client = _make_client(connection=connection)
840+
job = target_class.from_api_repr(begun_resource, client)
841+
842+
path = "/foo"
843+
api_request = mock.Mock(return_value={"rows": rows})
844+
row_iterator = table.RowIterator(client, api_request, path, schema)
845+
846+
job._properties["statistics"] = {
847+
"query": {
848+
"queryPlan": [
849+
{"name": "S00: Input", "id": "0", "status": "COMPLETE"},
850+
{"name": "S01: Output", "id": "1", "status": "COMPLETE"},
851+
]
852+
},
853+
}
854+
reload_patch = mock.patch(
855+
"google.cloud.bigquery.job._AsyncJob.reload", autospec=True
856+
)
857+
result_patch = mock.patch(
858+
"google.cloud.bigquery.job.QueryJob.result",
859+
side_effect=[concurrent.futures.TimeoutError, row_iterator],
860+
)
861+
862+
with result_patch as result_patch_tqdm, reload_patch:
863+
job.to_dataframe(
864+
progress_bar_type="tqdm", create_bqstorage_client=False, max_results=3
865+
)
866+
867+
assert result_patch_tqdm.call_count == 2
868+
result_patch_tqdm.assert_called_with(
869+
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=3
870+
)

0 commit comments

Comments
 (0)