Skip to content

Commit faed77b

Browse files
String dtype: fix pyarrow-based IO + update tests (pandas-dev#59478)
1 parent 1eb8f0e commit faed77b

File tree

6 files changed: +80 additions, -47 deletions (lines changed)

pandas/io/_util.py

+2
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,8 @@ def _arrow_dtype_mapping() -> dict:
2424
pa.string(): pd.StringDtype(),
2525
pa.float32(): pd.Float32Dtype(),
2626
pa.float64(): pd.Float64Dtype(),
27+
pa.string(): pd.StringDtype(),
28+
pa.large_string(): pd.StringDtype(),
2729
}
2830

2931

pandas/tests/io/test_feather.py

+17 -10
Original file line number | Diff line number | Diff line change
@@ -9,12 +9,10 @@
99

1010
from pandas.io.feather_format import read_feather, to_feather # isort:skip
1111

12-
pytestmark = [
13-
pytest.mark.filterwarnings(
14-
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
15-
),
16-
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
17-
]
12+
pytestmark = pytest.mark.filterwarnings(
13+
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
14+
)
15+
1816

1917
pa = pytest.importorskip("pyarrow")
2018

@@ -154,8 +152,8 @@ def test_path_localpath(self):
154152
def test_passthrough_keywords(self):
155153
df = pd.DataFrame(
156154
1.1 * np.arange(120).reshape((30, 4)),
157-
columns=pd.Index(list("ABCD"), dtype=object),
158-
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
155+
columns=pd.Index(list("ABCD")),
156+
index=pd.Index([f"i-{i}" for i in range(30)]),
159157
).reset_index()
160158
self.check_round_trip(df, write_kwargs={"version": 1})
161159

@@ -169,7 +167,9 @@ def test_http_path(self, feather_file, httpserver):
169167
res = read_feather(httpserver.url)
170168
tm.assert_frame_equal(expected, res)
171169

172-
def test_read_feather_dtype_backend(self, string_storage, dtype_backend):
170+
def test_read_feather_dtype_backend(
171+
self, string_storage, dtype_backend, using_infer_string
172+
):
173173
# GH#50765
174174
df = pd.DataFrame(
175175
{
@@ -191,7 +191,10 @@ def test_read_feather_dtype_backend(self, string_storage, dtype_backend):
191191

192192
if dtype_backend == "pyarrow":
193193
pa = pytest.importorskip("pyarrow")
194-
string_dtype = pd.ArrowDtype(pa.string())
194+
if using_infer_string:
195+
string_dtype = pd.ArrowDtype(pa.large_string())
196+
else:
197+
string_dtype = pd.ArrowDtype(pa.string())
195198
else:
196199
string_dtype = pd.StringDtype(string_storage)
197200

@@ -218,6 +221,10 @@ def test_read_feather_dtype_backend(self, string_storage, dtype_backend):
218221
}
219222
)
220223

224+
if using_infer_string:
225+
expected.columns = expected.columns.astype(
226+
pd.StringDtype(string_storage, na_value=np.nan)
227+
)
221228
tm.assert_frame_equal(result, expected)
222229

223230
def test_int_columns_and_index(self):

pandas/tests/io/test_fsspec.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -168,7 +168,7 @@ def test_excel_options(fsspectest):
168168
assert fsspectest.test[0] == "read"
169169

170170

171-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
171+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string) fastparquet")
172172
def test_to_parquet_new_file(cleared_fs, df1):
173173
"""Regression test for writing to a not-yet-existent GCS Parquet file."""
174174
pytest.importorskip("fastparquet")
@@ -198,7 +198,7 @@ def test_arrowparquet_options(fsspectest):
198198

199199

200200
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) fastparquet
201-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
201+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string) fastparquet")
202202
def test_fastparquet_options(fsspectest):
203203
"""Regression test for writing to a not-yet-existent GCS Parquet file."""
204204
pytest.importorskip("fastparquet")
@@ -256,7 +256,7 @@ def test_s3_protocols(s3_public_bucket_with_data, tips_file, protocol, s3so):
256256
)
257257

258258

259-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
259+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string) fastparquet")
260260
@pytest.mark.single_cpu
261261
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) fastparquet
262262
def test_s3_parquet(s3_public_bucket, s3so, df1):

