Skip to content

Commit 899ad12

Browse files
authored
Revert "Allow any combination of real dtypes in comparisons"
1 parent 77a9c2d commit 899ad12

File tree

4 files changed

+68
-89
lines changed

4 files changed

+68
-89
lines changed

Diff for: array_api_strict/_array_object.py

+8-16
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,7 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
152152
# spec in places where it either deviates from or is more strict than
153153
# NumPy behavior
154154

155-
def _check_allowed_dtypes(
156-
self,
157-
other: bool | int | float | Array,
158-
dtype_category: str,
159-
op: str,
160-
*,
161-
check_promotion: bool = True,
162-
) -> Array:
155+
def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array:
163156
"""
164157
Helper function for operators to only allow specific input dtypes
165158
@@ -183,8 +176,7 @@ def _check_allowed_dtypes(
183176
# This will raise TypeError for type combinations that are not allowed
184177
# to promote in the spec (even if the NumPy array operator would
185178
# promote them).
186-
if check_promotion:
187-
res_dtype = _result_type(self.dtype, other.dtype)
179+
res_dtype = _result_type(self.dtype, other.dtype)
188180
if op.startswith("__i"):
189181
# Note: NumPy will allow in-place operators in some cases where
190182
# the type promoted operator does not match the left-hand side
@@ -578,7 +570,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
578570
"""
579571
# Even though "all" dtypes are allowed, we still require them to be
580572
# promotable with each other.
581-
other = self._check_allowed_dtypes(other, "all", "__eq__", check_promotion=False)
573+
other = self._check_allowed_dtypes(other, "all", "__eq__")
582574
if other is NotImplemented:
583575
return other
584576
self, other = self._normalize_two_args(self, other)
@@ -612,7 +604,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
612604
"""
613605
Performs the operation __ge__.
614606
"""
615-
other = self._check_allowed_dtypes(other, "real numeric", "__ge__", check_promotion=False)
607+
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
616608
if other is NotImplemented:
617609
return other
618610
self, other = self._normalize_two_args(self, other)
@@ -646,7 +638,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
646638
"""
647639
Performs the operation __gt__.
648640
"""
649-
other = self._check_allowed_dtypes(other, "real numeric", "__gt__", check_promotion=False)
641+
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
650642
if other is NotImplemented:
651643
return other
652644
self, other = self._normalize_two_args(self, other)
@@ -700,7 +692,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
700692
"""
701693
Performs the operation __le__.
702694
"""
703-
other = self._check_allowed_dtypes(other, "real numeric", "__le__", check_promotion=False)
695+
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
704696
if other is NotImplemented:
705697
return other
706698
self, other = self._normalize_two_args(self, other)
@@ -722,7 +714,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
722714
"""
723715
Performs the operation __lt__.
724716
"""
725-
other = self._check_allowed_dtypes(other, "real numeric", "__lt__", check_promotion=False)
717+
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
726718
if other is NotImplemented:
727719
return other
728720
self, other = self._normalize_two_args(self, other)
@@ -767,7 +759,7 @@ def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
767759
"""
768760
Performs the operation __ne__.
769761
"""
770-
other = self._check_allowed_dtypes(other, "all", "__ne__", check_promotion=False)
762+
other = self._check_allowed_dtypes(other, "all", "__ne__")
771763
if other is NotImplemented:
772764
return other
773765
self, other = self._normalize_two_args(self, other)

Diff for: array_api_strict/_elementwise_functions.py

+12
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,8 @@ def equal(x1: Array, x2: Array, /) -> Array:
375375
376376
See its docstring for more information.
377377
"""
378+
# Call result type here just to raise on disallowed type combinations
379+
_result_type(x1.dtype, x2.dtype)
378380
x1, x2 = Array._normalize_two_args(x1, x2)
379381
return Array._new(np.equal(x1._array, x2._array))
380382

@@ -437,6 +439,8 @@ def greater(x1: Array, x2: Array, /) -> Array:
437439
"""
438440
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
439441
raise TypeError("Only real numeric dtypes are allowed in greater")
442+
# Call result type here just to raise on disallowed type combinations
443+
_result_type(x1.dtype, x2.dtype)
440444
x1, x2 = Array._normalize_two_args(x1, x2)
441445
return Array._new(np.greater(x1._array, x2._array))
442446

@@ -449,6 +453,8 @@ def greater_equal(x1: Array, x2: Array, /) -> Array:
449453
"""
450454
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
451455
raise TypeError("Only real numeric dtypes are allowed in greater_equal")
456+
# Call result type here just to raise on disallowed type combinations
457+
_result_type(x1.dtype, x2.dtype)
452458
x1, x2 = Array._normalize_two_args(x1, x2)
453459
return Array._new(np.greater_equal(x1._array, x2._array))
454460

@@ -518,6 +524,8 @@ def less(x1: Array, x2: Array, /) -> Array:
518524
"""
519525
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
520526
raise TypeError("Only real numeric dtypes are allowed in less")
527+
# Call result type here just to raise on disallowed type combinations
528+
_result_type(x1.dtype, x2.dtype)
521529
x1, x2 = Array._normalize_two_args(x1, x2)
522530
return Array._new(np.less(x1._array, x2._array))
523531

