|
13 | 13 | from array_api_extra._lib._at import _AtOp
|
14 | 14 | from array_api_extra._lib._testing import xp_assert_equal
|
15 | 15 | from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
|
16 |
| -from array_api_extra._lib._utils._typing import Array, SetIndex |
| 16 | +from array_api_extra._lib._utils._compat import device as get_device |
| 17 | +from array_api_extra._lib._utils._typing import Array, Device, SetIndex |
17 | 18 | from array_api_extra.testing import lazy_xp_function
|
18 | 19 |
|
19 | 20 | pytestmark = [
|
@@ -327,3 +328,15 @@ def test_gh134(xp: ModuleType, bool_mask: bool, copy: bool | None):
|
327 | 328 | idx = xp.asarray(True) if bool_mask else ()
|
328 | 329 | z = at_op(y, idx, _AtOp.SET, 1, copy=copy)
|
329 | 330 | xp_assert_equal(z, xp.asarray(1, dtype=x.dtype))
|
| 331 | + |
| 332 | + |
| 333 | +def test_device(xp: ModuleType, device: Device): |
| 334 | + x = xp.asarray([1, 2, 3], device=device) |
| 335 | + |
| 336 | + y = xp.asarray([4, 5], device=device) |
| 337 | + z = at(x)[:2].set(y) |
| 338 | + assert get_device(z) == get_device(x) |
| 339 | + |
| 340 | + idx = xp.asarray([True, False, True], device=device) |
| 341 | + z = at(x)[idx].set(4) |
| 342 | + assert get_device(z) == get_device(x) |
0 commit comments