diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 1304d5a..823274e 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -724,8 +724,24 @@ def __getitem__( # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index self._validate_index(key, op="getitem") - # Indexing self._array with array_api_strict arrays can be erroneous - np_key = key._array if isinstance(key, Array) else key + if isinstance(key, Array): + key = (key,) + np_key = key + devices = {self.device} + if isinstance(key, tuple): + devices.update( + [subkey.device for subkey in key if hasattr(subkey, "device")] + ) + if len(devices) > 1: + raise ValueError( + "Array indexing is only allowed when array to be indexed and all " + "indexing arrays are on the same device." + ) + # Indexing self._array with array_api_strict arrays can be erroneous + # e.g., when using non-default device + np_key = tuple( + subkey._array if isinstance(subkey, Array) else subkey for subkey in key + ) res = self._array.__getitem__(np_key) return self._new(res, device=self.device) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index e24a40f..51f4f31 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from .. import ones, arange, reshape, asarray, result_type, all, equal +from .. import ones, arange, reshape, asarray, result_type, all, equal, stack from .._array_object import Array, CPU_DEVICE, Device from .._dtypes import ( _all_dtypes, @@ -101,33 +101,40 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[idx]) -def test_indexing_arrays(): +@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"]) +def test_indexing_arrays(device): # indexing with 1D integer arrays and mixes of integers and 1D integer are allowed + device = None if device is None else Device(device) # 1D array - a = arange(5) - idx = asarray([1, 0, 1, 2, -1]) + a = arange(5, device=device) + idx = asarray([1, 0, 1, 2, -1], device=device) a_idx = a[idx] - a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])]) + a_idx_loop = stack([a[idx[i]] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) + assert a_idx.shape == idx.shape + assert a.device == idx.device == a_idx.device # setitem with arrays is not allowed with assert_raises(IndexError): a[idx] = 42 # mixed array and integer indexing - a = reshape(arange(3*4), (3, 4)) - idx = asarray([1, 0, 1, 2, -1]) + a = reshape(arange(3*4, device=device), (3, 4)) + idx = asarray([1, 0, 1, 2, -1], device=device) a_idx = a[idx, 1] - - a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])]) + a_idx_loop = stack([a[idx[i], 1] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) + assert a_idx.shape == idx.shape + assert a.device == idx.device == a_idx.device # index with two arrays a_idx = a[idx, idx] - a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])]) + a_idx_loop = stack([a[idx[i], idx[i]] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) + assert a_idx.shape == a_idx.shape + assert a.device == idx.device == a_idx.device # setitem with arrays is not allowed with assert_raises(IndexError): @@ -135,7 +142,24 @@ def test_indexing_arrays(): # smoke test indexing with ndim > 1 arrays idx = idx[..., None] - a[idx, idx] + a_idx = a[idx, idx] + assert a.device == idx.device == a_idx.device + + +def test_indexing_arrays_different_devices(): + # Ensure indexing via array on different device errors + device1 = Device("CPU_DEVICE") + device2 = Device("device1") + + a = arange(5, device=device1) + idx1 = asarray([1, 0, 1, 2, -1], device=device2) + idx2 = asarray([1, 0, 1, 2, -1], device=device1) + + with pytest.raises(ValueError, match="Array indexing is only allowed when"): + a[idx1] + + with pytest.raises(ValueError, match="Array indexing is only allowed when"): + a[idx1, idx2] def test_promoted_scalar_inherits_device():