Description
Not sure why indexing a 2D array on 'device1' or 'device2' with a pair of integer index arrays fails. Noticed while working on scikit-learn/scikit-learn#29431.
Minimal reproducible example:
import array_api_strict as xp
# Create device
device = xp.Device('device1')
# Create a 2D array on device1
data = xp.asarray([
    [10, 20, 30, 40],
    [50, 60, 70, 80],
    [90, 100, 110, 120]
], device=device)
# Create row and column index arrays on the same device
row_indices = xp.asarray([0, 2, 1], device=device)
col_indices = xp.asarray([1, 3, 0], device=device)
# Index the 2D array using both index arrays
result = data[row_indices, col_indices]
This raises the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[7], line 1
----> 1 result = data[row_indices, col_indices]
File ~/miniconda3/envs/skl-array-api/lib/python3.13/site-packages/array_api_strict/_array_object.py:703, in Array.__getitem__(self, key)
700 if isinstance(key, Array):
701 # Indexing self._array with array_api_strict arrays can be erroneous
702 key = key._array
--> 703 res = self._array.__getitem__(key)
704 return self._new(res, device=self.device)
File ~/miniconda3/envs/skl-array-api/lib/python3.13/site-packages/array_api_strict/_array_object.py:167, in Array.__array__(self, dtype, copy)
165 if _allow_array:
166 if self._device != CPU_DEVICE:
--> 167 raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
168 # copy keyword is new in 2.0.0; for older versions don't use it
169 # retry without that keyword.
170 if np.__version__[0] < '2':
RuntimeError: Can not convert array on the 'array_api_strict.Device('device1')' device to a Numpy array.
Doing the same with a 1D array and a single index array works. Setting device = xp.Device('CPU_DEVICE') also works.
Using array-api-strict version 2.3.1
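Reading the traceback, one plausible cause: Array.__getitem__ only unwraps key when it is a single Array (lines 700-702), so a tuple of Arrays is handed to NumPy unchanged. NumPy then converts each tuple element through __array__, which refuses anything not on CPU_DEVICE. That would also fit the two cases that work: a single index array is unwrapped before NumPy sees it, and CPU arrays pass the __array__ check. The sketch below models this with plain NumPy and a hypothetical FakeDeviceArray class. It illustrates the suspected mechanism and is not array-api-strict's actual code; getitem_fixed is likewise only one possible fix (also unwrapping arrays inside tuple keys), assumed here for illustration.

```python
import numpy as np


class FakeDeviceArray:
    """Hypothetical stand-in for array_api_strict.Array (illustration only)."""

    def __init__(self, data, device="CPU_DEVICE"):
        self._array = np.asarray(data)
        self.device = device

    def __array__(self, dtype=None, copy=None):
        # Mirrors the check in the traceback: only CPU arrays convert to NumPy.
        if self.device != "CPU_DEVICE":
            raise RuntimeError(
                f"Can not convert array on the {self.device!r} device to a Numpy array."
            )
        return self._array if dtype is None else self._array.astype(dtype)

    def __getitem__(self, key):
        # Like lines 700-702 of the traceback: only a bare array key is
        # unwrapped; elements of a tuple key reach NumPy untouched, and NumPy
        # converts them via __array__.
        if isinstance(key, FakeDeviceArray):
            key = key._array
        return FakeDeviceArray(self._array[key], device=self.device)

    def getitem_fixed(self, key):
        # Hypothetical fix: also unwrap arrays nested inside a tuple key.
        if isinstance(key, tuple):
            key = tuple(k._array if isinstance(k, FakeDeviceArray) else k for k in key)
        elif isinstance(key, FakeDeviceArray):
            key = key._array
        return FakeDeviceArray(self._array[key], device=self.device)


values = [[10, 20, 30, 40], [50, 60, 70, 80], [90, 100, 110, 120]]
data = FakeDeviceArray(values, device="device1")
rows = FakeDeviceArray([0, 2, 1], device="device1")
cols = FakeDeviceArray([1, 3, 0], device="device1")

try:
    data[rows, cols]  # tuple key: NumPy calls __array__ on rows and cols
except RuntimeError as exc:
    print("tuple key fails:", exc)

print(data.getitem_fixed((rows, cols))._array.tolist())  # [20, 120, 50]
```

In this model the 1D case (one index array) and the CPU_DEVICE case both succeed with the unmodified __getitem__, matching the behaviour reported above.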
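A possible workaround in the meantime is to avoid the tuple key entirely: flatten the 2D array and gather with a single linear index array through take, which only needs one 1D index array. take_pairs is a hypothetical helper, and this has not been verified against array-api-strict 2.3.1 on 'device1'; it assumes reshape, arithmetic, and take keep arrays on their device without going through __array__. NumPy stands in for the namespace so the sketch runs on its own.

```python
import numpy as np


def take_pairs(xp, data, row_indices, col_indices):
    # Hypothetical helper: gathers data[row_indices[i], col_indices[i]]
    # using only reshape and take, so no tuple key reaches __getitem__.
    n_cols = data.shape[1]
    # Array API reshape is row-major, so element (r, c) lands at r * n_cols + c.
    flat = xp.reshape(data, (-1,))
    linear = row_indices * n_cols + col_indices
    return xp.take(flat, linear, axis=0)


# NumPy stands in for the array namespace here; with array-api-strict one
# would pass xp=array_api_strict and arrays created on the same device.
data = np.asarray([[10, 20, 30, 40], [50, 60, 70, 80], [90, 100, 110, 120]])
rows = np.asarray([0, 2, 1])
cols = np.asarray([1, 3, 0])
print(take_pairs(np, data, rows, cols).tolist())  # [20, 120, 50]
```

The linear-index trick only covers the paired (row, column) gather from the example; other mixes of slices and index arrays would need their own rewrite.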