Skip to content

Revert "Allow any combination of real dtypes in comparisons" #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,7 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
# spec in places where it either deviates from or is more strict than
# NumPy behavior

def _check_allowed_dtypes(
self,
other: bool | int | float | Array,
dtype_category: str,
op: str,
*,
check_promotion: bool = True,
) -> Array:
def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array:
"""
Helper function for operators to only allow specific input dtypes

Expand All @@ -183,8 +176,7 @@ def _check_allowed_dtypes(
# This will raise TypeError for type combinations that are not allowed
# to promote in the spec (even if the NumPy array operator would
# promote them).
if check_promotion:
res_dtype = _result_type(self.dtype, other.dtype)
res_dtype = _result_type(self.dtype, other.dtype)
if op.startswith("__i"):
# Note: NumPy will allow in-place operators in some cases where
# the type promoted operator does not match the left-hand side
Expand Down Expand Up @@ -578,7 +570,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
"""
# Even though "all" dtypes are allowed, we still require them to be
# promotable with each other.
other = self._check_allowed_dtypes(other, "all", "__eq__", check_promotion=False)
other = self._check_allowed_dtypes(other, "all", "__eq__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -612,7 +604,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ge__.
"""
other = self._check_allowed_dtypes(other, "real numeric", "__ge__", check_promotion=False)
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -646,7 +638,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __gt__.
"""
other = self._check_allowed_dtypes(other, "real numeric", "__gt__", check_promotion=False)
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -700,7 +692,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __le__.
"""
other = self._check_allowed_dtypes(other, "real numeric", "__le__", check_promotion=False)
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand All @@ -722,7 +714,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __lt__.
"""
other = self._check_allowed_dtypes(other, "real numeric", "__lt__", check_promotion=False)
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -767,7 +759,7 @@ def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
"""
Performs the operation __ne__.
"""
other = self._check_allowed_dtypes(other, "all", "__ne__", check_promotion=False)
other = self._check_allowed_dtypes(other, "all", "__ne__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down
12 changes: 12 additions & 0 deletions array_api_strict/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ def equal(x1: Array, x2: Array, /) -> Array:

See its docstring for more information.
"""
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.equal(x1._array, x2._array))

Expand Down Expand Up @@ -437,6 +439,8 @@ def greater(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in greater")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.greater(x1._array, x2._array))

Expand All @@ -449,6 +453,8 @@ def greater_equal(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in greater_equal")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.greater_equal(x1._array, x2._array))

Expand Down Expand Up @@ -518,6 +524,8 @@ def less(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in less")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.less(x1._array, x2._array))

Expand All @@ -530,6 +538,8 @@ def less_equal(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in less_equal")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.less_equal(x1._array, x2._array))

Expand Down Expand Up @@ -705,6 +715,8 @@ def not_equal(x1: Array, x2: Array, /) -> Array:

See its docstring for more information.
"""
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.not_equal(x1._array, x2._array))

Expand Down
91 changes: 46 additions & 45 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import operator
from builtins import all as all_

import numpy.testing
from numpy.testing import assert_raises, suppress_warnings
import numpy as np
import pytest

Expand Down Expand Up @@ -29,10 +29,6 @@

import array_api_strict

def assert_raises(exception, func, msg=None):
with numpy.testing.assert_raises(exception, msg=msg):
func()

def test_validate_index():
# The indexing tests in the official array API test suite test that the
# array object correctly handles the subset of indices that are required
Expand Down Expand Up @@ -94,7 +90,7 @@ def test_validate_index():

def test_operators():
# For every operator, we test that it works for the required type
# combinations and assert_raises TypeError otherwise
# combinations and raises TypeError otherwise
binary_op_dtypes = {
"__add__": "numeric",
"__and__": "integer_or_boolean",
Expand All @@ -115,7 +111,6 @@ def test_operators():
"__truediv__": "floating",
"__xor__": "integer_or_boolean",
}
comparison_ops = ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]
# Recompute each time because of in-place ops
def _array_vals():
for d in _integer_dtypes:
Expand All @@ -129,7 +124,7 @@ def _array_vals():
BIG_INT = int(1e30)
for op, dtypes in binary_op_dtypes.items():
ops = [op]
if op not in comparison_ops:
if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]:
rop = "__r" + op[2:]
iop = "__i" + op[2:]
ops += [rop, iop]
Expand Down Expand Up @@ -160,16 +155,16 @@ def _array_vals():
or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
)):
if a.dtype in _integer_dtypes and s == BIG_INT:
assert_raises(OverflowError, lambda: getattr(a, _op)(s), _op)
assert_raises(OverflowError, lambda: getattr(a, _op)(s))
else:
# Only test for no error
with numpy.testing.suppress_warnings() as sup:
with suppress_warnings() as sup:
# ignore warnings from pow(BIG_INT)
sup.filter(RuntimeWarning,
"invalid value encountered in power")
getattr(a, _op)(s)
else:
assert_raises(TypeError, lambda: getattr(a, _op)(s), _op)
assert_raises(TypeError, lambda: getattr(a, _op)(s))

