diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py
index 50331fa0..30d9fe48 100644
--- a/array_api_compat/cupy/_aliases.py
+++ b/array_api_compat/cupy/_aliases.py
@@ -125,6 +125,20 @@ def astype(
     return out.copy() if copy and out is x else out
 
 
+# cupy.count_nonzero does not have keepdims
+def count_nonzero(
+    x: ndarray,
+    axis=None,
+    keepdims=False
+) -> ndarray:
+   result = cp.count_nonzero(x, axis)
+   if keepdims:
+       if axis is None:
+            return cp.reshape(result, [1]*x.ndim)
+       return cp.expand_dims(result, axis)
+   return result
+
+
 # These functions are completely new here. If the library already has them
 # (i.e., numpy 2.0), use the library version instead of our wrapper.
 if hasattr(cp, 'vecdot'):
@@ -146,6 +160,6 @@ def astype(
                               'acos', 'acosh', 'asin', 'asinh', 'atan',
                               'atan2', 'atanh', 'bitwise_left_shift',
                               'bitwise_invert', 'bitwise_right_shift',
-                              'bool', 'concat', 'pow', 'sign']
+                              'bool', 'concat', 'count_nonzero', 'pow', 'sign']
 
 _all_ignore = ['cp', 'get_xp']
diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py
index 4e2d26f9..80d66281 100644
--- a/array_api_compat/dask/array/_aliases.py
+++ b/array_api_compat/dask/array/_aliases.py
@@ -335,6 +335,21 @@ def argsort(
     return restore(x)
 
 
+# dask.array.count_nonzero does not have keepdims
+def count_nonzero(
+    x: Array,
+    axis=None,
+    keepdims=False
+) -> Array:
+   result = da.count_nonzero(x, axis)
+   if keepdims:
+       if axis is None:
+            return da.reshape(result, [1]*x.ndim)
+       return da.expand_dims(result, axis)
+   return result
+
+
+
 __all__ = _aliases.__all__ + [
                     '__array_namespace_info__', 'asarray', 'astype', 'acos',
                     'acosh', 'asin', 'asinh', 'atan', 'atan2',
@@ -343,6 +358,6 @@ def argsort(
                     'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
                     'uint8', 'uint16', 'uint32', 'uint64',
                     'complex64', 'complex128', 'iinfo', 'finfo',
-                    'can_cast', 'result_type']
+                    'can_cast', 'count_nonzero', 'result_type']
 
 _all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index 98eec121..a47f7121 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -127,6 +127,19 @@ def astype(
     return x.astype(dtype=dtype, copy=copy)
 
 
+# count_nonzero returns a python int for axis=None and keepdims=False
+# https://github.com/numpy/numpy/issues/17562
+def count_nonzero(
+    x : ndarray,
+    axis=None,
+    keepdims=False
+) -> ndarray:
+    result = np.count_nonzero(x, axis=axis, keepdims=keepdims)
+    if axis is None and not keepdims:
+        return np.asarray(result)
+    return result
+
+
 # These functions are completely new here. If the library already has them
 # (i.e., numpy 2.0), use the library version instead of our wrapper.
 if hasattr(np, 'vecdot'):
@@ -148,6 +161,6 @@ def astype(
                               'acos', 'acosh', 'asin', 'asinh', 'atan',
                               'atan2', 'atanh', 'bitwise_left_shift',
                               'bitwise_invert', 'bitwise_right_shift',
-                              'bool', 'concat', 'pow']
+                              'bool', 'concat', 'count_nonzero', 'pow']
 
 _all_ignore = ['np', 'get_xp']
diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index a6e833f9..b4786320 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -521,7 +521,7 @@ def diff(
     return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
 
 
-# torch uses `dim` instead of `axis`
+# torch uses `dim` instead of `axis`, does not have keepdims
 def count_nonzero(
     x: array,
     /,
@@ -529,7 +529,14 @@ def count_nonzero(
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
     keepdims: bool = False,
 ) -> array:
-    return torch.count_nonzero(x, dim=axis, keepdims=keepdims)
+    result = torch.count_nonzero(x, dim=axis)
+    if keepdims:
+        if axis is not None:
+            return result.unsqueeze(axis)
+        return _axis_none_keepdims(result, x.ndim, keepdims)
+    else:
+        return result
+
 
 
 def where(condition: array, x1: array, x2: array, /) -> array: