Skip to content

Commit 37ce670

Browse files
committed
nunique
1 parent cc9f403 commit 37ce670

File tree

4 files changed

+62
-0
lines changed

4 files changed

+62
-0
lines changed

Diff for: docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
create_diagonal
1313
expand_dims
1414
kron
15+
nunique
1516
setdiff1d
1617
sinc
1718
```

Diff for: src/array_api_extra/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
create_diagonal,
88
expand_dims,
99
kron,
10+
nunique,
1011
pad,
1112
setdiff1d,
1213
sinc,
@@ -23,6 +24,7 @@
2324
"create_diagonal",
2425
"expand_dims",
2526
"kron",
27+
"nunique",
2628
"pad",
2729
"setdiff1d",
2830
"sinc",

Diff for: src/array_api_extra/_funcs.py

+40
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
44
from __future__ import annotations
55

6+
import math
67
import operator
78
import warnings
89
from collections.abc import Callable
@@ -13,8 +14,10 @@
1314
from ._lib import _compat, _utils
1415
from ._lib._compat import (
1516
array_namespace,
17+
device,
1618
is_jax_array,
1719
is_writeable_array,
20+
size,
1821
)
1922
from ._lib._typing import Array, Index
2023

@@ -25,6 +28,7 @@
2528
"create_diagonal",
2629
"expand_dims",
2730
"kron",
31+
"nunique",
2832
"pad",
2933
"setdiff1d",
3034
"sinc",
@@ -638,6 +642,42 @@ def pad(
638642
return at(padded, tuple(slices)).set(x)
639643

640644

645+
def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
    """
    Count the number of unique elements in an array.

    Compatible with JAX and Dask, whose laziness would be otherwise
    problematic.

    Parameters
    ----------
    x : Array
        Input array.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    Array
        Scalar integer array holding the number of unique elements in `x`.
        It can be lazy.
    """
    namespace = array_namespace(x) if xp is None else xp

    if is_jax_array(x):
        # The `size=` keyword is JAX-specific; see
        # https://github.com/data-apis/array-api/issues/883
        _, counts = namespace.unique_counts(x, size=size(x))
    else:
        _, counts = namespace.unique_counts(x)
        n_counts = size(counts)
        # FIXME https://github.com/data-apis/array-api-compat/pull/231
        # The size is reported as None or NaN when it is not known
        # (e.g. Dask, ndonnx); only a known size can be returned directly.
        size_is_known = n_counts is not None and not math.isnan(n_counts)
        if size_is_known:
            return namespace.asarray(n_counts, device=device(x))

    # Shared fallback for JAX and for backends with an unknown size:
    # count the non-zero entries of `counts`.
    return namespace.astype(counts, namespace.bool).sum()
679+
680+
641681
class _AtOp(Enum):
642682
"""Operations for use in `xpx.at`."""
643683

Diff for: tests/test_funcs.py

+19
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
create_diagonal,
1212
expand_dims,
1313
kron,
14+
nunique,
1415
pad,
1516
setdiff1d,
1617
sinc,
@@ -448,3 +449,21 @@ def test_list_of_tuples_width(self, xp: ModuleType):
448449

449450
padded = pad(a, [(1, 0), (0, 0)])
450451
assert padded.shape == (4, 4)
452+
453+
454+
class TestNUnique:
    """Tests for `nunique`."""

    def test_simple(self, xp: ModuleType):
        """A 2-D integer array holding three distinct values yields 3."""
        values = xp.asarray([[1, 1], [0, 2], [2, 2]])
        actual = nunique(values)
        expected = xp.asarray(3)
        xp_assert_equal(actual, expected)

    def test_empty(self, xp: ModuleType):
        """An empty array yields 0."""
        empty = xp.asarray([])
        actual = nunique(empty)
        expected = xp.asarray(0)
        xp_assert_equal(actual, expected)

    def test_device(self, xp: ModuleType, device: Device):
        """The result is on the same device as the input."""
        scalar = xp.asarray(0.0, device=device)
        result = nunique(scalar)
        assert get_device(result) == device

    def test_xp(self, xp: ModuleType):
        """Passing the namespace explicitly gives the same count."""
        values = xp.asarray([[1, 1], [0, 2], [2, 2]])
        actual = nunique(values, xp=xp)
        expected = xp.asarray(3)
        xp_assert_equal(actual, expected)

0 commit comments

Comments
 (0)