Skip to content

Commit a9a5e93

Browse files
committed
Fix map_blocks HLG layering
This fixes an issue with the HighLevelGraph noted in pydata#3584 and exposed by a recent change in Dask that does more HLG fusion.
1 parent 87a25b6 commit a9a5e93

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

xarray/core/parallel.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
except ImportError:
88
pass
99

10+
import collections
1011
import itertools
1112
import operator
1213
from typing import (
1314
Any,
1415
Callable,
1516
Dict,
17+
DefaultDict,
1618
Hashable,
1719
Mapping,
1820
Sequence,
@@ -222,6 +224,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
222224
indexes.update({k: template.indexes[k] for k in new_indexes})
223225

224226
graph: Dict[Any, Any] = {}
227+
new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict)
225228
gname = "{}-{}".format(
226229
dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs)
227230
)
@@ -310,9 +313,13 @@ def _wrapper(func, obj, to_array, args, kwargs):
310313
# unchunked dimensions in the input have one chunk in the result
311314
key += (0,)
312315

313-
graph[key] = (operator.getitem, from_wrapper, name)
316+
new_layers[gname_l][key] = (operator.getitem, from_wrapper, name)
314317

315-
graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])
318+
hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])
319+
320+
for gname_l, layer in new_layers.items():
321+
hlg.dependencies[gname_l] = {gname}
322+
hlg.layers[gname_l] = layer
316323

317324
result = Dataset(coords=indexes, attrs=template.attrs)
318325
for name, gname_l in var_key_map.items():
@@ -325,7 +332,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
325332
var_chunks.append((len(indexes[dim]),))
326333

327334
data = dask.array.Array(
328-
graph, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
335+
hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
329336
)
330337
result[name] = (dims, data, template[name].attrs)
331338

xarray/tests/test_dask.py

+7
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,13 @@ def func(obj):
11891189
assert_identical(expected.compute(), actual.compute())
11901190

11911191

1192+
def test_map_blocks_hlg_layers():
1193+
ds = xr.Dataset({"x": (("y",), dask.array.ones(10, chunks=(5,)))})
1194+
mapped = ds.map_blocks(lambda x: x)
1195+
1196+
xr.testing.assert_equal(mapped, ds) # does not work
1197+
1198+
11921199
def test_make_meta(map_ds):
11931200
from ..core.parallel import make_meta
11941201

0 commit comments

Comments
 (0)