Skip to content

Add searchsorted to the specification #699

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 11, 2024
1 change: 1 addition & 0 deletions spec/draft/API_specification/searching_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ Objects in API
argmax
argmin
nonzero
searchsorted
where
54 changes: 52 additions & 2 deletions src/array_api_stubs/_draft/searching_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = ["argmax", "argmin", "nonzero", "where"]
__all__ = ["argmax", "argmin", "nonzero", "searchsorted", "where"]


from ._types import Optional, Tuple, array
from ._types import Optional, Tuple, Literal, array


def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array:
Expand Down Expand Up @@ -87,6 +87,56 @@ def nonzero(x: array, /) -> Tuple[array, ...]:
"""


def searchsorted(
x1: array,
x2: array,
/,
*,
side: Literal["left", "right"] = "left",
sorter: Optional[array] = None,
) -> array:
"""
Finds the indices into ``x1`` such that, if the corresponding elements in ``x2`` were inserted before the indices, the order of ``x1``, when sorted in ascending order, would be preserved.

Parameters
----------
x1: array
input array. Must be a one-dimensional array. Should have a real-valued data type. If ``sorter`` is ``None``, must be sorted in ascending order; otherwise, ``sorter`` must be an array of indices that sort ``x1`` in ascending order.
x2: array
array containing search values. Should have a real-valued data type.
side: Literal['left', 'right']
argument controlling which index is returned if a value lands exactly on an edge.

Let ``x`` be an array of rank ``N`` where ``v`` is an individual element given by ``v = x2[n,m,...,j]``.

If ``side == 'left'``, then

- each returned index ``i`` must satisfy the index condition ``x1[i-1] < v <= x1[i]``.
- if no index satisfies the index condition, then the returned index for that element must be ``0``.

Otherwise, if ``side == 'right'``, then

- each returned index ``i`` must satisfy the index condition ``x1[i-1] <= v < x1[i]``.
- if no index satisfies the index condition, then the returned index for that element must be ``N``, where ``N`` is the number of elements in ``x1``.

Default: ``'left'``.
sorter: Optional[array]
array of indices that sort ``x1`` in ascending order. The array must have the same shape as ``x1`` and have an integer data type. Default: ``None``.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we only require this to be the default array index type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would that be necessary?


Returns
-------
out: array
an array of indices with the same shape as ``x2``. The returned array must have the default array index data type.

Notes
-----

For real-valued floating-point arrays, the sort order of NaNs and signed zeros is unspecified and thus implementation-dependent. Accordingly, when a real-valued floating-point array contains NaNs and signed zeros, what constitutes ascending order may vary among specification-conforming array libraries.

While behavior for arrays containing NaNs and signed zeros is implementation-dependent, specification-conforming libraries should, however, ensure consistency with ``sort`` and ``argsort`` (i.e., if a value in ``x2`` is inserted into ``x1`` according to the corresponding index in the output array and ``sort`` is invoked on the resultant array, the sorted result should be an array in the same order).
"""


def where(condition: array, x1: array, x2: array, /) -> array:
"""
Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.
Expand Down