Skip to content

Commit f5a3882

Browse files
committed
BUG: work around/comment on the strategy for x in count_nonzero
On torch, work around count_nonzero not implemented for uints On jax, there are problems with integers > iinfo(jnp.int32)
1 parent 041717e commit f5a3882

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

Diff for: array_api_tests/test_searching_functions.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,18 @@ def test_argmin(x, data):
8888
ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected)
8989

9090

91+
# XXX: dtype= stanza below is to work around unsigned int dtypes in torch
92+
# (count_nonzero_cpu not implemented for uint32 etc)
93+
# XXX: the strategy for x is problematic on JAX unless JAX_ENABLE_X64 is on
94+
# the problem is tha for ints >iinfo(int32) it runs into essentially this:
95+
# >>> jnp.asarray[2147483648], dtype=jnp.int64)
96+
# .... https://github.com/jax-ml/jax/pull/6047 ...
97+
# Explicitly limiting the range in elements(...) runs into problems with
98+
# hypothesis where floating-point numbers are not exactly representable.
9199
@pytest.mark.min_version("2024.12")
92100
@given(
93101
x=hh.arrays(
94-
dtype=hh.real_dtypes,
102+
dtype=st.sampled_from(dh.int_dtypes + dh.real_float_dtypes + dh.complex_dtypes + (xp.bool,)),
95103
shape=hh.shapes(min_dims=1, min_side=1),
96104
elements={"allow_nan": False},
97105
),

0 commit comments

Comments
 (0)