@@ -19,7 +19,7 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
19
19
"""
20
20
if x .dtype not in _real_numeric_dtypes :
21
21
raise TypeError ("Only real numeric dtypes are allowed in argmax" )
22
- return Array ._new (np .asarray (np .argmax (x ._array , axis = axis , keepdims = keepdims )))
22
+ return Array ._new (np .asarray (np .argmax (x ._array , axis = axis , keepdims = keepdims )), device = x . device )
23
23
24
24
25
25
def argmin (x : Array , / , * , axis : Optional [int ] = None , keepdims : bool = False ) -> Array :
@@ -30,7 +30,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
30
30
"""
31
31
if x .dtype not in _real_numeric_dtypes :
32
32
raise TypeError ("Only real numeric dtypes are allowed in argmin" )
33
- return Array ._new (np .asarray (np .argmin (x ._array , axis = axis , keepdims = keepdims )))
33
+ return Array ._new (np .asarray (np .argmin (x ._array , axis = axis , keepdims = keepdims )), device = x . device )
34
34
35
35
36
36
@requires_data_dependent_shapes
@@ -61,12 +61,16 @@ def searchsorted(
61
61
"""
62
62
if x1 .dtype not in _real_numeric_dtypes or x2 .dtype not in _real_numeric_dtypes :
63
63
raise TypeError ("Only real numeric dtypes are allowed in searchsorted" )
64
+
65
+ if x1 .device != x2 .device :
66
+ raise RuntimeError (f"Arrays from two different devices ({ x1 .device } and { x2 .device } ) can not be combined." )
67
+
64
68
sorter = sorter ._array if sorter is not None else None
65
69
# TODO: The sort order of nans and signed zeros is implementation
66
70
# dependent. Should we error/warn if they are present?
67
71
68
72
# x1 must be 1-D, but NumPy already requires this.
69
- return Array ._new (np .searchsorted (x1 ._array , x2 ._array , side = side , sorter = sorter ))
73
+ return Array ._new (np .searchsorted (x1 ._array , x2 ._array , side = side , sorter = sorter ), device = x1 . device )
70
74
71
75
def where (condition : Array , x1 : Array , x2 : Array , / ) -> Array :
72
76
"""
@@ -76,5 +80,9 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array:
76
80
"""
77
81
# Call result type here just to raise on disallowed type combinations
78
82
_result_type (x1 .dtype , x2 .dtype )
83
+
84
+ if len ({a .device for a in (condition , x1 , x2 )}) > 1 :
85
+ raise ValueError ("where inputs must all be on the same device" )
86
+
79
87
x1 , x2 = Array ._normalize_two_args (x1 , x2 )
80
88
return Array ._new (np .where (condition ._array , x1 ._array , x2 ._array ))
0 commit comments