Skip to content

Commit 77a9c2d

Browse files
authored
Merge pull request #53 from asmeurer/relational-any-dtypes
Allow any combination of real dtypes in comparisons
2 parents 6b0079b + 1f87699 commit 77a9c2d

File tree

4 files changed

+89
-68
lines changed

4 files changed

+89
-68
lines changed

Diff for: array_api_strict/_array_object.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,14 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
152152
# spec in places where it either deviates from or is more strict than
153153
# NumPy behavior
154154

155-
def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array:
155+
def _check_allowed_dtypes(
156+
self,
157+
other: bool | int | float | Array,
158+
dtype_category: str,
159+
op: str,
160+
*,
161+
check_promotion: bool = True,
162+
) -> Array:
156163
"""
157164
Helper function for operators to only allow specific input dtypes
158165
@@ -176,7 +183,8 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor
176183
# This will raise TypeError for type combinations that are not allowed
177184
# to promote in the spec (even if the NumPy array operator would
178185
# promote them).
179-
res_dtype = _result_type(self.dtype, other.dtype)
186+
if check_promotion:
187+
res_dtype = _result_type(self.dtype, other.dtype)
180188
if op.startswith("__i"):
181189
# Note: NumPy will allow in-place operators in some cases where
182190
# the type promoted operator does not match the left-hand side
@@ -570,7 +578,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
570578
"""
571579
# Even though "all" dtypes are allowed, we still require them to be
572580
# promotable with each other.
573-
other = self._check_allowed_dtypes(other, "all", "__eq__")
581+
other = self._check_allowed_dtypes(other, "all", "__eq__", check_promotion=False)
574582
if other is NotImplemented:
575583
return other
576584
self, other = self._normalize_two_args(self, other)
@@ -604,7 +612,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
604612
"""
605613
Performs the operation __ge__.
606614
"""
607-
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
615+
other = self._check_allowed_dtypes(other, "real numeric", "__ge__", check_promotion=False)
608616
if other is NotImplemented:
609617
return other
610618
self, other = self._normalize_two_args(self, other)
@@ -638,7 +646,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
638646
"""
639647
Performs the operation __gt__.
640648
"""
641-
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
649+
other = self._check_allowed_dtypes(other, "real numeric", "__gt__", check_promotion=False)
642650
if other is NotImplemented:
643651
return other
644652
self, other = self._normalize_two_args(self, other)
@@ -692,7 +700,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
692700
"""
693701
Performs the operation __le__.
694702
"""
695-
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
703+
other = self._check_allowed_dtypes(other, "real numeric", "__le__", check_promotion=False)
696704
if other is NotImplemented:
697705
return other
698706
self, other = self._normalize_two_args(self, other)
@@ -714,7 +722,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
714722
"""
715723
Performs the operation __lt__.
716724
"""
717-
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
725+
other = self._check_allowed_dtypes(other, "real numeric", "__lt__", check_promotion=False)
718726
if other is NotImplemented:
719727
return other
720728
self, other = self._normalize_two_args(self, other)
@@ -759,7 +767,7 @@ def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
759767
"""
760768
Performs the operation __ne__.
761769
"""
762-
other = self._check_allowed_dtypes(other, "all", "__ne__")
770+
other = self._check_allowed_dtypes(other, "all", "__ne__", check_promotion=False)
763771
if other is NotImplemented:
764772
return other
765773
self, other = self._normalize_two_args(self, other)

Diff for: array_api_strict/_elementwise_functions.py

-12
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,6 @@ def equal(x1: Array, x2: Array, /) -> Array:
375375
376376
See its docstring for more information.
377377
"""
378-
# Call result type here just to raise on disallowed type combinations
379-
_result_type(x1.dtype, x2.dtype)
380378
x1, x2 = Array._normalize_two_args(x1, x2)
381379
return Array._new(np.equal(x1._array, x2._array))
382380

@@ -439,8 +437,6 @@ def greater(x1: Array, x2: Array, /) -> Array:
439437
"""
440438
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
441439
raise TypeError("Only real numeric dtypes are allowed in greater")
442-
# Call result type here just to raise on disallowed type combinations
443-
_result_type(x1.dtype, x2.dtype)
444440
x1, x2 = Array._normalize_two_args(x1, x2)
445441
return Array._new(np.greater(x1._array, x2._array))
446442

@@ -453,8 +449,6 @@ def greater_equal(x1: Array, x2: Array, /) -> Array:
453449
"""
454450
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
455451
raise TypeError("Only real numeric dtypes are allowed in greater_equal")
456-
# Call result type here just to raise on disallowed type combinations
457-
_result_type(x1.dtype, x2.dtype)
458452
x1, x2 = Array._normalize_two_args(x1, x2)
459453
return Array._new(np.greater_equal(x1._array, x2._array))
460454

@@ -524,8 +518,6 @@ def less(x1: Array, x2: Array, /) -> Array:
524518
"""
525519
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
526520
raise TypeError("Only real numeric dtypes are allowed in less")
527-
# Call result type here just to raise on disallowed type combinations
528-
_result_type(x1.dtype, x2.dtype)
529521
x1, x2 = Array._normalize_two_args(x1, x2)
530522
return Array._new(np.less(x1._array, x2._array))
531523

