7
7
except ImportError :
8
8
pass
9
9
10
+ import collections
10
11
import itertools
11
12
import operator
12
13
from typing import (
13
14
Any ,
14
15
Callable ,
15
16
Dict ,
17
+ DefaultDict ,
16
18
Hashable ,
17
19
Mapping ,
18
20
Sequence ,
@@ -221,7 +223,12 @@ def _wrapper(func, obj, to_array, args, kwargs):
221
223
indexes = {dim : dataset .indexes [dim ] for dim in preserved_indexes }
222
224
indexes .update ({k : template .indexes [k ] for k in new_indexes })
223
225
226
+ # We're building a new HighLevelGraph hlg. We'll have one new layer
227
+ # for each variable in the dataset, which is the result of the
228
+ # func applied to the values.
229
+
224
230
graph : Dict [Any , Any ] = {}
231
+ new_layers : DefaultDict [str , Dict [Any , Any ]] = collections .defaultdict (dict )
225
232
gname = "{}-{}" .format (
226
233
dask .utils .funcname (func ), dask .base .tokenize (dataset , args , kwargs )
227
234
)
@@ -310,9 +317,20 @@ def _wrapper(func, obj, to_array, args, kwargs):
310
317
# unchunked dimensions in the input have one chunk in the result
311
318
key += (0 ,)
312
319
313
- graph [key ] = (operator .getitem , from_wrapper , name )
320
+ # We're adding multiple new layers to the graph:
321
+ # The first new layer is the result of the computation on
322
+ # the array.
323
+ # Then we add one layer per variable, which extracts the
324
+ # result for that variable, and depends on just the first new
325
+ # layer.
326
+ new_layers [gname_l ][key ] = (operator .getitem , from_wrapper , name )
327
+
328
+ hlg = HighLevelGraph .from_collections (gname , graph , dependencies = [dataset ])
314
329
315
- graph = HighLevelGraph .from_collections (gname , graph , dependencies = [dataset ])
330
+ for gname_l , layer in new_layers .items ():
331
+ # This adds in the getitems for each variable in the dataset.
332
+ hlg .dependencies [gname_l ] = {gname }
333
+ hlg .layers [gname_l ] = layer
316
334
317
335
result = Dataset (coords = indexes , attrs = template .attrs )
318
336
for name , gname_l in var_key_map .items ():
@@ -325,7 +343,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
325
343
var_chunks .append ((len (indexes [dim ]),))
326
344
327
345
data = dask .array .Array (
328
- graph , name = gname_l , chunks = var_chunks , dtype = template [name ].dtype
346
+ hlg , name = gname_l , chunks = var_chunks , dtype = template [name ].dtype
329
347
)
330
348
result [name ] = (dims , data , template [name ].attrs )
331
349
0 commit comments