pandas/tests/io/test_gcs.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -197,7 +197,7 @@ def test_to_csv_compression_encoding_gcs(
197197
tm.assert_frame_equal(df, read_df)
198198

199199

200-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
200+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string) fastparquet")
201201
def test_to_parquet_gcs_new_file(monkeypatch, tmpdir):
202202
"""Regression test for writing to a not-yet-existent GCS Parquet file."""
203203
pytest.importorskip("fastparquet")

pandas/tests/io/test_orc.py

+14 -11
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,6 @@
88
import numpy as np
99
import pytest
1010

11-
from pandas._config import using_string_dtype
12-
1311
import pandas as pd
1412
from pandas import read_orc
1513
import pandas._testing as tm
@@ -19,12 +17,9 @@
1917

2018
import pyarrow as pa
2119

22-
pytestmark = [
23-
pytest.mark.filterwarnings(
24-
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
25-
),
26-
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
27-
]
20+
pytestmark = pytest.mark.filterwarnings(
21+
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
22+
)
2823

2924

3025
@pytest.fixture
@@ -47,7 +42,7 @@ def orc_writer_dtypes_not_supported(request):
4742
return pd.DataFrame({"unimpl": request.param})
4843

4944

50-
def test_orc_reader_empty(dirpath):
45+
def test_orc_reader_empty(dirpath, using_infer_string):
5146
columns = [
5247
"boolean1",
5348
"byte1",
@@ -68,11 +63,12 @@ def test_orc_reader_empty(dirpath):
6863
"float32",
6964
"float64",
7065
"object",
71-
"object",
66+
"str" if using_infer_string else "object",
7267
]
7368
expected = pd.DataFrame(index=pd.RangeIndex(0))
7469
for colname, dtype in zip(columns, dtypes):
7570
expected[colname] = pd.Series(dtype=dtype)
71+
expected.columns = expected.columns.astype("str")
7672

7773
inputfile = os.path.join(dirpath, "TestOrcFile.emptyFile.orc")
7874
got = read_orc(inputfile, columns=columns)
@@ -309,7 +305,7 @@ def test_orc_writer_dtypes_not_supported(orc_writer_dtypes_not_supported):
309305
orc_writer_dtypes_not_supported.to_orc()
310306

311307

312-
def test_orc_dtype_backend_pyarrow():
308+
def test_orc_dtype_backend_pyarrow(using_infer_string):
313309
pytest.importorskip("pyarrow")
314310
df = pd.DataFrame(
315311
{
@@ -340,6 +336,13 @@ def test_orc_dtype_backend_pyarrow():
340336
for col in df.columns
341337
}
342338
)
339+
if using_infer_string:
340+
# ORC does not preserve distinction between string and large string
341+
# -> the default large string comes back as string
342+
string_dtype = pd.ArrowDtype(pa.string())
343+
expected["string"] = expected["string"].astype(string_dtype)
344+
expected["string_with_nan"] = expected["string_with_nan"].astype(string_dtype)
345+
expected["string_with_none"] = expected["string_with_none"].astype(string_dtype)
343346

344347
tm.assert_frame_equal(result, expected)
345348

pandas/tests/io/test_parquet.py

+43 -22
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,6 @@
5555
pytest.mark.filterwarnings(
5656
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
5757
),
58-
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
5958
]
6059

6160

