@@ -229,7 +229,8 @@ def test_hypothesis( # type: ignore[no-any-decorated]
229
229
cond_shape , * shapes = input_shapes
230
230
231
231
# cupy/cupy#8382
232
- elements = {"allow_subnormal" : False } if library is Backend .CUPY else None
232
+ # https://github.com/jax-ml/jax/issues/26658
233
+ elements = {"allow_subnormal" : library not in (Backend .CUPY , Backend .JAX )}
233
234
234
235
fill_value = xp .asarray (
235
236
data .draw (npst .arrays (dtype = dtype , shape = (), elements = elements ))
@@ -258,12 +259,9 @@ def f2(*args: Array) -> Array:
258
259
# TODO remove asarrays once all backends support Array API 2024.12
259
260
ref3 = xp .where (cond , * asarrays (f1 (* arrays ), float_fill_value , xp = xp ))
260
261
261
- # https://github.com/jax-ml/jax/issues/26658
262
- atol = 1e-300 if library is Backend .JAX else 0
263
-
264
- xp_assert_close (res1 , ref1 , atol = atol , rtol = 2e-16 )
265
- xp_assert_close (res2 , ref2 , atol = atol , rtol = 2e-16 )
266
- xp_assert_close (res3 , ref3 , atol = atol , rtol = 2e-16 )
262
+ xp_assert_close (res1 , ref1 , rtol = 2e-16 )
263
+ xp_assert_equal (res2 , ref2 )
264
+ xp_assert_equal (res3 , ref3 )
267
265
268
266
269
267
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
0 commit comments