Skip to content

Commit fb6842d

Browse files
authored
BUG (string): ArrowEA comparisons with mismatched types (#59505)
* BUG: ArrowEA comparisons with mismatched types * move whatsnew * GH ref
1 parent 6423ee8 commit fb6842d

File tree

4 files changed

+35
-11
lines changed

4 files changed

+35
-11
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,7 @@ ExtensionArray
667667
^^^^^^^^^^^^^^
668668
- Bug in :meth:`.arrays.ArrowExtensionArray.__setitem__` which caused wrong behavior when using an integer array with repeated values as a key (:issue:`58530`)
669669
- Bug in :meth:`api.types.is_datetime64_any_dtype` where a custom :class:`ExtensionDtype` would return ``False`` for array-likes (:issue:`57055`)
670+
- Bug in comparisons between an object with :class:`ArrowDtype` and an object with an incompatible dtype (e.g. string vs bool) incorrectly raising instead of returning all-``False`` (for ``==``) or all-``True`` (for ``!=``) (:issue:`59505`)
670671
- Bug in various :class:`DataFrame` reductions for pyarrow temporal dtypes returning incorrect dtype when result was null (:issue:`59234`)
671672

672673
Styler

pandas/core/arrays/arrow/array.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,13 @@ def _cmp_method(self, other, op) -> ArrowExtensionArray:
709709
if isinstance(
710710
other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray)
711711
) or isinstance(getattr(other, "dtype", None), CategoricalDtype):
712-
result = pc_func(self._pa_array, self._box_pa(other))
712+
try:
713+
result = pc_func(self._pa_array, self._box_pa(other))
714+
except pa.ArrowNotImplementedError:
715+
# TODO: could this be wrong if other is object dtype?
716+
# in which case we need to operate pointwise?
717+
result = ops.invalid_comparison(self, other, op)
718+
result = pa.array(result, type=pa.bool_())
713719
elif is_scalar(other):
714720
try:
715721
result = pc_func(self._pa_array, self._box_pa(other))

pandas/core/arrays/string_arrow.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
BaseStringArray,
3737
StringDtype,
3838
)
39-
from pandas.core.ops import invalid_comparison
4039
from pandas.core.strings.object_array import ObjectStringArrayMixin
4140

4241
if not pa_version_under10p1:
@@ -563,10 +562,7 @@ def _convert_int_dtype(self, result):
563562
return result
564563

565564
def _cmp_method(self, other, op):
566-
try:
567-
result = super()._cmp_method(other, op)
568-
except pa.ArrowNotImplementedError:
569-
return invalid_comparison(self, other, op)
565+
result = super()._cmp_method(other, op)
570566
if op == operator.ne:
571567
return result.to_numpy(np.bool_, na_value=True)
572568
else:

pandas/tests/series/test_logical_ops.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pandas.compat import HAS_PYARROW
1010

1111
from pandas import (
12+
ArrowDtype,
1213
DataFrame,
1314
Index,
1415
Series,
@@ -523,18 +524,38 @@ def test_int_dtype_different_index_not_bool(self):
523524
result = ser1 ^ ser2
524525
tm.assert_series_equal(result, expected)
525526

527+
# TODO: this belongs in comparison tests
526528
def test_pyarrow_numpy_string_invalid(self):
527529
# GH#56008
528-
pytest.importorskip("pyarrow")
530+
pa = pytest.importorskip("pyarrow")
529531
ser = Series([False, True])
530532
ser2 = Series(["a", "b"], dtype="string[pyarrow_numpy]")
531533
result = ser == ser2
532-
expected = Series(False, index=ser.index)
533-
tm.assert_series_equal(result, expected)
534+
expected_eq = Series(False, index=ser.index)
535+
tm.assert_series_equal(result, expected_eq)
534536

535537
result = ser != ser2
536-
expected = Series(True, index=ser.index)
537-
tm.assert_series_equal(result, expected)
538+
expected_ne = Series(True, index=ser.index)
539+
tm.assert_series_equal(result, expected_ne)
538540

539541
with pytest.raises(TypeError, match="Invalid comparison"):
540542
ser > ser2
543+
544+
# GH#59505
545+
ser3 = ser2.astype("string[pyarrow]")
546+
result3_eq = ser3 == ser
547+
tm.assert_series_equal(result3_eq, expected_eq.astype("bool[pyarrow]"))
548+
result3_ne = ser3 != ser
549+
tm.assert_series_equal(result3_ne, expected_ne.astype("bool[pyarrow]"))
550+
551+
with pytest.raises(TypeError, match="Invalid comparison"):
552+
ser > ser3
553+
554+
ser4 = ser2.astype(ArrowDtype(pa.string()))
555+
result4_eq = ser4 == ser
556+
tm.assert_series_equal(result4_eq, expected_eq.astype("bool[pyarrow]"))
557+
result4_ne = ser4 != ser
558+
tm.assert_series_equal(result4_ne, expected_ne.astype("bool[pyarrow]"))
559+
560+
with pytest.raises(TypeError, match="Invalid comparison"):
561+
ser > ser4

0 commit comments

Comments
 (0)