Skip to content

Commit 54468e1

Browse files
fujiisoup authored and shoyer committed
Vectorized lazy indexing (#1899)
* Start working * First support of lazy vectorized indexing. * Some optimization. * Use unique to decompose vectorized indexing. * Consolidate vectorizedIndexing * Support vectorized_indexing in h5py * Refactoring backend array. Added indexing.decompose_indexers. Drop unwrap_explicit_indexers. * typo * bugfix and typo * Fix based on @WeatherGod comments. * Use enum-like object to indicate indexing-support types. * Update test_decompose_indexers. * Bugfix and benchmarks. * fix: support outer/basic indexer in LazilyVectorizedIndexedArray * More comments. * Fixing style errors. * Remove unintended dupicate * combine indexers for on-memory np.ndarray. * fix whats new * fix pydap * Update comments. * Support VectorizedIndexing for rasterio. Some bugfix. * flake8 * More tests * Use LazilyIndexedArray for scalar array instead of loading. * Support negative step slice in rasterio. * Make slice-step always positive * Bugfix in slice-slice * Add pydap support. * Rename LazilyIndexedArray -> LazilyOuterIndexedArray. Remove duplicate in zarr.py * flake8 * Added transpose to LazilyOuterIndexedArray
1 parent 55128aa commit 54468e1

15 files changed

+859
-269
lines changed

asv_bench/benchmarks/dataset_io.py

+46-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import xarray as xr
77

8-
from . import randn, requires_dask
8+
from . import randn, randint, requires_dask
99

1010
try:
1111
import dask
@@ -71,6 +71,15 @@ def make_ds(self):
7171

7272
self.ds.attrs = {'history': 'created for xarray benchmarking'}
7373

74+
self.oinds = {'time': randint(0, self.nt, 120),
75+
'lon': randint(0, self.nx, 20),
76+
'lat': randint(0, self.ny, 10)}
77+
self.vinds = {'time': xr.DataArray(randint(0, self.nt, 120),
78+
dims='x'),
79+
'lon': xr.DataArray(randint(0, self.nx, 120),
80+
dims='x'),
81+
'lat': slice(3, 20)}
82+
7483

7584
class IOWriteSingleNetCDF3(IOSingleNetCDF):
7685
def setup(self):
@@ -98,6 +107,14 @@ def setup(self):
98107
def time_load_dataset_netcdf4(self):
99108
xr.open_dataset(self.filepath, engine='netcdf4').load()
100109

110+
def time_orthogonal_indexing(self):
111+
ds = xr.open_dataset(self.filepath, engine='netcdf4')
112+
ds = ds.isel(**self.oinds).load()
113+
114+
def time_vectorized_indexing(self):
115+
ds = xr.open_dataset(self.filepath, engine='netcdf4')
116+
ds = ds.isel(**self.vinds).load()
117+
101118

102119
class IOReadSingleNetCDF3(IOReadSingleNetCDF4):
103120
def setup(self):
@@ -111,6 +128,14 @@ def setup(self):
111128
def time_load_dataset_scipy(self):
112129
xr.open_dataset(self.filepath, engine='scipy').load()
113130

131+
def time_orthogonal_indexing(self):
132+
ds = xr.open_dataset(self.filepath, engine='scipy')
133+
ds = ds.isel(**self.oinds).load()
134+
135+
def time_vectorized_indexing(self):
136+
ds = xr.open_dataset(self.filepath, engine='scipy')
137+
ds = ds.isel(**self.vinds).load()
138+
114139

115140
class IOReadSingleNetCDF4Dask(IOSingleNetCDF):
116141
def setup(self):
@@ -127,6 +152,16 @@ def time_load_dataset_netcdf4_with_block_chunks(self):
127152
xr.open_dataset(self.filepath, engine='netcdf4',
128153
chunks=self.block_chunks).load()
129154

155+
def time_load_dataset_netcdf4_with_block_chunks_oindexing(self):
156+
ds = xr.open_dataset(self.filepath, engine='netcdf4',
157+
chunks=self.block_chunks)
158+
ds = ds.isel(**self.oinds).load()
159+
160+
def time_load_dataset_netcdf4_with_block_chunks_vindexing(self):
161+
ds = xr.open_dataset(self.filepath, engine='netcdf4',
162+
chunks=self.block_chunks)
163+
ds = ds.isel(**self.vinds).load()
164+
130165
def time_load_dataset_netcdf4_with_block_chunks_multiprocessing(self):
131166
with dask.set_options(get=dask.multiprocessing.get):
132167
xr.open_dataset(self.filepath, engine='netcdf4',
@@ -158,6 +193,16 @@ def time_load_dataset_scipy_with_block_chunks(self):
158193
xr.open_dataset(self.filepath, engine='scipy',
159194
chunks=self.block_chunks).load()
160195

196+
def time_load_dataset_scipy_with_block_chunks_oindexing(self):
197+
ds = xr.open_dataset(self.filepath, engine='scipy',
198+
chunks=self.block_chunks)
199+
ds = ds.isel(**self.oinds).load()
200+
201+
def time_load_dataset_scipy_with_block_chunks_vindexing(self):
202+
ds = xr.open_dataset(self.filepath, engine='scipy',
203+
chunks=self.block_chunks)
204+
ds = ds.isel(**self.vinds).load()
205+
161206
def time_load_dataset_scipy_with_time_chunks(self):
162207
with dask.set_options(get=dask.multiprocessing.get):
163208
xr.open_dataset(self.filepath, engine='scipy',

doc/whats-new.rst

+5
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ Documentation
3838
Enhancements
3939
~~~~~~~~~~~~
4040

41+
- Support lazy vectorized-indexing. After this change, flexible indexing such
42+
as orthogonal/vectorized indexing, becomes possible for all the backend
43+
arrays. Also, lazy ``transpose`` is now also supported. (:issue:`1897`)
44+
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
45+
4146
- Improve :py:func:`~xarray.DataArray.rolling` logic.
4247
:py:func:`~xarray.DataArrayRolling` object now supports
4348
:py:func:`~xarray.DataArrayRolling.construct` method that returns a view

xarray/backends/h5netcdf_.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,20 @@
1616

1717
class H5NetCDFArrayWrapper(BaseNetCDF4Array):
1818
def __getitem__(self, key):
19-
key = indexing.unwrap_explicit_indexer(
20-
key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer))
19+
key, np_inds = indexing.decompose_indexer(
20+
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR)
21+
2122
# h5py requires using lists for fancy indexing:
2223
# https://github.com/h5py/h5py/issues/992
23-
# OuterIndexer only holds 1D integer ndarrays, so it's safe to convert
24-
# them to lists.
25-
key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in key)
24+
key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in
25+
key.tuple)
2626
with self.datastore.ensure_open(autoclose=True):
27-
return self.get_array()[key]
27+
array = self.get_array()[key]
28+
29+
if len(np_inds.tuple) > 0:
30+
array = indexing.NumpyIndexingAdapter(array)[np_inds]
31+
32+
return array
2833

