Skip to content

Commit fe8fa8b

Browse files
lucascolleyOmarManzoorogrisel
committed
ENH: add setdiff1d
Co-authored-by: Omar Salman <[email protected]> Co-authored-by: Olivier Grisel <[email protected]>
1 parent 9f28511 commit fe8fa8b

File tree

6 files changed

+263
-5
lines changed

6 files changed

+263
-5
lines changed

Diff for: codecov.yml

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
comment: false
22
ignore:
3+
- "src/array_api_extra/_compat"
34
- "src/array_api_extra/_typing"

Diff for: src/array_api_extra/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
from __future__ import annotations
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc
3+
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
44

55
__version__ = "0.2.1.dev0"
66

7+
# pylint: disable=duplicate-code
78
__all__ = [
89
"__version__",
910
"atleast_nd",
1011
"cov",
1112
"create_diagonal",
1213
"expand_dims",
1314
"kron",
15+
"setdiff1d",
1416
"sinc",
1517
]

Diff for: src/array_api_extra/_compat.py

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
### Helpers borrowed from array-api-compat
2+
3+
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
4+
5+
import inspect
6+
import sys
7+
import typing
8+
9+
if typing.TYPE_CHECKING:
10+
from ._typing import Array, Device
11+
12+
__all__ = ["device"]
13+
14+
15+
# Placeholder object to represent the dask device
16+
# when the array backend is not the CPU.
17+
# (since it is not easy to tell which device a dask array is on)
18+
class _dask_device: # pylint: disable=invalid-name
19+
def __repr__(self) -> str:
20+
return "DASK_DEVICE"
21+
22+
23+
_DASK_DEVICE = _dask_device()
24+
25+
26+
# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
27+
# or cupy.ndarray. They are not included in array objects of this library
28+
# because this library just reuses the respective ndarray classes without
29+
# wrapping or subclassing them. These helper functions can be used instead of
30+
# the wrapper functions for libraries that need to support both NumPy/CuPy and
31+
# other libraries that use devices.
32+
def device(x: Array, /) -> Device:
33+
"""
34+
Hardware device the array data resides on.
35+
36+
This is equivalent to `x.device` according to the `standard
37+
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.device.html>`__.
38+
This helper is included because some array libraries either do not have
39+
the `device` attribute or include it with an incompatible API.
40+
41+
Parameters
42+
----------
43+
x: array
44+
array instance from an array API compatible library.
45+
46+
Returns
47+
-------
48+
out: device
49+
a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
50+
section of the array API specification).
51+
52+
Notes
53+
-----
54+
55+
For NumPy the device is always `"cpu"`. For Dask, the device is always a
56+
special `DASK_DEVICE` object.
57+
58+
See Also
59+
--------
60+
61+
to_device : Move array data to a different device.
62+
63+
"""
64+
if _is_numpy_array(x):
65+
return "cpu"
66+
if _is_dask_array(x):
67+
# Peek at the metadata of the jax array to determine type
68+
try:
69+
import numpy as np # pylint: disable=import-outside-toplevel
70+
71+
if isinstance(x._meta, np.ndarray): # pylint: disable=protected-access
72+
# Must be on CPU since backed by numpy
73+
return "cpu"
74+
except ImportError:
75+
pass
76+
return _DASK_DEVICE
77+
if _is_jax_array(x):
78+
# JAX has .device() as a method, but it is being deprecated so that it
79+
# can become a property, in accordance with the standard. In order for
80+
# this function to not break when JAX makes the flip, we check for
81+
# both here.
82+
if inspect.ismethod(x.device):
83+
return x.device()
84+
return x.device
85+
if _is_pydata_sparse_array(x):
86+
# `sparse` will gain `.device`, so check for this first.
87+
x_device = getattr(x, "device", None)
88+
if x_device is not None:
89+
return x_device
90+
# Everything but DOK has this attr.
91+
try:
92+
inner = x.data
93+
except AttributeError:
94+
return "cpu"
95+
# Return the device of the constituent array
96+
return device(inner)
97+
return x.device
98+
99+
100+
def _is_numpy_array(x: Array) -> bool:
101+
"""Return True if `x` is a NumPy array."""
102+
# Avoid importing NumPy if it isn't already
103+
if "numpy" not in sys.modules:
104+
return False
105+
106+
import numpy as np # pylint: disable=import-outside-toplevel
107+
108+
# TODO: Should we reject ndarray subclasses?
109+
return isinstance(x, (np.ndarray, np.generic)) and not _is_jax_zero_gradient_array(
110+
x
111+
)
112+
113+
114+
def _is_dask_array(x: Array) -> bool:
115+
"""Return True if `x` is a dask.array Array."""
116+
# Avoid importing dask if it isn't already
117+
if "dask.array" not in sys.modules:
118+
return False
119+
120+
# pylint: disable=import-error, import-outside-toplevel
121+
import dask.array # type: ignore[import-not-found]
122+
123+
return isinstance(x, dask.array.Array)
124+
125+
126+
def _is_jax_zero_gradient_array(x: Array) -> bool:
127+
"""Return True if `x` is a zero-gradient array.
128+
129+
These arrays are a design quirk of Jax that may one day be removed.
130+
See https://github.com/google/jax/issues/20620.
131+
"""
132+
if "numpy" not in sys.modules or "jax" not in sys.modules:
133+
return False
134+
135+
# pylint: disable=import-error, import-outside-toplevel
136+
import jax # type: ignore[import-not-found]
137+
import numpy as np # pylint: disable=import-outside-toplevel
138+
139+
return isinstance(x, np.ndarray) and x.dtype == jax.float0
140+
141+
142+
def _is_jax_array(x: Array) -> bool:
143+
"""Return True if `x` is a JAX array."""
144+
# Avoid importing jax if it isn't already
145+
if "jax" not in sys.modules:
146+
return False
147+
148+
# pylint: disable=import-error, import-outside-toplevel
149+
import jax
150+
151+
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
152+
153+
154+
def _is_pydata_sparse_array(x: Array) -> bool:
155+
"""Return True if `x` is an array from the `sparse` package."""
156+
157+
# Avoid importing jax if it isn't already
158+
if "sparse" not in sys.modules:
159+
return False
160+
161+
# pylint: disable=import-error, import-outside-toplevel
162+
import sparse # type: ignore[import-not-found]
163+
164+
# TODO: Account for other backends.
165+
return isinstance(x, sparse.SparseArray)

