Skip to content

Commit 1c03aaa

Browse files
committed
Restore extra testing of elementwise function disallowed type promotions
1 parent 899ad12 commit 1c03aaa

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

Diff for: array_api_strict/tests/test_elementwise_functions.py

+16
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
_boolean_dtypes,
1010
_floating_dtypes,
1111
_integer_dtypes,
12+
int8,
13+
int16,
14+
int32,
15+
int64,
16+
uint64,
1217
)
1318
from .._flags import set_array_api_strict_flags
1419

@@ -115,6 +120,17 @@ def _array_vals():
115120
func = getattr(_elementwise_functions, func_name)
116121
if nargs(func) == 2:
117122
for y in _array_vals():
123+
# Disallow dtypes that aren't type promotable
124+
if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
125+
or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
126+
or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
127+
or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
128+
or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
129+
or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
130+
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
131+
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
132+
):
133+
assert_raises(TypeError, lambda: func(x, y))
118134
if x.dtype not in dtypes or y.dtype not in dtypes:
119135
assert_raises(TypeError, lambda: func(x, y))
120136
else:

0 commit comments

Comments
 (0)