Skip to content

Commit ee024f5

Browse files
lucascolleyOmarManzoorogrisel
authored
ENH: add setdiff1d (#35)
* ENH: add `setdiff1d` Co-authored-by: Omar Salman <[email protected]> Co-authored-by: Olivier Grisel <[email protected]> * (temp) switch to pyright * appease linter * cleanup * upgrade deps * TST: setdiff1d: add tests * TST: `_utils.in1d`: add tests * adjust test to hit alternative path * cover all cases * DOC: setdiff1d: add docs --------- Co-authored-by: Omar Salman <[email protected]> Co-authored-by: Olivier Grisel <[email protected]>
1 parent 9f28511 commit ee024f5

13 files changed

+473
-94
lines changed

Diff for: codecov.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
comment: false
22
ignore:
3-
- "src/array_api_extra/_typing"
3+
- "src/array_api_extra/_lib/_compat"
4+
- "src/array_api_extra/_lib/_typing"

Diff for: docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
create_diagonal
1212
expand_dims
1313
kron
14+
setdiff1d
1415
sinc
1516
```

Diff for: pixi.lock

+98-75
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626
"Typing :: Typed",
2727
]
2828
dynamic = ["version"]
29-
dependencies = []
29+
dependencies = ["typing-extensions"]
3030

3131
[project.optional-dependencies]
3232
tests = [
@@ -64,6 +64,7 @@ platforms = ["linux-64", "osx-arm64", "win-64"]
6464

6565
[tool.pixi.dependencies]
6666
python = ">=3.10.15,<3.14"
67+
typing_extensions = ">=4.12.2,<4.13"
6768

6869
[tool.pixi.pypi-dependencies]
6970
array-api-extra = { path = ".", editable = true }

Diff for: src/array_api_extra/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
from __future__ import annotations
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc
3+
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
44

55
__version__ = "0.2.1.dev0"
66

7+
# pylint: disable=duplicate-code
78
__all__ = [
89
"__version__",
910
"atleast_nd",
1011
"cov",
1112
"create_diagonal",
1213
"expand_dims",
1314
"kron",
15+
"setdiff1d",
1416
"sinc",
1517
]

Diff for: src/array_api_extra/_funcs.py

+60-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
1-
from __future__ import annotations
1+
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

33
import typing
44
import warnings
55

66
if typing.TYPE_CHECKING:
7-
from ._typing import Array, ModuleType
7+
from ._lib._typing import Array, ModuleType
88

9-
__all__ = ["atleast_nd", "cov", "create_diagonal", "expand_dims", "kron", "sinc"]
9+
from ._lib import _utils
10+
11+
__all__ = [
12+
"atleast_nd",
13+
"cov",
14+
"create_diagonal",
15+
"expand_dims",
16+
"kron",
17+
"setdiff1d",
18+
"sinc",
19+
]
1020

1121

1222
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
@@ -399,6 +409,53 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
399409
return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape)))
400410

401411

412+
def setdiff1d(
413+
x1: Array, x2: Array, /, *, assume_unique: bool = False, xp: ModuleType
414+
) -> Array:
415+
"""
416+
Find the set difference of two arrays.
417+
418+
Return the unique values in `x1` that are not in `x2`.
419+
420+
Parameters
421+
----------
422+
x1 : array
423+
Input array.
424+
x2 : array
425+
Input comparison array.
426+
assume_unique : bool
427+
If ``True``, the input arrays are both assumed to be unique, which
428+
can speed up the calculation. Default is ``False``.
429+
xp : array_namespace
430+
The standard-compatible namespace for `x1` and `x2`.
431+
432+
Returns
433+
-------
434+
res : array
435+
1D array of values in `x1` that are not in `x2`. The result
436+
is sorted when `assume_unique` is ``False``, but otherwise only sorted
437+
if the input is sorted.
438+
439+
Examples
440+
--------
441+
>>> import array_api_strict as xp
442+
>>> import array_api_extra as xpx
443+
444+
>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
445+
>>> x2 = xp.asarray([3, 4, 5, 6])
446+
>>> xpx.setdiff1d(x1, x2, xp=xp)
447+
Array([1, 2], dtype=array_api_strict.int64)
448+
449+
"""
450+
451+
if assume_unique:
452+
x1 = xp.reshape(x1, (-1,))
453+
else:
454+
x1 = xp.unique_values(x1)
455+
x2 = xp.unique_values(x2)
456+
return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
457+
458+
402459
def sinc(x: Array, /, *, xp: ModuleType) -> Array:
403460
r"""
404461
Return the normalized sinc function.

Diff for: src/array_api_extra/_lib/_compat.py

+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
### Helpers borrowed from array-api-compat
2+
3+
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
4+
5+
import inspect
6+
import sys
7+
import typing
8+
9+
from typing_extensions import override
10+
11+
if typing.TYPE_CHECKING:
12+
from ._typing import Array, Device
13+
14+
__all__ = ["device"]
15+
16+
17+
# Placeholder object to represent the dask device
18+
# when the array backend is not the CPU.
19+
# (since it is not easy to tell which device a dask array is on)
20+
class _dask_device: # pylint: disable=invalid-name
21+
@override
22+
def __repr__(self) -> str:
23+
return "DASK_DEVICE"
24+
25+
26+
_DASK_DEVICE = _dask_device()
27+
28+
29+
# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
30+
# or cupy.ndarray. They are not included in array objects of this library
31+
# because this library just reuses the respective ndarray classes without
32+
# wrapping or subclassing them. These helper functions can be used instead of
33+
# the wrapper functions for libraries that need to support both NumPy/CuPy and
34+
# other libraries that use devices.
35+
def device(x: Array, /) -> Device:
36+
"""
37+
Hardware device the array data resides on.
38+
39+
This is equivalent to `x.device` according to the `standard
40+
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.device.html>`__.
41+
This helper is included because some array libraries either do not have
42+
the `device` attribute or include it with an incompatible API.
43+
44+
Parameters
45+
----------
46+
x: array
47+
array instance from an array API compatible library.
48+
49+
Returns
50+
-------
51+
out: device
52+
a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
53+
section of the array API specification).
54+
55+
Notes
56+
-----
57+
58+
For NumPy the device is always `"cpu"`. For Dask, the device is always a
59+
special `DASK_DEVICE` object.
60+
61+
See Also
62+
--------
63+
64+
to_device : Move array data to a different device.
65+
66+
"""
67+
if _is_numpy_array(x):
68+
return "cpu"
69+
if _is_dask_array(x):
70+
# Peek at the metadata of the jax array to determine type
71+
try:
72+
import numpy as np # pylint: disable=import-outside-toplevel
73+
74+
if isinstance(x._meta, np.ndarray): # pylint: disable=protected-access
75+
# Must be on CPU since backed by numpy
76+
return "cpu"
77+
except ImportError:
78+
pass
79+
return _DASK_DEVICE
80+
if _is_jax_array(x):
81+
# JAX has .device() as a method, but it is being deprecated so that it
82+
# can become a property, in accordance with the standard. In order for
83+
# this function to not break when JAX makes the flip, we check for
84+
# both here.
85+
if inspect.ismethod(x.device):
86+
return x.device()
87+
return x.device
88+
if _is_pydata_sparse_array(x):
89+
# `sparse` will gain `.device`, so check for this first.
90+
x_device = getattr(x, "device", None)
91+
if x_device is not None:
92+
return x_device
93+
# Everything but DOK has this attr.
94+
try:
95+
inner = x.data
96+
except AttributeError:
97+
return "cpu"
98+
# Return the device of the constituent array
99+
return device(inner)
100+
return x.device
101+
102+
103+
def _is_numpy_array(x: Array) -> bool:
104+
"""Return True if `x` is a NumPy array."""
105+
# Avoid importing NumPy if it isn't already
106+
if "numpy" not in sys.modules:
107+
return False
108+
109+
import numpy as np # pylint: disable=import-outside-toplevel
110+
111+
# TODO: Should we reject ndarray subclasses?
112+
return isinstance(x, (np.ndarray, np.generic)) and not _is_jax_zero_gradient_array(
113+
x
114+
)
115+
116+
117+
def _is_dask_array(x: Array) -> bool:
118+
"""Return True if `x` is a dask.array Array."""
119+
# Avoid importing dask if it isn't already
120+
if "dask.array" not in sys.modules:
121+
return False
122+
123+
# pylint: disable=import-error, import-outside-toplevel
124+
import dask.array # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
125+
126+
return isinstance(x, dask.array.Array)
127+
128+
129+
def _is_jax_zero_gradient_array(x: Array) -> bool:
130+
"""Return True if `x` is a zero-gradient array.
131+
132+
These arrays are a design quirk of Jax that may one day be removed.
133+
See https://github.com/google/jax/issues/20620.
134+
"""
135+
if "numpy" not in sys.modules or "jax" not in sys.modules:
136+
return False
137+
138+
# pylint: disable=import-error, import-outside-toplevel
139+
import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
140+
import numpy as np # pylint: disable=import-outside-toplevel
141+
142+
return isinstance(x, np.ndarray) and x.dtype == jax.float0 # pyright: ignore[reportUnknownVariableType]
143+
144+
145+
def _is_jax_array(x: Array) -> bool:
146+
"""Return True if `x` is a JAX array."""
147+
# Avoid importing jax if it isn't already
148+
if "jax" not in sys.modules:
149+
return False
150+
151+
# pylint: disable=import-error, import-outside-toplevel
152+
import jax # pyright: ignore[reportMissingImports]
153+
154+
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
155+
156+
157+
def _is_pydata_sparse_array(x: Array) -> bool:
158+
"""Return True if `x` is an array from the `sparse` package."""
159+
160+
# Avoid importing jax if it isn't already
161+
if "sparse" not in sys.modules:
162+
return False
163+
164+
# pylint: disable=import-error, import-outside-toplevel
165+
import sparse # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
166+
167+
# TODO: Account for other backends.
168+
return isinstance(x, sparse.SparseArray)

Diff for: src/array_api_extra/_lib/_typing.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
2+
3+
from types import ModuleType
4+
from typing import Any
5+
6+
# To be changed to a Protocol later (see data-apis/array-api#589)
7+
Array = Any # type: ignore[no-any-explicit]
8+
Device = Any # type: ignore[no-any-explicit]
9+
10+
__all__ = ["Array", "Device", "ModuleType"]

Diff for: src/array_api_extra/_lib/_utils.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
2+
3+
import typing
4+
5+
if typing.TYPE_CHECKING:
6+
from ._typing import Array, ModuleType
7+
8+
from . import _compat
9+
10+
__all__ = ["in1d"]
11+
12+
13+
def in1d(
14+
x1: Array,
15+
x2: Array,
16+
/,
17+
*,
18+
assume_unique: bool = False,
19+
invert: bool = False,
20+
xp: ModuleType,
21+
) -> Array:
22+
"""Checks whether each element of an array is also present in a
23+
second array.
24+
25+
Returns a boolean array the same length as `x1` that is True
26+
where an element of `x1` is in `x2` and False otherwise.
27+
28+
This function has been adapted using the original implementation
29+
present in numpy:
30+
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
31+
"""
32+
33+
# This code is run to make the code significantly faster
34+
if x2.shape[0] < 10 * x1.shape[0] ** 0.145:
35+
if invert:
36+
mask = xp.ones(x1.shape[0], dtype=xp.bool, device=x1.device)
37+
for a in x2:
38+
mask &= x1 != a
39+
else:
40+
mask = xp.zeros(x1.shape[0], dtype=xp.bool, device=x1.device)
41+
for a in x2:
42+
mask |= x1 == a
43+
return mask
44+
45+
rev_idx = xp.empty(0) # placeholder
46+
if not assume_unique:
47+
x1, rev_idx = xp.unique_inverse(x1)
48+
x2 = xp.unique_values(x2)
49+
50+
ar = xp.concat((x1, x2))
51+
device_ = _compat.device(ar)
52+
# We need this to be a stable sort.
53+
order = xp.argsort(ar, stable=True)
54+
reverse_order = xp.argsort(order, stable=True)
55+
sar = xp.take(ar, order, axis=0)
56+
if sar.size >= 1:
57+
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
58+
else:
59+
bool_ar = xp.asarray([False]) if invert else xp.asarray([True])
60+
flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
61+
ret = xp.take(flag, reverse_order, axis=0)
62+
63+
if assume_unique:
64+
return ret[: x1.shape[0]]
65+
return xp.take(ret, rev_idx, axis=0)

Diff for: src/array_api_extra/_typing.py

-9
This file was deleted.

0 commit comments

Comments
 (0)