|
1 | 1 | """
|
2 | 2 | Test element-wise functions/operators against reference implementations.
|
3 | 3 | """
|
| 4 | +import cmath |
4 | 5 | import math
|
5 | 6 | import operator
|
6 | 7 | from copy import copy
|
@@ -67,6 +68,29 @@ def isclose(
|
67 | 68 | return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
|
68 | 69 |
|
69 | 70 |
|
| 71 | +def isclose_complex( |
| 72 | + a: complex, |
| 73 | + b: complex, |
| 74 | + M: float, |
| 75 | + *, |
| 76 | + rel_tol: float = 0.25, |
| 77 | + abs_tol: float = 1, |
| 78 | +) -> bool: |
| 79 | + """Wraps math.isclose with very generous defaults. |
| 80 | +
|
| 81 | + This is useful for many floating-point operations where the spec does not |
| 82 | + make accuracy requirements. |
| 83 | + """ |
| 84 | + if cmath.isnan(a) or cmath.isnan(b): |
| 85 | + raise ValueError(f"{a=} and {b=}, but input must be non-NaN") |
| 86 | + if cmath.isinf(a): |
| 87 | + return cmath.isinf(b) or abs(b) > cmath.log(M) |
| 88 | + elif cmath.isinf(b): |
| 89 | + return cmath.isinf(a) or abs(a) > cmath.log(M) |
| 90 | + return cmath.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) |
| 91 | + |
| 92 | + |
| 93 | + |
70 | 94 | def default_filter(s: Scalar) -> bool:
|
71 | 95 | """Returns False when s is a non-finite or a signed zero.
|
72 | 96 |
|
@@ -254,8 +278,7 @@ def unary_assert_against_refimpl(
|
254 | 278 | f"{f_i}={scalar_i}"
|
255 | 279 | )
|
256 | 280 | if res.dtype in dh.complex_dtypes:
|
257 |
| - assert isclose(scalar_o.real, expected.real, M), msg |
258 |
| - assert isclose(scalar_o.imag, expected.imag, M), msg |
| 281 | + assert isclose_complex(scalar_o, expected, M), msg |
259 | 282 | else:
|
260 | 283 | assert isclose(scalar_o, expected, M), msg
|
261 | 284 | else:
|
@@ -330,8 +353,7 @@ def binary_assert_against_refimpl(
|
330 | 353 | f"{f_l}={scalar_l}, {f_r}={scalar_r}"
|
331 | 354 | )
|
332 | 355 | if res.dtype in dh.complex_dtypes:
|
333 |
| - assert isclose(scalar_o.real, expected.real, M), msg |
334 |
| - assert isclose(scalar_o.imag, expected.imag, M), msg |
| 356 | + assert isclose_complex(scalar_o, expected, M), msg |
335 | 357 | else:
|
336 | 358 | assert isclose(scalar_o, expected, M), msg
|
337 | 359 | else:
|
@@ -403,8 +425,7 @@ def right_scalar_assert_against_refimpl(
|
403 | 425 | f"{f_l}={scalar_l}"
|
404 | 426 | )
|
405 | 427 | if res.dtype in dh.complex_dtypes:
|
406 |
| - assert isclose(scalar_o.real, expected.real, M), msg |
407 |
| - assert isclose(scalar_o.imag, expected.imag, M), msg |
| 428 | + assert isclose_complex(scalar_o, expected, M), msg |
408 | 429 | else:
|
409 | 430 | assert isclose(scalar_o, expected, M), msg
|
410 | 431 | else:
|
@@ -1394,7 +1415,7 @@ def test_square(x):
|
1394 | 1415 | ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype)
|
1395 | 1416 | ph.assert_shape("square", out_shape=out.shape, expected=x.shape)
|
1396 | 1417 | unary_assert_against_refimpl(
|
1397 |
| - "square", x, out, lambda s: s**2, expr_template="{}²={}" |
| 1418 | + "square", x, out, lambda s: s*s, expr_template="{}²={}" |
1398 | 1419 | )
|
1399 | 1420 |
|
1400 | 1421 |
|
|
0 commit comments