@@ -538,8 +530,6 @@ def less_equal(x1: Array, x2: Array, /) -> Array:
538530
"""
539531
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
540532
raise TypeError("Only real numeric dtypes are allowed in less_equal")
541-
# Call result type here just to raise on disallowed type combinations
542-
_result_type(x1.dtype, x2.dtype)
543533
x1, x2 = Array._normalize_two_args(x1, x2)
544534
return Array._new(np.less_equal(x1._array, x2._array))
545535

@@ -715,8 +705,6 @@ def not_equal(x1: Array, x2: Array, /) -> Array:
715705
716706
See its docstring for more information.
717707
"""
718-
# Call result type here just to raise on disallowed type combinations
719-
_result_type(x1.dtype, x2.dtype)
720708
x1, x2 = Array._normalize_two_args(x1, x2)
721709
return Array._new(np.not_equal(x1._array, x2._array))
722710

Diff for: array_api_strict/tests/test_array_object.py

+45-46
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import operator
22
from builtins import all as all_
33

4-
from numpy.testing import assert_raises, suppress_warnings
4+
import numpy.testing
55
import numpy as np
66
import pytest
77

@@ -29,6 +29,10 @@
2929

3030
import array_api_strict
3131

32+
def assert_raises(exception, func, msg=None):
33+
with numpy.testing.assert_raises(exception, msg=msg):
34+
func()
35+
3236
def test_validate_index():
3337
# The indexing tests in the official array API test suite test that the
3438
# array object correctly handles the subset of indices that are required
@@ -90,7 +94,7 @@ def test_validate_index():
9094

9195
def test_operators():
9296
# For every operator, we test that it works for the required type
93-
# combinations and raises TypeError otherwise
97+
# combinations and assert_raises TypeError otherwise
9498
binary_op_dtypes = {
9599
"__add__": "numeric",
96100
"__and__": "integer_or_boolean",
@@ -111,6 +115,7 @@ def test_operators():
111115
"__truediv__": "floating",
112116
"__xor__": "integer_or_boolean",
113117
}
118+
comparison_ops = ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]
114119
# Recompute each time because of in-place ops
115120
def _array_vals():
116121
for d in _integer_dtypes:
@@ -124,7 +129,7 @@ def _array_vals():
124129
BIG_INT = int(1e30)
125130
for op, dtypes in binary_op_dtypes.items():
126131
ops = [op]
127-
if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]:
132+
if op not in comparison_ops:
128133
rop = "__r" + op[2:]
129134
iop = "__i" + op[2:]
130135
ops += [rop, iop]
@@ -155,16 +160,16 @@ def _array_vals():
155160
or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
156161
)):
157162
if a.dtype in _integer_dtypes and s == BIG_INT:
158-
assert_raises(OverflowError, lambda: getattr(a, _op)(s))
163+
assert_raises(OverflowError, lambda: getattr(a, _op)(s), _op)
159164
else:
160165
# Only test for no error
161-
with suppress_warnings() as sup:
166+
with numpy.testing.suppress_warnings() as sup:
162167
# ignore warnings from pow(BIG_INT)
163168
sup.filter(RuntimeWarning,
164169
"invalid value encountered in power")
165170
getattr(a, _op)(s)
166171
else:
167-
assert_raises(TypeError, lambda: getattr(a, _op)(s))
172+
assert_raises(TypeError, lambda: getattr(a, _op)(s), _op)
168173

169174
# Test array op array.
170175
for _op in ops:
@@ -173,25 +178,25 @@ def _array_vals():
173178
# See the promotion table in NEP 47 or the array
174179
# API spec page on type promotion. Mixed kind
175180
# promotion is not defined.
176-
if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
177-
or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
178-
or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
179-
or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
180-
or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
181-
or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
182-
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
183-
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
184-
):
185-
assert_raises(TypeError, lambda: getattr(x, _op)(y))
181+
if (op not in comparison_ops and
182+
(x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
183+
or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
184+
or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
185+
or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
186+
or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
187+
or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
188+
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
189+
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
190+
)):
191+
assert_raises(TypeError, lambda: getattr(x, _op)(y), _op)
186192
# Ensure in-place operators only promote to the same dtype as the left operand.
187193
elif (
188194
_op.startswith("__i")
189195
and result_type(x.dtype, y.dtype) != x.dtype
190196
):
191-
assert_raises(TypeError, lambda: getattr(x, _op)(y))
197+
assert_raises(TypeError, lambda: getattr(x, _op)(y), _op)
192198
# Ensure only those dtypes that are required for every operator are allowed.
193-
elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
194-
or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
199+
elif (dtypes == "all"
195200
or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes)
196201
or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
197202
or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
@@ -202,7 +207,7 @@ def _array_vals():
202207
):
203208
getattr(x, _op)(y)
204209
else:
205-
assert_raises(TypeError, lambda: getattr(x, _op)(y))
210+
assert_raises(TypeError, lambda: getattr(x, _op)(y), (x, _op, y))
206211

