Skip to content

Commit 2b8f5f5

Browse files
rgommershonno
authored andcommitted
Fix complex scalar checking and test_square failure
Closes gh-190
1 parent 6fa42b3 commit 2b8f5f5

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

Diff for: array_api_tests/test_operators_and_elementwise_functions.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Test element-wise functions/operators against reference implementations.
33
"""
4+
import cmath
45
import math
56
import operator
67
from copy import copy
@@ -67,6 +68,29 @@ def isclose(
6768
return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
6869

6970

71+
def isclose_complex(
72+
a: complex,
73+
b: complex,
74+
M: float,
75+
*,
76+
rel_tol: float = 0.25,
77+
abs_tol: float = 1,
78+
) -> bool:
79+
"""Wraps math.isclose with very generous defaults.
80+
81+
This is useful for many floating-point operations where the spec does not
82+
make accuracy requirements.
83+
"""
84+
if cmath.isnan(a) or cmath.isnan(b):
85+
raise ValueError(f"{a=} and {b=}, but input must be non-NaN")
86+
if cmath.isinf(a):
87+
return cmath.isinf(b) or abs(b) > cmath.log(M)
88+
elif cmath.isinf(b):
89+
return cmath.isinf(a) or abs(a) > cmath.log(M)
90+
return cmath.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
91+
92+
93+
7094
def default_filter(s: Scalar) -> bool:
7195
"""Returns False when s is a non-finite or a signed zero.
7296
@@ -254,8 +278,7 @@ def unary_assert_against_refimpl(
254278
f"{f_i}={scalar_i}"
255279
)
256280
if res.dtype in dh.complex_dtypes:
257-
assert isclose(scalar_o.real, expected.real, M), msg
258-
assert isclose(scalar_o.imag, expected.imag, M), msg
281+
assert isclose_complex(scalar_o, expected, M), msg
259282
else:
260283
assert isclose(scalar_o, expected, M), msg
261284
else:
@@ -330,8 +353,7 @@ def binary_assert_against_refimpl(
330353
f"{f_l}={scalar_l}, {f_r}={scalar_r}"
331354
)
332355
if res.dtype in dh.complex_dtypes:
333-
assert isclose(scalar_o.real, expected.real, M), msg
334-
assert isclose(scalar_o.imag, expected.imag, M), msg
356+
assert isclose_complex(scalar_o, expected, M), msg
335357
else:
336358
assert isclose(scalar_o, expected, M), msg
337359
else:
@@ -403,8 +425,7 @@ def right_scalar_assert_against_refimpl(
403425
f"{f_l}={scalar_l}"
404426
)
405427
if res.dtype in dh.complex_dtypes:
406-
assert isclose(scalar_o.real, expected.real, M), msg
407-
assert isclose(scalar_o.imag, expected.imag, M), msg
428+
assert isclose_complex(scalar_o, expected, M), msg
408429
else:
409430
assert isclose(scalar_o, expected, M), msg
410431
else:
@@ -1394,7 +1415,7 @@ def test_square(x):
13941415
ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype)
13951416
ph.assert_shape("square", out_shape=out.shape, expected=x.shape)
13961417
unary_assert_against_refimpl(
1397-
"square", x, out, lambda s: s**2, expr_template="{}²={}"
1418+
"square", x, out, lambda s: s*s, expr_template="{}²={}"
13981419
)
13991420

14001421

0 commit comments

Comments
 (0)