Skip to content

Commit ac854f0

Browse files
authored
Make Indexer classes not inherit from tuple. (#1705)
* Make Indexer classes not inherit from tuple.

  I'm not entirely sure this is a good idea. The advantage is that it ensures that all our indexing code is entirely explicit: everything that reaches a backend *must* be an ExplicitIndexer. The downside is that it removes a bit of internal flexibility: we can't just use tuples in place of basic indexers anymore. On the whole, I think this is probably worth it but I would appreciate feedback.

* Add validation to ExplicitIndexer classes
* Fix pynio test failure
* Rename and add comments
* flake8
* Fix windows test failure
* typo
* leftover from debugging
1 parent 9d8ec38 commit ac854f0

19 files changed

+515
-262
lines changed

xarray/backends/common.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from distutils.version import LooseVersion
1111

1212
from ..conventions import cf_encoder
13-
from ..core.utils import FrozenOrderedDict
13+
from ..core import indexing
14+
from ..core.utils import FrozenOrderedDict, NdimSizeLenMixin
1415
from ..core.pycompat import iteritems, dask_array_type
1516

1617
try:
@@ -76,6 +77,13 @@ def robust_getitem(array, key, catch=Exception, max_retries=6,
7677
time.sleep(1e-3 * next_delay)
7778

7879

80+
class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
81+
82+
def __array__(self, dtype=None):
83+
key = indexing.BasicIndexer((slice(None),) * self.ndim)
84+
return np.asarray(self[key], dtype=dtype)
85+
86+
7987
class AbstractDataStore(Mapping):
8088
_autoclose = False
8189

xarray/backends/h5netcdf_.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@
1515

1616
class H5NetCDFArrayWrapper(BaseNetCDF4Array):
1717
def __getitem__(self, key):
18-
if isinstance(key, indexing.VectorizedIndexer):
19-
raise NotImplementedError(
20-
'Vectorized indexing for {} is not implemented. Load your '
21-
'data first with .load() or .compute().'.format(type(self)))
22-
key = indexing.to_tuple(key)
18+
key = indexing.unwrap_explicit_indexer(
19+
key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer))
2320
with self.datastore.ensure_open(autoclose=True):
2421
return self.get_array()[key]
2522

xarray/backends/netCDF4_.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,10 @@
1010
from .. import Variable
1111
from ..conventions import pop_to
1212
from ..core import indexing
13-
from ..core.utils import (FrozenOrderedDict, NdimSizeLenMixin,
14-
DunderArrayMixin, close_on_error,
15-
is_remote_uri)
13+
from ..core.utils import (FrozenOrderedDict, close_on_error, is_remote_uri)
1614
from ..core.pycompat import iteritems, basestring, OrderedDict, PY3, suppress
1715

18-
from .common import (WritableCFDataStore, robust_getitem,
16+
from .common import (WritableCFDataStore, robust_getitem, BackendArray,
1917
DataStorePickleMixin, find_root)
2018
from .netcdf3 import (encode_nc3_attr_value, encode_nc3_variable)
2119

@@ -27,8 +25,7 @@
2725
'|': 'native'}
2826

2927

30-
class BaseNetCDF4Array(NdimSizeLenMixin, DunderArrayMixin,
31-
indexing.NDArrayIndexable):
28+
class BaseNetCDF4Array(BackendArray):
3229
def __init__(self, variable_name, datastore):
3330
self.datastore = datastore
3431
self.variable_name = variable_name
@@ -51,12 +48,8 @@ def get_array(self):
5148

5249
class NetCDF4ArrayWrapper(BaseNetCDF4Array):
5350
def __getitem__(self, key):
54-
if isinstance(key, indexing.VectorizedIndexer):
55-
raise NotImplementedError(
56-
'Vectorized indexing for {} is not implemented. Load your '
57-
'data first with .load() or .compute().'.format(type(self)))
58-
59-
key = indexing.to_tuple(key)
51+
key = indexing.unwrap_explicit_indexer(
52+
key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer))
6053

6154
if self.datastore.is_remote: # pragma: no cover
6255
getitem = functools.partial(robust_getitem, catch=RuntimeError)

xarray/backends/pydap_.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
import numpy as np
55

66
from .. import Variable
7-
from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin
7+
from ..core.utils import FrozenOrderedDict, Frozen
88
from ..core import indexing
99
from ..core.pycompat import integer_types
1010

11-
from .common import AbstractDataStore, robust_getitem
11+
from .common import AbstractDataStore, BackendArray, robust_getitem
1212

1313

14-
class PydapArrayWrapper(NDArrayMixin, indexing.NDArrayIndexable):
14+
class PydapArrayWrapper(BackendArray):
1515
def __init__(self, array):
1616
self.array = array
1717

@@ -27,17 +27,9 @@ def dtype(self):
2727
return np.dtype(t.typecode + str(t.size))
2828

