Skip to content

Commit 6699efb

Browse files
authored
BUG: fix device compat (#63)
1 parent 7ff0d0a commit 6699efb

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

Diff for: src/array_api_extra/_funcs.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import warnings
22

3-
from ._lib import _utils
3+
from ._lib import _compat, _utils
44
from ._lib._compat import array_namespace
55
from ._lib._typing import Array, ModuleType
66

@@ -200,7 +200,7 @@ def create_diagonal(
200200
err_msg = "`x` must be 1-dimensional."
201201
raise ValueError(err_msg)
202202
n = x.shape[0] + abs(offset)
203-
diag = xp.zeros(n**2, dtype=x.dtype, device=x.device)
203+
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))
204204
i = offset if offset >= 0 else abs(offset) * n
205205
diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x
206206
return xp.reshape(diag, (n, n))
@@ -540,6 +540,6 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
540540
y = xp.pi * xp.where(
541541
xp.astype(x, xp.bool),
542542
x,
543-
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device),
543+
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
544544
)
545545
return xp.sin(y) / y

0 commit comments

Comments
 (0)