Skip to content

Commit ff37de7

Browse files
committed
Add multi-device support for searching
1 parent 9323324 commit ff37de7

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

Diff for: array_api_strict/_searching_functions.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
1919
"""
2020
if x.dtype not in _real_numeric_dtypes:
2121
raise TypeError("Only real numeric dtypes are allowed in argmax")
22-
return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims)))
22+
return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims)), device=x.device)
2323

2424

2525
def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array:
@@ -30,7 +30,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
3030
"""
3131
if x.dtype not in _real_numeric_dtypes:
3232
raise TypeError("Only real numeric dtypes are allowed in argmin")
33-
return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims)))
33+
return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims)), device=x.device)
3434

3535

3636
@requires_data_dependent_shapes
@@ -61,12 +61,16 @@ def searchsorted(
6161
"""
6262
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
6363
raise TypeError("Only real numeric dtypes are allowed in searchsorted")
64+
65+
if x1.device != x2.device:
66+
raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
67+
6468
sorter = sorter._array if sorter is not None else None
6569
# TODO: The sort order of nans and signed zeros is implementation
6670
# dependent. Should we error/warn if they are present?
6771

6872
# x1 must be 1-D, but NumPy already requires this.
69-
return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter))
73+
return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device)
7074

7175
def where(condition: Array, x1: Array, x2: Array, /) -> Array:
7276
"""
@@ -76,5 +80,9 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array:
7680
"""
7781
# Call result type here just to raise on disallowed type combinations
7882
_result_type(x1.dtype, x2.dtype)
83+
84+
if len({a.device for a in (condition, x1, x2)}) > 1:
85+
raise ValueError("where inputs must all be on the same device")
86+
7987
x1, x2 = Array._normalize_two_args(x1, x2)
80-
return Array._new(np.where(condition._array, x1._array, x2._array))
88+
return Array._new(np.where(condition._array, x1._array, x2._array), device=x1.device)

0 commit comments

Comments
 (0)