2934

3035
def maybe_decode_bytes(txt):
@@ -85,7 +90,7 @@ def __init__(self, filename, mode='r', format=None, group=None,
8590
def open_store_variable(self, name, var):
8691
with self.ensure_open(autoclose=False):
8792
dimensions = var.dimensions
88-
data = indexing.LazilyIndexedArray(
93+
data = indexing.LazilyOuterIndexedArray(
8994
H5NetCDFArrayWrapper(name, self))
9095
attrs = _read_attributes(var)
9196

xarray/backends/netCDF4_.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,16 @@ def get_array(self):
4848

4949
class NetCDF4ArrayWrapper(BaseNetCDF4Array):
5050
def __getitem__(self, key):
51-
key = indexing.unwrap_explicit_indexer(
52-
key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer))
53-
51+
key, np_inds = indexing.decompose_indexer(
52+
key, self.shape, indexing.IndexingSupport.OUTER)
5453
if self.datastore.is_remote: # pragma: no cover
5554
getitem = functools.partial(robust_getitem, catch=RuntimeError)
5655
else:
5756
getitem = operator.getitem
5857

5958
with self.datastore.ensure_open(autoclose=True):
6059
try:
61-
data = getitem(self.get_array(), key)
60+
array = getitem(self.get_array(), key.tuple)
6261
except IndexError:
6362
# Catch IndexError in netCDF4 and return a more informative
6463
# error message. This is most often called when an unsorted
@@ -71,7 +70,10 @@ def __getitem__(self, key):
7170
msg += '\n\nOriginal traceback:\n' + traceback.format_exc()
7271
raise IndexError(msg)
7372

74-
return data
73+
if len(np_inds.tuple) > 0:
74+
array = indexing.NumpyIndexingAdapter(array)[np_inds]
75+
76+
return array
7577

7678

7779
def _encode_nc4_variable(var):
@@ -277,7 +279,8 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None,
277279
def open_store_variable(self, name, var):
278280
with self.ensure_open(autoclose=False):
279281
dimensions = var.dimensions
280-
data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
282+
data = indexing.LazilyOuterIndexedArray(
283+
NetCDF4ArrayWrapper(name, self))
281284
attributes = OrderedDict((k, var.getncattr(k))
282285
for k in var.ncattrs())
283286
_ensure_fill_value_valid(data, attributes)

xarray/backends/pydap_.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,22 @@ def dtype(self):
2222
return self.array.dtype
2323

2424
def __getitem__(self, key):
25-
key = indexing.unwrap_explicit_indexer(
26-
key, target=self, allow=indexing.BasicIndexer)
25+
key, np_inds = indexing.decompose_indexer(
26+
key, self.shape, indexing.IndexingSupport.BASIC)
2727

2828
# pull the data from the array attribute if possible, to avoid
2929
# downloading coordinate data twice
3030
array = getattr(self.array, 'array', self.array)
31-
result = robust_getitem(array, key, catch=ValueError)
31+
result = robust_getitem(array, key.tuple, catch=ValueError)
3232
# pydap doesn't squeeze axes automatically like numpy
33-
axis = tuple(n for n, k in enumerate(key)
33+
axis = tuple(n for n, k in enumerate(key.tuple)
3434
if isinstance(k, integer_types))
3535
if len(axis) > 0:
3636
result = np.squeeze(result, axis)
37+
38+
if len(np_inds.tuple) > 0:
39+
result = indexing.NumpyIndexingAdapter(np.asarray(result))[np_inds]
40+
3741
return result
3842

3943

@@ -74,7 +78,7 @@ def open(cls, url, session=None):
7478
return cls(ds)
7579

7680
def open_store_variable(self, var):
77-
data = indexing.LazilyIndexedArray(PydapArrayWrapper(var))
81+
data = indexing.LazilyOuterIndexedArray(PydapArrayWrapper(var))
7882
return Variable(var.dimensions, data,
7983
_fix_attributes(var.attributes))
8084

xarray/backends/pynio_.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,19 @@ def get_array(self):
2424
return self.datastore.ds.variables[self.variable_name]
2525

2626
def __getitem__(self, key):
27-
key = indexing.unwrap_explicit_indexer(
28-
key, target=self, allow=indexing.BasicIndexer)
27+
key, np_inds = indexing.decompose_indexer(
28+
key, self.shape, indexing.IndexingSupport.BASIC)
2929

3030
with self.datastore.ensure_open(autoclose=True):
3131
array = self.get_array()
32-
if key == () and self.ndim == 0:
32+
if key.tuple == () and self.ndim == 0:
3333
return array.get_value()
34-
return array[key]
34+
35+
array = array[key.tuple]
36+
if len(np_inds.tuple) > 0:
37+
array = indexing.NumpyIndexingAdapter(array)[np_inds]
38+
39+
return array
3540

3641

3742
class NioDataStore(AbstractDataStore, DataStorePickleMixin):
@@ -51,7 +56,7 @@ def __init__(self, filename, mode='r', autoclose=False):
5156
self._mode = mode
5257

5358
def open_store_variable(self, name, var):
54-
data = indexing.LazilyIndexedArray(NioArrayWrapper(name, self))
59+
data = indexing.LazilyOuterIndexedArray(NioArrayWrapper(name, self))
5560
return Variable(var.dimensions, data, var.attributes)
5661

5762
def get_variables(self):

xarray/backends/rasterio_.py

+46-21
Original file line numberDiff line numberDiff line change
@@ -42,48 +42,73 @@ def dtype(self):
4242
def shape(self):
4343
return self._shape
4444

45-
def __getitem__(self, key):
46-
key = indexing.unwrap_explicit_indexer(
47-
key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer))
45+
def _get_indexer(self, key):
46+
""" Get indexer for rasterio array.
47+
48+
Parameter
49+
---------
50+
key: ExplicitIndexer
51+
52+
Returns
53+
-------
54+
band_key: an indexer for the 1st dimension
55+
window: two tuples. Each consists of (start, stop).
56+
squeeze_axis: axes to be squeezed
57+
np_ind: indexer for loaded numpy array
58+
59+
See also
60+
--------
61+
indexing.decompose_indexer
62+
"""
63+
key, np_inds = indexing.decompose_indexer(
64+
key, self.shape, indexing.IndexingSupport.OUTER)
4865

4966
# bands cannot be windowed but they can be listed
50-
band_key = key[0]
51-
n_bands = self.shape[0]
67+
band_key = key.tuple[0]
68+
new_shape = []
69+
np_inds2 = []
70+
# bands (axis=0) cannot be windowed but they can be listed
5271
if isinstance(band_key, slice):
53-
start, stop, step = band_key.indices(n_bands)
54-
if step is not None and step != 1:
55-
raise IndexError(_ERROR_MSG)
56-
band_key = np.arange(start, stop)
72+
start, stop, step = band_key.indices(self.shape[0])
73+
band_key = np.arange(start, stop, step)
5774
# be sure we give out a list
5875
band_key = (np.asarray(band_key) + 1).tolist()
76+
if isinstance(band_key, list): # if band_key is not a scalar
77+
new_shape.append(len(band_key))
78+
np_inds2.append(slice(None))
5979

6080
# but other dims can only be windowed
6181
window = []
6282
squeeze_axis = []
63-
for i, (k, n) in enumerate(zip(key[1:], self.shape[1:])):
83+
for i, (k, n) in enumerate(zip(key.tuple[1:], self.shape[1:])):
6484
if isinstance(k, slice):
85+
# step is always positive. see indexing.decompose_indexer
6586
start, stop, step = k.indices(n)
66-
if step is not None and step != 1:
67-
raise IndexError(_ERROR_MSG)
87+
np_inds2.append(slice(None, None, step))
88+
new_shape.append(stop - start)
6889
elif is_scalar(k):
6990
# windowed operations will always return an array
7091
# we will have to squeeze it later
71-
squeeze_axis.append(i + 1)
92+
squeeze_axis.append(- (2 - i))
7293
start = k
7394
stop = k + 1
7495
else:
75-
k = np.asarray(k)
76-
start = k[0]
77-
stop = k[-1] + 1
78-
ids = np.arange(start, stop)
79-
if not ((k.shape == ids.shape) and np.all(k == ids)):
80-
raise IndexError(_ERROR_MSG)
96+
start, stop = np.min(k), np.max(k) + 1
97+
np_inds2.append(k - start)
98+
new_shape.append(stop - start)
8199
window.append((start, stop))
82100

101+
np_inds = indexing._combine_indexers(
102+
indexing.OuterIndexer(tuple(np_inds2)), new_shape, np_inds)
103+
return band_key, window, tuple(squeeze_axis), np_inds
104+
105+
def __getitem__(self, key):
106+
band_key, window, squeeze_axis, np_inds = self._get_indexer(key)
107+
83108
out = self.rasterio_ds.read(band_key, window=tuple(window))
84109
if squeeze_axis:
85110
out = np.squeeze(out, axis=squeeze_axis)
86-
return out
111+
return indexing.NumpyIndexingAdapter(out)[np_inds]
87112

88113

89114
def _parse_envi(meta):
@@ -249,7 +274,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
249274
else:
250275
attrs[k] = v
251276

252-
data = indexing.LazilyIndexedArray(RasterioArrayWrapper(riods))
277+
data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(riods))
253278

254279
# this lets you write arrays loaded with rasterio
255280
data = indexing.CopyOnWriteArray(data)

0 commit comments

Comments (0)