7
7
except ImportError :
8
8
pass
9
9
10
+ import collections
10
11
import itertools
11
12
import operator
12
13
from typing import (
13
14
Any ,
14
15
Callable ,
15
16
Dict ,
17
+ DefaultDict ,
16
18
Hashable ,
17
19
Mapping ,
18
20
Sequence ,
@@ -222,6 +224,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
222
224
indexes .update ({k : template .indexes [k ] for k in new_indexes })
223
225
224
226
graph : Dict [Any , Any ] = {}
227
+ new_layers : DefaultDict [str , Dict [Any , Any ]] = collections .defaultdict (dict )
225
228
gname = "{}-{}" .format (
226
229
dask .utils .funcname (func ), dask .base .tokenize (dataset , args , kwargs )
227
230
)
@@ -310,9 +313,13 @@ def _wrapper(func, obj, to_array, args, kwargs):
310
313
# unchunked dimensions in the input have one chunk in the result
311
314
key += (0 ,)
312
315
313
- graph [key ] = (operator .getitem , from_wrapper , name )
316
+ new_layers [ gname_l ] [key ] = (operator .getitem , from_wrapper , name )
314
317
315
- graph = HighLevelGraph .from_collections (gname , graph , dependencies = [dataset ])
318
+ hlg = HighLevelGraph .from_collections (gname , graph , dependencies = [dataset ])
319
+
320
+ for gname_l , layer in new_layers .items ():
321
+ hlg .dependencies [gname_l ] = {gname }
322
+ hlg .layers [gname_l ] = layer
316
323
317
324
result = Dataset (coords = indexes , attrs = template .attrs )
318
325
for name , gname_l in var_key_map .items ():
@@ -325,7 +332,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
325
332
var_chunks .append ((len (indexes [dim ]),))
326
333
327
334
data = dask .array .Array (
328
- graph , name = gname_l , chunks = var_chunks , dtype = template [name ].dtype
335
+ hlg , name = gname_l , chunks = var_chunks , dtype = template [name ].dtype
329
336
)
330
337
result [name ] = (dims , data , template [name ].attrs )
331
338
0 commit comments