Skip to content

Commit 6c9a1cc

Browse files
Joe HammanTomNicholas
Joe Hamman
andauthored
Add zarr read/write xarray-contrib/datatree#30
* add test for roundtrip and support empty nodes * update roundtrip test, improves empty node handling in IO * add zarr read/write support * support netcdf4 or h5netcdf * netcdf is optional, zarr too! * Apply suggestions from code review Co-authored-by: Tom Nicholas <[email protected]> Co-authored-by: Tom Nicholas <[email protected]>
1 parent 6807504 commit 6c9a1cc

File tree

5 files changed

+159
-35
lines changed

5 files changed

+159
-35
lines changed

ci/environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ dependencies:
1111
- black
1212
- codecov
1313
- pytest-cov
14+
- zarr

datatree/datatree.py

+31
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,37 @@ def to_netcdf(
854854
**kwargs,
855855
)
856856

857+
def to_zarr(self, store, mode: str = "w", encoding=None, **kwargs):
858+
"""
859+
Write datatree contents to a netCDF file.
860+
861+
Parameters
862+
---------
863+
store : MutableMapping, str or Path, optional
864+
Store or path to directory in file system
865+
mode : {{"w", "w-", "a", "r+", None}, default: "w"
866+
Persistence mode: “w” means create (overwrite if exists); “w-” means create (fail if exists);
867+
“a” means override existing variables (create if does not exist); “r+” means modify existing
868+
array values only (raise an error if any metadata or shapes would change). The default mode
869+
is “a” if append_dim is set. Otherwise, it is “r+” if region is set and w- otherwise.
870+
encoding : dict, optional
871+
Nested dictionary with variable names as keys and dictionaries of
872+
variable specific encodings as values, e.g.,
873+
``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1}, ...}, ...}``.
874+
See ``xarray.Dataset.to_zarr`` for available options.
875+
kwargs :
876+
Addional keyword arguments to be passed to ``xarray.Dataset.to_zarr``
877+
"""
878+
from .io import _datatree_to_zarr
879+
880+
_datatree_to_zarr(
881+
self,
882+
store,
883+
mode=mode,
884+
encoding=encoding,
885+
**kwargs,
886+
)
887+
857888
def plot(self):
858889
raise NotImplementedError
859890

datatree/io.py

+115-25
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import os
2-
from typing import Dict, Sequence
1+
import pathlib
2+
from typing import Sequence
33

4-
import netCDF4
54
from xarray import open_dataset
65

7-
from .datatree import DataNode, DataTree, PathType
6+
from .datatree import DataTree, PathType
87

98

109
def _ds_or_none(ds):
@@ -14,37 +13,87 @@ def _ds_or_none(ds):
1413
return None
1514

1615

17-
def _open_group_children_recursively(filename, node, ncgroup, chunks, **kwargs):
18-
for g in ncgroup.groups.values():
16+
def _iter_zarr_groups(root, parrent=""):
17+
parrent = pathlib.Path(parrent)
18+
for path, group in root.groups():
19+
gpath = parrent / path
20+
yield str(gpath)
21+
yield from _iter_zarr_groups(group, parrent=gpath)
1922

20-
# Open and add this node's dataset to the tree
21-
name = os.path.basename(g.path)
22-
ds = open_dataset(filename, group=g.path, chunks=chunks, **kwargs)
23-
ds = _ds_or_none(ds)
24-
child_node = DataNode(name, ds)
25-
node.add_child(child_node)
2623

27-
_open_group_children_recursively(filename, node[name], g, chunks, **kwargs)
24+
def _iter_nc_groups(root, parrent=""):
25+
parrent = pathlib.Path(parrent)
26+
for path, group in root.groups.items():
27+
gpath = parrent / path
28+
yield str(gpath)
29+
yield from _iter_nc_groups(group, parrent=gpath)
2830

2931

30-
def open_datatree(filename: str, chunks: Dict = None, **kwargs) -> DataTree:
32+
def _get_nc_dataset_class(engine):
33+
if engine == "netcdf4":
34+
from netCDF4 import Dataset
35+
elif engine == "h5netcdf":
36+
from h5netcdf import Dataset
37+
elif engine is None:
38+
try:
39+
from netCDF4 import Dataset
40+
except ImportError:
41+
from h5netcdf import Dataset
42+
else:
43+
raise ValueError(f"unsupported engine: {engine}")
44+
return Dataset
45+
46+
47+
def open_datatree(filename_or_obj, engine=None, **kwargs) -> DataTree:
3148
"""
3249
Open and decode a dataset from a file or file-like object, creating one Tree node for each group in the file.
3350
3451
Parameters
3552
----------
36-
filename
37-
chunks
53+
filename_or_obj : str, Path, file-like, or DataStore
54+
Strings and Path objects are interpreted as a path to a netCDF file or Zarr store.
55+
engine : str, optional
56+
Xarray backend engine to us. Valid options include `{"netcdf4", "h5netcdf", "zarr"}`.
57+
kwargs :
58+
Additional keyword arguments passed to ``xarray.open_dataset`` for each group.
3859
3960
Returns
4061
-------
4162
DataTree
4263
"""
4364

