Skip to content

Commit cafcaee

Browse files
TomAugspurgerdcherian
authored andcommitted
Fix map_blocks HLG layering (#3598)
* Fix map_blocks HLG layering This fixes an issue with the HighLevelGraph noted in #3584, and exposed by a recent change in Dask to do more HLG fusion. * update * black * update
1 parent 4c51aa2 commit cafcaee

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ Bug fixes
3636
~~~~~~~~~
3737
- Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`, :pull:`3441`)
3838
By `Deepak Cherian <https://github.com/dcherian>`_.
39+
- Fix issue with Dask-backed datasets raising a ``KeyError`` on some computations involving ``map_blocks`` (:pull:`3598`)
40+
By `Tom Augspurger <https://github.com/TomAugspurger>`_.
3941

4042
Documentation
4143
~~~~~~~~~~~~~

xarray/core/parallel.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
except ImportError:
88
pass
99

10+
import collections
1011
import itertools
1112
import operator
1213
from typing import (
1314
Any,
1415
Callable,
1516
Dict,
17+
DefaultDict,
1618
Hashable,
1719
Mapping,
1820
Sequence,
@@ -221,7 +223,12 @@ def _wrapper(func, obj, to_array, args, kwargs):
221223
indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes}
222224
indexes.update({k: template.indexes[k] for k in new_indexes})
223225

226+
# We're building a new HighLevelGraph hlg. We'll have one new layer
227+
# for each variable in the dataset, which is the result of the
228+
# func applied to the values.
229+
224230
graph: Dict[Any, Any] = {}
231+
new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict)
225232
gname = "{}-{}".format(
226233
dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs)
227234
)
@@ -310,9 +317,20 @@ def _wrapper(func, obj, to_array, args, kwargs):
310317
# unchunked dimensions in the input have one chunk in the result
311318
key += (0,)
312319

313-
graph[key] = (operator.getitem, from_wrapper, name)
320+
# We're adding multiple new layers to the graph:
321+
# The first new layer is the result of the computation on
322+
# the array.
323+
# Then we add one layer per variable, which extracts the
324+
# result for that variable, and depends on just the first new
325+
# layer.
326+
new_layers[gname_l][key] = (operator.getitem, from_wrapper, name)
327+
328+
hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])
314329

315-
graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])
330+
for gname_l, layer in new_layers.items():
331+
# This adds in the getitems for each variable in the dataset.
332+
hlg.dependencies[gname_l] = {gname}
333+
hlg.layers[gname_l] = layer
316334

317335
result = Dataset(coords=indexes, attrs=template.attrs)
318336
for name, gname_l in var_key_map.items():
@@ -325,7 +343,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
325343
var_chunks.append((len(indexes[dim]),))
326344

327345
data = dask.array.Array(
328-
graph, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
346+
hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
329347
)
330348
result[name] = (dims, data, template[name].attrs)
331349

xarray/tests/test_dask.py

+13
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,19 @@ def func(obj):
11891189
assert_identical(expected.compute(), actual.compute())
11901190

11911191

1192+
def test_map_blocks_hlg_layers():
1193+
# regression test for #3599
1194+
ds = xr.Dataset(
1195+
{
1196+
"x": (("a",), dask.array.ones(10, chunks=(5,))),
1197+
"z": (("b",), dask.array.ones(10, chunks=(5,))),
1198+
}
1199+
)
1200+
mapped = ds.map_blocks(lambda x: x)
1201+
1202+
xr.testing.assert_equal(mapped, ds)
1203+
1204+
11921205
def test_make_meta(map_ds):
11931206
from ..core.parallel import make_meta
11941207

0 commit comments

Comments
 (0)