From e16e1a4ca8138132872181550745a798d19d2245 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 25 Aug 2021 12:46:32 -0700 Subject: [PATCH 1/6] add test for roundtrip and support empty nodes --- datatree/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datatree/io.py b/datatree/io.py index c717203a..8da1d493 100644 --- a/datatree/io.py +++ b/datatree/io.py @@ -2,7 +2,7 @@ from typing import Dict, Sequence import netCDF4 -from xarray import open_dataset +from xarray import Dataset, open_dataset from .datatree import DataNode, DataTree, PathType From a9c680e4325466adf68beccddaf23d03274ba5b0 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 25 Aug 2021 15:07:21 -0700 Subject: [PATCH 2/6] update roundtrip test, improves empty node handling in IO --- datatree/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datatree/io.py b/datatree/io.py index 8da1d493..c717203a 100644 --- a/datatree/io.py +++ b/datatree/io.py @@ -2,7 +2,7 @@ from typing import Dict, Sequence import netCDF4 -from xarray import Dataset, open_dataset +from xarray import open_dataset from .datatree import DataNode, DataTree, PathType From b000936441d5c9d308a21554b4f61e50d9fd550c Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Thu, 26 Aug 2021 09:30:58 -0700 Subject: [PATCH 3/6] add zarr read/write support --- datatree/datatree.py | 31 +++++++++ datatree/io.py | 115 ++++++++++++++++++++++++++------ datatree/tests/test_datatree.py | 19 ++++++ 3 files changed, 145 insertions(+), 20 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index 416e1894..5533e364 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -912,6 +912,37 @@ def to_netcdf( **kwargs, ) + def to_zarr(self, store, mode: str = "w", encoding=None, **kwargs): + """ + Write datatree contents to a netCDF file. 
+ + Paramters + --------- + store :MutableMapping, str or Path, optional + Store or path to directory in file system + mode : {{"w", "w-", "a", "r+", None}, default: "w" + Persistence mode: “w” means create (overwrite if exists); “w-” means create (fail if exists); + “a” means override existing variables (create if does not exist); “r+” means modify existing + array values only (raise an error if any metadata or shapes would change). The default mode + is “a” if append_dim is set. Otherwise, it is “r+” if region is set and w- otherwise. + encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1}, ...}, ...}``. + See ``xarray.Dataset.to_zarr`` for available options. + kwargs : + Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr`` + """ + from .io import _datatree_to_zarr + + _datatree_to_zarr( + self, + store, + mode=mode, + encoding=encoding, + **kwargs, + ) + def plot(self): raise NotImplementedError diff --git a/datatree/io.py b/datatree/io.py index c717203a..2430f297 100644 --- a/datatree/io.py +++ b/datatree/io.py @@ -1,10 +1,9 @@ import os -from typing import Dict, Sequence +from typing import Sequence -import netCDF4 from xarray import open_dataset -from .datatree import DataNode, DataTree, PathType +from .datatree import DataTree, PathType def _ds_or_none(ds): @@ -14,20 +13,21 @@ def _ds_or_none(ds): return None -def _open_group_children_recursively(filename, node, ncgroup, chunks, **kwargs): - for g in ncgroup.groups.values(): +def _iter_zarr_groups(root, parrent=""): + for path, group in root.groups(): + gpath = os.path.join(parrent, path) + yield gpath + yield from _iter_zarr_groups(group, parrent=gpath) - # Open and add this node's dataset to the tree - name = os.path.basename(g.path) - ds = open_dataset(filename, group=g.path, chunks=chunks, **kwargs) - ds = _ds_or_none(ds) - 
child_node = DataNode(name, ds) - node.add_child(child_node) - _open_group_children_recursively(filename, node[name], g, chunks, **kwargs) +def _iter_nc_groups(root, parrent=""): + for path, group in root.groups.items(): + gpath = os.path.join(parrent, path) + yield gpath + yield from _iter_nc_groups(group, parrent=gpath) -def open_datatree(filename: str, chunks: Dict = None, **kwargs) -> DataTree: +def open_datatree(filename_or_obj, engine=None, **kwargs) -> DataTree: """ Open and decode a dataset from a file or file-like object, creating one Tree node for each group in the file. @@ -41,11 +41,34 @@ def open_datatree(filename: str, chunks: Dict = None, **kwargs) -> DataTree: DataTree """ - with netCDF4.Dataset(filename, mode="r") as ncfile: - ds = open_dataset(filename, chunks=chunks, **kwargs) + if engine == "zarr": + return _open_datatree_zarr(filename_or_obj, **kwargs) + else: + return _open_datatree_netcdf(filename_or_obj, engine=engine, **kwargs) + + +def _open_datatree_netcdf(filename: str, **kwargs) -> DataTree: + import netCDF4 + + with netCDF4.Dataset(filename, mode="r") as ncds: + ds = open_dataset(filename, **kwargs).pipe(_ds_or_none) + tree_root = DataTree(data_objects={"root": ds}) + for key in _iter_nc_groups(ncds): + tree_root[key] = open_dataset(filename, group=key, **kwargs).pipe( + _ds_or_none + ) + + +def _open_datatree_zarr(store, **kwargs) -> DataTree: + import zarr + + with zarr.Dataset(store, mode="r") as zds: + ds = open_dataset(store, engine="zarr", **kwargs).pipe(_ds_or_none) tree_root = DataTree(data_objects={"root": ds}) - _open_group_children_recursively(filename, tree_root, ncfile, chunks, **kwargs) - return tree_root + for key in _iter_zarr_groups(zds): + tree_root[key] = open_dataset( + store, engine="zarr", group=key, **kwargs + ).pipe(_ds_or_none) def open_mfdatatree( @@ -80,7 +103,9 @@ def _maybe_extract_group_kwargs(enc, group): return None -def _create_empty_group(filename, group, mode): +def 
_create_empty_netcdf_group(filename, group, mode): + import netCDF4 + with netCDF4.Dataset(filename, mode=mode) as rootgrp: rootgrp.createGroup(group) @@ -117,7 +142,7 @@ def _datatree_to_netcdf( ds = dt.ds group_path = dt.pathstr.replace(dt.root.pathstr, "") if ds is None: - _create_empty_group(filepath, group_path, mode) + _create_empty_netcdf_group(filepath, group_path, mode) else: ds.to_netcdf( filepath, @@ -133,7 +158,7 @@ def _datatree_to_netcdf( ds = node.ds group_path = node.pathstr.replace(dt.root.pathstr, "") if ds is None: - _create_empty_group(filepath, group_path, mode) + _create_empty_netcdf_group(filepath, group_path, mode) else: ds.to_netcdf( filepath, @@ -143,3 +168,53 @@ def _datatree_to_netcdf( unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), **kwargs ) + + +def _create_empty_zarr_group(store, group, mode): + import zarr + + root = zarr.open_group(store, mode=mode) + root.create_group(group, overwrite=True) + + +def _datatree_to_zarr(dt: DataTree, store, mode: str = "w", encoding=None, **kwargs): + + if kwargs.get("group", None) is not None: + raise NotImplementedError( + "specifying a root group for the tree has not been implemented" + ) + + if not kwargs.get("compute", True): + raise NotImplementedError("compute=False has not been implemented yet") + + if encoding is None: + encoding = {} + + ds = dt.ds + group_path = dt.pathstr.replace(dt.root.pathstr, "") + if ds is None: + _create_empty_zarr_group(store, group_path, mode) + else: + ds.to_zarr( + store, + group=group_path, + mode=mode, + encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), + **kwargs + ) + if "w" in mode: + mode = "a" + + for node in dt.descendants: + ds = node.ds + group_path = node.pathstr.replace(dt.root.pathstr, "") + if ds is None: + _create_empty_zarr_group(store, group_path, mode) + else: + ds.to_zarr( + store, + group=group_path, + mode=mode, + encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), + **kwargs + ) diff --git 
a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index b82ae984..ebe4e112 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -316,3 +316,22 @@ def test_to_netcdf(self, tmpdir): assert a.ds.identical(b.ds) else: assert a.ds is b.ds + + def test_to_zarr(self, tmpdir): + filepath = str( + tmpdir / "test.zarr" + ) # casting to str avoids a pathlib bug in xarray + original_dt = create_test_datatree() + original_dt.to_zarr(filepath) + + roundtrip_dt = open_datatree(filepath, engine="zarr") + + original_dt.name == roundtrip_dt.name + assert original_dt.ds.identical(roundtrip_dt.ds) + for a, b in zip(original_dt.descendants, roundtrip_dt.descendants): + assert a.name == b.name + assert a.pathstr == b.pathstr + if a.has_data: + assert a.ds.identical(b.ds) + else: + assert a.ds is b.ds From 537e3052ab21d395807ac77969b3b4a5f6fc8a25 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Thu, 26 Aug 2021 10:38:53 -0700 Subject: [PATCH 4/6] support netcdf4 or h5netcdf --- datatree/io.py | 65 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 45 insertions(+), 20 deletions(-) diff --git a/datatree/io.py b/datatree/io.py index 2430f297..90c63785 100644 --- a/datatree/io.py +++ b/datatree/io.py @@ -27,14 +27,33 @@ def _iter_nc_groups(root, parrent=""): yield from _iter_nc_groups(group, parrent=gpath) +def _get_nc_dataset_class(engine): + if engine == "netcdf4": + from netCDF4 import Dataset + elif engine == "h5netcdf": + from h5netcdf import Dataset + elif engine is None: + try: + from netCDF4 import Dataset + except ImportError: + from h5netcdf import Dataset + else: + raise ValueError(f"unsupported engine: {engine}") + return Dataset + + def open_datatree(filename_or_obj, engine=None, **kwargs) -> DataTree: """ Open and decode a dataset from a file or file-like object, creating one Tree node for each group in the file. 
Parameters ---------- - filename - chunks + filename_or_obj : str, Path, file-like, or DataStore + Strings and Path objects are interpreted as a path to a netCDF file or Zarr store. + engine : str, optional + Xarray backend engine to use. Valid options include `{"netcdf4", "h5netcdf", "zarr"}`. + kwargs : + Additional keyword arguments passed to ``xarray.open_dataset`` for each group. Returns ------- @@ -43,32 +62,37 @@ def open_datatree(filename_or_obj, engine=None, **kwargs) -> DataTree: if engine == "zarr": return _open_datatree_zarr(filename_or_obj, **kwargs) - else: + elif engine in [None, "netcdf4", "h5netcdf"]: return _open_datatree_netcdf(filename_or_obj, engine=engine, **kwargs) def _open_datatree_netcdf(filename: str, **kwargs) -> DataTree: - import netCDF4 + ncDataset = _get_nc_dataset_class(kwargs.get("engine", None)) - with netCDF4.Dataset(filename, mode="r") as ncds: + with ncDataset(filename, mode="r") as ncds: ds = open_dataset(filename, **kwargs).pipe(_ds_or_none) tree_root = DataTree(data_objects={"root": ds}) for key in _iter_nc_groups(ncds): tree_root[key] = open_dataset(filename, group=key, **kwargs).pipe( _ds_or_none ) + return tree_root def _open_datatree_zarr(store, **kwargs) -> DataTree: import zarr - with zarr.Dataset(store, mode="r") as zds: + with zarr.open_group(store, mode="r") as zds: ds = open_dataset(store, engine="zarr", **kwargs).pipe(_ds_or_none) tree_root = DataTree(data_objects={"root": ds}) for key in _iter_zarr_groups(zds): - tree_root[key] = open_dataset( - store, engine="zarr", group=key, **kwargs - ).pipe(_ds_or_none) + try: + tree_root[key] = open_dataset( + store, engine="zarr", group=key, **kwargs + ).pipe(_ds_or_none) + except zarr.errors.PathNotFoundError: + tree_root[key] = None + return tree_root def open_mfdatatree( @@ -103,10 +127,10 @@ def _maybe_extract_group_kwargs(enc, group): return None -def _create_empty_netcdf_group(filename, group, mode): - import netCDF4 +def _create_empty_netcdf_group(filename, group, 
mode, engine): + ncDataset = _get_nc_dataset_class(engine) - with netCDF4.Dataset(filename, mode=mode) as rootgrp: + with ncDataset(filename, mode=mode) as rootgrp: rootgrp.createGroup(group) @@ -116,13 +140,14 @@ def _datatree_to_netcdf( mode: str = "w", encoding=None, unlimited_dims=None, - **kwargs + **kwargs, ): if kwargs.get("format", None) not in [None, "NETCDF4"]: raise ValueError("to_netcdf only supports the NETCDF4 format") - if kwargs.get("engine", None) not in [None, "netcdf4", "h5netcdf"]: + engine = kwargs.get("engine", None) + if engine not in [None, "netcdf4", "h5netcdf"]: raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines") if kwargs.get("group", None) is not None: @@ -142,7 +167,7 @@ def _datatree_to_netcdf( ds = dt.ds group_path = dt.pathstr.replace(dt.root.pathstr, "") if ds is None: - _create_empty_netcdf_group(filepath, group_path, mode) + _create_empty_netcdf_group(filepath, group_path, mode, engine) else: ds.to_netcdf( filepath, @@ -150,7 +175,7 @@ def _datatree_to_netcdf( mode=mode, encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), - **kwargs + **kwargs, ) mode = "a" @@ -158,7 +183,7 @@ def _datatree_to_netcdf( ds = node.ds group_path = node.pathstr.replace(dt.root.pathstr, "") if ds is None: - _create_empty_netcdf_group(filepath, group_path, mode) + _create_empty_netcdf_group(filepath, group_path, mode, engine) else: ds.to_netcdf( filepath, @@ -166,7 +191,7 @@ def _datatree_to_netcdf( mode=mode, encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), - **kwargs + **kwargs, ) @@ -200,7 +225,7 @@ def _datatree_to_zarr(dt: DataTree, store, mode: str = "w", encoding=None, **kwa group=group_path, mode=mode, encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), - **kwargs + **kwargs, ) if "w" in mode: mode = "a" @@ -216,5 +241,5 @@ def _datatree_to_zarr(dt: 
DataTree, store, mode: str = "w", encoding=None, **kwa group=group_path, mode=mode, encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), - **kwargs + **kwargs, ) From 7ec2f903d06142a85f934cc005c186fcdb12f90c Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Thu, 26 Aug 2021 10:51:20 -0700 Subject: [PATCH 5/6] netcdf is optional, zarr too! --- ci/environment.yml | 1 + requirements.txt | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/environment.yml b/ci/environment.yml index 8486fc92..e379a9fa 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -11,3 +11,4 @@ dependencies: - black - codecov - pytest-cov + - zarr diff --git a/requirements.txt b/requirements.txt index 67e19d19..a95f277b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ xarray>=0.19.0 -netcdf4 anytree future From 82fd7002c12e40b26615a1ec98fccfca3ff820cb Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Mon, 30 Aug 2021 08:17:14 -0700 Subject: [PATCH 6/6] Apply suggestions from code review Co-authored-by: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> --- datatree/datatree.py | 4 ++-- datatree/io.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index 5533e364..02c8f6d6 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -916,9 +916,9 @@ def to_zarr(self, store, mode: str = "w", encoding=None, **kwargs): """ Write datatree contents to a netCDF file. 
- Paramters + Parameters --------- - store :MutableMapping, str or Path, optional + store : MutableMapping, str or Path, optional Store or path to directory in file system mode : {{"w", "w-", "a", "r+", None}, default: "w" Persistence mode: “w” means create (overwrite if exists); “w-” means create (fail if exists); diff --git a/datatree/io.py b/datatree/io.py index 90c63785..14fd879b 100644 --- a/datatree/io.py +++ b/datatree/io.py @@ -20,9 +20,9 @@ def _iter_zarr_groups(root, parrent=""): yield from _iter_zarr_groups(group, parrent=gpath) -def _iter_nc_groups(root, parrent=""): +def _iter_nc_groups(root, parent=""): for path, group in root.groups.items(): - gpath = os.path.join(parrent, path) + gpath = os.path.join(parent, path) yield gpath yield from _iter_nc_groups(group, parrent=gpath)