This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Add zarr read/write #30

Merged
merged 7 commits on Aug 30, 2021
1 change: 1 addition & 0 deletions ci/environment.yml
@@ -11,3 +11,4 @@ dependencies:
- black
- codecov
- pytest-cov
- zarr
31 changes: 31 additions & 0 deletions datatree/datatree.py
@@ -854,6 +854,37 @@ def to_netcdf(
**kwargs,
)

def to_zarr(self, store, mode: str = "w", encoding=None, **kwargs):
"""
Write datatree contents to a Zarr store.

Parameters
----------
store : MutableMapping, str or Path, optional
Store or path to directory in file system
mode : {"w", "w-", "a", "r+", None}, default: "w"
Persistence mode: "w" means create (overwrite if exists); "w-" means create (fail if exists);
"a" means override existing variables (create if does not exist); "r+" means modify existing
array values only (raise an error if any metadata or shapes would change). The default
mode is "w".
encoding : dict, optional
Nested dictionary with variable names as keys and dictionaries of
variable specific encodings as values, e.g.,
``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1}, ...}, ...}``.
See ``xarray.Dataset.to_zarr`` for available options.
kwargs :
Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr``.
"""
from .io import _datatree_to_zarr

_datatree_to_zarr(
self,
store,
mode=mode,
encoding=encoding,
**kwargs,
)

def plot(self):
raise NotImplementedError

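For context on the ``encoding`` layout documented in the new ``to_zarr`` method above, here is a minimal usage sketch. It is not part of the diff: the store name ``example.zarr`` is a placeholder, and the tree is built with the ``data_objects`` constructor the same way ``io.py`` builds the root group.

import xarray as xr
from datatree import DataTree

# Build a one-node tree; "root" is the key io.py uses for the root dataset.
dt = DataTree(data_objects={"root": xr.Dataset({"a": ("x", [1.0, 2.0, 3.0])})})

# encoding is nested: group path -> variable name -> encoding options.
dt.to_zarr(
    "example.zarr",
    encoding={"root": {"a": {"dtype": "float32"}}},
)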
140 changes: 115 additions & 25 deletions datatree/io.py
@@ -1,10 +1,9 @@
import os
from typing import Dict, Sequence
import pathlib
from typing import Sequence

import netCDF4
from xarray import open_dataset

from .datatree import DataNode, DataTree, PathType
from .datatree import DataTree, PathType


def _ds_or_none(ds):
@@ -14,37 +13,87 @@ def _ds_or_none(ds):
return None


def _open_group_children_recursively(filename, node, ncgroup, chunks, **kwargs):
for g in ncgroup.groups.values():
def _iter_zarr_groups(root, parrent=""):
parrent = pathlib.Path(parrent)
for path, group in root.groups():
gpath = parrent / path
yield str(gpath)
yield from _iter_zarr_groups(group, parrent=gpath)

# Open and add this node's dataset to the tree
name = os.path.basename(g.path)
ds = open_dataset(filename, group=g.path, chunks=chunks, **kwargs)
ds = _ds_or_none(ds)
child_node = DataNode(name, ds)
node.add_child(child_node)

_open_group_children_recursively(filename, node[name], g, chunks, **kwargs)
def _iter_nc_groups(root, parrent=""):
parrent = pathlib.Path(parrent)
for path, group in root.groups.items():
gpath = parrent / path
yield str(gpath)
yield from _iter_nc_groups(group, parrent=gpath)


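As a quick illustration of the two ``_iter_*_groups`` generators above — a depth-first walk that yields slash-joined group paths relative to the root — here is a small sketch against an in-memory Zarr store. The group names are made up, and it assumes the helper is imported from ``datatree.io`` as defined in this diff.

import zarr
from datatree.io import _iter_zarr_groups

root = zarr.group()               # in-memory store
set1 = root.create_group("set1")
set1.create_group("nested")

# Yields "set1", then "set1/nested" (paths are joined with pathlib).
print(list(_iter_zarr_groups(root)))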
def open_datatree(filename: str, chunks: Dict = None, **kwargs) -> DataTree:
def _get_nc_dataset_class(engine):
if engine == "netcdf4":
from netCDF4 import Dataset
elif engine == "h5netcdf":
from h5netcdf import Dataset
elif engine is None:
try:
from netCDF4 import Dataset
except ImportError:
from h5netcdf import Dataset
else:
raise ValueError(f"unsupported engine: {engine}")
return Dataset


def open_datatree(filename_or_obj, engine=None, **kwargs) -> DataTree:
"""
Open and decode a dataset from a file or file-like object, creating one Tree node for each group in the file.

Parameters
----------
filename
chunks
filename_or_obj : str, Path, file-like, or DataStore
Strings and Path objects are interpreted as a path to a netCDF file or Zarr store.
engine : str, optional
Xarray backend engine to use. Valid options include `{"netcdf4", "h5netcdf", "zarr"}`.
kwargs :
Additional keyword arguments passed to ``xarray.open_dataset`` for each group.

Returns
-------
DataTree
"""

with netCDF4.Dataset(filename, mode="r") as ncfile:
ds = open_dataset(filename, chunks=chunks, **kwargs)
if engine == "zarr":
return _open_datatree_zarr(filename_or_obj, **kwargs)
elif engine in [None, "netcdf4", "h5netcdf"]:
return _open_datatree_netcdf(filename_or_obj, engine=engine, **kwargs)


def _open_datatree_netcdf(filename: str, **kwargs) -> DataTree:
ncDataset = _get_nc_dataset_class(kwargs.get("engine", None))

with ncDataset(filename, mode="r") as ncds:
ds = open_dataset(filename, **kwargs).pipe(_ds_or_none)
tree_root = DataTree(data_objects={"root": ds})
for key in _iter_nc_groups(ncds):
tree_root[key] = open_dataset(filename, group=key, **kwargs).pipe(
_ds_or_none
)
return tree_root


def _open_datatree_zarr(store, **kwargs) -> DataTree:
import zarr

with zarr.open_group(store, mode="r") as zds:
ds = open_dataset(store, engine="zarr", **kwargs).pipe(_ds_or_none)
tree_root = DataTree(data_objects={"root": ds})
_open_group_children_recursively(filename, tree_root, ncfile, chunks, **kwargs)
for key in _iter_zarr_groups(zds):
try:
tree_root[key] = open_dataset(
store, engine="zarr", group=key, **kwargs
).pipe(_ds_or_none)
except zarr.errors.PathNotFoundError:
tree_root[key] = None
return tree_root

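For reference, a reading sketch for the dispatching ``open_datatree`` above. The file paths are placeholders, and ``open_datatree`` is imported from ``datatree.io``, where this diff defines it.

from datatree.io import open_datatree

# netCDF path: engine resolved via _get_nc_dataset_class (netCDF4, else h5netcdf)
nc_tree = open_datatree("example.nc")

# Zarr path: dispatches to _open_datatree_zarr, one node per group
zarr_tree = open_datatree("example.zarr", engine="zarr")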

@@ -80,8 +129,10 @@ def _maybe_extract_group_kwargs(enc, group):
return None


def _create_empty_group(filename, group, mode):
with netCDF4.Dataset(filename, mode=mode) as rootgrp:
def _create_empty_netcdf_group(filename, group, mode, engine):
ncDataset = _get_nc_dataset_class(engine)

with ncDataset(filename, mode=mode) as rootgrp:
rootgrp.createGroup(group)


@@ -91,13 +142,14 @@ def _datatree_to_netcdf(
mode: str = "w",
encoding=None,
unlimited_dims=None,
**kwargs
**kwargs,
):

if kwargs.get("format", None) not in [None, "NETCDF4"]:
raise ValueError("to_netcdf only supports the NETCDF4 format")

if kwargs.get("engine", None) not in [None, "netcdf4", "h5netcdf"]:
engine = kwargs.get("engine", None)
if engine not in [None, "netcdf4", "h5netcdf"]:
raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines")

if kwargs.get("group", None) is not None:
@@ -118,14 +170,52 @@
ds = node.ds
group_path = node.pathstr.replace(dt.root.pathstr, "")
if ds is None:
_create_empty_group(filepath, group_path, mode)
_create_empty_netcdf_group(filepath, group_path, mode, engine)
else:

ds.to_netcdf(
filepath,
group=group_path,
mode=mode,
encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr),
unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr),
**kwargs
**kwargs,
)
mode = "a"


def _create_empty_zarr_group(store, group, mode):
import zarr

root = zarr.open_group(store, mode=mode)
root.create_group(group, overwrite=True)


def _datatree_to_zarr(dt: DataTree, store, mode: str = "w", encoding=None, **kwargs):

if kwargs.get("group", None) is not None:
raise NotImplementedError(
"specifying a root group for the tree has not been implemented"
)

if not kwargs.get("compute", True):
raise NotImplementedError("compute=False has not been implemented yet")

if encoding is None:
encoding = {}

for node in dt.subtree:
ds = node.ds
group_path = node.pathstr.replace(dt.root.pathstr, "")
if ds is None:
_create_empty_zarr_group(store, group_path, mode)
else:
ds.to_zarr(
store,
group=group_path,
mode=mode,
encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr),
**kwargs,
)
if "w" in mode:
mode = "a"
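The zarr writer above walks ``dt.subtree`` and switches the persistence mode to "a" once the first group has been written, so the store is created (or overwritten) exactly once and every later group is appended. Stripped of the tree machinery, the effect is roughly the hand-written sequence below; the store path and datasets are placeholders.

import xarray as xr

ds_root = xr.Dataset({"a": ("x", [1, 2, 3])})
ds_child = xr.Dataset({"b": ("y", [4.0, 5.0])})

ds_root.to_zarr("example.zarr", mode="w")                  # first write creates/overwrites the store
ds_child.to_zarr("example.zarr", group="set1", mode="a")   # later groups append to it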
21 changes: 12 additions & 9 deletions datatree/tests/test_datatree.py
@@ -322,12 +322,15 @@ def test_to_netcdf(self, tmpdir):

roundtrip_dt = open_datatree(filepath)

original_dt.name == roundtrip_dt.name
assert original_dt.ds.identical(roundtrip_dt.ds)
for a, b in zip(original_dt.descendants, roundtrip_dt.descendants):
assert a.name == b.name
assert a.pathstr == b.pathstr
if a.has_data:
assert a.ds.identical(b.ds)
else:
assert a.ds is b.ds
assert_tree_equal(original_dt, roundtrip_dt)

def test_to_zarr(self, tmpdir):
filepath = str(
tmpdir / "test.zarr"
) # casting to str avoids a pathlib bug in xarray
original_dt = create_test_datatree()
original_dt.to_zarr(filepath)

roundtrip_dt = open_datatree(filepath, engine="zarr")

assert_tree_equal(original_dt, roundtrip_dt)
1 change: 0 additions & 1 deletion requirements.txt
@@ -1,4 +1,3 @@
xarray>=0.19.0
netcdf4
anytree
future