Skip to content

Commit 93cc035

Browse files
committed
WIP ENH: setdiff1d for Dask and jax.jit
1 parent 59687b3 commit 93cc035

File tree

4 files changed

+96
-112
lines changed

4 files changed

+96
-112
lines changed

Diff for: src/array_api_extra/_lib/_funcs.py

+92-8
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
14-
from ._utils._compat import array_namespace, is_jax_array
14+
from ._utils._compat import (
15+
array_namespace,
16+
is_dask_namespace,
17+
is_jax_array,
18+
is_jax_namespace,
19+
)
1520
from ._utils._helpers import asarrays
1621
from ._utils._typing import Array
1722

@@ -547,6 +552,7 @@ def setdiff1d(
547552
/,
548553
*,
549554
assume_unique: bool = False,
555+
fill_value: object | None = None,
550556
xp: ModuleType | None = None,
551557
) -> Array:
552558
"""
@@ -563,6 +569,11 @@ def setdiff1d(
563569
assume_unique : bool
564570
If ``True``, the input arrays are both assumed to be unique, which
565571
can speed up the calculation. Default is ``False``.
572+
fill_value : object, optional
573+
Pad the output array with this value.
574+
575+
This is exclusively used for JAX arrays when running inside ``jax.jit``,
576+
where all array shapes need to be known in advance.
566577
xp : array_namespace, optional
567578
The standard-compatible namespace for `x1` and `x2`. Default: infer.
568579
@@ -587,13 +598,86 @@ def setdiff1d(
587598
xp = array_namespace(x1, x2)
588599
x1, x2 = asarrays(x1, x2, xp=xp)
589600

590-
if assume_unique:
591-
x1 = xp.reshape(x1, (-1,))
592-
x2 = xp.reshape(x2, (-1,))
593-
else:
594-
x1 = xp.unique_values(x1)
595-
x2 = xp.unique_values(x2)
596-
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
601+
x1 = xp.reshape(x1, (-1,))
602+
x2 = xp.reshape(x2, (-1,))
603+
if x1.shape == (0,) or x2.shape == (0,):
604+
return x1
605+
606+
def _x1_not_in_x2(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
607+
"""For each element of x1, return True if it is not also in x2."""
608+
# Even when assume_unique=True, there is no provision for x to be sorted
609+
x2 = xp.sort(x2)
610+
idx = xp.searchsorted(x2, x1)
611+
612+
# FIXME at() is faster but needs JAX jit support for bool mask
613+
# idx = at(idx, idx == x2.shape[0]).set(0)
614+
idx = xp.where(idx == x2.shape[0], xp.zeros_like(idx), idx)
615+
616+
return xp.take(x2, idx, axis=0) != x1
617+
618+
def _generic_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
619+
"""Generic implementation (including eager JAX)."""
620+
# Note: there is no provision in the Array API for xp.unique_values to sort
621+
if not assume_unique:
622+
# Call unique_values early to speed up the algorithm
623+
x1 = xp.unique_values(x1)
624+
x2 = xp.unique_values(x2)
625+
mask = _x1_not_in_x2(x1, x2)
626+
x1 = x1[mask]
627+
return x1 if assume_unique else xp.sort(x1)
628+
629+
def _dask_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
630+
"""
631+
Dask implementation.
632+
633+
Works around unique_values returning unknown shapes.
634+
"""
635+
# Do not call unique_values yet, as it would make array shapes unknown
636+
mask = _x1_not_in_x2(x1, x2)
637+
x1 = x1[mask]
638+
# Note: da.unique_values sorts
639+
return x1 if assume_unique else xp.unique_values(x1)
640+
641+
def _jax_jit_impl(
642+
x1: Array, x2: Array, fill_value: object | None
643+
) -> Array: # numpydoc ignore=PR01,RT01
644+
"""
645+
JAX implementation inside jax.jit.
646+
647+
Works around unique_values requiring a size= parameter
648+
and not being able to filter by a boolean mask.
649+
Returns array the same size as x1, padded with fill_value.
650+
"""
651+
# unique_values inside jax.jit is not supported unless it's got a fixed size
652+
mask = _x1_not_in_x2(x1, x2)
653+
654+
if fill_value is None:
655+
fill_value = xp.zeros((), dtype=x1.dtype)
656+
else:
657+
fill_value = xp.asarray(fill_value, dtype=x1.dtype)
658+
if cast(Array, fill_value).ndim != 0:
659+
msg = "`fill_value` must be a scalar."
660+
raise ValueError(msg)
661+
662+
x1 = xp.where(mask, x1, fill_value)
663+
# Note: jnp.unique_values sorts
664+
return xp.unique_values(x1, size=x1.size, fill_value=fill_value)
665+
666+
if is_dask_namespace(xp):
667+
return _dask_impl(x1, x2)
668+
669+
if is_jax_namespace(xp):
670+
import jax
671+
672+
try:
673+
return _generic_impl(x1, x2) # eager mode
674+
except (
675+
jax.errors.ConcretizationTypeError,
676+
jax.errors.NonConcreteBooleanIndexError,
677+
):
678+
return _jax_jit_impl(x1, x2, fill_value) # inside jax.jit
679+
680+
return _generic_impl(x1, x2)
597681

598682

599683
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:

Diff for: src/array_api_extra/_lib/_utils/_helpers.py

+1-60
Original file line numberDiff line numberDiff line change
@@ -10,66 +10,7 @@
1010
from ._compat import is_array_api_obj, is_numpy_array
1111
from ._typing import Array
1212

13-
__all__ = ["in1d", "mean"]
14-
15-
16-
def in1d(
17-
x1: Array,
18-
x2: Array,
19-
/,
20-
*,
21-
assume_unique: bool = False,
22-
invert: bool = False,
23-
xp: ModuleType | None = None,
24-
) -> Array: # numpydoc ignore=PR01,RT01
25-
"""
26-
Check whether each element of an array is also present in a second array.
27-
28-
Returns a boolean array the same length as `x1` that is True
29-
where an element of `x1` is in `x2` and False otherwise.
30-
31-
This function has been adapted using the original implementation
32-
present in numpy:
33-
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
34-
"""
35-
if xp is None:
36-
xp = _compat.array_namespace(x1, x2)
37-
38-
# This code is run to make the code significantly faster
39-
if x2.shape[0] < 10 * x1.shape[0] ** 0.145:
40-
if invert:
41-
mask = xp.ones(x1.shape[0], dtype=xp.bool, device=_compat.device(x1))
42-
for a in x2:
43-
mask &= x1 != a
44-
else:
45-
mask = xp.zeros(x1.shape[0], dtype=xp.bool, device=_compat.device(x1))
46-
for a in x2:
47-
mask |= x1 == a
48-
return mask
49-
50-
rev_idx = xp.empty(0) # placeholder
51-
if not assume_unique:
52-
x1, rev_idx = xp.unique_inverse(x1)
53-
x2 = xp.unique_values(x2)
54-
55-
ar = xp.concat((x1, x2))
56-
device_ = _compat.device(ar)
57-
# We need this to be a stable sort.
58-
order = xp.argsort(ar, stable=True)
59-
reverse_order = xp.argsort(order, stable=True)
60-
sar = xp.take(ar, order, axis=0)
61-
ar_size = _compat.size(sar)
62-
assert ar_size is not None, "xp.unique*() on lazy backends raises"
63-
if ar_size >= 1:
64-
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
65-
else:
66-
bool_ar = xp.asarray([False]) if invert else xp.asarray([True])
67-
flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
68-
ret = xp.take(flag, reverse_order, axis=0)
69-
70-
if assume_unique:
71-
return ret[: x1.shape[0]]
72-
return xp.take(ret, rev_idx, axis=0)
13+
__all__ = ["mean"]
7314

7415

7516
def mean(

Diff for: tests/test_funcs.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@
3535
lazy_xp_function(kron, static_argnames="xp")
3636
lazy_xp_function(nunique, static_argnames="xp")
3737
lazy_xp_function(pad, static_argnames=("pad_width", "mode", "constant_values", "xp"))
38-
# FIXME calls in1d which calls xp.unique_values without size
39-
lazy_xp_function(setdiff1d, jax_jit=False, static_argnames=("assume_unique", "xp"))
38+
lazy_xp_function(setdiff1d, static_argnames=("assume_unique", "xp"))
4039
# FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238
4140
lazy_xp_function(sinc, jax_jit=False, static_argnames="xp")
4241

@@ -576,8 +575,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):
576575
assert padded.shape == (4, 4)
577576

578577

579-
@pytest.mark.skip_xp_backend(Backend.DASK, reason="no argsort")
580-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray")
578+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no sort")
581579
class TestSetDiff1D:
582580
@pytest.mark.skip_xp_backend(
583581
Backend.TORCH, reason="index_select not implemented for uint32"

Diff for: tests/test_utils.py

+1-40
Original file line numberDiff line numberDiff line change
@@ -4,49 +4,10 @@
44
import pytest
55

66
from array_api_extra._lib import Backend
7-
from array_api_extra._lib._testing import xp_assert_equal
8-
from array_api_extra._lib._utils._compat import device as get_device
9-
from array_api_extra._lib._utils._helpers import asarrays, in1d
10-
from array_api_extra._lib._utils._typing import Device
11-
from array_api_extra.testing import lazy_xp_function
7+
from array_api_extra._lib._utils._helpers import asarrays
128

139
# mypy: disable-error-code=no-untyped-usage
1410

15-
# FIXME calls xp.unique_values without size
16-
lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp"))
17-
18-
19-
class TestIn1D:
20-
@pytest.mark.skip_xp_backend(Backend.DASK, reason="no argsort")
21-
@pytest.mark.skip_xp_backend(
22-
Backend.SPARSE, reason="no unique_inverse, no device kwarg in asarray"
23-
)
24-
# cover both code paths
25-
@pytest.mark.parametrize("n", [9, 15])
26-
def test_no_invert_assume_unique(self, xp: ModuleType, n: int):
27-
x1 = xp.asarray([3, 8, 20])
28-
x2 = xp.arange(n)
29-
expected = xp.asarray([True, True, False])
30-
actual = in1d(x1, x2)
31-
xp_assert_equal(actual, expected)
32-
33-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray")
34-
def test_device(self, xp: ModuleType, device: Device):
35-
x1 = xp.asarray([3, 8, 20], device=device)
36-
x2 = xp.asarray([2, 3, 4], device=device)
37-
assert get_device(in1d(x1, x2)) == device
38-
39-
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="explicit xp")
40-
@pytest.mark.skip_xp_backend(
41-
Backend.SPARSE, reason="no arange, no device kwarg in asarray"
42-
)
43-
def test_xp(self, xp: ModuleType):
44-
x1 = xp.asarray([1, 6])
45-
x2 = xp.arange(5)
46-
expected = xp.asarray([True, False])
47-
actual = in1d(x1, x2, xp=xp)
48-
xp_assert_equal(actual, expected)
49-
5011

5112
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
5213
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)