Skip to content

Commit d9760f3

Browse files
andersy005dcherianpre-commit-ci[bot]
authored
refactor indexing.py: introduce .oindex for Explicitly Indexed Arrays (#8750)
Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 01f7b4f commit d9760f3

File tree

3 files changed

+81
-28
lines changed

3 files changed

+81
-28
lines changed

doc/whats-new.rst

+4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ v2024.03.0 (unreleased)
2323
New Features
2424
~~~~~~~~~~~~
2525

26+
- Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`)
27+
By `Anderson Banihirwe <https://github.com/andersy005>`_.
28+
2629

2730
Breaking changes
2831
~~~~~~~~~~~~~~~~
@@ -44,6 +47,7 @@ Internal Changes
4447
~~~~~~~~~~~~~~~~
4548

4649

50+
4751
.. _whats-new.2024.02.0:
4852

4953
v2024.02.0 (Feb 19, 2024)

xarray/core/indexing.py

+62-21
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,21 @@ def as_integer_slice(value):
325325
return slice(start, stop, step)
326326

327327

328+
class IndexCallable:
329+
"""Provide getitem syntax for a callable object."""
330+
331+
__slots__ = ("func",)
332+
333+
def __init__(self, func):
334+
self.func = func
335+
336+
def __getitem__(self, key):
337+
return self.func(key)
338+
339+
def __setitem__(self, key, value):
340+
raise NotImplementedError
341+
342+
328343
class BasicIndexer(ExplicitIndexer):
329344
"""Tuple for basic indexing.
330345
@@ -470,6 +485,13 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
470485
# Note this is the base class for all lazy indexing classes
471486
return np.asarray(self.get_duck_array(), dtype=dtype)
472487

488+
def _oindex_get(self, key):
489+
raise NotImplementedError("This method should be overridden")
490+
491+
@property
492+
def oindex(self):
493+
return IndexCallable(self._oindex_get)
494+
473495

474496
class ImplicitToExplicitIndexingAdapter(NDArrayMixin):
475497
"""Wrap an array, converting tuples into the indicated explicit indexer."""
@@ -560,6 +582,9 @@ def get_duck_array(self):
560582
def transpose(self, order):
561583
return LazilyVectorizedIndexedArray(self.array, self.key).transpose(order)
562584

585+
def _oindex_get(self, indexer):
586+
return type(self)(self.array, self._updated_key(indexer))
587+
563588
def __getitem__(self, indexer):
564589
if isinstance(indexer, VectorizedIndexer):
565590
array = LazilyVectorizedIndexedArray(self.array, self.key)
@@ -663,6 +688,9 @@ def _ensure_copied(self):
663688
def get_duck_array(self):
664689
return self.array.get_duck_array()
665690

691+
def _oindex_get(self, key):
692+
return type(self)(_wrap_numpy_scalars(self.array[key]))
693+
666694
def __getitem__(self, key):
667695
return type(self)(_wrap_numpy_scalars(self.array[key]))
668696

@@ -696,6 +724,9 @@ def get_duck_array(self):
696724
self._ensure_cached()
697725
return self.array.get_duck_array()
698726

727+
def _oindex_get(self, key):
728+
return type(self)(_wrap_numpy_scalars(self.array[key]))
729+
699730
def __getitem__(self, key):
700731
return type(self)(_wrap_numpy_scalars(self.array[key]))
701732

@@ -1332,6 +1363,10 @@ def _indexing_array_and_key(self, key):
13321363
def transpose(self, order):
13331364
return self.array.transpose(order)
13341365

1366+
def _oindex_get(self, key):
1367+
array, key = self._indexing_array_and_key(key)
1368+
return array[key]
1369+
13351370
def __getitem__(self, key):
13361371
array, key = self._indexing_array_and_key(key)
13371372
return array[key]
@@ -1376,16 +1411,19 @@ def __init__(self, array):
13761411
)
13771412
self.array = array
13781413

1414+
def _oindex_get(self, key):
1415+
# manual orthogonal indexing (implemented like DaskIndexingAdapter)
1416+
key = key.tuple
1417+
value = self.array
1418+
for axis, subkey in reversed(list(enumerate(key))):
1419+
value = value[(slice(None),) * axis + (subkey, Ellipsis)]
1420+
return value
1421+
13791422
def __getitem__(self, key):
13801423
if isinstance(key, BasicIndexer):
13811424
return self.array[key.tuple]
13821425
elif isinstance(key, OuterIndexer):
1383-
# manual orthogonal indexing (implemented like DaskIndexingAdapter)
1384-
key = key.tuple
1385-
value = self.array
1386-
for axis, subkey in reversed(list(enumerate(key))):
1387-
value = value[(slice(None),) * axis + (subkey, Ellipsis)]
1388-
return value
1426+
return self.oindex[key]
13891427
else:
13901428
if isinstance(key, VectorizedIndexer):
13911429
raise TypeError("Vectorized indexing is not supported")
@@ -1395,11 +1433,10 @@ def __getitem__(self, key):
13951433
def __setitem__(self, key, value):
13961434
if isinstance(key, (BasicIndexer, OuterIndexer)):
13971435
self.array[key.tuple] = value
1436+
elif isinstance(key, VectorizedIndexer):
1437+
raise TypeError("Vectorized indexing is not supported")
13981438
else:
1399-
if isinstance(key, VectorizedIndexer):
1400-
raise TypeError("Vectorized indexing is not supported")
1401-
else:
1402-
raise TypeError(f"Unrecognized indexer: {key}")
1439+
raise TypeError(f"Unrecognized indexer: {key}")
14031440

14041441
def transpose(self, order):
14051442
xp = self.array.__array_namespace__()
@@ -1417,24 +1454,25 @@ def __init__(self, array):
14171454
"""
14181455
self.array = array
14191456

1420-
def __getitem__(self, key):
1457+
def _oindex_get(self, key):
1458+
key = key.tuple
1459+
try:
1460+
return self.array[key]
1461+
except NotImplementedError:
1462+
# manual orthogonal indexing
1463+
value = self.array
1464+
for axis, subkey in reversed(list(enumerate(key))):
1465+
value = value[(slice(None),) * axis + (subkey,)]
1466+
return value
14211467

1468+
def __getitem__(self, key):
14221469
if isinstance(key, BasicIndexer):
14231470
return self.array[key.tuple]
14241471
elif isinstance(key, VectorizedIndexer):
14251472
return self.array.vindex[key.tuple]
14261473
else:
14271474
assert isinstance(key, OuterIndexer)
1428-
key = key.tuple
1429-
try:
1430-
return self.array[key]
1431-
except NotImplementedError:
1432-
# manual orthogonal indexing.
1433-
# TODO: port this upstream into dask in a saner way.
1434-
value = self.array
1435-
for axis, subkey in reversed(list(enumerate(key))):
1436-
value = value[(slice(None),) * axis + (subkey,)]
1437-
return value
1475+
return self.oindex[key]
14381476

14391477
def __setitem__(self, key, value):
14401478
if isinstance(key, BasicIndexer):
@@ -1510,6 +1548,9 @@ def _convert_scalar(self, item):
15101548
# a NumPy array.
15111549
return to_0d_array(item)
15121550

1551+
def _oindex_get(self, key):
1552+
return self.__getitem__(key)
1553+
15131554
def __getitem__(
15141555
self, indexer
15151556
) -> (

xarray/core/variable.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,7 @@
4141
maybe_coerce_to_str,
4242
)
4343
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
44-
from xarray.namedarray.pycompat import (
45-
integer_types,
46-
is_0d_dask_array,
47-
to_duck_array,
48-
)
44+
from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array
4945

5046
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
5147
indexing.ExplicitlyIndexed,
@@ -761,7 +757,14 @@ def __getitem__(self, key) -> Self:
761757
array `x.values` directly.
762758
"""
763759
dims, indexer, new_order = self._broadcast_indexes(key)
764-
data = as_indexable(self._data)[indexer]
760+
indexable = as_indexable(self._data)
761+
762+
if isinstance(indexer, BasicIndexer):
763+
data = indexable[indexer]
764+
elif isinstance(indexer, OuterIndexer):
765+
data = indexable.oindex[indexer]
766+
else:
767+
data = indexable[indexer]
765768
if new_order:
766769
data = np.moveaxis(data, range(len(new_order)), new_order)
767770
return self._finalize_indexing_result(dims, data)
@@ -794,7 +797,12 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA):
794797
else:
795798
actual_indexer = indexer
796799

797-
data = as_indexable(self._data)[actual_indexer]
800+
indexable = as_indexable(self._data)
801+
802+
if isinstance(indexer, OuterIndexer):
803+
data = indexable.oindex[indexer]
804+
else:
805+
data = indexable[actual_indexer]
798806
mask = indexing.create_mask(indexer, self.shape, data)
799807
# we need to invert the mask in order to pass data first. This helps
800808
# pint to choose the correct unit

0 commit comments

Comments
 (0)