
Commit 0046742

feat: support CSV format in load_table_from_dataframe pandas connector (#399)
* WIP: support alternative serialization formats for load_table_from_dataframe
* fix: address review comments
* docs: make clear repeated fields are not supported in CSV
1 parent a9d8ae8 commit 0046742
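
In practice, this change lets callers opt into CSV serialization by setting source_format on the job config; the default remains PARQUET. A minimal usage sketch based on the system test added in this commit (the project, dataset, and table names below are placeholders, not part of the commit):

    import pandas
    from google.cloud import bigquery

    client = bigquery.Client()
    dataframe = pandas.DataFrame({"id": [1, 2], "age": [100, 60]})

    job_config = bigquery.LoadJobConfig(
        # An explicit schema is a good idea for CSV, since the file itself
        # carries no type information.
        schema=[
            bigquery.SchemaField("id", "INTEGER"),
            bigquery.SchemaField("age", "INTEGER"),
        ],
        source_format=bigquery.SourceFormat.CSV,  # PARQUET when omitted
    )

    load_job = client.load_table_from_dataframe(
        dataframe, "my-project.my_dataset.my_table", job_config=job_config
    )
    load_job.result()  # wait for the load job to finish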

File tree

3 files changed: +239 -27 lines

google/cloud/bigquery/client.py

+55 -27
@@ -2111,9 +2111,12 @@ def load_table_from_dataframe(
 
         .. note::
 
-            Due to the way REPEATED fields are encoded in the ``parquet`` file
-            format, a mismatch with the existing table schema can occur, and
-            100% compatibility cannot be guaranteed for REPEATED fields.
+            REPEATED fields are NOT supported when using the CSV source format.
+            They are supported when using the PARQUET source format, but
+            due to the way they are encoded in the ``parquet`` file,
+            a mismatch with the existing table schema can occur, so
+            100% compatibility cannot be guaranteed for REPEATED fields when
+            using the parquet format.
 
             https://github.com/googleapis/python-bigquery/issues/17
 
@@ -2153,6 +2156,14 @@ def load_table_from_dataframe(
                 column names matching those of the dataframe. The BigQuery
                 schema is used to determine the correct data type conversion.
                 Indexes are not loaded. Requires the :mod:`pyarrow` library.
+
+                By default, this method uses the parquet source format. To
+                override this, supply a value for
+                :attr:`~google.cloud.bigquery.job.LoadJobConfig.source_format`
+                with the format name. Currently only
+                :attr:`~google.cloud.bigquery.job.SourceFormat.CSV` and
+                :attr:`~google.cloud.bigquery.job.SourceFormat.PARQUET` are
+                supported.
             parquet_compression (Optional[str]):
                 [Beta] The compression method to use if intermittently
                 serializing ``dataframe`` to a parquet file.
@@ -2181,10 +2192,6 @@ def load_table_from_dataframe(
                 If ``job_config`` is not an instance of :class:`~google.cloud.bigquery.job.LoadJobConfig`
                 class.
         """
-        if pyarrow is None:
-            # pyarrow is now the only supported parquet engine.
-            raise ValueError("This method requires pyarrow to be installed")
-
         job_id = _make_job_id(job_id, job_id_prefix)
 
         if job_config:
@@ -2197,15 +2204,20 @@ def load_table_from_dataframe(
         else:
             job_config = job.LoadJobConfig()
 
-        if job_config.source_format:
-            if job_config.source_format != job.SourceFormat.PARQUET:
-                raise ValueError(
-                    "Got unexpected source_format: '{}'. Currently, only PARQUET is supported".format(
-                        job_config.source_format
-                    )
-                )
-        else:
+        supported_formats = {job.SourceFormat.CSV, job.SourceFormat.PARQUET}
+        if job_config.source_format is None:
+            # default value
             job_config.source_format = job.SourceFormat.PARQUET
+        if job_config.source_format not in supported_formats:
+            raise ValueError(
+                "Got unexpected source_format: '{}'. Currently, only PARQUET and CSV are supported".format(
+                    job_config.source_format
+                )
+            )
+
+        if pyarrow is None and job_config.source_format == job.SourceFormat.PARQUET:
+            # pyarrow is now the only supported parquet engine.
+            raise ValueError("This method requires pyarrow to be installed")
 
         if location is None:
             location = self.location
@@ -2245,27 +2257,43 @@ def load_table_from_dataframe(
                 stacklevel=2,
             )
 
-        tmpfd, tmppath = tempfile.mkstemp(suffix="_job_{}.parquet".format(job_id[:8]))
+        tmpfd, tmppath = tempfile.mkstemp(
+            suffix="_job_{}.{}".format(job_id[:8], job_config.source_format.lower())
+        )
         os.close(tmpfd)
 
         try:
-            if job_config.schema:
-                if parquet_compression == "snappy":  # adjust the default value
-                    parquet_compression = parquet_compression.upper()
 
-                _pandas_helpers.dataframe_to_parquet(
-                    dataframe,
-                    job_config.schema,
+            if job_config.source_format == job.SourceFormat.PARQUET:
+
+                if job_config.schema:
+                    if parquet_compression == "snappy":  # adjust the default value
+                        parquet_compression = parquet_compression.upper()
+
+                    _pandas_helpers.dataframe_to_parquet(
+                        dataframe,
+                        job_config.schema,
+                        tmppath,
+                        parquet_compression=parquet_compression,
+                    )
+                else:
+                    dataframe.to_parquet(tmppath, compression=parquet_compression)
+
+            else:
+
+                dataframe.to_csv(
                     tmppath,
-                    parquet_compression=parquet_compression,
+                    index=False,
+                    header=False,
+                    encoding="utf-8",
+                    float_format="%.17g",
+                    date_format="%Y-%m-%d %H:%M:%S.%f",
                 )
-            else:
-                dataframe.to_parquet(tmppath, compression=parquet_compression)
 
-            with open(tmppath, "rb") as parquet_file:
+            with open(tmppath, "rb") as tmpfile:
                 file_size = os.path.getsize(tmppath)
                 return self.load_table_from_file(
-                    parquet_file,
+                    tmpfile,
                     destination,
                     num_retries=num_retries,
                     rewind=True,
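
For reference, the new CSV branch serializes the frame with pandas alone, so with this change a CSV load no longer requires pyarrow (the pyarrow check above now applies only to the PARQUET path). A standalone sketch of the equivalent serialization call, using a placeholder path in place of the mkstemp() temp file:

    import pandas

    dataframe = pandas.DataFrame({"id": [1, 2], "age": [100, 60]})

    # Mirrors the to_csv settings used by the CSV branch: no index or header
    # row, UTF-8 text, up to 17 significant digits for floats, and
    # microsecond-precision datetimes.
    dataframe.to_csv(
        "/tmp/example_load.csv",  # placeholder; the client writes to a temp file
        index=False,
        header=False,
        encoding="utf-8",
        float_format="%.17g",
        date_format="%Y-%m-%d %H:%M:%S.%f",
    )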

tests/system.py

+134
@@ -1165,6 +1165,140 @@ def test_load_table_from_json_basic_use(self):
         self.assertEqual(tuple(table.schema), table_schema)
         self.assertEqual(table.num_rows, 2)
 
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    def test_load_table_from_dataframe_w_explicit_schema_source_format_csv(self):
+        from google.cloud.bigquery.job import SourceFormat
+
+        table_schema = (
+            bigquery.SchemaField("bool_col", "BOOLEAN"),
+            bigquery.SchemaField("bytes_col", "BYTES"),
+            bigquery.SchemaField("date_col", "DATE"),
+            bigquery.SchemaField("dt_col", "DATETIME"),
+            bigquery.SchemaField("float_col", "FLOAT"),
+            bigquery.SchemaField("geo_col", "GEOGRAPHY"),
+            bigquery.SchemaField("int_col", "INTEGER"),
+            bigquery.SchemaField("num_col", "NUMERIC"),
+            bigquery.SchemaField("str_col", "STRING"),
+            bigquery.SchemaField("time_col", "TIME"),
+            bigquery.SchemaField("ts_col", "TIMESTAMP"),
+        )
+        df_data = collections.OrderedDict(
+            [
+                ("bool_col", [True, None, False]),
+                ("bytes_col", ["abc", None, "def"]),
+                (
+                    "date_col",
+                    [datetime.date(1, 1, 1), None, datetime.date(9999, 12, 31)],
+                ),
+                (
+                    "dt_col",
+                    [
+                        datetime.datetime(1, 1, 1, 0, 0, 0),
+                        None,
+                        datetime.datetime(9999, 12, 31, 23, 59, 59, 999999),
+                    ],
+                ),
+                ("float_col", [float("-inf"), float("nan"), float("inf")]),
+                (
+                    "geo_col",
+                    [
+                        "POINT(30 10)",
+                        None,
+                        "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))",
+                    ],
+                ),
+                ("int_col", [-9223372036854775808, None, 9223372036854775807]),
+                (
+                    "num_col",
+                    [
+                        decimal.Decimal("-99999999999999999999999999999.999999999"),
+                        None,
+                        decimal.Decimal("99999999999999999999999999999.999999999"),
+                    ],
+                ),
+                ("str_col", [u"abc", None, u"def"]),
+                (
+                    "time_col",
+                    [datetime.time(0, 0, 0), None, datetime.time(23, 59, 59, 999999)],
+                ),
+                (
+                    "ts_col",
+                    [
+                        datetime.datetime(1, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
+                        None,
+                        datetime.datetime(
+                            9999, 12, 31, 23, 59, 59, 999999, tzinfo=pytz.utc
+                        ),
+                    ],
+                ),
+            ]
+        )
+        dataframe = pandas.DataFrame(df_data, dtype="object", columns=df_data.keys())
+
+        dataset_id = _make_dataset_id("bq_load_test")
+        self.temp_dataset(dataset_id)
+        table_id = "{}.{}.load_table_from_dataframe_w_explicit_schema_csv".format(
+            Config.CLIENT.project, dataset_id
+        )
+
+        job_config = bigquery.LoadJobConfig(
+            schema=table_schema, source_format=SourceFormat.CSV
+        )
+        load_job = Config.CLIENT.load_table_from_dataframe(
+            dataframe, table_id, job_config=job_config
+        )
+        load_job.result()
+
+        table = Config.CLIENT.get_table(table_id)
+        self.assertEqual(tuple(table.schema), table_schema)
+        self.assertEqual(table.num_rows, 3)
+
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    def test_load_table_from_dataframe_w_explicit_schema_source_format_csv_floats(self):
+        from google.cloud.bigquery.job import SourceFormat
+
+        table_schema = (bigquery.SchemaField("float_col", "FLOAT"),)
+        df_data = collections.OrderedDict(
+            [
+                (
+                    "float_col",
+                    [
+                        0.14285714285714285,
+                        0.51428571485748,
+                        0.87128748,
+                        1.807960649,
+                        2.0679610649,
+                        2.4406779661016949,
+                        3.7148514257,
+                        3.8571428571428572,
+                        1.51251252e40,
+                    ],
+                ),
+            ]
+        )
+        dataframe = pandas.DataFrame(df_data, dtype="object", columns=df_data.keys())
+
+        dataset_id = _make_dataset_id("bq_load_test")
+        self.temp_dataset(dataset_id)
+        table_id = "{}.{}.load_table_from_dataframe_w_explicit_schema_csv".format(
+            Config.CLIENT.project, dataset_id
+        )
+
+        job_config = bigquery.LoadJobConfig(
+            schema=table_schema, source_format=SourceFormat.CSV
+        )
+        load_job = Config.CLIENT.load_table_from_dataframe(
+            dataframe, table_id, job_config=job_config
+        )
+        load_job.result()
+
+        table = Config.CLIENT.get_table(table_id)
+        rows = self._fetch_single_page(table)
+        floats = [r.values()[0] for r in rows]
+        self.assertEqual(tuple(table.schema), table_schema)
+        self.assertEqual(table.num_rows, 9)
+        self.assertEqual(floats, df_data["float_col"])
+
     def test_load_table_from_json_schema_autodetect(self):
         json_rows = [
             {"name": "John", "age": 18, "birthday": "2001-10-15", "is_awesome": False},

tests/unit/test_client.py

+50
@@ -8410,6 +8410,56 @@ def test_load_table_from_dataframe_w_invaild_job_config(self):
         err_msg = str(exc.value)
         assert "Expected an instance of LoadJobConfig" in err_msg
 
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    def test_load_table_from_dataframe_with_csv_source_format(self):
+        from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
+        from google.cloud.bigquery import job
+        from google.cloud.bigquery.schema import SchemaField
+
+        client = self._make_client()
+        records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
+        dataframe = pandas.DataFrame(records)
+        job_config = job.LoadJobConfig(
+            write_disposition=job.WriteDisposition.WRITE_TRUNCATE,
+            source_format=job.SourceFormat.CSV,
+        )
+
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            return_value=mock.Mock(
+                schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
+            ),
+        )
+        load_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
+        )
+        with load_patch as load_table_from_file, get_table_patch:
+            client.load_table_from_dataframe(
+                dataframe, self.TABLE_REF, job_config=job_config
+            )
+
+        load_table_from_file.assert_called_once_with(
+            client,
+            mock.ANY,
+            self.TABLE_REF,
+            num_retries=_DEFAULT_NUM_RETRIES,
+            rewind=True,
+            size=mock.ANY,
+            job_id=mock.ANY,
+            job_id_prefix=None,
+            location=None,
+            project=None,
+            job_config=mock.ANY,
+            timeout=None,
+        )
+
+        sent_file = load_table_from_file.mock_calls[0][1][1]
+        assert sent_file.closed
+
+        sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
+        assert sent_config.source_format == job.SourceFormat.CSV
+
     def test_load_table_from_json_basic_use(self):
         from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
         from google.cloud.bigquery import job
