Skip to content

Commit 72f86f3

Browse files
authored
Add searchsorted to the specification (#699)
1 parent 5c2423a commit 72f86f3

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

Diff for: spec/draft/API_specification/searching_functions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ Objects in API
2323
argmax
2424
argmin
2525
nonzero
26+
searchsorted
2627
where

Diff for: src/array_api_stubs/_draft/searching_functions.py

+52-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
__all__ = ["argmax", "argmin", "nonzero", "where"]
1+
__all__ = ["argmax", "argmin", "nonzero", "searchsorted", "where"]
22

33

4-
from ._types import Optional, Tuple, array
4+
from ._types import Optional, Tuple, Literal, array
55

66

77
def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array:
@@ -87,6 +87,56 @@ def nonzero(x: array, /) -> Tuple[array, ...]:
8787
"""
8888

8989

90+
def searchsorted(
91+
x1: array,
92+
x2: array,
93+
/,
94+
*,
95+
side: Literal["left", "right"] = "left",
96+
sorter: Optional[array] = None,
97+
) -> array:
98+
"""
99+
Finds the indices into ``x1`` such that, if the corresponding elements in ``x2`` were inserted before the indices, the order of ``x1``, when sorted in ascending order, would be preserved.
100+
101+
Parameters
102+
----------
103+
x1: array
104+
input array. Must be a one-dimensional array. Should have a real-valued data type. If ``sorter`` is ``None``, must be sorted in ascending order; otherwise, ``sorter`` must be an array of indices that sort ``x1`` in ascending order.
105+
x2: array
106+
array containing search values. Should have a real-valued data type.
107+
side: Literal['left', 'right']
108+
argument controlling which index is returned if a value lands exactly on an edge.
109+
110+
Let ``x`` be an array of rank ``N`` where ``v`` is an individual element given by ``v = x2[n,m,...,j]``.
111+
112+
If ``side == 'left'``, then
113+
114+
- each returned index ``i`` must satisfy the index condition ``x1[i-1] < v <= x1[i]``.
115+
- if no index satisfies the index condition, then the returned index for that element must be ``0``.
116+
117+
Otherwise, if ``side == 'right'``, then
118+
119+
- each returned index ``i`` must satisfy the index condition ``x1[i-1] <= v < x1[i]``.
120+
- if no index satisfies the index condition, then the returned index for that element must be ``N``, where ``N`` is the number of elements in ``x1``.
121+
122+
Default: ``'left'``.
123+
sorter: Optional[array]
124+
array of indices that sort ``x1`` in ascending order. The array must have the same shape as ``x1`` and have an integer data type. Default: ``None``.
125+
126+
Returns
127+
-------
128+
out: array
129+
an array of indices with the same shape as ``x2``. The returned array must have the default array index data type.
130+
131+
Notes
132+
-----
133+
134+
For real-valued floating-point arrays, the sort order of NaNs and signed zeros is unspecified and thus implementation-dependent. Accordingly, when a real-valued floating-point array contains NaNs and signed zeros, what constitutes ascending order may vary among specification-conforming array libraries.
135+
136+
While behavior for arrays containing NaNs and signed zeros is implementation-dependent, specification-conforming libraries should, however, ensure consistency with ``sort`` and ``argsort`` (i.e., if a value in ``x2`` is inserted into ``x1`` according to the corresponding index in the output array and ``sort`` is invoked on the resultant array, the sorted result should be an array in the same order).
137+
"""
138+
139+
90140
def where(condition: array, x1: array, x2: array, /) -> array:
91141
"""
92142
Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.

0 commit comments

Comments
 (0)