|
9 | 9 | from pandas.compat import HAS_PYARROW
|
10 | 10 |
|
11 | 11 | from pandas import (
|
| 12 | + ArrowDtype, |
12 | 13 | DataFrame,
|
13 | 14 | Index,
|
14 | 15 | Series,
|
@@ -539,18 +540,38 @@ def test_int_dtype_different_index_not_bool(self):
|
539 | 540 | result = ser1 ^ ser2
|
540 | 541 | tm.assert_series_equal(result, expected)
|
541 | 542 |
|
| 543 | + # TODO: this belongs in comparison tests |
542 | 544 | def test_pyarrow_numpy_string_invalid(self):
|
543 | 545 | # GH#56008
|
544 |
| - pytest.importorskip("pyarrow") |
| 546 | + pa = pytest.importorskip("pyarrow") |
545 | 547 | ser = Series([False, True])
|
546 | 548 | ser2 = Series(["a", "b"], dtype="string[pyarrow_numpy]")
|
547 | 549 | result = ser == ser2
|
548 |
| - expected = Series(False, index=ser.index) |
549 |
| - tm.assert_series_equal(result, expected) |
| 550 | + expected_eq = Series(False, index=ser.index) |
| 551 | + tm.assert_series_equal(result, expected_eq) |
550 | 552 |
|
551 | 553 | result = ser != ser2
|
552 |
| - expected = Series(True, index=ser.index) |
553 |
| - tm.assert_series_equal(result, expected) |
| 554 | + expected_ne = Series(True, index=ser.index) |
| 555 | + tm.assert_series_equal(result, expected_ne) |
554 | 556 |
|
555 | 557 | with pytest.raises(TypeError, match="Invalid comparison"):
|
556 | 558 | ser > ser2
|
| 559 | + |
| 560 | + # GH#59505 |
| 561 | + ser3 = ser2.astype("string[pyarrow]") |
| 562 | + result3_eq = ser3 == ser |
| 563 | + tm.assert_series_equal(result3_eq, expected_eq.astype("bool[pyarrow]")) |
| 564 | + result3_ne = ser3 != ser |
| 565 | + tm.assert_series_equal(result3_ne, expected_ne.astype("bool[pyarrow]")) |
| 566 | + |
| 567 | + with pytest.raises(TypeError, match="Invalid comparison"): |
| 568 | + ser > ser3 |
| 569 | + |
| 570 | + ser4 = ser2.astype(ArrowDtype(pa.string())) |
| 571 | + result4_eq = ser4 == ser |
| 572 | + tm.assert_series_equal(result4_eq, expected_eq.astype("bool[pyarrow]")) |
| 573 | + result4_ne = ser4 != ser |
| 574 | + tm.assert_series_equal(result4_ne, expected_ne.astype("bool[pyarrow]")) |
| 575 | + |
| 576 | + with pytest.raises(TypeError, match="Invalid comparison"): |
| 577 | + ser > ser4 |
0 commit comments