# Test array op array.
for _op in ops:
Expand All @@ -178,25 +173,25 @@ def _array_vals():
# See the promotion table in NEP 47 or the array
# API spec page on type promotion. Mixed kind
# promotion is not defined.
if (op not in comparison_ops and
(x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
)):
assert_raises(TypeError, lambda: getattr(x, _op)(y), _op)
if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
):
assert_raises(TypeError, lambda: getattr(x, _op)(y))
# Ensure in-place operators only promote to the same dtype as the left operand.
elif (
_op.startswith("__i")
and result_type(x.dtype, y.dtype) != x.dtype
):
assert_raises(TypeError, lambda: getattr(x, _op)(y), _op)
assert_raises(TypeError, lambda: getattr(x, _op)(y))
# Ensure only those dtypes that are required for every operator are allowed.
elif (dtypes == "all"
elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes)
or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
Expand All @@ -207,7 +202,7 @@ def _array_vals():
):
getattr(x, _op)(y)
else:
assert_raises(TypeError, lambda: getattr(x, _op)(y), (x, _op, y))
assert_raises(TypeError, lambda: getattr(x, _op)(y))

unary_op_dtypes = {
"__abs__": "numeric",
Expand All @@ -226,7 +221,7 @@ def _array_vals():
# Only test for no error
getattr(a, op)()
else:
assert_raises(TypeError, lambda: getattr(a, op)(), _op)
assert_raises(TypeError, lambda: getattr(a, op)())

# Finally, matmul() must be tested separately, because it works a bit
# different from the other operations.
Expand All @@ -245,9 +240,9 @@ def _matmul_array_vals():
or type(s) == int and a.dtype in _integer_dtypes):
# Type promotion is valid, but @ is not allowed on 0-D
# inputs, so the error is a ValueError
assert_raises(ValueError, lambda: getattr(a, _op)(s), _op)
assert_raises(ValueError, lambda: getattr(a, _op)(s))
else:
assert_raises(TypeError, lambda: getattr(a, _op)(s), _op)
assert_raises(TypeError, lambda: getattr(a, _op)(s))

for x in _matmul_array_vals():
for y in _matmul_array_vals():
Expand Down Expand Up @@ -361,17 +356,20 @@ def test_allow_newaxis():

def test_disallow_flat_indexing_with_newaxis():
a = ones((3, 3, 3))
assert_raises(IndexError, lambda: a[None, 0, 0])
with pytest.raises(IndexError):
a[None, 0, 0]

def test_disallow_mask_with_newaxis():
a = ones((3, 3, 3))
assert_raises(IndexError, lambda: a[None, asarray(True)])
with pytest.raises(IndexError):
a[None, asarray(True)]

@pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)])
@pytest.mark.parametrize("index", ["string", False, True])
def test_error_on_invalid_index(shape, index):
a = ones(shape)
assert_raises(IndexError, lambda: a[index])
with pytest.raises(IndexError):
a[index]

def test_mask_0d_array_without_errors():
a = ones(())
Expand All @@ -382,8 +380,10 @@ def test_mask_0d_array_without_errors():
)
def test_error_on_invalid_index_with_ellipsis(i):
a = ones((3, 3, 3))
assert_raises(IndexError, lambda: a[..., i])
assert_raises(IndexError, lambda: a[i, ...])
with pytest.raises(IndexError):
a[..., i]
with pytest.raises(IndexError):
a[i, ...]

def test_array_keys_use_private_array():
"""
Expand All @@ -400,7 +400,8 @@ def test_array_keys_use_private_array():

a = ones((0,), dtype=bool_)
key = ones((0, 0), dtype=bool_)
assert_raises(IndexError, lambda: a[key])
with pytest.raises(IndexError):
a[key]

def test_array_namespace():
a = ones((3, 3))
Expand All @@ -421,16 +422,16 @@ def test_array_namespace():
assert a.__array_namespace__(api_version="2021.12") is array_api_strict
assert array_api_strict.__array_api_version__ == "2021.12"

assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))

def test_iter():
assert_raises(TypeError, lambda: iter(asarray(3)))
pytest.raises(TypeError, lambda: iter(asarray(3)))
assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)]
assert all_(isinstance(a, Array) for a in iter(ones(3)))
assert all_(a.shape == () for a in iter(ones(3)))
assert all_(a.dtype == float64 for a in iter(ones(3)))
assert_raises(TypeError, lambda: iter(ones((3, 3))))
pytest.raises(TypeError, lambda: iter(ones((3, 3))))

@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
def dlpack_2023_12(api_version):
Expand All @@ -446,17 +447,17 @@ def dlpack_2023_12(api_version):


exception = NotImplementedError if api_version >= '2023.12' else ValueError
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(dl_device=CPU_DEVICE))
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(dl_device=None))
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(max_version=(1, 0)))
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(max_version=None))
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(copy=False))
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(copy=True))
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(copy=None))
Loading
Loading