@@ -530,6 +538,8 @@ def less_equal(x1: Array, x2: Array, /) -> Array:
530538
"""
531539
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
532540
raise TypeError("Only real numeric dtypes are allowed in less_equal")
541+
# Call result type here just to raise on disallowed type combinations
542+
_result_type(x1.dtype, x2.dtype)
533543
x1, x2 = Array._normalize_two_args(x1, x2)
534544
return Array._new(np.less_equal(x1._array, x2._array))
535545

@@ -705,6 +715,8 @@ def not_equal(x1: Array, x2: Array, /) -> Array:
705715
706716
See its docstring for more information.
707717
"""
718+
# Call result type here just to raise on disallowed type combinations
719+
_result_type(x1.dtype, x2.dtype)
708720
x1, x2 = Array._normalize_two_args(x1, x2)
709721
return Array._new(np.not_equal(x1._array, x2._array))
710722

Diff for: array_api_strict/tests/test_array_object.py

+46-45
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import operator
22
from builtins import all as all_
33

4-
import numpy.testing
4+
from numpy.testing import assert_raises, suppress_warnings
55
import numpy as np
66
import pytest
77

@@ -29,10 +29,6 @@
2929

3030
import array_api_strict
3131

32-
def assert_raises(exception, func, msg=None):
33-
with numpy.testing.assert_raises(exception, msg=msg):
34-
func()
35-
3632
def test_validate_index():
3733
# The indexing tests in the official array API test suite test that the
3834
# array object correctly handles the subset of indices that are required
@@ -94,7 +90,7 @@ def test_validate_index():
9490

9591
def test_operators():
9692
# For every operator, we test that it works for the required type
97-
# combinations and assert_raises TypeError otherwise
93+
# combinations and raises TypeError otherwise
9894
binary_op_dtypes = {
9995
"__add__": "numeric",
10096
"__and__": "integer_or_boolean",
@@ -115,7 +111,6 @@ def test_operators():
115111
"__truediv__": "floating",
116112
"__xor__": "integer_or_boolean",
117113
}
118-
comparison_ops = ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]
119114
# Recompute each time because of in-place ops
120115
def _array_vals():
121116
for d in _integer_dtypes:
@@ -129,7 +124,7 @@ def _array_vals():
129124
BIG_INT = int(1e30)
130125
for op, dtypes in binary_op_dtypes.items():
131126
ops = [op]
132-
if op not in comparison_ops:
127+
if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]:
133128
rop = "__r" + op[2:]
134129
iop = "__i" + op[2:]
135130
ops += [rop, iop]
@@ -160,16 +155,16 @@ def _array_vals():
160155
or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
161156
)):
162157
if a.dtype in _integer_dtypes and s == BIG_INT:
163-
assert_raises(OverflowError, lambda: getattr(a, _op)(s), _op)
158+
assert_raises(OverflowError, lambda: getattr(a, _op)(s))
164159
else:
165160
# Only test for no error
166-
with numpy.testing.suppress_warnings() as sup:
161+
with suppress_warnings() as sup:
167162
# ignore warnings from pow(BIG_INT)
168163
sup.filter(RuntimeWarning,
169164
"invalid value encountered in power")
170165
getattr(a, _op)(s)
171166
else:
172-
assert_raises(TypeError, lambda: getattr(a, _op)(s), _op)
167+
assert_raises(TypeError, lambda: getattr(a, _op)(s))
173168

174169
# Test array op array.
175170
for _op in ops:
@@ -178,25 +173,25 @@ def _array_vals():
178173
# See the promotion table in NEP 47 or the array
179174
# API spec page on type promotion. Mixed kind
180175
# promotion is not defined.
181-
if (op not in comparison_ops and
182-
(x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
183-
or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
184-
or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
185-
or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
186-
or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
187-
or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
188-
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
189-
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
190-
)):
191-
assert_raises(TypeError, lambda: getattr(x, _op)(y), _op)
176+
if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
177+
or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
178+
or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
179+
or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
180+
or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
181+
or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
182+
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
183+
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
184+
):
185+
assert_raises(TypeError, lambda: getattr(x, _op)(y))
192186
# Ensure in-place operators only promote to the same dtype as the left operand.
193187
elif (
194188
_op.startswith("__i")
195189
and result_type(x.dtype, y.dtype) != x.dtype
196190
):
197-
assert_raises(TypeError, lambda: getattr(x, _op)(y), _op)
191+
assert_raises(TypeError, lambda: getattr(x, _op)(y))
198192
# Ensure only those dtypes that are required for every operator are allowed.
199-
elif (dtypes == "all"
193+
elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
194+
or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
200195
or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes)
201196
or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
202197
or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
@@ -207,7 +202,7 @@ def _array_vals():
207202
):
208203
getattr(x, _op)(y)
209204
else:
210-
assert_raises(TypeError, lambda: getattr(x, _op)(y), (x, _op, y))
205+
assert_raises(TypeError, lambda: getattr(x, _op)(y))
211206

