Skip to content

Commit 94f335b

Browse files
committed
Add special handling for float16
1 parent 06e6958 commit 94f335b

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

pygmt/tests/test_clib_to_numpy.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,12 @@ def test_to_numpy_pandas_series_pyarrow_dtypes_numeric(dtype, expected_dtype):
187187
"""
188188
Test the _to_numpy function with pandas.Series of PyArrow numeric dtypes.
189189
"""
190-
series = pd.Series([1, 2, 3, 4, 5, 6], dtype=dtype)[::2] # Not C-contiguous
190+
data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
191+
if dtype == "float16[pyarrow]" and Version(pd.__version__) < Version("2.2"):
192+
# float16 needs special handling for pandas < 2.2.
193+
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
194+
data = np.array(data, dtype=np.float16)
195+
series = pd.Series(data, dtype=dtype)[::2] # Not C-contiguous
191196
result = _to_numpy(series)
192197
_check_result(result, expected_dtype)
193198
npt.assert_array_equal(result, series)
@@ -214,7 +219,12 @@ def test_to_numpy_pandas_series_pyarrow_dtypes_numeric_with_na(dtype, expected_d
214219
"""
215220
Test the _to_numpy function with pandas.Series of PyArrow numeric dtypes and NA.
216221
"""
217-
series = pd.Series([1, 2, pd.NA, 4, 5, 6], dtype=dtype)[::2]
222+
data = [1.0, 2.0, None, 4.0, 5.0, 6.0]
223+
if dtype == "float16[pyarrow]" and Version(pd.__version__) < Version("2.2"):
224+
# float16 needs special handling for pandas < 2.2.
225+
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
226+
data = np.array(data, dtype=np.float16)
227+
series = pd.Series(data, dtype=dtype)[::2] # Not C-contiguous
218228
assert series.isna().any()
219229
result = _to_numpy(series)
220230
_check_result(result, expected_dtype)

0 commit comments

Comments
 (0)