207212
unary_op_dtypes = {
208213
"__abs__": "numeric",
@@ -221,7 +226,7 @@ def _array_vals():
221226
# Only test for no error
222227
getattr(a, op)()
223228
else:
224-
assert_raises(TypeError, lambda: getattr(a, op)())
229+
assert_raises(TypeError, lambda: getattr(a, op)(), _op)
225230

226231
# Finally, matmul() must be tested separately, because it works a bit
227232
# different from the other operations.
@@ -240,9 +245,9 @@ def _matmul_array_vals():
240245
or type(s) == int and a.dtype in _integer_dtypes):
241246
# Type promotion is valid, but @ is not allowed on 0-D
242247
# inputs, so the error is a ValueError
243-
assert_raises(ValueError, lambda: getattr(a, _op)(s))
248+
assert_raises(ValueError, lambda: getattr(a, _op)(s), _op)
244249
else:
245-
assert_raises(TypeError, lambda: getattr(a, _op)(s))
250+
assert_raises(TypeError, lambda: getattr(a, _op)(s), _op)
246251

247252
for x in _matmul_array_vals():
248253
for y in _matmul_array_vals():
@@ -356,20 +361,17 @@ def test_allow_newaxis():
356361

357362
def test_disallow_flat_indexing_with_newaxis():
358363
a = ones((3, 3, 3))
359-
with pytest.raises(IndexError):
360-
a[None, 0, 0]
364+
assert_raises(IndexError, lambda: a[None, 0, 0])
361365

362366
def test_disallow_mask_with_newaxis():
363367
a = ones((3, 3, 3))
364-
with pytest.raises(IndexError):
365-
a[None, asarray(True)]
368+
assert_raises(IndexError, lambda: a[None, asarray(True)])
366369

367370
@pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)])
368371
@pytest.mark.parametrize("index", ["string", False, True])
369372
def test_error_on_invalid_index(shape, index):
370373
a = ones(shape)
371-
with pytest.raises(IndexError):
372-
a[index]
374+
assert_raises(IndexError, lambda: a[index])
373375

374376
def test_mask_0d_array_without_errors():
375377
a = ones(())
@@ -380,10 +382,8 @@ def test_mask_0d_array_without_errors():
380382
)
381383
def test_error_on_invalid_index_with_ellipsis(i):
382384
a = ones((3, 3, 3))
383-
with pytest.raises(IndexError):
384-
a[..., i]
385-
with pytest.raises(IndexError):
386-
a[i, ...]
385+
assert_raises(IndexError, lambda: a[..., i])
386+
assert_raises(IndexError, lambda: a[i, ...])
387387

388388
def test_array_keys_use_private_array():
389389
"""
@@ -400,8 +400,7 @@ def test_array_keys_use_private_array():
400400

401401
a = ones((0,), dtype=bool_)
402402
key = ones((0, 0), dtype=bool_)
403-
with pytest.raises(IndexError):
404-
a[key]
403+
assert_raises(IndexError, lambda: a[key])
405404

406405
def test_array_namespace():
407406
a = ones((3, 3))
@@ -422,16 +421,16 @@ def test_array_namespace():
422421
assert a.__array_namespace__(api_version="2021.12") is array_api_strict
423422
assert array_api_strict.__array_api_version__ == "2021.12"
424423

425-
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
426-
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))
424+
assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
425+
assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))
427426

428427
def test_iter():
429-
pytest.raises(TypeError, lambda: iter(asarray(3)))
428+
assert_raises(TypeError, lambda: iter(asarray(3)))
430429
assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)]
431430
assert all_(isinstance(a, Array) for a in iter(ones(3)))
432431
assert all_(a.shape == () for a in iter(ones(3)))
433432
assert all_(a.dtype == float64 for a in iter(ones(3)))
434-
pytest.raises(TypeError, lambda: iter(ones((3, 3))))
433+
assert_raises(TypeError, lambda: iter(ones((3, 3))))
435434

436435
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
437436
def dlpack_2023_12(api_version):
@@ -447,17 +446,17 @@ def dlpack_2023_12(api_version):
447446

448447

449448
exception = NotImplementedError if api_version >= '2023.12' else ValueError
450-
pytest.raises(exception, lambda:
449+
assert_raises(exception, lambda:
451450
a.__dlpack__(dl_device=CPU_DEVICE))
452-
pytest.raises(exception, lambda:
451+
assert_raises(exception, lambda:
453452
a.__dlpack__(dl_device=None))
454-
pytest.raises(exception, lambda:
453+
assert_raises(exception, lambda:
455454
a.__dlpack__(max_version=(1, 0)))
456-
pytest.raises(exception, lambda:
455+
assert_raises(exception, lambda:
457456
a.__dlpack__(max_version=None))
458-
pytest.raises(exception, lambda:
457+
assert_raises(exception, lambda:
459458
a.__dlpack__(copy=False))
460-
pytest.raises(exception, lambda:
459+
assert_raises(exception, lambda:
461460
a.__dlpack__(copy=True))
462-
pytest.raises(exception, lambda:
461+
assert_raises(exception, lambda:
463462
a.__dlpack__(copy=None))

0 commit comments

Comments
 (0)