Skip to content

Commit d42a562

Browse files
sycai and tswast authored
feat: add bigquery_client as a parameter for read_gbq and to_gbq (#878)
Co-authored-by: Tim Sweña (Swast) <[email protected]>
1 parent efdbc13 commit d42a562

File tree

5 files changed

+90
-2
lines changed

5 files changed

+90
-2
lines changed

Diff for: pandas_gbq/gbq.py

+41-2
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ def __init__(
269269
client_secret=None,
270270
user_agent=None,
271271
rfc9110_delimiter=False,
272+
bigquery_client=None,
272273
):
273274
global context
274275
from google.api_core.exceptions import ClientError, GoogleAPIError
@@ -288,6 +289,14 @@ def __init__(
288289
self.client_secret = client_secret
289290
self.user_agent = user_agent
290291
self.rfc9110_delimiter = rfc9110_delimiter
292+
self.use_bqstorage_api = use_bqstorage_api
293+
294+
if bigquery_client is not None:
295+
# If a bq client is already provided, use it to populate auth fields.
296+
self.project_id = bigquery_client.project
297+
self.credentials = bigquery_client._credentials
298+
self.client = bigquery_client
299+
return
291300

292301
default_project = None
293302

@@ -325,8 +334,9 @@ def __init__(
325334
if context.project is None:
326335
context.project = self.project_id
327336

328-
self.client = self.get_client()
329-
self.use_bqstorage_api = use_bqstorage_api
337+
self.client = _get_client(
338+
self.user_agent, self.rfc9110_delimiter, self.project_id, self.credentials
339+
)
330340

331341
def _start_timer(self):
332342
self.start = time.time()
@@ -702,6 +712,7 @@ def read_gbq(
702712
client_secret=None,
703713
*,
704714
col_order=None,
715+
bigquery_client=None,
705716
):
706717
r"""Read data from Google BigQuery to a pandas DataFrame.
707718
@@ -849,6 +860,9 @@ def read_gbq(
849860
the user is attempting to connect to.
850861
col_order : list(str), optional
851862
Alias for columns, retained for backwards compatibility.
863+
bigquery_client : google.cloud.bigquery.Client, optional
864+
A Google Cloud BigQuery Python Client instance. If provided, it will be used for reading
865+
data, while the project and credentials parameters will be ignored.
852866
853867
Returns
854868
-------
@@ -900,6 +914,7 @@ def read_gbq(
900914
auth_redirect_uri=auth_redirect_uri,
901915
client_id=client_id,
902916
client_secret=client_secret,
917+
bigquery_client=bigquery_client,
903918
)
904919

905920
if _is_query(query_or_table):
@@ -971,6 +986,7 @@ def to_gbq(
971986
client_secret=None,
972987
user_agent=None,
973988
rfc9110_delimiter=False,
989+
bigquery_client=None,
974990
):
975991
"""Write a DataFrame to a Google BigQuery table.
976992
@@ -1087,6 +1103,9 @@ def to_gbq(
10871103
rfc9110_delimiter : bool
10881104
Sets user agent delimiter to a hyphen or a slash.
10891105
Default is False, meaning a hyphen will be used.
1106+
bigquery_client : google.cloud.bigquery.Client, optional
1107+
A Google Cloud BigQuery Python Client instance. If provided, it will be used for writing
1108+
data, while the project, user_agent, and credentials parameters will be ignored.
10901109
10911110
.. versionadded:: 0.23.3
10921111
"""
@@ -1157,6 +1176,7 @@ def to_gbq(
11571176
client_secret=client_secret,
11581177
user_agent=user_agent,
11591178
rfc9110_delimiter=rfc9110_delimiter,
1179+
bigquery_client=bigquery_client,
11601180
)
11611181
bqclient = connector.client
11621182

@@ -1492,3 +1512,22 @@ def create_user_agent(
14921512
user_agent = f"{user_agent} {identity}"
14931513

14941514
return user_agent
1515+
1516+
1517+
def _get_client(user_agent, rfc9110_delimiter, project_id, credentials):
1518+
import google.api_core.client_info
1519+
1520+
bigquery = FEATURES.bigquery_try_import()
1521+
1522+
user_agent = create_user_agent(
1523+
user_agent=user_agent, rfc9110_delimiter=rfc9110_delimiter
1524+
)
1525+
1526+
client_info = google.api_core.client_info.ClientInfo(
1527+
user_agent=user_agent,
1528+
)
1529+
return bigquery.Client(
1530+
project=project_id,
1531+
credentials=credentials,
1532+
client_info=client_info,
1533+
)

Diff for: tests/system/conftest.py

+14
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ def to_gbq(credentials, project_id):
5454
)
5555

5656

57+
@pytest.fixture
58+
def to_gbq_with_bq_client(bigquery_client):
59+
import pandas_gbq
60+
61+
return functools.partial(pandas_gbq.to_gbq, bigquery_client=bigquery_client)
62+
63+
5764
@pytest.fixture
5865
def read_gbq(credentials, project_id):
5966
import pandas_gbq
@@ -63,6 +70,13 @@ def read_gbq(credentials, project_id):
6370
)
6471

6572

73+
@pytest.fixture
74+
def read_gbq_with_bq_client(bigquery_client):
75+
import pandas_gbq
76+
77+
return functools.partial(pandas_gbq.read_gbq, bigquery_client=bigquery_client)
78+
79+
6680
@pytest.fixture()
6781
def random_dataset_id(bigquery_client: bigquery.Client, project_id: str):
6882
dataset_id = prefixer.create_prefix()

Diff for: tests/system/test_gbq.py

+10
Original file line numberDiff line numberDiff line change
@@ -1398,3 +1398,13 @@ def test_to_gbq_does_not_override_mode(gbq_table, gbq_connector):
13981398
)
13991399

14001400
assert verify_schema(gbq_connector, gbq_table.dataset_id, table_id, table_schema)
1401+
1402+
1403+
def test_gbqconnector_init_with_bq_client(bigquery_client):
1404+
gbq_connector = gbq.GbqConnector(
1405+
project_id="project_id", credentials=None, bigquery_client=bigquery_client
1406+
)
1407+
1408+
assert gbq_connector.project_id == bigquery_client.project
1409+
assert gbq_connector.credentials is bigquery_client._credentials
1410+
assert gbq_connector.client is bigquery_client

Diff for: tests/system/test_read_gbq.py

+11
Original file line numberDiff line numberDiff line change
@@ -659,3 +659,14 @@ def test_dml_query(read_gbq, writable_table: str):
659659
"""
660660
result = read_gbq(query)
661661
assert result is not None
662+
663+
664+
def test_read_gbq_with_bq_client(read_gbq_with_bq_client):
665+
query = "SELECT * FROM UNNEST([1, 2, 3]) AS numbers"
666+
667+
actual_result = read_gbq_with_bq_client(query)
668+
669+
expected_result = pandas.DataFrame(
670+
{"numbers": pandas.Series([1, 2, 3], dtype="Int64")}
671+
)
672+
pandas.testing.assert_frame_equal(actual_result, expected_result)

Diff for: tests/system/test_to_gbq.py

+14
Original file line numberDiff line numberDiff line change
@@ -615,3 +615,17 @@ def test_dataframe_round_trip_with_table_schema(
615615
pandas.testing.assert_frame_equal(
616616
expected_df.set_index("row_num").sort_index(), round_trip
617617
)
618+
619+
620+
def test_dataframe_round_trip_with_bq_client(
621+
to_gbq_with_bq_client, read_gbq_with_bq_client, random_dataset_id
622+
):
623+
table_id = (
624+
f"{random_dataset_id}.round_trip_w_bq_client_{random.randrange(1_000_000)}"
625+
)
626+
df = pandas.DataFrame({"numbers": pandas.Series([1, 2, 3], dtype="Int64")})
627+
628+
to_gbq_with_bq_client(df, table_id)
629+
result = read_gbq_with_bq_client(table_id)
630+
631+
pandas.testing.assert_frame_equal(result, df)

0 commit comments

Comments
 (0)