2929
def __getitem__(self, key):
30-
if isinstance(key, indexing.VectorizedIndexer):
31-
raise NotImplementedError(
32-
'Vectorized indexing for {} is not implemented. Load your '
33-
'data first with .load() or .compute().'.format(type(self)))
34-
key = indexing.to_tuple(key)
35-
if not isinstance(key, tuple):
36-
key = (key,)
37-
for k in key:
38-
if not (isinstance(k, integer_types + (slice,)) or k is Ellipsis):
39-
raise IndexError('pydap only supports indexing with int, '
40-
'slice and Ellipsis objects')
30+
key = indexing.unwrap_explicit_indexer(
31+
key, target=self, allow=indexing.BasicIndexer)
32+
4133
# pull the data from the array attribute if possible, to avoid
4234
# downloading coordinate data twice
4335
array = getattr(self.array, 'array', self.array)

xarray/backends/pynio_.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,13 @@
77
import numpy as np
88

99
from .. import Variable
10-
from ..core.utils import (FrozenOrderedDict, Frozen,
11-
NdimSizeLenMixin, DunderArrayMixin)
10+
from ..core.utils import (FrozenOrderedDict, Frozen)
1211
from ..core import indexing
13-
from ..core.pycompat import integer_types
1412

15-
from .common import AbstractDataStore, DataStorePickleMixin
13+
from .common import AbstractDataStore, DataStorePickleMixin, BackendArray
1614

1715

18-
class NioArrayWrapper(NdimSizeLenMixin, DunderArrayMixin,
19-
indexing.NDArrayIndexable):
16+
class NioArrayWrapper(BackendArray):
2017

2118
def __init__(self, variable_name, datastore):
2219
self.datastore = datastore
@@ -30,13 +27,9 @@ def get_array(self):
3027
return self.datastore.ds.variables[self.variable_name]
3128

3229
def __getitem__(self, key):
33-
if isinstance(key, (indexing.VectorizedIndexer,
34-
indexing.OuterIndexer)):
35-
raise NotImplementedError(
36-
'Nio backend does not support vectorized / outer indexing. '
37-
'Load your data first with .load() or .compute(). '
38-
'Given {}'.format(key))
39-
key = indexing.to_tuple(key)
30+
key = indexing.unwrap_explicit_indexer(
31+
key, target=self, allow=indexing.BasicIndexer)
32+
4033
with self.datastore.ensure_open(autoclose=True):
4134
array = self.get_array()
4235
if key == () and self.ndim == 0:

xarray/backends/rasterio_.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import numpy as np
44

55
from .. import DataArray
6-
from ..core.utils import DunderArrayMixin, NdimSizeLenMixin, is_scalar
6+
from ..core.utils import is_scalar
77
from ..core import indexing
8+
from .common import BackendArray
89
try:
910
from dask.utils import SerializableLock as Lock
1011
except ImportError:
@@ -17,8 +18,7 @@
1718
'first.')
1819

1920

20-
class RasterioArrayWrapper(NdimSizeLenMixin, DunderArrayMixin,
21-
indexing.NDArrayIndexable):
21+
class RasterioArrayWrapper(BackendArray):
2222
"""A wrapper around rasterio dataset objects"""
2323
def __init__(self, rasterio_ds):
2424
self.rasterio_ds = rasterio_ds
@@ -38,11 +38,9 @@ def shape(self):
3838
return self._shape
3939

4040
def __getitem__(self, key):
41-
if isinstance(key, indexing.VectorizedIndexer):
42-
raise NotImplementedError(
43-
'Vectorized indexing for {} is not implemented. Load your '
44-
'data first with .load() or .compute().'.format(type(self)))
45-
key = indexing.to_tuple(key)
41+
key = indexing.unwrap_explicit_indexer(
42+
key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer))
43+
4644
# bands cannot be windowed but they can be listed
4745
band_key = key[0]
4846
n_bands = self.shape[0]

xarray/backends/scipy_.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,10 @@
99

1010
from .. import Variable
1111
from ..core.pycompat import iteritems, OrderedDict, basestring
12-
from ..core.utils import (Frozen, FrozenOrderedDict, NdimSizeLenMixin,
13-
DunderArrayMixin)
14-
from ..core.indexing import NumpyIndexingAdapter, NDArrayIndexable
12+
from ..core.utils import (Frozen, FrozenOrderedDict)
13+
from ..core.indexing import NumpyIndexingAdapter
1514

16-
from .common import WritableCFDataStore, DataStorePickleMixin
15+
from .common import WritableCFDataStore, DataStorePickleMixin, BackendArray
1716
from .netcdf3 import (is_valid_nc3_name, encode_nc3_attr_value,
1817
encode_nc3_variable)
1918

@@ -31,7 +30,7 @@ def _decode_attrs(d):
3130
for (k, v) in iteritems(d))
3231

3332

34-
class ScipyArrayWrapper(NdimSizeLenMixin, DunderArrayMixin, NDArrayIndexable):
33+
class ScipyArrayWrapper(BackendArray):
3534

3635
def __init__(self, variable_name, datastore):
3736
self.datastore = datastore

