@@ -88,6 +88,52 @@ def test_argmin(x, data):
88
88
ph .assert_scalar_equals ("argmin" , type_ = int , idx = out_idx , out = min_i , expected = expected )
89
89
90
90
91
+ # XXX: dtype= stanza below is to work around unsigned int dtypes in torch
92
+ # (count_nonzero_cpu not implemented for uint32 etc)
93
+ # XXX: the strategy for x is problematic on JAX unless JAX_ENABLE_X64 is on
94
+ # the problem is tha for ints >iinfo(int32) it runs into essentially this:
95
+ # >>> jnp.asarray[2147483648], dtype=jnp.int64)
96
+ # .... https://github.com/jax-ml/jax/pull/6047 ...
97
+ # Explicitly limiting the range in elements(...) runs into problems with
98
+ # hypothesis where floating-point numbers are not exactly representable.
99
+ @pytest .mark .min_version ("2024.12" )
100
+ @given (
101
+ x = hh .arrays (
102
+ dtype = st .sampled_from (dh .int_dtypes + dh .real_float_dtypes + dh .complex_dtypes + (xp .bool ,)),
103
+ shape = hh .shapes (min_dims = 1 , min_side = 1 ),
104
+ elements = {"allow_nan" : False },
105
+ ),
106
+ data = st .data (),
107
+ )
108
+ def test_count_nonzero (x , data ):
109
+ kw = data .draw (
110
+ hh .kwargs (
111
+ axis = st .none () | st .integers (- x .ndim , max (x .ndim - 1 , 0 )),
112
+ keepdims = st .booleans (),
113
+ ),
114
+ label = "kw" ,
115
+ )
116
+ keepdims = kw .get ("keepdims" , False )
117
+
118
+ out = xp .count_nonzero (x , ** kw )
119
+
120
+ ph .assert_default_index ("count_nonzero" , out .dtype )
121
+ axes = sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
122
+ ph .assert_keepdimable_shape (
123
+ "count_nonzero" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw
124
+ )
125
+ scalar_type = dh .get_scalar_type (x .dtype )
126
+
127
+ for indices , out_idx in zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
128
+ count = int (out [out_idx ])
129
+ elements = []
130
+ for idx in indices :
131
+ s = scalar_type (x [idx ])
132
+ elements .append (s )
133
+ expected = sum (el != 0 for el in elements )
134
+ ph .assert_scalar_equals ("count_nonzero" , type_ = int , idx = out_idx , out = count , expected = expected )
135
+
136
+
91
137
@given (hh .arrays (dtype = hh .all_dtypes , shape = ()))
92
138
def test_nonzero_zerodim_error (x ):
93
139
with pytest .raises (Exception ):
0 commit comments