Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Commit ac1a68e

Browse files
committed
2 parents 75abfe1 + 20ada4e commit ac1a68e

File tree

3 files changed

+34
-32
lines changed

3 files changed

+34
-32
lines changed

datatree/__init__.py

+1
Original file line number | Diff line number | Diff line change
@@ -1 +1,2 @@
11
from .datatree import DataTree, map_over_subtree, DataNode
2+
from .io import open_datatree

datatree/datatree.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -639,7 +639,7 @@ def as_dataarray(self) -> DataArray:
639639
@property
640640
def groups(self):
641641
"""Return all netCDF4 groups in the tree, given as a tuple of path-like strings."""
642-
return tuple(node.path for node in self.subtree_nodes)
642+
return tuple(node.pathstr for node in self.subtree_nodes)
643643

644644
def to_netcdf(self, filename: str):
645645
from .io import _datatree_to_netcdf

datatree/io.py

+32-31
Original file line number | Diff line number | Diff line change
@@ -1,46 +1,47 @@
1-
from typing import Sequence
1+
from typing import Sequence, Dict
2+
import os
23

3-
from netCDF4 import Dataset as nc_dataset
4+
import netCDF4
45

56
from xarray import open_dataset
67

7-
from .datatree import DataTree, PathType
8+
from .datatree import DataTree, DataNode, PathType
89

910

10-
def _get_group_names(file):
11-
rootgrp = nc_dataset("test.nc", "r", format="NETCDF4")
11+
def _open_group_children_recursively(filename, node, ncgroup, chunks, **kwargs):
12+
for g in ncgroup.groups.values():
1213

13-
def walktree(top):
14-
yield top.groups.values()
15-
for value in top.groups.values():
16-
yield from walktree(value)
14+
# Open and add this node's dataset to the tree
15+
name = os.path.basename(g.path)
16+
ds = open_dataset(filename, group=g.path, chunks=chunks, **kwargs)
17+
child_node = DataNode(name, ds)
18+
node.add_child(child_node)
1719

18-
groups = []
19-
for children in walktree(rootgrp):
20-
for child in children:
21-
# TODO include parents in saved path
22-
groups.append(child.name)
20+
_open_group_children_recursively(filename, node[name], g, chunks, **kwargs)
2321

24-
rootgrp.close()
25-
return groups
2622

27-
28-
def open_datatree(filename_or_obj, engine=None, chunks=None, **kwargs) -> DataTree:
29-
"""
30-
Open and decode a dataset from a file or file-like object, creating one DataTree node
31-
for each group in the file.
23+
def open_datatree(filename: str, chunks: Dict = None, **kwargs) -> DataTree:
3224
"""
25+
Open and decode a dataset from a file or file-like object, creating one Tree node for each group in the file.
3326
34-
# TODO find all the netCDF groups in the file
35-
file_groups = _get_group_names(filename_or_obj)
27+
Parameters
28+
----------
29+
filename
30+
chunks
31+
32+
Returns
33+
-------
34+
DataTree
35+
"""
3636

37-
# Populate the DataTree with the groups
38-
groups_and_datasets = {group_path: open_dataset(engine=engine, chunks=chunks, **kwargs)
39-
for group_path in file_groups}
40-
return DataTree(data_objects=groups_and_datasets)
37+
with netCDF4.Dataset(filename, mode='r') as ncfile:
38+
ds = open_dataset(filename, chunks=chunks, **kwargs)
39+
tree_root = DataTree(data_objects={'root': ds})
40+
_open_group_children_recursively(filename, tree_root, ncfile, chunks, **kwargs)
41+
return tree_root
4142

4243

43-
def open_mfdatatree(filepaths, rootnames: Sequence[PathType] = None, engine=None, chunks=None, **kwargs) -> DataTree:
44+
def open_mfdatatree(filepaths, rootnames: Sequence[PathType] = None, chunks=None, **kwargs) -> DataTree:
4445
"""
4546
Open multiple files as a single DataTree.
4647
@@ -55,11 +56,11 @@ def open_mfdatatree(filepaths, rootnames: Sequence[PathType] = None, engine=None
5556
full_tree = DataTree()
5657

5758
for file, root in zip(filepaths, rootnames):
58-
dt = open_datatree(file, engine=engine, chunks=chunks, **kwargs)
59-
full_tree._set_item(path=root, value=dt, new_nodes_along_path=True, allow_overwrites=False)
59+
dt = open_datatree(file, chunks=chunks, **kwargs)
60+
full_tree.set_node(path=root, node=dt, new_nodes_along_path=True, allow_overwrite=False)
6061

6162
return full_tree
6263

6364

64-
def _datatree_to_netcdf(dt: DataTree, path_or_file: str):
65+
def _datatree_to_netcdf(dt: DataTree, filepath: str):
6566
raise NotImplementedError

0 commit comments

Comments
 (0)