|
30 | 30 | from six.moves import http_client
|
31 | 31 | import pytest
|
32 | 32 | import pytz
|
| 33 | +import pkg_resources |
33 | 34 |
|
34 | 35 | try:
|
35 | 36 | import fastparquet
|
|
56 | 57 | bigquery_storage_v1beta1 = None
|
57 | 58 | from tests.unit.helpers import make_connection
|
58 | 59 |
|
# Minimum pandas version required by the nullable-Int64 dataframe tests, which
# skip themselves when the installed pandas is older than this.
# NOTE: the constant name (spelling included) is kept unchanged because the
# ``skipIf`` decorators in this module reference it.
PANDAS_MINIUM_VERSION = pkg_resources.parse_version("1.0.0")

# Version of the pandas distribution that is actually installed.  pandas is an
# optional dependency here (tests guard on ``pandas is None``), so looking the
# distribution up must not break importing this module when it is missing.
try:
    PANDAS_INSTALLED_VERSION = pkg_resources.get_distribution("pandas").parsed_version
except pkg_resources.DistributionNotFound:
    # Fall back to a version lower than any real release so that every
    # ``PANDAS_INSTALLED_VERSION < PANDAS_MINIUM_VERSION`` check evaluates to
    # True and the version-gated tests are skipped instead of erroring.
    PANDAS_INSTALLED_VERSION = pkg_resources.parse_version("0.0.0")
| 62 | + |
59 | 63 |
|
60 | 64 | def _make_credentials():
|
61 | 65 | import google.auth.credentials
|
@@ -6973,6 +6977,98 @@ def test_load_table_from_dataframe_no_schema_warning_wo_pyarrow(self):
|
6973 | 6977 | ]
|
6974 | 6978 | assert matches, "A missing schema deprecation warning was not raised."
|
6975 | 6979 |
|
@unittest.skipIf(
    pandas is None or PANDAS_INSTALLED_VERSION < PANDAS_MINIUM_VERSION,
    "Only `pandas version >=1.0.0` supported",
)
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_nullable_int64_datatype(self):
    # Loads a dataframe whose only column uses the nullable "Int64" dtype
    # while ``get_table`` is patched to return a table with a matching
    # INT64 / NULLABLE schema, then inspects the load job configuration that
    # was handed to ``load_table_from_file``.
    from google.cloud.bigquery import job
    from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
    from google.cloud.bigquery.schema import SchemaField

    frame = pandas.DataFrame({"x": [1, 2, None, 4]}, dtype="Int64")
    bq_client = self._make_client()

    # The destination table lookup succeeds and reports the column schema.
    existing_table = mock.Mock(schema=[SchemaField("x", "INT64", "NULLABLE")])
    patch_get_table = mock.patch(
        "google.cloud.bigquery.client.Client.get_table",
        autospec=True,
        return_value=existing_table,
    )
    # The actual upload is replaced so that only its arguments are recorded.
    patch_load = mock.patch(
        "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
    )

    with patch_load as mocked_load, patch_get_table:
        bq_client.load_table_from_dataframe(
            frame, self.TABLE_REF, location=self.LOCATION
        )

    expected_kwargs = {
        "num_retries": _DEFAULT_NUM_RETRIES,
        "rewind": True,
        "job_id": mock.ANY,
        "job_id_prefix": None,
        "location": self.LOCATION,
        "project": None,
        "job_config": mock.ANY,
    }
    mocked_load.assert_called_once_with(
        bq_client, mock.ANY, self.TABLE_REF, **expected_kwargs
    )

    # Each recorded call is a (name, args, kwargs) triple; index 2 is kwargs.
    recorded_call = mocked_load.mock_calls[0]
    job_config = recorded_call[2]["job_config"]
    expected_schema = (SchemaField("x", "INT64", "NULLABLE", None),)

    assert job_config.source_format == job.SourceFormat.PARQUET
    assert tuple(job_config.schema) == expected_schema
| 7025 | + |
@unittest.skipIf(
    pandas is None or PANDAS_INSTALLED_VERSION < PANDAS_MINIUM_VERSION,
    "Only `pandas version >=1.0.0` supported",
)
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_nullable_int64_datatype_automatic_schema(self):
    # Same scenario as the nullable "Int64" load test, except that the
    # patched ``get_table`` raises NotFound, so the schema cannot come from
    # the table lookup.  The job configuration handed to
    # ``load_table_from_file`` must still carry an INT64 / NULLABLE column.
    from google.cloud.bigquery import job
    from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
    from google.cloud.bigquery.schema import SchemaField

    frame = pandas.DataFrame({"x": [1, 2, None, 4]}, dtype="Int64")
    bq_client = self._make_client()

    # The destination table lookup fails as if the table did not exist.
    table_missing = google.api_core.exceptions.NotFound("Table not found")
    patch_get_table = mock.patch(
        "google.cloud.bigquery.client.Client.get_table",
        autospec=True,
        side_effect=table_missing,
    )
    # The actual upload is replaced so that only its arguments are recorded.
    patch_load = mock.patch(
        "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
    )

    with patch_load as mocked_load, patch_get_table:
        bq_client.load_table_from_dataframe(
            frame, self.TABLE_REF, location=self.LOCATION
        )

    expected_kwargs = {
        "num_retries": _DEFAULT_NUM_RETRIES,
        "rewind": True,
        "job_id": mock.ANY,
        "job_id_prefix": None,
        "location": self.LOCATION,
        "project": None,
        "job_config": mock.ANY,
    }
    mocked_load.assert_called_once_with(
        bq_client, mock.ANY, self.TABLE_REF, **expected_kwargs
    )

    # Each recorded call is a (name, args, kwargs) triple; index 2 is kwargs.
    recorded_call = mocked_load.mock_calls[0]
    job_config = recorded_call[2]["job_config"]
    expected_schema = (SchemaField("x", "INT64", "NULLABLE", None),)

    assert job_config.source_format == job.SourceFormat.PARQUET
    assert tuple(job_config.schema) == expected_schema
| 7071 | + |
6976 | 7072 | @unittest.skipIf(pandas is None, "Requires `pandas`")
|
6977 | 7073 | @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
|
6978 | 7074 | def test_load_table_from_dataframe_struct_fields_error(self):
|
|
0 commit comments