Skip to content

Commit f468a06

Browse files
dcheriancrusaderky
andauthored
Optimize isel for lazy array equality checking (pydata#3588)
* Add some xfailed tests. * Only xfail failing tests. * Add DataArray.rename_dims, DataArray.rename_vars * Update tests. * Fix isel. Tests pass. * todos * All tests pass. * Add comments. * wip * cleanup * Revert "Add DataArray.rename_dims, DataArray.rename_vars" This reverts commit 61b7334. * more tests * Add comment * Add optimization to DaskIndexingAdapter * Update xarray/core/variable.py Co-Authored-By: crusaderky <[email protected]> * minor. Co-authored-by: crusaderky <[email protected]>
1 parent 5e41b60 commit f468a06

File tree

3 files changed

+77
-2
lines changed

3 files changed

+77
-2
lines changed

xarray/core/indexing.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections import defaultdict
55
from contextlib import suppress
66
from datetime import timedelta
7-
from typing import Any, Callable, Sequence, Tuple, Union
7+
from typing import Any, Callable, Iterable, Sequence, Tuple, Union
88

99
import numpy as np
1010
import pandas as pd
@@ -1314,6 +1314,24 @@ def __init__(self, array):
13141314
self.array = array
13151315

13161316
def __getitem__(self, key):
1317+
1318+
if not isinstance(key, VectorizedIndexer):
1319+
# if possible, short-circuit when keys are effectively slice(None)
1320+
# This preserves dask name and passes lazy array equivalence checks
1321+
# (see duck_array_ops.lazy_array_equiv)
1322+
rewritten_indexer = False
1323+
new_indexer = []
1324+
for idim, k in enumerate(key.tuple):
1325+
if isinstance(k, Iterable) and duck_array_ops.array_equiv(
1326+
k, np.arange(self.array.shape[idim])
1327+
):
1328+
new_indexer.append(slice(None))
1329+
rewritten_indexer = True
1330+
else:
1331+
new_indexer.append(k)
1332+
if rewritten_indexer:
1333+
key = type(key)(tuple(new_indexer))
1334+
13171335
if isinstance(key, BasicIndexer):
13181336
return self.array[key.tuple]
13191337
elif isinstance(key, VectorizedIndexer):

xarray/core/variable.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1057,7 +1057,9 @@ def isel(
10571057

10581058
invalid = indexers.keys() - set(self.dims)
10591059
if invalid:
1060-
raise ValueError("dimensions %r do not exist" % invalid)
1060+
raise ValueError(
1061+
f"dimensions {invalid} do not exist. Expected one or more of {self.dims}"
1062+
)
10611063

10621064
key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
10631065
return self[key]

xarray/tests/test_dask.py

+55
Original file line numberDiff line numberDiff line change
@@ -1390,3 +1390,58 @@ def test_lazy_array_equiv_merge(compat):
13901390
xr.merge([da1, da3], compat=compat)
13911391
with raise_if_dask_computes(max_computes=2):
13921392
xr.merge([da1, da2 / 2], compat=compat)
1393+
1394+
1395+
@pytest.mark.filterwarnings("ignore::FutureWarning") # transpose_coords
1396+
@pytest.mark.parametrize("obj", [make_da(), make_ds()])
1397+
@pytest.mark.parametrize(
1398+
"transform",
1399+
[
1400+
lambda a: a.assign_attrs(new_attr="anew"),
1401+
lambda a: a.assign_coords(cxy=a.cxy),
1402+
lambda a: a.copy(),
1403+
lambda a: a.isel(x=np.arange(a.sizes["x"])),
1404+
lambda a: a.isel(x=slice(None)),
1405+
lambda a: a.loc[dict(x=slice(None))],
1406+
lambda a: a.loc[dict(x=np.arange(a.sizes["x"]))],
1407+
lambda a: a.loc[dict(x=a.x)],
1408+
lambda a: a.sel(x=a.x),
1409+
lambda a: a.sel(x=a.x.values),
1410+
lambda a: a.transpose(...),
1411+
lambda a: a.squeeze(), # no dimensions to squeeze
1412+
lambda a: a.sortby("x"), # "x" is already sorted
1413+
lambda a: a.reindex(x=a.x),
1414+
lambda a: a.reindex_like(a),
1415+
lambda a: a.rename({"cxy": "cnew"}).rename({"cnew": "cxy"}),
1416+
lambda a: a.pipe(lambda x: x),
1417+
lambda a: xr.align(a, xr.zeros_like(a))[0],
1418+
# assign
1419+
# swap_dims
1420+
# set_index / reset_index
1421+
],
1422+
)
1423+
def test_transforms_pass_lazy_array_equiv(obj, transform):
1424+
with raise_if_dask_computes():
1425+
assert_equal(obj, transform(obj))
1426+
1427+
1428+
def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds):
1429+
with raise_if_dask_computes():
1430+
assert_equal(map_ds.cxy.broadcast_like(map_ds.cxy), map_ds.cxy)
1431+
assert_equal(xr.broadcast(map_ds.cxy, map_ds.cxy)[0], map_ds.cxy)
1432+
assert_equal(map_ds.map(lambda x: x), map_ds)
1433+
assert_equal(map_ds.set_coords("a").reset_coords("a"), map_ds)
1434+
assert_equal(map_ds.update({"a": map_ds.a}), map_ds)
1435+
1436+
# fails because of index error
1437+
# assert_equal(
1438+
# map_ds.rename_dims({"x": "xnew"}).rename_dims({"xnew": "x"}), map_ds
1439+
# )
1440+
1441+
assert_equal(
1442+
map_ds.rename_vars({"cxy": "cnew"}).rename_vars({"cnew": "cxy"}), map_ds
1443+
)
1444+
1445+
assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da)
1446+
assert_equal(map_da.astype(map_da.dtype), map_da)
1447+
assert_equal(map_da.transpose("y", "x", transpose_coords=False).cxy, map_da.cxy)

0 commit comments

Comments
 (0)