diff --git a/tests/test_at.py b/tests/test_at.py index d9a5854..218b05b 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -13,7 +13,8 @@ from array_api_extra._lib._at import _AtOp from array_api_extra._lib._testing import xp_assert_equal from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array -from array_api_extra._lib._utils._typing import Array, SetIndex +from array_api_extra._lib._utils._compat import device as get_device +from array_api_extra._lib._utils._typing import Array, Device, SetIndex from array_api_extra.testing import lazy_xp_function pytestmark = [ @@ -327,3 +328,15 @@ def test_gh134(xp: ModuleType, bool_mask: bool, copy: bool | None): idx = xp.asarray(True) if bool_mask else () z = at_op(y, idx, _AtOp.SET, 1, copy=copy) xp_assert_equal(z, xp.asarray(1, dtype=x.dtype)) + + +def test_device(xp: ModuleType, device: Device): + x = xp.asarray([1, 2, 3], device=device) + + y = xp.asarray([4, 5], device=device) + z = at(x)[:2].set(y) + assert get_device(z) == get_device(x) + + idx = xp.asarray([True, False, True], device=device) + z = at(x)[idx].set(4) + assert get_device(z) == get_device(x)