Diff for: src/array_api_extra/_funcs.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
1-
from __future__ import annotations
1+
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

33
import typing
44
import warnings
55

66
if typing.TYPE_CHECKING:
77
from ._typing import Array, ModuleType
88

9-
__all__ = ["atleast_nd", "cov", "create_diagonal", "expand_dims", "kron", "sinc"]
9+
from . import _utils
10+
11+
__all__ = [
12+
"atleast_nd",
13+
"cov",
14+
"create_diagonal",
15+
"expand_dims",
16+
"kron",
17+
"setdiff1d",
18+
"sinc",
19+
]
1020

1121

1222
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
@@ -399,6 +409,22 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
399409
return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape)))
400410

401411

412+
def setdiff1d(
413+
x1: Array, x2: Array, /, *, assume_unique: bool = False, xp: ModuleType
414+
) -> Array:
415+
"""Find the set difference of two arrays.
416+
417+
Return the unique values in `x1` that are not in `x2`.
418+
"""
419+
420+
if assume_unique:
421+
x1 = xp.reshape(x1, (-1,))
422+
else:
423+
x1 = xp.unique_values(x1)
424+
x2 = xp.unique_values(x2)
425+
return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
426+
427+
402428
def sinc(x: Array, /, *, xp: ModuleType) -> Array:
403429
r"""
404430
Return the normalized sinc function.

Diff for: src/array_api_extra/_typing.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from __future__ import annotations
1+
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

33
from types import ModuleType
44
from typing import Any
55

66
# To be changed to a Protocol later (see data-apis/array-api#589)
77
Array = Any # type: ignore[no-any-explicit]
8+
Device = Any
89

9-
__all__ = ["Array", "ModuleType"]
10+
__all__ = ["Array", "Device", "ModuleType"]

Diff for: src/array_api_extra/_utils.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
2+
3+
import typing
4+
5+
if typing.TYPE_CHECKING:
6+
from ._typing import Array, ModuleType
7+
8+
from . import _compat
9+
10+
__all__ = ["in1d"]
11+
12+
13+
def in1d(
14+
x1: Array,
15+
x2: Array,
16+
/,
17+
*,
18+
assume_unique: bool = False,
19+
invert: bool = False,
20+
xp: ModuleType,
21+
) -> Array:
22+
"""Checks whether each element of an array is also present in a
23+
second array.
24+
25+
Returns a boolean array the same length as `x1` that is True
26+
where an element of `x1` is in `x2` and False otherwise.
27+
28+
This function has been adapted using the original implementation
29+
present in numpy:
30+
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
31+
"""
32+
33+
# This code is run to make the code significantly faster
34+
if x2.shape[0] < 10 * x1.shape[0] ** 0.145:
35+
if invert:
36+
mask = xp.ones(x1.shape[0], dtype=xp.bool, device=x1.device)
37+
for a in x2:
38+
mask &= x1 != a
39+
else:
40+
mask = xp.zeros(x1.shape[0], dtype=xp.bool, device=x1.device)
41+
for a in x2:
42+
mask |= x1 == a
43+
return mask
44+
45+
if not assume_unique:
46+
x1, rev_idx = xp.unique_inverse(x1)
47+
x2 = xp.unique_values(x2)
48+
49+
ar = xp.concat((x1, x2))
50+
device_ = _compat.device(ar)
51+
# We need this to be a stable sort.
52+
order = xp.argsort(ar, stable=True)
53+
reverse_order = xp.argsort(order, stable=True)
54+
sar = xp.take(ar, order, axis=0)
55+
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
56+
flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
57+
ret = xp.take(flag, reverse_order, axis=0)
58+
59+
if assume_unique:
60+
return ret[: x1.shape[0]]
61+
# https://github.com/pylint-dev/pylint/issues/10095
62+
# pylint: disable=possibly-used-before-assignment
63+
return xp.take(ret, rev_idx, axis=0)

0 commit comments

Comments
 (0)