xarray/conventions.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def encode_cf_timedelta(timedeltas, units=None):
337337
return (num, units)
338338

339339

340-
class MaskedAndScaledArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
340+
class MaskedAndScaledArray(indexing.ExplicitlyIndexedNDArrayMixin):
341341
"""Wrapper around array-like objects to create a new indexable object where
342342
values, when accessed, are automatically scaled and masked according to
343343
CF conventions for packed and missing data values.
@@ -395,7 +395,7 @@ def __repr__(self):
395395
self.scale_factor, self.add_offset, self._dtype))
396396

397397

398-
class DecodedCFDatetimeArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
398+
class DecodedCFDatetimeArray(indexing.ExplicitlyIndexedNDArrayMixin):
399399
"""Wrapper around array-like objects to create a new indexable object where
400400
values, when accessed, are automatically converted into datetime objects
401401
using decode_cf_datetime.
@@ -408,8 +408,9 @@ def __init__(self, array, units, calendar=None):
408408
# Verify that at least the first and last date can be decoded
409409
# successfully. Otherwise, tracebacks end up swallowed by
410410
# Dataset.__repr__ when users try to view their lazily decoded array.
411-
example_value = np.concatenate([first_n_items(array, 1) or [0],
412-
last_item(array) or [0]])
411+
values = indexing.ImplicitToExplicitIndexingAdapter(self.array)
412+
example_value = np.concatenate([first_n_items(values, 1) or [0],
413+
last_item(values) or [0]])
413414

414415
try:
415416
result = decode_cf_datetime(example_value, units, calendar)
@@ -434,7 +435,7 @@ def __getitem__(self, key):
434435
calendar=self.calendar)
435436

436437

437-
class DecodedCFTimedeltaArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
438+
class DecodedCFTimedeltaArray(indexing.ExplicitlyIndexedNDArrayMixin):
438439
"""Wrapper around array-like objects to create a new indexable object where
439440
values, when accessed, are automatically converted into timedelta objects
440441
using decode_cf_timedelta.
@@ -451,7 +452,7 @@ def __getitem__(self, key):
451452
return decode_cf_timedelta(self.array[key], units=self.units)
452453

453454

454-
class StackedBytesArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
455+
class StackedBytesArray(indexing.ExplicitlyIndexedNDArrayMixin):
455456
"""Wrapper around array-like objects to create a new indexable object where
456457
values, when accessed, are automatically stacked along the last dimension.
457458
@@ -482,7 +483,7 @@ def shape(self):
482483
def __str__(self):
483484
# TODO(shoyer): figure out why we need this special case?
484485
if self.ndim == 0:
485-
return str(self[...].item())
486+
return str(np.array(self).item())
486487
else:
487488
return repr(self)
488489

@@ -491,13 +492,13 @@ def __repr__(self):
491492

492493
def __getitem__(self, key):
493494
# require slicing the last dimension completely
494-
key = indexing.expanded_indexer(key, self.array.ndim)
495-
if key[-1] != slice(None):
495+
key = type(key)(indexing.expanded_indexer(key.tuple, self.array.ndim))
496+
if key.tuple[-1] != slice(None):
496497
raise IndexError('too many indices')
497498
return char_to_bytes(self.array[key])
498499

499500

500-
class BytesToStringArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
501+
class BytesToStringArray(indexing.ExplicitlyIndexedNDArrayMixin):
501502
"""Wrapper that decodes bytes to unicode when values are read.
502503
503504
>>> BytesToStringArray(np.array([b'abc']))[:]
@@ -524,7 +525,7 @@ def dtype(self):
524525
def __str__(self):
525526
# TODO(shoyer): figure out why we need this special case?
526527
if self.ndim == 0:
527-
return str(self[...].item())
528+
return str(np.array(self).item())
528529
else:
529530
return repr(self)
530531

@@ -536,7 +537,7 @@ def __getitem__(self, key):
536537
return decode_bytes_array(self.array[key], self.encoding)
537538

538539

539-
class NativeEndiannessArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
540+
class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin):
540541
"""Decode arrays on the fly from non-native to native endianness
541542
542543
This is useful for decoding arrays from netCDF3 files (which are all
@@ -565,7 +566,7 @@ def __getitem__(self, key):
565566
return np.asarray(self.array[key], dtype=self.dtype)
566567

567568

568-
class BoolTypeArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
569+
class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin):
569570
"""Decode arrays on the fly from integer to boolean datatype
570571
571572
This is useful for decoding boolean arrays from integer typed netCDF
@@ -593,7 +594,7 @@ def __getitem__(self, key):
593594
return np.asarray(self.array[key], dtype=self.dtype)
594595

595596

596-
class UnsignedIntTypeArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
597+
class UnsignedIntTypeArray(indexing.ExplicitlyIndexedNDArrayMixin):
597598
"""Decode arrays on the fly from signed integer to unsigned
598599
integer. Typically used when _Unsigned is set at as a netCDF
599600
attribute on a signed integer variable.

0 commit comments

Comments
 (0)