11
11
12
12
from ._at import at
13
13
from ._utils import _compat , _helpers
14
- from ._utils ._compat import array_namespace , is_jax_array
14
+ from ._utils ._compat import (
15
+ array_namespace ,
16
+ is_dask_namespace ,
17
+ is_jax_array ,
18
+ is_jax_namespace ,
19
+ )
15
20
from ._utils ._helpers import asarrays
16
21
from ._utils ._typing import Array
17
22
@@ -547,6 +552,7 @@ def setdiff1d(
547
552
/ ,
548
553
* ,
549
554
assume_unique : bool = False ,
555
+ fill_value : object | None = None ,
550
556
xp : ModuleType | None = None ,
551
557
) -> Array :
552
558
"""
@@ -563,6 +569,11 @@ def setdiff1d(
563
569
assume_unique : bool
564
570
If ``True``, the input arrays are both assumed to be unique, which
565
571
can speed up the calculation. Default is ``False``.
572
+ fill_value : object, optional
573
+ Pad the output array with this value.
574
+
575
+ This is exclusively used for JAX arrays when running inside ``jax.jit``,
576
+ where all array shapes need to be known in advance.
566
577
xp : array_namespace, optional
567
578
The standard-compatible namespace for `x1` and `x2`. Default: infer.
568
579
@@ -587,13 +598,86 @@ def setdiff1d(
587
598
xp = array_namespace (x1 , x2 )
588
599
x1 , x2 = asarrays (x1 , x2 , xp = xp )
589
600
590
- if assume_unique :
591
- x1 = xp .reshape (x1 , (- 1 ,))
592
- x2 = xp .reshape (x2 , (- 1 ,))
593
- else :
594
- x1 = xp .unique_values (x1 )
595
- x2 = xp .unique_values (x2 )
596
- return x1 [_helpers .in1d (x1 , x2 , assume_unique = True , invert = True , xp = xp )]
601
+ x1 = xp .reshape (x1 , (- 1 ,))
602
+ x2 = xp .reshape (x2 , (- 1 ,))
603
+ if x1 .shape == (0 ,) or x2 .shape == (0 ,):
604
+ return x1
605
+
606
+ def _x1_not_in_x2 (x1 : Array , x2 : Array ) -> Array : # numpydoc ignore=PR01,RT01
607
+ """For each element of x1, return True if it is not also in x2."""
608
+ # Even when assume_unique=True, there is no guarantee that x2 is sorted
609
+ x2 = xp .sort (x2 )
610
+ idx = xp .searchsorted (x2 , x1 )
611
+
612
+ # FIXME at() is faster but needs JAX jit support for bool mask
613
+ # idx = at(idx, idx == x2.shape[0]).set(0)
614
+ idx = xp .where (idx == x2 .shape [0 ], xp .zeros_like (idx ), idx )
615
+
616
+ return xp .take (x2 , idx , axis = 0 ) != x1
617
+
618
+ def _generic_impl (x1 : Array , x2 : Array ) -> Array : # numpydoc ignore=PR01,RT01
619
+ """Generic implementation (including eager JAX)."""
620
+ # Note: the Array API does not guarantee that xp.unique_values returns sorted output
621
+ if not assume_unique :
622
+ # Call unique_values early to speed up the algorithm
623
+ x1 = xp .unique_values (x1 )
624
+ x2 = xp .unique_values (x2 )
625
+ mask = _x1_not_in_x2 (x1 , x2 )
626
+ x1 = x1 [mask ]
627
+ return x1 if assume_unique else xp .sort (x1 )
628
+
629
+ def _dask_impl (x1 : Array , x2 : Array ) -> Array : # numpydoc ignore=PR01,RT01
630
+ """
631
+ Dask implementation.
632
+
633
+ Works around unique_values returning unknown shapes.
634
+ """
635
+ # Do not call unique_values yet, as it would make array shapes unknown
636
+ mask = _x1_not_in_x2 (x1 , x2 )
637
+ x1 = x1 [mask ]
638
+ # Note: da.unique_values sorts
639
+ return x1 if assume_unique else xp .unique_values (x1 )
640
+
641
+ def _jax_jit_impl (
642
+ x1 : Array , x2 : Array , fill_value : object | None
643
+ ) -> Array : # numpydoc ignore=PR01,RT01
644
+ """
645
+ JAX implementation inside jax.jit.
646
+
647
+ Works around unique_values requiring a size= parameter
648
+ and not being able to filter by a boolean mask.
649
+ Returns an array of the same size as x1, padded with fill_value.
650
+ """
651
+ # unique_values inside jax.jit is not supported unless it's got a fixed size
652
+ mask = _x1_not_in_x2 (x1 , x2 )
653
+
654
+ if fill_value is None :
655
+ fill_value = xp .zeros ((), dtype = x1 .dtype )
656
+ else :
657
+ fill_value = xp .asarray (fill_value , dtype = x1 .dtype )
658
+ if cast (Array , fill_value ).ndim != 0 :
659
+ msg = "`fill_value` must be a scalar."
660
+ raise ValueError (msg )
661
+
662
+ x1 = xp .where (mask , x1 , fill_value )
663
+ # Note: jnp.unique_values sorts
664
+ return xp .unique_values (x1 , size = x1 .size , fill_value = fill_value )
665
+
666
+ if is_dask_namespace (xp ):
667
+ return _dask_impl (x1 , x2 )
668
+
669
+ if is_jax_namespace (xp ):
670
+ import jax
671
+
672
+ try :
673
+ return _generic_impl (x1 , x2 ) # eager mode
674
+ except (
675
+ jax .errors .ConcretizationTypeError ,
676
+ jax .errors .NonConcreteBooleanIndexError ,
677
+ ):
678
+ return _jax_jit_impl (x1 , x2 , fill_value ) # inside jax.jit
679
+
680
+ return _generic_impl (x1 , x2 )
597
681
598
682
599
683
def sinc (x : Array , / , * , xp : ModuleType | None = None ) -> Array :
0 commit comments