Skip to content

Commit ffd22c1

Browse files
jbrockmendeljorisvandenbossche
authored andcommitted
BUG (string): ArrowEA comparisons with mismatched types (#59505)
* BUG: ArrowEA comparisons with mismatched types * move whatsnew * GH ref
1 parent 13ad111 commit ffd22c1

File tree

3 files changed

+34
-11
lines changed

3 files changed

+34
-11
lines changed

pandas/core/arrays/arrow/array.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,13 @@ def _cmp_method(self, other, op):
704704
if isinstance(
705705
other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray)
706706
) or isinstance(getattr(other, "dtype", None), CategoricalDtype):
707-
result = pc_func(self._pa_array, self._box_pa(other))
707+
try:
708+
result = pc_func(self._pa_array, self._box_pa(other))
709+
except pa.ArrowNotImplementedError:
710+
# TODO: could this be wrong if other is object dtype?
711+
# in which case we need to operate pointwise?
712+
result = ops.invalid_comparison(self, other, op)
713+
result = pa.array(result, type=pa.bool_())
708714
elif is_scalar(other):
709715
try:
710716
result = pc_func(self._pa_array, self._box_pa(other))

pandas/core/arrays/string_arrow.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
BaseStringArray,
3838
StringDtype,
3939
)
40-
from pandas.core.ops import invalid_comparison
4140
from pandas.core.strings.object_array import ObjectStringArrayMixin
4241

4342
if not pa_version_under10p1:
@@ -565,10 +564,7 @@ def _convert_int_dtype(self, result):
565564
return result
566565

567566
def _cmp_method(self, other, op):
568-
try:
569-
result = super()._cmp_method(other, op)
570-
except pa.ArrowNotImplementedError:
571-
return invalid_comparison(self, other, op)
567+
result = super()._cmp_method(other, op)
572568
if op == operator.ne:
573569
return result.to_numpy(np.bool_, na_value=True)
574570
else:

pandas/tests/series/test_logical_ops.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pandas.compat import HAS_PYARROW
1010

1111
from pandas import (
12+
ArrowDtype,
1213
DataFrame,
1314
Index,
1415
Series,
@@ -539,18 +540,38 @@ def test_int_dtype_different_index_not_bool(self):
539540
result = ser1 ^ ser2
540541
tm.assert_series_equal(result, expected)
541542

543+
# TODO: this belongs in comparison tests
542544
def test_pyarrow_numpy_string_invalid(self):
543545
# GH#56008
544-
pytest.importorskip("pyarrow")
546+
pa = pytest.importorskip("pyarrow")
545547
ser = Series([False, True])
546548
ser2 = Series(["a", "b"], dtype="string[pyarrow_numpy]")
547549
result = ser == ser2
548-
expected = Series(False, index=ser.index)
549-
tm.assert_series_equal(result, expected)
550+
expected_eq = Series(False, index=ser.index)
551+
tm.assert_series_equal(result, expected_eq)
550552

551553
result = ser != ser2
552-
expected = Series(True, index=ser.index)
553-
tm.assert_series_equal(result, expected)
554+
expected_ne = Series(True, index=ser.index)
555+
tm.assert_series_equal(result, expected_ne)
554556

555557
with pytest.raises(TypeError, match="Invalid comparison"):
556558
ser > ser2
559+
560+
# GH#59505
561+
ser3 = ser2.astype("string[pyarrow]")
562+
result3_eq = ser3 == ser
563+
tm.assert_series_equal(result3_eq, expected_eq.astype("bool[pyarrow]"))
564+
result3_ne = ser3 != ser
565+
tm.assert_series_equal(result3_ne, expected_ne.astype("bool[pyarrow]"))
566+
567+
with pytest.raises(TypeError, match="Invalid comparison"):
568+
ser > ser3
569+
570+
ser4 = ser2.astype(ArrowDtype(pa.string()))
571+
result4_eq = ser4 == ser
572+
tm.assert_series_equal(result4_eq, expected_eq.astype("bool[pyarrow]"))
573+
result4_ne = ser4 != ser
574+
tm.assert_series_equal(result4_ne, expected_ne.astype("bool[pyarrow]"))
575+
576+
with pytest.raises(TypeError, match="Invalid comparison"):
577+
ser > ser4

0 commit comments

Comments
 (0)