Skip to content

Commit 0b89c52

Browse files
authored
Merge pull request #347 from ev-br/test_count_nonzero
add a test for count_nonzero
2 parents 28f1dbf + f5a3882 commit 0b89c52

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

Diff for: array_api_tests/test_searching_functions.py

+46
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,52 @@ def test_argmin(x, data):
8888
ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected)
8989

9090

91+
# XXX: dtype= stanza below is to work around unsigned int dtypes in torch
92+
# (count_nonzero_cpu not implemented for uint32 etc)
93+
# XXX: the strategy for x is problematic on JAX unless JAX_ENABLE_X64 is on
94+
# the problem is tha for ints >iinfo(int32) it runs into essentially this:
95+
# >>> jnp.asarray[2147483648], dtype=jnp.int64)
96+
# .... https://github.com/jax-ml/jax/pull/6047 ...
97+
# Explicitly limiting the range in elements(...) runs into problems with
98+
# hypothesis where floating-point numbers are not exactly representable.
99+
@pytest.mark.min_version("2024.12")
100+
@given(
101+
x=hh.arrays(
102+
dtype=st.sampled_from(dh.int_dtypes + dh.real_float_dtypes + dh.complex_dtypes + (xp.bool,)),
103+
shape=hh.shapes(min_dims=1, min_side=1),
104+
elements={"allow_nan": False},
105+
),
106+
data=st.data(),
107+
)
108+
def test_count_nonzero(x, data):
109+
kw = data.draw(
110+
hh.kwargs(
111+
axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
112+
keepdims=st.booleans(),
113+
),
114+
label="kw",
115+
)
116+
keepdims = kw.get("keepdims", False)
117+
118+
out = xp.count_nonzero(x, **kw)
119+
120+
ph.assert_default_index("count_nonzero", out.dtype)
121+
axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
122+
ph.assert_keepdimable_shape(
123+
"count_nonzero", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
124+
)
125+
scalar_type = dh.get_scalar_type(x.dtype)
126+
127+
for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
128+
count = int(out[out_idx])
129+
elements = []
130+
for idx in indices:
131+
s = scalar_type(x[idx])
132+
elements.append(s)
133+
expected = sum(el != 0 for el in elements)
134+
ph.assert_scalar_equals("count_nonzero", type_=int, idx=out_idx, out=count, expected=expected)
135+
136+
91137
@given(hh.arrays(dtype=hh.all_dtypes, shape=()))
92138
def test_nonzero_zerodim_error(x):
93139
with pytest.raises(Exception):

0 commit comments

Comments
 (0)