Skip to content

Commit 003385d

Browse files
WillAyd and phofl committed
String dtype: implement object-dtype based StringArray variant with NumPy semantics (pandas-dev#58451)
Co-authored-by: Patrick Hoefler <[email protected]>
1 parent 67caf2d commit 003385d

File tree

14 files changed

+232
-48
lines changed

14 files changed

+232
-48
lines changed

Diff for: pandas/_libs/lib.pyx

+1 -1
Original file line numberDiff line numberDiff line change
@@ -2728,7 +2728,7 @@ def maybe_convert_objects(ndarray[object] objects,
27282728
if using_string_dtype() and is_string_array(objects, skipna=True):
27292729
from pandas.core.arrays.string_ import StringDtype
27302730

2731-
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
2731+
dtype = StringDtype(na_value=np.nan)
27322732
return dtype.construct_array_type()._from_sequence(objects, dtype=dtype)
27332733

27342734
seen.object_ = True

Diff for: pandas/_testing/asserters.py

+18
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,24 @@ def assert_extension_array_equal(
811811
left_na, right_na, obj=f"{obj} NA mask", index_values=index_values
812812
)
813813

814+
# Specifically for StringArrayNumpySemantics, validate here we have a valid array
815+
if (
816+
isinstance(left.dtype, StringDtype)
817+
and left.dtype.storage == "python"
818+
and left.dtype.na_value is np.nan
819+
):
820+
assert np.all(
821+
[np.isnan(val) for val in left._ndarray[left_na]] # type: ignore[attr-defined]
822+
), "wrong missing value sentinels"
823+
if (
824+
isinstance(right.dtype, StringDtype)
825+
and right.dtype.storage == "python"
826+
and right.dtype.na_value is np.nan
827+
):
828+
assert np.all(
829+
[np.isnan(val) for val in right._ndarray[right_na]] # type: ignore[attr-defined]
830+
), "wrong missing value sentinels"
831+
814832
left_valid = left[~left_na].to_numpy(dtype=object)
815833
right_valid = right[~right_na].to_numpy(dtype=object)
816834
if check_exact:

Diff for: pandas/compat/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import pandas.compat.compressors
2626
from pandas.compat.numpy import is_numpy_dev
2727
from pandas.compat.pyarrow import (
28+
HAS_PYARROW,
2829
pa_version_under10p1,
2930
pa_version_under11p0,
3031
pa_version_under13p0,
@@ -190,6 +191,7 @@ def get_bz2_file() -> type[pandas.compat.compressors.BZ2File]:
190191
"pa_version_under14p1",
191192
"pa_version_under16p0",
192193
"pa_version_under17p0",
194+
"HAS_PYARROW",
193195
"IS64",
194196
"ISMUSL",
195197
"PY310",

Diff for: pandas/compat/pyarrow.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
pa_version_under15p0 = _palv < Version("15.0.0")
1818
pa_version_under16p0 = _palv < Version("16.0.0")
1919
pa_version_under17p0 = _palv < Version("17.0.0")
20+
HAS_PYARROW = True
2021
except ImportError:
2122
pa_version_under10p1 = True
2223
pa_version_under11p0 = True
@@ -27,3 +28,4 @@
2728
pa_version_under15p0 = True
2829
pa_version_under16p0 = True
2930
pa_version_under17p0 = True
31+
HAS_PYARROW = False

Diff for: pandas/conftest.py

+4
Original file line numberDiff line numberDiff line change
@@ -1265,6 +1265,7 @@ def string_storage(request):
12651265
("python", pd.NA),
12661266
pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
12671267
pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
1268+
("python", np.nan),
12681269
]
12691270
)
12701271
def string_dtype_arguments(request):
@@ -1326,12 +1327,14 @@ def object_dtype(request):
13261327
("python", pd.NA),
13271328
pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
13281329
pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
1330+
("python", np.nan),
13291331
],
13301332
ids=[
13311333
"string=object",
13321334
"string=string[python]",
13331335
"string=string[pyarrow]",
13341336
"string=str[pyarrow]",
1337+
"string=str[python]",
13351338
],
13361339
)
13371340
def any_string_dtype(request):
@@ -1341,6 +1344,7 @@ def any_string_dtype(request):
13411344
* 'string[python]' (NA variant)
13421345
* 'string[pyarrow]' (NA variant)
13431346
* 'str' (NaN variant, with pyarrow)
1347+
* 'str' (NaN variant, without pyarrow)
13441348
"""
13451349
if isinstance(request.param, np.dtype):
13461350
return request.param

0 commit comments

Comments
 (0)