|
14 | 14 | "doesnt_raise",
|
15 | 15 | "nargs",
|
16 | 16 | "fmt_kw",
|
| 17 | + "is_pos_zero", |
| 18 | + "is_neg_zero", |
17 | 19 | "assert_dtype",
|
18 | 20 | "assert_kw_dtype",
|
19 | 21 | "assert_default_float",
|
|
22 | 24 | "assert_shape",
|
23 | 25 | "assert_result_shape",
|
24 | 26 | "assert_keepdimable_shape",
|
| 27 | + "assert_0d_equals", |
25 | 28 | "assert_fill",
|
26 | 29 | "assert_array",
|
27 | 30 | ]
|
@@ -69,6 +72,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str:
|
69 | 72 | return ", ".join(f"{k}={v}" for k, v in kw.items())
|
70 | 73 |
|
71 | 74 |
|
| 75 | +def is_pos_zero(n: float) -> bool: |
| 76 | + return n == 0 and math.copysign(1, n) == 1 |
| 77 | + |
| 78 | + |
| 79 | +def is_neg_zero(n: float) -> bool: |
| 80 | + return n == 0 and math.copysign(1, n) == -1 |
| 81 | + |
| 82 | + |
72 | 83 | def assert_dtype(
|
73 | 84 | func_name: str,
|
74 | 85 | in_dtype: Union[DataType, Sequence[DataType]],
|
@@ -232,15 +243,28 @@ def assert_fill(
|
232 | 243 | def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
|
233 | 244 | assert_dtype(func_name, out.dtype, expected.dtype)
|
234 | 245 | assert_shape(func_name, out.shape, expected.shape, **kw)
|
235 |
| - msg = f"out not as expected [{func_name}({fmt_kw(kw)})]\n{out=}\n{expected=}" |
| 246 | + f_func = f"[{func_name}({fmt_kw(kw)})]" |
236 | 247 | if dh.is_float_dtype(out.dtype):
|
237 |
| - neg_zeros = expected == -0.0 |
238 |
| - assert xp.all((out == -0.0) == neg_zeros), msg |
239 |
| - pos_zeros = expected == +0.0 |
240 |
| - assert xp.all((out == +0.0) == pos_zeros), msg |
241 |
| - nans = xp.isnan(expected) |
242 |
| - assert xp.all(xp.isnan(out) == nans), msg |
243 |
| - mask = ~(neg_zeros | pos_zeros | nans) |
244 |
| - assert xp.all(out[mask] == expected[mask]), msg |
| 248 | + for idx in sh.ndindex(out.shape): |
| 249 | + at_out = out[idx] |
| 250 | + at_expected = expected[idx] |
| 251 | + msg = ( |
| 252 | + f"{sh.fmt_idx('out', idx)}={at_out}, should be {at_expected} " |
| 253 | + f"{f_func}" |
| 254 | + ) |
| 255 | + if xp.isnan(at_expected): |
| 256 | + assert xp.isnan(at_out), msg |
| 257 | + elif at_expected == 0.0 or at_expected == -0.0: |
| 258 | + scalar_at_expected = float(at_expected) |
| 259 | + scalar_at_out = float(at_out) |
| 260 | + if is_pos_zero(scalar_at_expected): |
| 261 | + assert is_pos_zero(scalar_at_out), msg |
| 262 | + else: |
| 263 | + assert is_neg_zero(scalar_at_expected) # sanity check |
| 264 | + assert is_neg_zero(scalar_at_out), msg |
| 265 | + else: |
| 266 | + assert at_out == at_expected, msg |
245 | 267 | else:
|
246 |
| - assert xp.all(out == expected), msg |
| 268 | + assert xp.all(out == expected), ( |
| 269 | + f"out not as expected {f_func}\n" f"{out=}\n{expected=}" |
| 270 | + ) |
0 commit comments