Skip to content

Commit 577221d

Browse files
authored
Fix writing of DataTree subgroups to zarr or netCDF (#9677)
* Fix writing of DataTree subgroups to zarr or netCDF. Consider a DataTree with a group, e.g., `tree = DataTree.from_dict({'/': ..., '/child': ...})`. If we write `tree['/child']` to disk, the result should have groups relative to `'/child'`, so writing and reading from the same path restores the same object. In addition, coordinates defined at the root should be written to disk instead of being omitted.
* Add write_inherited_coords for additional control in DataTree.to_zarr. As discussed in the last xarray meeting, this defaults to write_inherited_coords=True, which has a little more overhead but means you always get coordinates when opening a sub-group.
* Switch write_inherited_coords default to false
* add whats new
* remove unused import
1 parent fc05da9 commit 577221d

File tree

4 files changed

+124
-79
lines changed

4 files changed

+124
-79
lines changed

doc/whats-new.rst

+8-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ New Features
2323
~~~~~~~~~~~~
2424
- Added :py:meth:`DataTree.persist` method (:issue:`9675`, :pull:`9682`).
2525
By `Sam Levang <https://github.com/slevang>`_.
26+
- Added ``write_inherited_coords`` option to :py:meth:`DataTree.to_netcdf`
27+
and :py:meth:`DataTree.to_zarr` (:pull:`9677`).
28+
By `Stephan Hoyer <https://github.com/shoyer>`_.
2629
- Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
2730
(:issue:`2852`, :issue:`757`).
2831
By `Deepak Cherian <https://github.com/dcherian>`_.
@@ -42,7 +45,11 @@ Deprecations
4245
Bug fixes
4346
~~~~~~~~~
4447

45-
- Fix inadvertent deep-copying of child data in DataTree.
48+
- Fix inadvertent deep-copying of child data in DataTree (:issue:`9683`,
49+
:pull:`9684`).
50+
By `Stephan Hoyer <https://github.com/shoyer>`_.
51+
- Avoid including parent groups when writing DataTree subgroups to Zarr or
52+
netCDF (:pull:`9677`).
4653
By `Stephan Hoyer <https://github.com/shoyer>`_.
4754
- Fix regression in the interoperability of :py:meth:`DataArray.polyfit` and :py:meth:`xr.polyval` for date-time coordinates. (:pull:`9691`).
4855
By `Pascal Bourgault <https://github.com/aulemahal>`_.

xarray/core/datatree.py

+14
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,7 @@ def to_netcdf(
15731573
format: T_DataTreeNetcdfTypes | None = None,
15741574
engine: T_DataTreeNetcdfEngine | None = None,
15751575
group: str | None = None,
1576+
write_inherited_coords: bool = False,
15761577
compute: bool = True,
15771578
**kwargs,
15781579
):
@@ -1609,6 +1610,11 @@ def to_netcdf(
16091610
group : str, optional
16101611
Path to the netCDF4 group in the given file to open as the root group
16111612
of the ``DataTree``. Currently, specifying a group is not supported.
1613+
write_inherited_coords : bool, default: False
1614+
If true, replicate inherited coordinates on all descendant nodes.
1615+
Otherwise, only write coordinates at the level at which they are
1616+
originally defined. Writing them only once saves disk space, but requires opening the
1617+
full tree to load inherited coordinates.
16121618
compute : bool, default: True
16131619
If true compute immediately, otherwise return a
16141620
``dask.delayed.Delayed`` object that can be computed later.
@@ -1632,6 +1638,7 @@ def to_netcdf(
16321638
format=format,
16331639
engine=engine,
16341640
group=group,
1641+
write_inherited_coords=write_inherited_coords,
16351642
compute=compute,
16361643
**kwargs,
16371644
)
@@ -1643,6 +1650,7 @@ def to_zarr(
16431650
encoding=None,
16441651
consolidated: bool = True,
16451652
group: str | None = None,
1653+
write_inherited_coords: bool = False,
16461654
compute: Literal[True] = True,
16471655
**kwargs,
16481656
):
@@ -1668,6 +1676,11 @@ def to_zarr(
16681676
after writing metadata for all groups.
16691677
group : str, optional
16701678
Group path. (a.k.a. `path` in zarr terminology.)
1679+
write_inherited_coords : bool, default: False
1680+
If true, replicate inherited coordinates on all descendant nodes.
1681+
Otherwise, only write coordinates at the level at which they are
1682+
originally defined. Writing them only once saves disk space, but requires opening the
1683+
full tree to load inherited coordinates.
16711684
compute : bool, default: True
16721685
If true compute immediately, otherwise return a
16731686
``dask.delayed.Delayed`` object that can be computed later. Metadata
@@ -1690,6 +1703,7 @@ def to_zarr(
16901703
encoding=encoding,
16911704
consolidated=consolidated,
16921705
group=group,
1706+
write_inherited_coords=write_inherited_coords,
16931707
compute=compute,
16941708
**kwargs,
16951709
)

xarray/core/datatree_io.py

+28-78
Original file line numberDiff line numberDiff line change
@@ -2,54 +2,15 @@
22

33
from collections.abc import Mapping, MutableMapping
44
from os import PathLike
5-
from typing import TYPE_CHECKING, Any, Literal, get_args
5+
from typing import Any, Literal, get_args
66

77
from xarray.core.datatree import DataTree
88
from xarray.core.types import NetcdfWriteModes, ZarrWriteModes
99

10-
if TYPE_CHECKING:
11-
from h5netcdf.legacyapi import Dataset as h5Dataset
12-
from netCDF4 import Dataset as ncDataset
13-
1410
T_DataTreeNetcdfEngine = Literal["netcdf4", "h5netcdf"]
1511
T_DataTreeNetcdfTypes = Literal["NETCDF4"]
1612

1713

18-
def _get_nc_dataset_class(
19-
engine: T_DataTreeNetcdfEngine | None,
20-
) -> type[ncDataset] | type[h5Dataset]:
21-
if engine == "netcdf4":
22-
from netCDF4 import Dataset as ncDataset
23-
24-
return ncDataset
25-
if engine == "h5netcdf":
26-
from h5netcdf.legacyapi import Dataset as h5Dataset
27-
28-
return h5Dataset
29-
if engine is None:
30-
try:
31-
from netCDF4 import Dataset as ncDataset
32-
33-
return ncDataset
34-
except ImportError:
35-
from h5netcdf.legacyapi import Dataset as h5Dataset
36-
37-
return h5Dataset
38-
raise ValueError(f"unsupported engine: {engine}")
39-
40-
41-
def _create_empty_netcdf_group(
42-
filename: str | PathLike,
43-
group: str,
44-
mode: NetcdfWriteModes,
45-
engine: T_DataTreeNetcdfEngine | None,
46-
) -> None:
47-
ncDataset = _get_nc_dataset_class(engine)
48-
49-
with ncDataset(filename, mode=mode) as rootgrp:
50-
rootgrp.createGroup(group)
51-
52-
5314
def _datatree_to_netcdf(
5415
dt: DataTree,
5516
filepath: str | PathLike,
@@ -59,6 +20,7 @@ def _datatree_to_netcdf(
5920
format: T_DataTreeNetcdfTypes | None = None,
6021
engine: T_DataTreeNetcdfEngine | None = None,
6122
group: str | None = None,
23+
write_inherited_coords: bool = False,
6224
compute: bool = True,
6325
**kwargs,
6426
) -> None:
@@ -97,41 +59,31 @@ def _datatree_to_netcdf(
9759
unlimited_dims = {}
9860

9961
for node in dt.subtree:
100-
ds = node.to_dataset(inherit=False)
101-
group_path = node.path
102-
if ds is None:
103-
_create_empty_netcdf_group(filepath, group_path, mode, engine)
104-
else:
105-
ds.to_netcdf(
106-
filepath,
107-
group=group_path,
108-
mode=mode,
109-
encoding=encoding.get(node.path),
110-
unlimited_dims=unlimited_dims.get(node.path),
111-
engine=engine,
112-
format=format,
113-
compute=compute,
114-
**kwargs,
115-
)
62+
at_root = node is dt
63+
ds = node.to_dataset(inherit=write_inherited_coords or at_root)
64+
group_path = None if at_root else "/" + node.relative_to(dt)
65+
ds.to_netcdf(
66+
filepath,
67+
group=group_path,
68+
mode=mode,
69+
encoding=encoding.get(node.path),
70+
unlimited_dims=unlimited_dims.get(node.path),
71+
engine=engine,
72+
format=format,
73+
compute=compute,
74+
**kwargs,
75+
)
11676
mode = "a"
11777

11878

119-
def _create_empty_zarr_group(
120-
store: MutableMapping | str | PathLike[str], group: str, mode: ZarrWriteModes
121-
):
122-
import zarr
123-
124-
root = zarr.open_group(store, mode=mode)
125-
root.create_group(group, overwrite=True)
126-
127-
12879
def _datatree_to_zarr(
12980
dt: DataTree,
13081
store: MutableMapping | str | PathLike[str],
13182
mode: ZarrWriteModes = "w-",
13283
encoding: Mapping[str, Any] | None = None,
13384
consolidated: bool = True,
13485
group: str | None = None,
86+
write_inherited_coords: bool = False,
13587
compute: Literal[True] = True,
13688
**kwargs,
13789
):
@@ -163,19 +115,17 @@ def _datatree_to_zarr(
163115
)
164116

165117
for node in dt.subtree:
166-
ds = node.to_dataset(inherit=False)
167-
group_path = node.path
168-
if ds is None:
169-
_create_empty_zarr_group(store, group_path, mode)
170-
else:
171-
ds.to_zarr(
172-
store,
173-
group=group_path,
174-
mode=mode,
175-
encoding=encoding.get(node.path),
176-
consolidated=False,
177-
**kwargs,
178-
)
118+
at_root = node is dt
119+
ds = node.to_dataset(inherit=write_inherited_coords or at_root)
120+
group_path = None if at_root else "/" + node.relative_to(dt)
121+
ds.to_zarr(
122+
store,
123+
group=group_path,
124+
mode=mode,
125+
encoding=encoding.get(node.path),
126+
consolidated=False,
127+
**kwargs,
128+
)
179129
if "w" in mode:
180130
mode = "a"
181131

xarray/tests/test_backends_datatree.py

+74
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,24 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):
196196
with pytest.raises(ValueError, match="unexpected encoding group.*"):
197197
original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)
198198

199+
def test_write_subgroup(self, tmpdir):
200+
original_dt = DataTree.from_dict(
201+
{
202+
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
203+
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
204+
}
205+
).children["child"]
206+
207+
expected_dt = original_dt.copy()
208+
expected_dt.name = None
209+
210+
filepath = tmpdir / "test.zarr"
211+
original_dt.to_netcdf(filepath, engine=self.engine)
212+
213+
with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
214+
assert_equal(original_dt, roundtrip_dt)
215+
assert_identical(expected_dt, roundtrip_dt)
216+
199217

200218
@requires_netCDF4
201219
class TestNetCDF4DatatreeIO(DatatreeIOBase):
@@ -556,3 +574,59 @@ def test_open_groups_chunks(self, tmpdir) -> None:
556574

557575
for ds in dict_of_datasets.values():
558576
ds.close()
577+
578+
def test_write_subgroup(self, tmpdir):
579+
original_dt = DataTree.from_dict(
580+
{
581+
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
582+
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
583+
}
584+
).children["child"]
585+
586+
expected_dt = original_dt.copy()
587+
expected_dt.name = None
588+
589+
filepath = tmpdir / "test.zarr"
590+
original_dt.to_zarr(filepath)
591+
592+
with open_datatree(filepath, engine="zarr") as roundtrip_dt:
593+
assert_equal(original_dt, roundtrip_dt)
594+
assert_identical(expected_dt, roundtrip_dt)
595+
596+
def test_write_inherited_coords_false(self, tmpdir):
597+
original_dt = DataTree.from_dict(
598+
{
599+
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
600+
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
601+
}
602+
)
603+
604+
filepath = tmpdir / "test.zarr"
605+
original_dt.to_zarr(filepath, write_inherited_coords=False)
606+
607+
with open_datatree(filepath, engine="zarr") as roundtrip_dt:
608+
assert_identical(original_dt, roundtrip_dt)
609+
610+
expected_child = original_dt.children["child"].copy(inherit=False)
611+
expected_child.name = None
612+
with open_datatree(filepath, group="child", engine="zarr") as roundtrip_child:
613+
assert_identical(expected_child, roundtrip_child)
614+
615+
def test_write_inherited_coords_true(self, tmpdir):
616+
original_dt = DataTree.from_dict(
617+
{
618+
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
619+
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
620+
}
621+
)
622+
623+
filepath = tmpdir / "test.zarr"
624+
original_dt.to_zarr(filepath, write_inherited_coords=True)
625+
626+
with open_datatree(filepath, engine="zarr") as roundtrip_dt:
627+
assert_identical(original_dt, roundtrip_dt)
628+
629+
expected_child = original_dt.children["child"].copy(inherit=True)
630+
expected_child.name = None
631+
with open_datatree(filepath, group="child", engine="zarr") as roundtrip_child:
632+
assert_identical(expected_child, roundtrip_child)

0 commit comments

Comments
 (0)