1
1
"""
2
2
Test element-wise functions/operators against reference implementations.
3
3
"""
4
+ import cmath
4
5
import math
5
6
import operator
6
7
from copy import copy
@@ -48,7 +49,7 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
48
49
def isclose (
49
50
a : float ,
50
51
b : float ,
51
- M : float ,
52
+ maximum : float ,
52
53
* ,
53
54
rel_tol : float = 0.25 ,
54
55
abs_tol : float = 1 ,
@@ -61,12 +62,30 @@ def isclose(
61
62
if math .isnan (a ) or math .isnan (b ):
62
63
raise ValueError (f"{ a = } and { b = } , but input must be non-NaN" )
63
64
if math .isinf (a ):
64
- return math .isinf (b ) or abs (b ) > math .log (M )
65
+ return math .isinf (b ) or abs (b ) > math .log (maximum )
65
66
elif math .isinf (b ):
66
- return math .isinf (a ) or abs (a ) > math .log (M )
67
+ return math .isinf (a ) or abs (a ) > math .log (maximum )
67
68
return math .isclose (a , b , rel_tol = rel_tol , abs_tol = abs_tol )
68
69
69
70
71
+ def isclose_complex (
72
+ a : complex ,
73
+ b : complex ,
74
+ maximum : float ,
75
+ * ,
76
+ rel_tol : float = 0.25 ,
77
+ abs_tol : float = 1 ,
78
+ ) -> bool :
79
+ """Like isclose() but specifically for complex values."""
80
+ if cmath .isnan (a ) or cmath .isnan (b ):
81
+ raise ValueError (f"{ a = } and { b = } , but input must be non-NaN" )
82
+ if cmath .isinf (a ):
83
+ return cmath .isinf (b ) or abs (b ) > cmath .log (maximum )
84
+ elif cmath .isinf (b ):
85
+ return cmath .isinf (a ) or abs (a ) > cmath .log (maximum )
86
+ return cmath .isclose (a , b , rel_tol = rel_tol , abs_tol = abs_tol )
87
+
88
+
70
89
def default_filter (s : Scalar ) -> bool :
71
90
"""Returns False when s is a non-finite or a signed zero.
72
91
@@ -254,8 +273,7 @@ def unary_assert_against_refimpl(
254
273
f"{ f_i } ={ scalar_i } "
255
274
)
256
275
if res .dtype in dh .complex_dtypes :
257
- assert isclose (scalar_o .real , expected .real , M ), msg
258
- assert isclose (scalar_o .imag , expected .imag , M ), msg
276
+ assert isclose_complex (scalar_o , expected , M ), msg
259
277
else :
260
278
assert isclose (scalar_o , expected , M ), msg
261
279
else :
@@ -330,8 +348,7 @@ def binary_assert_against_refimpl(
330
348
f"{ f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
331
349
)
332
350
if res .dtype in dh .complex_dtypes :
333
- assert isclose (scalar_o .real , expected .real , M ), msg
334
- assert isclose (scalar_o .imag , expected .imag , M ), msg
351
+ assert isclose_complex (scalar_o , expected , M ), msg
335
352
else :
336
353
assert isclose (scalar_o , expected , M ), msg
337
354
else :
@@ -403,8 +420,7 @@ def right_scalar_assert_against_refimpl(
403
420
f"{ f_l } ={ scalar_l } "
404
421
)
405
422
if res .dtype in dh .complex_dtypes :
406
- assert isclose (scalar_o .real , expected .real , M ), msg
407
- assert isclose (scalar_o .imag , expected .imag , M ), msg
423
+ assert isclose_complex (scalar_o , expected , M ), msg
408
424
else :
409
425
assert isclose (scalar_o , expected , M ), msg
410
426
else :
@@ -1394,7 +1410,7 @@ def test_square(x):
1394
1410
ph .assert_dtype ("square" , in_dtype = x .dtype , out_dtype = out .dtype )
1395
1411
ph .assert_shape ("square" , out_shape = out .shape , expected = x .shape )
1396
1412
unary_assert_against_refimpl (
1397
- "square" , x , out , lambda s : s ** 2 , expr_template = "{}²={}"
1413
+ "square" , x , out , lambda s : s * s , expr_template = "{}²={}"
1398
1414
)
1399
1415
1400
1416
0 commit comments