44-
with netCDF4.Dataset(filename, mode="r") as ncfile:
45-
ds = open_dataset(filename, chunks=chunks, **kwargs)
65+
if engine == "zarr":
66+
return _open_datatree_zarr(filename_or_obj, **kwargs)
67+
elif engine in [None, "netcdf4", "h5netcdf"]:
68+
return _open_datatree_netcdf(filename_or_obj, engine=engine, **kwargs)
69+
70+
71+
def _open_datatree_netcdf(filename: str, **kwargs) -> DataTree:
72+
ncDataset = _get_nc_dataset_class(kwargs.get("engine", None))
73+
74+
with ncDataset(filename, mode="r") as ncds:
75+
ds = open_dataset(filename, **kwargs).pipe(_ds_or_none)
76+
tree_root = DataTree(data_objects={"root": ds})
77+
for key in _iter_nc_groups(ncds):
78+
tree_root[key] = open_dataset(filename, group=key, **kwargs).pipe(
79+
_ds_or_none
80+
)
81+
return tree_root
82+
83+
84+
def _open_datatree_zarr(store, **kwargs) -> DataTree:
85+
import zarr
86+
87+
with zarr.open_group(store, mode="r") as zds:
88+
ds = open_dataset(store, engine="zarr", **kwargs).pipe(_ds_or_none)
4689
tree_root = DataTree(data_objects={"root": ds})
47-
_open_group_children_recursively(filename, tree_root, ncfile, chunks, **kwargs)
90+
for key in _iter_zarr_groups(zds):
91+
try:
92+
tree_root[key] = open_dataset(
93+
store, engine="zarr", group=key, **kwargs
94+
).pipe(_ds_or_none)
95+
except zarr.errors.PathNotFoundError:
96+
tree_root[key] = None
4897
return tree_root
4998

5099

@@ -80,8 +129,10 @@ def _maybe_extract_group_kwargs(enc, group):
80129
return None
81130

82131

83-
def _create_empty_group(filename, group, mode):
84-
with netCDF4.Dataset(filename, mode=mode) as rootgrp:
132+
def _create_empty_netcdf_group(filename, group, mode, engine):
133+
ncDataset = _get_nc_dataset_class(engine)
134+
135+
with ncDataset(filename, mode=mode) as rootgrp:
85136
rootgrp.createGroup(group)
86137

87138

@@ -91,13 +142,14 @@ def _datatree_to_netcdf(
91142
mode: str = "w",
92143
encoding=None,
93144
unlimited_dims=None,
94-
**kwargs
145+
**kwargs,
95146
):
96147

97148
if kwargs.get("format", None) not in [None, "NETCDF4"]:
98149
raise ValueError("to_netcdf only supports the NETCDF4 format")
99150

100-
if kwargs.get("engine", None) not in [None, "netcdf4", "h5netcdf"]:
151+
engine = kwargs.get("engine", None)
152+
if engine not in [None, "netcdf4", "h5netcdf"]:
101153
raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines")
102154

103155
if kwargs.get("group", None) is not None:
@@ -118,14 +170,52 @@ def _datatree_to_netcdf(
118170
ds = node.ds
119171
group_path = node.pathstr.replace(dt.root.pathstr, "")
120172
if ds is None:
121-
_create_empty_group(filepath, group_path, mode)
173+
_create_empty_netcdf_group(filepath, group_path, mode, engine)
122174
else:
175+
123176
ds.to_netcdf(
124177
filepath,
125178
group=group_path,
126179
mode=mode,
127180
encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr),
128181
unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr),
129-
**kwargs
182+
**kwargs,
130183
)
131184
mode = "a"
185+
186+
187+
def _create_empty_zarr_group(store, group, mode):
188+
import zarr
189+
190+
root = zarr.open_group(store, mode=mode)
191+
root.create_group(group, overwrite=True)
192+
193+
194+
def _datatree_to_zarr(dt: DataTree, store, mode: str = "w", encoding=None, **kwargs):
195+
196+
if kwargs.get("group", None) is not None:
197+
raise NotImplementedError(
198+
"specifying a root group for the tree has not been implemented"
199+
)
200+
201+
if not kwargs.get("compute", True):
202+
raise NotImplementedError("compute=False has not been implemented yet")
203+
204+
if encoding is None:
205+
encoding = {}
206+
207+
for node in dt.subtree:
208+
ds = node.ds
209+
group_path = node.pathstr.replace(dt.root.pathstr, "")
210+
if ds is None:
211+
_create_empty_zarr_group(store, group_path, mode)
212+
else:
213+
ds.to_zarr(
214+
store,
215+
group=group_path,
216+
mode=mode,
217+
encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr),
218+
**kwargs,
219+
)
220+
if "w" in mode:
221+
mode = "a"

datatree/tests/test_datatree.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,15 @@ def test_to_netcdf(self, tmpdir):
322322

323323
roundtrip_dt = open_datatree(filepath)
324324

325-
original_dt.name == roundtrip_dt.name
326-
assert original_dt.ds.identical(roundtrip_dt.ds)
327-
for a, b in zip(original_dt.descendants, roundtrip_dt.descendants):
328-
assert a.name == b.name
329-
assert a.pathstr == b.pathstr
330-
if a.has_data:
331-
assert a.ds.identical(b.ds)
332-
else:
333-
assert a.ds is b.ds
325+
assert_tree_equal(original_dt, roundtrip_dt)
326+
327+
def test_to_zarr(self, tmpdir):
328+
filepath = str(
329+
tmpdir / "test.zarr"
330+
) # casting to str avoids a pathlib bug in xarray
331+
original_dt = create_test_datatree()
332+
original_dt.to_zarr(filepath)
333+
334+
roundtrip_dt = open_datatree(filepath, engine="zarr")
335+
336+
assert_tree_equal(original_dt, roundtrip_dt)

requirements.txt

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
xarray>=0.19.0
2-
netcdf4
32
anytree
43
future

0 commit comments

Comments
 (0)