Skip to content

Commit 6160fee

Browse files
fix: validate job_config.source_format in load_table_from_dataframe (#262)
* fix: address job_congig.source_format * fix: nit
1 parent ae647eb commit 6160fee

File tree

2 files changed

+85
-3
lines changed

2 files changed

+85
-3
lines changed

google/cloud/bigquery/client.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -2174,7 +2174,15 @@ def load_table_from_dataframe(
21742174
else:
21752175
job_config = job.LoadJobConfig()
21762176

2177-
job_config.source_format = job.SourceFormat.PARQUET
2177+
if job_config.source_format:
2178+
if job_config.source_format != job.SourceFormat.PARQUET:
2179+
raise ValueError(
2180+
"Got unexpected source_format: '{}'. Currently, only PARQUET is supported".format(
2181+
job_config.source_format
2182+
)
2183+
)
2184+
else:
2185+
job_config.source_format = job.SourceFormat.PARQUET
21782186

21792187
if location is None:
21802188
location = self.location

tests/unit/test_client.py

+76-2
Original file line numberDiff line numberDiff line change
@@ -7544,7 +7544,7 @@ def test_load_table_from_dataframe_w_client_location(self):
75447544

75457545
@unittest.skipIf(pandas is None, "Requires `pandas`")
75467546
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
7547-
def test_load_table_from_dataframe_w_custom_job_config(self):
7547+
def test_load_table_from_dataframe_w_custom_job_config_wihtout_source_format(self):
75487548
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
75497549
from google.cloud.bigquery import job
75507550
from google.cloud.bigquery.schema import SchemaField
@@ -7553,7 +7553,7 @@ def test_load_table_from_dataframe_w_custom_job_config(self):
75537553
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
75547554
dataframe = pandas.DataFrame(records)
75557555
job_config = job.LoadJobConfig(
7556-
write_disposition=job.WriteDisposition.WRITE_TRUNCATE
7556+
write_disposition=job.WriteDisposition.WRITE_TRUNCATE,
75577557
)
75587558
original_config_copy = copy.deepcopy(job_config)
75597559

@@ -7595,6 +7595,80 @@ def test_load_table_from_dataframe_w_custom_job_config(self):
75957595
# the original config object should not have been modified
75967596
assert job_config.to_api_repr() == original_config_copy.to_api_repr()
75977597

7598+
@unittest.skipIf(pandas is None, "Requires `pandas`")
7599+
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
7600+
def test_load_table_from_dataframe_w_custom_job_config_w_source_format(self):
7601+
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
7602+
from google.cloud.bigquery import job
7603+
from google.cloud.bigquery.schema import SchemaField
7604+
7605+
client = self._make_client()
7606+
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
7607+
dataframe = pandas.DataFrame(records)
7608+
job_config = job.LoadJobConfig(
7609+
write_disposition=job.WriteDisposition.WRITE_TRUNCATE,
7610+
source_format=job.SourceFormat.PARQUET,
7611+
)
7612+
original_config_copy = copy.deepcopy(job_config)
7613+
7614+
get_table_patch = mock.patch(
7615+
"google.cloud.bigquery.client.Client.get_table",
7616+
autospec=True,
7617+
return_value=mock.Mock(
7618+
schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
7619+
),
7620+
)
7621+
load_patch = mock.patch(
7622+
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
7623+
)
7624+
with load_patch as load_table_from_file, get_table_patch as get_table:
7625+
client.load_table_from_dataframe(
7626+
dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION
7627+
)
7628+
7629+
# no need to fetch and inspect table schema for WRITE_TRUNCATE jobs
7630+
assert not get_table.called
7631+
7632+
load_table_from_file.assert_called_once_with(
7633+
client,
7634+
mock.ANY,
7635+
self.TABLE_REF,
7636+
num_retries=_DEFAULT_NUM_RETRIES,
7637+
rewind=True,
7638+
job_id=mock.ANY,
7639+
job_id_prefix=None,
7640+
location=self.LOCATION,
7641+
project=None,
7642+
job_config=mock.ANY,
7643+
)
7644+
7645+
sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
7646+
assert sent_config.source_format == job.SourceFormat.PARQUET
7647+
assert sent_config.write_disposition == job.WriteDisposition.WRITE_TRUNCATE
7648+
7649+
# the original config object should not have been modified
7650+
assert job_config.to_api_repr() == original_config_copy.to_api_repr()
7651+
7652+
@unittest.skipIf(pandas is None, "Requires `pandas`")
7653+
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
7654+
def test_load_table_from_dataframe_w_custom_job_config_w_wrong_source_format(self):
7655+
from google.cloud.bigquery import job
7656+
7657+
client = self._make_client()
7658+
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
7659+
dataframe = pandas.DataFrame(records)
7660+
job_config = job.LoadJobConfig(
7661+
write_disposition=job.WriteDisposition.WRITE_TRUNCATE,
7662+
source_format=job.SourceFormat.ORC,
7663+
)
7664+
7665+
with pytest.raises(ValueError) as exc:
7666+
client.load_table_from_dataframe(
7667+
dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION
7668+
)
7669+
7670+
assert "Got unexpected source_format:" in str(exc.value)
7671+
75987672
@unittest.skipIf(pandas is None, "Requires `pandas`")
75997673
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
76007674
def test_load_table_from_dataframe_w_automatic_schema(self):

0 commit comments

Comments
 (0)