Skip to content

Commit 0c6cded

Browse files
authored
Add DataTree.persist (#9682)
* add persist * add to api.rst * add persist to chunkmanager * more generalization * whats-new internal changes
1 parent 88612ce commit 0c6cded

File tree

7 files changed

+150
-6
lines changed

7 files changed

+150
-6
lines changed

doc/api.rst

+5
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,7 @@ This interface echoes that of ``xarray.Dataset``.
656656
DataTree.has_attrs
657657
DataTree.is_empty
658658
DataTree.is_hollow
659+
DataTree.chunksizes
659660

660661
Dictionary Interface
661662
--------------------
@@ -968,6 +969,10 @@ DataTree methods
968969
DataTree.to_dict
969970
DataTree.to_netcdf
970971
DataTree.to_zarr
972+
DataTree.chunk
973+
DataTree.load
974+
DataTree.compute
975+
DataTree.persist
971976

972977
.. ..
973978

doc/whats-new.rst

+4-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ v.2024.10.1 (unreleased)
2121

2222
New Features
2323
~~~~~~~~~~~~
24-
24+
- Added :py:meth:`DataTree.persist` method (:issue:`9675`, :pull:`9682`).
25+
By `Sam Levang <https://github.com/slevang>`_.
2526

2627
Breaking changes
2728
~~~~~~~~~~~~~~~~
@@ -43,6 +44,8 @@ Documentation
4344

4445
Internal Changes
4546
~~~~~~~~~~~~~~~~
47+
- ``persist`` methods now route through the :py:class:`xarray.namedarray.parallelcompat.ChunkManagerEntrypoint` (:pull:`9682`).
48+
By `Sam Levang <https://github.com/slevang>`_.
4649

4750
.. _whats-new.2024.10.0:
4851

xarray/core/dataset.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1046,24 +1046,24 @@ def compute(self, **kwargs) -> Self:
10461046
return new.load(**kwargs)
10471047

10481048
def _persist_inplace(self, **kwargs) -> Self:
    """Persist all chunked arrays in memory, mutating this dataset in place.

    Parameters
    ----------
    **kwargs : dict
        Additional keyword arguments passed on to the chunk manager's
        ``persist`` method (e.g. ``dask.persist``).

    Returns
    -------
    Self
        This dataset, with every chunked variable's data replaced by its
        persisted counterpart.
    """
    # Access ._data (not .data) to avoid triggering computation; this
    # coerces everything to numpy or chunked (e.g. dask) arrays.
    lazy_data = {
        k: v._data for k, v in self.variables.items() if is_chunked_array(v._data)
    }
    if lazy_data:
        chunkmanager = get_chunked_array_type(*lazy_data.values())

        # evaluate all the chunked arrays simultaneously, so shared
        # intermediate results are only computed once
        evaluated_data = chunkmanager.persist(*lazy_data.values(), **kwargs)

        # strict=True: persist returns exactly one array per input, so a
        # length mismatch would be a bug and must not pass silently
        for k, data in zip(lazy_data, evaluated_data, strict=True):
            self.variables[k].data = data

    return self
10641064

10651065
def persist(self, **kwargs) -> Self:
1066-
"""Trigger computation, keeping data as dask arrays
1066+
"""Trigger computation, keeping data as chunked arrays.
10671067
10681068
This operation can be used to trigger computation on underlying dask
10691069
arrays, similar to ``.compute()`` or ``.load()``. However this

xarray/core/datatree.py

+57
Original file line numberDiff line numberDiff line change
@@ -1984,6 +1984,63 @@ def compute(self, **kwargs) -> Self:
19841984
new = self.copy(deep=False)
19851985
return new.load(**kwargs)
19861986

1987+
def _persist_inplace(self, **kwargs) -> Self:
    """Persist all chunked arrays in memory, mutating this tree in place.

    Parameters
    ----------
    **kwargs : dict
        Additional keyword arguments passed on to the chunk manager's
        ``persist`` method (e.g. ``dask.persist``).

    Returns
    -------
    Self
        This tree, with every chunked variable's data (in every node)
        replaced by its persisted counterpart.
    """
    # Access ._data (not .data) to avoid triggering computation; this
    # coerces everything to numpy or chunked (e.g. dask) arrays.
    lazy_data = {
        path: {
            k: v._data
            for k, v in node.variables.items()
            if is_chunked_array(v._data)
        }
        for path, node in self.subtree_with_keys
    }
    # Flatten to (path, var_name) keys so every chunked array across the
    # whole tree can be persisted in a single chunk-manager call.
    flat_lazy_data = {
        (path, var_name): array
        for path, node in lazy_data.items()
        for var_name, array in node.items()
    }
    if flat_lazy_data:
        chunkmanager = get_chunked_array_type(*flat_lazy_data.values())

        # evaluate all the chunked arrays simultaneously, so shared
        # intermediate results are only computed once
        evaluated_data = chunkmanager.persist(*flat_lazy_data.values(), **kwargs)

        # strict=True: persist returns exactly one array per input, so a
        # length mismatch would be a bug and must not pass silently
        for (path, var_name), data in zip(
            flat_lazy_data, evaluated_data, strict=True
        ):
            self[path].variables[var_name].data = data

    return self
2015+
2016+
def persist(self, **kwargs) -> Self:
    """Trigger computation, keeping data as chunked arrays.

    This operation can be used to trigger computation on underlying dask
    arrays, similar to ``.compute()`` or ``.load()``. However this
    operation keeps the data as dask arrays. This is particularly useful
    when using the dask.distributed scheduler and you want to load a large
    amount of data into distributed memory.

    Like compute (but unlike load), the original object is left unaltered.

    Parameters
    ----------
    **kwargs : dict
        Additional keyword arguments passed on to ``dask.persist``.

    Returns
    -------
    object : DataTree
        New object with all dask-backed coordinates and data variables as persisted dask arrays.

    See Also
    --------
    dask.persist
    """
    # Shallow copy so the persisted arrays land in a new tree while the
    # original tree's graphs stay untouched.
    shallow_copy = self.copy(deep=False)
    return shallow_copy._persist_inplace(**kwargs)
2043+
19872044
@property
19882045
def chunksizes(self) -> Mapping[str, Mapping[Hashable, tuple[int, ...]]]:
19892046
"""

xarray/namedarray/daskmanager.py

+5
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ def compute(
8585

8686
return compute(*data, **kwargs) # type: ignore[no-untyped-call, no-any-return]
8787

88+
def persist(self, *data: Any, **kwargs: Any) -> tuple[DaskArray | Any, ...]:
    """Persist dask collections into memory via ``dask.persist``.

    Non-dask objects in *data* are passed through unchanged; extra
    keyword arguments are forwarded to ``dask.persist``.
    """
    import dask

    return dask.persist(*data, **kwargs)  # type: ignore[no-untyped-call, no-any-return]
92+
8893
@property
8994
def array_api(self) -> Any:
9095
from dask import array as da

xarray/namedarray/parallelcompat.py

+23
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,29 @@ def compute(
357357
"""
358358
raise NotImplementedError()
359359

360+
def persist(
    self, *data: T_ChunkedArray | Any, **kwargs: Any
) -> tuple[T_ChunkedArray | Any, ...]:
    """
    Persist one or more chunked arrays in memory.

    Abstract hook: chunk-manager entrypoints override this to implement
    their backend's persist semantics (e.g. ``dask.persist``).

    Parameters
    ----------
    *data : object
        Any number of objects. Instances of this manager's chunked array
        type are persisted as chunked arrays held in memory; every other
        object should be passed through unchanged.

    Returns
    -------
    objs
        The input, but with all chunked arrays now persisted in memory.

    See Also
    --------
    dask.persist
    """
    raise NotImplementedError()
382+
360383
@property
361384
def array_api(self) -> Any:
362385
"""

xarray/tests/test_datatree.py

+51
Original file line numberDiff line numberDiff line change
@@ -2293,6 +2293,57 @@ def test_compute(self):
22932293
assert actual.chunksizes == expected_chunksizes, "mismatching chunksizes"
22942294
assert tree.chunksizes == original_chunksizes, "original tree was modified"
22952295

2296+
def test_persist(self):
    """persist() should materialize chunked data into a new tree, collapsing
    the dask graph to one layer while leaving the original tree untouched."""
    datasets = {
        "/": xr.Dataset({"a": ("x", np.arange(10))}),
        "/group1": xr.Dataset({"b": ("y", np.arange(5))}),
        "/group2": xr.Dataset({"c": ("z", np.arange(4))}),
        "/group1/subgroup1": xr.Dataset({"d": ("x", np.arange(-5, 5))}),
    }
    chunk_specs = {
        "/": {"x": 5},
        "/group1": {"y": 3},
        "/group2": {"z": 2},
        "/group1/subgroup1": {"x": 5},
    }

    def double(ds):
        return 2 * ds

    expected = xr.DataTree.from_dict(
        {path: double(ds).chunk(chunk_specs[path]) for path, ds in datasets.items()}
    )
    # Doubling after chunking adds a trivial second layer to the task
    # graph; persist should reduce it back to one.
    tree = xr.DataTree.from_dict(
        {path: double(ds.chunk(chunk_specs[path])) for path, ds in datasets.items()}
    )
    original_chunksizes = tree.chunksizes
    original_hlg_depths = {
        node.path: len(node.dataset.__dask_graph__().layers)
        for node in tree.subtree
    }

    actual = tree.persist()
    actual_hlg_depths = {
        node.path: len(node.dataset.__dask_graph__().layers)
        for node in actual.subtree
    }

    assert_identical(actual, expected)

    assert actual.chunksizes == original_chunksizes, "chunksizes were modified"
    assert (
        tree.chunksizes == original_chunksizes
    ), "original chunksizes were modified"
    assert all(
        depth == 1 for depth in actual_hlg_depths.values()
    ), "unexpected dask graph depth"
    assert all(
        depth == 2 for depth in original_hlg_depths.values()
    ), "original dask graph was modified"
2346+
22962347
def test_chunk(self):
22972348
ds1 = xr.Dataset({"a": ("x", np.arange(10))})
22982349
ds2 = xr.Dataset({"b": ("y", np.arange(5))})

0 commit comments

Comments
 (0)