Skip to content

Commit 83f0bcd

Browse files
authored
Merge pull request #204 from rgommers/fix-complex-scalar-check
Fix complex scalar checking and `test_square` failure
2 parents 6fa42b3 + 88b64af commit 83f0bcd

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

Diff for: array_api_tests/test_operators_and_elementwise_functions.py

+26-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Test element-wise functions/operators against reference implementations.
33
"""
4+
import cmath
45
import math
56
import operator
67
from copy import copy
@@ -48,7 +49,7 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
4849
def isclose(
4950
a: float,
5051
b: float,
51-
M: float,
52+
maximum: float,
5253
*,
5354
rel_tol: float = 0.25,
5455
abs_tol: float = 1,
@@ -61,12 +62,30 @@ def isclose(
6162
if math.isnan(a) or math.isnan(b):
6263
raise ValueError(f"{a=} and {b=}, but input must be non-NaN")
6364
if math.isinf(a):
64-
return math.isinf(b) or abs(b) > math.log(M)
65+
return math.isinf(b) or abs(b) > math.log(maximum)
6566
elif math.isinf(b):
66-
return math.isinf(a) or abs(a) > math.log(M)
67+
return math.isinf(a) or abs(a) > math.log(maximum)
6768
return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
6869

6970

71+
def isclose_complex(
72+
a: complex,
73+
b: complex,
74+
maximum: float,
75+
*,
76+
rel_tol: float = 0.25,
77+
abs_tol: float = 1,
78+
) -> bool:
79+
"""Like isclose() but specifically for complex values."""
80+
if cmath.isnan(a) or cmath.isnan(b):
81+
raise ValueError(f"{a=} and {b=}, but input must be non-NaN")
82+
if cmath.isinf(a):
83+
return cmath.isinf(b) or abs(b) > cmath.log(maximum)
84+
elif cmath.isinf(b):
85+
return cmath.isinf(a) or abs(a) > cmath.log(maximum)
86+
return cmath.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
87+
88+
7089
def default_filter(s: Scalar) -> bool:
7190
"""Returns False when s is a non-finite or a signed zero.
7291
@@ -254,8 +273,7 @@ def unary_assert_against_refimpl(
254273
f"{f_i}={scalar_i}"
255274
)
256275
if res.dtype in dh.complex_dtypes:
257-
assert isclose(scalar_o.real, expected.real, M), msg
258-
assert isclose(scalar_o.imag, expected.imag, M), msg
276+
assert isclose_complex(scalar_o, expected, M), msg
259277
else:
260278
assert isclose(scalar_o, expected, M), msg
261279
else:
@@ -330,8 +348,7 @@ def binary_assert_against_refimpl(
330348
f"{f_l}={scalar_l}, {f_r}={scalar_r}"
331349
)
332350
if res.dtype in dh.complex_dtypes:
333-
assert isclose(scalar_o.real, expected.real, M), msg
334-
assert isclose(scalar_o.imag, expected.imag, M), msg
351+
assert isclose_complex(scalar_o, expected, M), msg
335352
else:
336353
assert isclose(scalar_o, expected, M), msg
337354
else:
@@ -403,8 +420,7 @@ def right_scalar_assert_against_refimpl(
403420
f"{f_l}={scalar_l}"
404421
)
405422
if res.dtype in dh.complex_dtypes:
406-
assert isclose(scalar_o.real, expected.real, M), msg
407-
assert isclose(scalar_o.imag, expected.imag, M), msg
423+
assert isclose_complex(scalar_o, expected, M), msg
408424
else:
409425
assert isclose(scalar_o, expected, M), msg
410426
else:
@@ -1394,7 +1410,7 @@ def test_square(x):
13941410
ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype)
13951411
ph.assert_shape("square", out_shape=out.shape, expected=x.shape)
13961412
unary_assert_against_refimpl(
1397-
"square", x, out, lambda s: s**2, expr_template="{}²={}"
1413+
"square", x, out, lambda s: s*s, expr_template="{}²={}"
13981414
)
13991415

14001416

0 commit comments

Comments
 (0)