Skip to content

Commit ec890f1

Browse files
ENH: new function nunique (#90)
Co-authored-by: Lucas Colley <[email protected]>
1 parent cc9f403 commit ec890f1

File tree

6 files changed

+62
-1
lines changed

6 files changed

+62
-1
lines changed

Diff for: docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
create_diagonal
1313
expand_dims
1414
kron
15+
nunique
1516
setdiff1d
1617
sinc
1718
```

Diff for: pixi.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ messages_control.disable = [
299299
"line-too-long",
300300
"missing-module-docstring",
301301
"missing-function-docstring",
302+
"too-many-lines",
302303
"wrong-import-position",
303304
]
304305

Diff for: src/array_api_extra/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
create_diagonal,
88
expand_dims,
99
kron,
10+
nunique,
1011
pad,
1112
setdiff1d,
1213
sinc,
@@ -23,6 +24,7 @@
2324
"create_diagonal",
2425
"expand_dims",
2526
"kron",
27+
"nunique",
2628
"pad",
2729
"setdiff1d",
2830
"sinc",

Diff for: src/array_api_extra/_funcs.py

+38
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
44
from __future__ import annotations
55

6+
import math
67
import operator
78
import warnings
89
from collections.abc import Callable
@@ -25,6 +26,7 @@
2526
"create_diagonal",
2627
"expand_dims",
2728
"kron",
29+
"nunique",
2830
"pad",
2931
"setdiff1d",
3032
"sinc",
@@ -638,6 +640,42 @@ def pad(
638640
return at(padded, tuple(slices)).set(x)
639641

640642

643+
def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
    """
    Count the number of unique elements in an array.

    Compatible with JAX and Dask, whose laziness would be otherwise
    problematic.

    Parameters
    ----------
    x : Array
        Input array.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array
        0-dimensional integer array holding the number of unique elements
        in `x`. It can be lazy.
    """
    if xp is None:
        xp = array_namespace(x)

    if is_jax_array(x):
        # size= is JAX-specific
        # https://github.com/data-apis/array-api/issues/883
        # Asking for as many slots as `x` has elements gives `counts` a
        # length that does not depend on the data. Slots that do not
        # correspond to a real unique value are expected to carry a zero
        # count (per the JAX documentation of `size=`), so the number of
        # non-zero counts is the number of unique elements.
        _, counts = xp.unique_counts(x, size=_compat.size(x))
        # Reduce through the namespace: array *methods* such as `.sum()` are
        # not part of the array API standard, only `xp.sum` is.
        return xp.sum(xp.astype(counts, xp.bool))

    _, counts = xp.unique_counts(x)
    n = _compat.size(counts)
    # FIXME https://github.com/data-apis/array-api-compat/pull/231
    if n is None or math.isnan(n):  # e.g. Dask, ndonnx
        # The length of `counts` is not known eagerly, so it cannot be
        # returned directly; count its entries with a (possibly lazy)
        # reduction instead.
        return xp.sum(xp.astype(counts, xp.bool))
    return xp.asarray(n, device=_compat.device(x))
677+
678+
641679
class _AtOp(Enum):
642680
"""Operations for use in `xpx.at`."""
643681

Diff for: tests/test_funcs.py

+19
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
create_diagonal,
1212
expand_dims,
1313
kron,
14+
nunique,
1415
pad,
1516
setdiff1d,
1617
sinc,
@@ -448,3 +449,21 @@ def test_list_of_tuples_width(self, xp: ModuleType):
448449

449450
padded = pad(a, [(1, 0), (0, 0)])
450451
assert padded.shape == (4, 4)
452+
453+
454+
class TestNUnique:
    """Tests for `nunique`."""

    def test_simple(self, xp: ModuleType):
        """A 2-D input holding the values 0, 1 and 2 has three unique elements."""
        values = xp.asarray([[1, 1], [0, 2], [2, 2]])
        actual = nunique(values)
        xp_assert_equal(actual, xp.asarray(3))

    def test_empty(self, xp: ModuleType):
        """An empty input has zero unique elements."""
        empty = xp.asarray([])
        actual = nunique(empty)
        xp_assert_equal(actual, xp.asarray(0))

    def test_device(self, xp: ModuleType, device: Device):
        """The result is placed on the same device as the input."""
        scalar = xp.asarray(0.0, device=device)
        result = nunique(scalar)
        assert get_device(result) == device

    def test_xp(self, xp: ModuleType):
        """Passing the namespace explicitly via `xp=` gives the same count."""
        values = xp.asarray([[1, 1], [0, 2], [2, 2]])
        actual = nunique(values, xp=xp)
        xp_assert_equal(actual, xp.asarray(3))

0 commit comments

Comments
 (0)