212207
unary_op_dtypes = {
213208
"__abs__": "numeric",
@@ -226,7 +221,7 @@ def _array_vals():
226221
# Only test for no error
227222
getattr(a, op)()
228223
else:
229-
assert_raises(TypeError, lambda: getattr(a, op)(), _op)
224+
assert_raises(TypeError, lambda: getattr(a, op)())
230225

231226
# Finally, matmul() must be tested separately, because it works a bit
232227
# different from the other operations.
@@ -245,9 +240,9 @@ def _matmul_array_vals():
245240
or type(s) == int and a.dtype in _integer_dtypes):
246241
# Type promotion is valid, but @ is not allowed on 0-D
247242
# inputs, so the error is a ValueError
248-
assert_raises(ValueError, lambda: getattr(a, _op)(s), _op)
243+
assert_raises(ValueError, lambda: getattr(a, _op)(s))
249244
else:
250-
assert_raises(TypeError, lambda: getattr(a, _op)(s), _op)
245+
assert_raises(TypeError, lambda: getattr(a, _op)(s))
251246

252247
for x in _matmul_array_vals():
253248
for y in _matmul_array_vals():
@@ -361,17 +356,20 @@ def test_allow_newaxis():
361356

362357
def test_disallow_flat_indexing_with_newaxis():
363358
a = ones((3, 3, 3))
364-
assert_raises(IndexError, lambda: a[None, 0, 0])
359+
with pytest.raises(IndexError):
360+
a[None, 0, 0]
365361

366362
def test_disallow_mask_with_newaxis():
367363
a = ones((3, 3, 3))
368-
assert_raises(IndexError, lambda: a[None, asarray(True)])
364+
with pytest.raises(IndexError):
365+
a[None, asarray(True)]
369366

370367
@pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)])
371368
@pytest.mark.parametrize("index", ["string", False, True])
372369
def test_error_on_invalid_index(shape, index):
373370
a = ones(shape)
374-
assert_raises(IndexError, lambda: a[index])
371+
with pytest.raises(IndexError):
372+
a[index]
375373

376374
def test_mask_0d_array_without_errors():
377375
a = ones(())
@@ -382,8 +380,10 @@ def test_mask_0d_array_without_errors():
382380
)
383381
def test_error_on_invalid_index_with_ellipsis(i):
384382
a = ones((3, 3, 3))
385-
assert_raises(IndexError, lambda: a[..., i])
386-
assert_raises(IndexError, lambda: a[i, ...])
383+
with pytest.raises(IndexError):
384+
a[..., i]
385+
with pytest.raises(IndexError):
386+
a[i, ...]
387387

388388
def test_array_keys_use_private_array():
389389
"""
@@ -400,7 +400,8 @@ def test_array_keys_use_private_array():
400400

401401
a = ones((0,), dtype=bool_)
402402
key = ones((0, 0), dtype=bool_)
403-
assert_raises(IndexError, lambda: a[key])
403+
with pytest.raises(IndexError):
404+
a[key]
404405

405406
def test_array_namespace():
406407
a = ones((3, 3))
@@ -421,16 +422,16 @@ def test_array_namespace():
421422
assert a.__array_namespace__(api_version="2021.12") is array_api_strict
422423
assert array_api_strict.__array_api_version__ == "2021.12"
423424

424-
assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
425-
assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))
425+
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
426+
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))
426427

427428
def test_iter():
428-
assert_raises(TypeError, lambda: iter(asarray(3)))
429+
pytest.raises(TypeError, lambda: iter(asarray(3)))
429430
assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)]
430431
assert all_(isinstance(a, Array) for a in iter(ones(3)))
431432
assert all_(a.shape == () for a in iter(ones(3)))
432433
assert all_(a.dtype == float64 for a in iter(ones(3)))
433-
assert_raises(TypeError, lambda: iter(ones((3, 3))))
434+
pytest.raises(TypeError, lambda: iter(ones((3, 3))))
434435

435436
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
436437
def dlpack_2023_12(api_version):
@@ -446,17 +447,17 @@ def dlpack_2023_12(api_version):
446447

447448

448449
exception = NotImplementedError if api_version >= '2023.12' else ValueError
449-
assert_raises(exception, lambda:
450+
pytest.raises(exception, lambda:
450451
a.__dlpack__(dl_device=CPU_DEVICE))
451-
assert_raises(exception, lambda:
452+
pytest.raises(exception, lambda:
452453
a.__dlpack__(dl_device=None))
453-
assert_raises(exception, lambda:
454+
pytest.raises(exception, lambda:
454455
a.__dlpack__(max_version=(1, 0)))
455-
assert_raises(exception, lambda:
456+
pytest.raises(exception, lambda:
456457
a.__dlpack__(max_version=None))
457-
assert_raises(exception, lambda:
458+
pytest.raises(exception, lambda:
458459
a.__dlpack__(copy=False))
459-
assert_raises(exception, lambda:
460+
pytest.raises(exception, lambda:
460461
a.__dlpack__(copy=True))
461-
assert_raises(exception, lambda:
462+
pytest.raises(exception, lambda:
462463
a.__dlpack__(copy=None))

0 commit comments

Comments
 (0)