@@ -2293,6 +2293,57 @@ def test_compute(self):
2293
2293
assert actual .chunksizes == expected_chunksizes , "mismatching chunksizes"
2294
2294
assert tree .chunksizes == original_chunksizes , "original tree was modified"
2295
2295
2296
+ def test_persist (self ):
2297
+ ds1 = xr .Dataset ({"a" : ("x" , np .arange (10 ))})
2298
+ ds2 = xr .Dataset ({"b" : ("y" , np .arange (5 ))})
2299
+ ds3 = xr .Dataset ({"c" : ("z" , np .arange (4 ))})
2300
+ ds4 = xr .Dataset ({"d" : ("x" , np .arange (- 5 , 5 ))})
2301
+
2302
+ def fn (x ):
2303
+ return 2 * x
2304
+
2305
+ expected = xr .DataTree .from_dict (
2306
+ {
2307
+ "/" : fn (ds1 ).chunk ({"x" : 5 }),
2308
+ "/group1" : fn (ds2 ).chunk ({"y" : 3 }),
2309
+ "/group2" : fn (ds3 ).chunk ({"z" : 2 }),
2310
+ "/group1/subgroup1" : fn (ds4 ).chunk ({"x" : 5 }),
2311
+ }
2312
+ )
2313
+ # Add trivial second layer to the task graph, persist should reduce to one
2314
+ tree = xr .DataTree .from_dict (
2315
+ {
2316
+ "/" : fn (ds1 .chunk ({"x" : 5 })),
2317
+ "/group1" : fn (ds2 .chunk ({"y" : 3 })),
2318
+ "/group2" : fn (ds3 .chunk ({"z" : 2 })),
2319
+ "/group1/subgroup1" : fn (ds4 .chunk ({"x" : 5 })),
2320
+ }
2321
+ )
2322
+ original_chunksizes = tree .chunksizes
2323
+ original_hlg_depths = {
2324
+ node .path : len (node .dataset .__dask_graph__ ().layers )
2325
+ for node in tree .subtree
2326
+ }
2327
+
2328
+ actual = tree .persist ()
2329
+ actual_hlg_depths = {
2330
+ node .path : len (node .dataset .__dask_graph__ ().layers )
2331
+ for node in actual .subtree
2332
+ }
2333
+
2334
+ assert_identical (actual , expected )
2335
+
2336
+ assert actual .chunksizes == original_chunksizes , "chunksizes were modified"
2337
+ assert (
2338
+ tree .chunksizes == original_chunksizes
2339
+ ), "original chunksizes were modified"
2340
+ assert all (
2341
+ d == 1 for d in actual_hlg_depths .values ()
2342
+ ), "unexpected dask graph depth"
2343
+ assert all (
2344
+ d == 2 for d in original_hlg_depths .values ()
2345
+ ), "original dask graph was modified"
2346
+
2296
2347
def test_chunk (self ):
2297
2348
ds1 = xr .Dataset ({"a" : ("x" , np .arange (10 ))})
2298
2349
ds2 = xr .Dataset ({"b" : ("y" , np .arange (5 ))})
0 commit comments