Skip to content

Commit 47722cc

Browse files
authored
Merge pull request #180 from crusaderky/at_device
TST: `at`: add test for device
2 parents f55076f + 3a822e2 commit 47722cc

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

tests/test_at.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from array_api_extra._lib._at import _AtOp
1414
from array_api_extra._lib._testing import xp_assert_equal
1515
from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
16-
from array_api_extra._lib._utils._typing import Array, SetIndex
16+
from array_api_extra._lib._utils._compat import device as get_device
17+
from array_api_extra._lib._utils._typing import Array, Device, SetIndex
1718
from array_api_extra.testing import lazy_xp_function
1819

1920
pytestmark = [
@@ -327,3 +328,15 @@ def test_gh134(xp: ModuleType, bool_mask: bool, copy: bool | None):
327328
idx = xp.asarray(True) if bool_mask else ()
328329
z = at_op(y, idx, _AtOp.SET, 1, copy=copy)
329330
xp_assert_equal(z, xp.asarray(1, dtype=x.dtype))
331+
332+
333+
def test_device(xp: ModuleType, device: Device):
334+
x = xp.asarray([1, 2, 3], device=device)
335+
336+
y = xp.asarray([4, 5], device=device)
337+
z = at(x)[:2].set(y)
338+
assert get_device(z) == get_device(x)
339+
340+
idx = xp.asarray([True, False, True], device=device)
341+
z = at(x)[idx].set(4)
342+
assert get_device(z) == get_device(x)

0 commit comments

Comments
 (0)