@@ -64,11 +63,18 @@
6463
params=[
6564
pytest.param(
6665
"fastparquet",
67-
marks=pytest.mark.skipif(
68-
not _HAVE_FASTPARQUET
69-
or _get_option("mode.data_manager", silent=True) == "array",
70-
reason="fastparquet is not installed or ArrayManager is used",
71-
),
66+
marks=[
67+
pytest.mark.skipif(
68+
not _HAVE_FASTPARQUET
69+
or _get_option("mode.data_manager", silent=True) == "array",
70+
reason="fastparquet is not installed or ArrayManager is used",
71+
),
72+
pytest.mark.xfail(
73+
using_string_dtype(),
74+
reason="TODO(infer_string) fastparquet",
75+
strict=False,
76+
),
77+
],
7278
),
7379
pytest.param(
7480
"pyarrow",
@@ -90,17 +96,24 @@ def pa():
9096

9197

9298
@pytest.fixture
93-
def fp():
99+
def fp(request):
94100
if not _HAVE_FASTPARQUET:
95101
pytest.skip("fastparquet is not installed")
96102
elif _get_option("mode.data_manager", silent=True) == "array":
97103
pytest.skip("ArrayManager is not supported with fastparquet")
104+
if using_string_dtype():
105+
request.applymarker(
106+
pytest.mark.xfail(reason="TODO(infer_string) fastparquet", strict=False)
107+
)
98108
return "fastparquet"
99109

100110

101111
@pytest.fixture
102112
def df_compat():
103-
return pd.DataFrame({"A": [1, 2, 3], "B": "foo"})
113+
# TODO(infer_string) should this give str columns?
114+
return pd.DataFrame(
115+
{"A": [1, 2, 3], "B": "foo"}, columns=pd.Index(["A", "B"], dtype=object)
116+
)
104117

105118

106119
@pytest.fixture
@@ -389,16 +402,6 @@ def check_external_error_on_write(self, df, engine, exc):
389402
with tm.external_error_raised(exc):
390403
to_parquet(df, path, engine, compression=None)
391404

392-
@pytest.mark.network
393-
@pytest.mark.single_cpu
394-
def test_parquet_read_from_url(self, httpserver, datapath, df_compat, engine):
395-
if engine != "auto":
396-
pytest.importorskip(engine)
397-
with open(datapath("io", "data", "parquet", "simple.parquet"), mode="rb") as f:
398-
httpserver.serve_content(content=f.read())
399-
df = read_parquet(httpserver.url)
400-
tm.assert_frame_equal(df, df_compat)
401-
402405

403406
class TestBasic(Base):
404407
def test_error(self, engine):
@@ -696,6 +699,16 @@ def test_read_empty_array(self, pa, dtype):
696699
df, pa, read_kwargs={"dtype_backend": "numpy_nullable"}, expected=expected
697700
)
698701

702+
@pytest.mark.network
703+
@pytest.mark.single_cpu
704+
def test_parquet_read_from_url(self, httpserver, datapath, df_compat, engine):
705+
if engine != "auto":
706+
pytest.importorskip(engine)
707+
with open(datapath("io", "data", "parquet", "simple.parquet"), mode="rb") as f:
708+
httpserver.serve_content(content=f.read())
709+
df = read_parquet(httpserver.url, engine=engine)
710+
tm.assert_frame_equal(df, df_compat)
711+
699712

700713
class TestParquetPyArrow(Base):
701714
def test_basic(self, pa, df_full):
@@ -925,7 +938,7 @@ def test_write_with_schema(self, pa):
925938
out_df = df.astype(bool)
926939
check_round_trip(df, pa, write_kwargs={"schema": schema}, expected=out_df)
927940

928-
def test_additional_extension_arrays(self, pa):
941+
def test_additional_extension_arrays(self, pa, using_infer_string):
929942
# test additional ExtensionArrays that are supported through the
930943
# __arrow_array__ protocol
931944
pytest.importorskip("pyarrow")
@@ -936,17 +949,25 @@ def test_additional_extension_arrays(self, pa):
936949
"c": pd.Series(["a", None, "c"], dtype="string"),
937950
}
938951
)
939-
check_round_trip(df, pa)
952+
if using_infer_string:
953+
check_round_trip(df, pa, expected=df.astype({"c": "str"}))
954+
else:
955+
check_round_trip(df, pa)
940956

941957
df = pd.DataFrame({"a": pd.Series([1, 2, 3, None], dtype="Int64")})
942958
check_round_trip(df, pa)
943959

944-
def test_pyarrow_backed_string_array(self, pa, string_storage):
960+
def test_pyarrow_backed_string_array(self, pa, string_storage, using_infer_string):
945961
# test ArrowStringArray supported through the __arrow_array__ protocol
946962
pytest.importorskip("pyarrow")
947963
df = pd.DataFrame({"a": pd.Series(["a", None, "c"], dtype="string[pyarrow]")})
948964
with pd.option_context("string_storage", string_storage):
949-
check_round_trip(df, pa, expected=df.astype(f"string[{string_storage}]"))
965+
if using_infer_string:
966+
expected = df.astype("str")
967+
expected.columns = expected.columns.astype("str")
968+
else:
969+
expected = df.astype(f"string[{string_storage}]")
970+
check_round_trip(df, pa, expected=expected)
950971

951972
def test_additional_extension_types(self, pa):
952973
# test additional ExtensionArrays that are supported through the

0 commit comments

Comments
 (0)