Skip to content

Commit c9be347

Browse files
committed
Tweak unit test
1 parent 6458d8d commit c9be347

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

Diff for: tests/test_funcs.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,8 @@ def test_hypothesis( # type: ignore[no-any-decorated]
229229
cond_shape, *shapes = input_shapes
230230

231231
# cupy/cupy#8382
232-
elements = {"allow_subnormal": False} if library is Backend.CUPY else None
232+
# https://github.com/jax-ml/jax/issues/26658
233+
elements = {"allow_subnormal": library not in (Backend.CUPY, Backend.JAX)}
233234

234235
fill_value = xp.asarray(
235236
data.draw(npst.arrays(dtype=dtype, shape=(), elements=elements))
@@ -258,12 +259,9 @@ def f2(*args: Array) -> Array:
258259
# TODO remove asarrays once all backends support Array API 2024.12
259260
ref3 = xp.where(cond, *asarrays(f1(*arrays), float_fill_value, xp=xp))
260261

261-
# https://github.com/jax-ml/jax/issues/26658
262-
atol = 1e-300 if library is Backend.JAX else 0
263-
264-
xp_assert_close(res1, ref1, atol=atol, rtol=2e-16)
265-
xp_assert_close(res2, ref2, atol=atol, rtol=2e-16)
266-
xp_assert_close(res3, ref3, atol=atol, rtol=2e-16)
262+
xp_assert_close(res1, ref1, rtol=2e-16)
263+
xp_assert_equal(res2, ref2)
264+
xp_assert_equal(res3, ref3)
267265

268266

269267
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")

0 commit comments

Comments
 (0)