Skip to content

Commit 3d08919

Browse files
authored
clib.conversion._to_numpy: Add tests for pandas.Series with pandas string dtype (#3607)
1 parent d982275 commit 3d08919

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

pygmt/clib/conversion.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Functions to convert data types into ctypes friendly formats.
33
"""
44

5+
import contextlib
56
import ctypes as ctp
67
import warnings
78
from collections.abc import Sequence
@@ -160,7 +161,7 @@ def _to_numpy(data: Any) -> np.ndarray:
160161
dtypes: dict[str, type | str] = {
161162
# For string dtypes.
162163
"large_string": np.str_, # pa.large_string and pa.large_utf8
163-
"string": np.str_, # pa.string and pa.utf8
164+
"string": np.str_, # pa.string, pa.utf8, pd.StringDtype
164165
"string_view": np.str_, # pa.string_view
165166
# For datetime dtypes.
166167
"date32[day][pyarrow]": "datetime64[D]",
@@ -180,6 +181,11 @@ def _to_numpy(data: Any) -> np.ndarray:
180181
else:
181182
vec_dtype = str(getattr(data, "dtype", getattr(data, "type", "")))
182183
array = np.ascontiguousarray(data, dtype=dtypes.get(vec_dtype))
184+
185+
# Check if a np.object_ array can be converted to np.str_.
186+
if array.dtype == np.object_:
187+
with contextlib.suppress(TypeError, ValueError):
188+
return np.ascontiguousarray(array, dtype=np.str_)
183189
return array
184190

185191

pygmt/clib/session.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1475,7 +1475,7 @@ def virtualfile_from_vectors(
14751475
# 2 columns contain coordinates like longitude, latitude, or datetime string
14761476
# types.
14771477
for col, array in enumerate(arrays[2:]):
1478-
if pd.api.types.is_string_dtype(array.dtype):
1478+
if np.issubdtype(array.dtype, np.str_):
14791479
columns = col + 2
14801480
break
14811481

@@ -1506,9 +1506,9 @@ def virtualfile_from_vectors(
15061506
strings = string_arrays[0]
15071507
elif len(string_arrays) > 1:
15081508
strings = np.array(
1509-
[" ".join(vals) for vals in zip(*string_arrays, strict=True)]
1509+
[" ".join(vals) for vals in zip(*string_arrays, strict=True)],
1510+
dtype=np.str_,
15101511
)
1511-
strings = np.asanyarray(a=strings, dtype=np.str_)
15121512
self.put_strings(
15131513
dataset, family="GMT_IS_VECTOR|GMT_IS_DUPLICATE", strings=strings
15141514
)

pygmt/tests/test_clib_to_numpy.py

+26
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pytest
1212
from packaging.version import Version
1313
from pygmt.clib.conversion import _to_numpy
14+
from pygmt.helpers.testing import skip_if_no
1415

1516
try:
1617
import pyarrow as pa
@@ -174,6 +175,31 @@ def test_to_numpy_pandas_series_numpy_dtypes_numeric(dtype, expected_dtype):
174175
npt.assert_array_equal(result, series)
175176

176177

178+
@pytest.mark.parametrize(
179+
"dtype",
180+
[
181+
None,
182+
np.str_,
183+
"U10",
184+
"string[python]",
185+
pytest.param("string[pyarrow]", marks=skip_if_no(package="pyarrow")),
186+
pytest.param("string[pyarrow_numpy]", marks=skip_if_no(package="pyarrow")),
187+
],
188+
)
189+
def test_to_numpy_pandas_series_pandas_dtypes_string(dtype):
190+
"""
191+
Test the _to_numpy function with pandas.Series of pandas string types.
192+
193+
In pandas, string arrays can be specified in multiple ways.
194+
195+
Reference: https://pandas.pydata.org/docs/reference/api/pandas.StringDtype.html
196+
"""
197+
array = pd.Series(["abc", "defg", "12345"], dtype=dtype)
198+
result = _to_numpy(array)
199+
_check_result(result, np.str_)
200+
npt.assert_array_equal(result, array)
201+
202+
177203
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
178204
@pytest.mark.parametrize(
179205
("dtype", "expected_dtype"),

0 commit comments

Comments
 (0)