Skip to content

Commit 0e48638

Browse files
authored
* add to_dict method * added roundtrip test * add xfailed test for rountripping with named root * whats-new
1 parent 64d98f5 commit 0e48638

File tree

4 files changed

+35
-5
lines changed

4 files changed

+35
-5
lines changed

datatree/datatree.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
TYPE_CHECKING,
77
Any,
88
Callable,
9+
Dict,
910
Generic,
1011
Iterable,
1112
Mapping,
@@ -316,7 +317,7 @@ def update(self, other: Dataset | Mapping[str, DataTree | DataArray]) -> None:
316317
@classmethod
317318
def from_dict(
318319
cls,
319-
d: MutableMapping[str, DataTree | Dataset | DataArray],
320+
d: MutableMapping[str, Dataset | DataArray | None],
320321
name: str = None,
321322
) -> DataTree:
322323
"""
@@ -337,19 +338,22 @@ def from_dict(
337338
Returns
338339
-------
339340
DataTree
341+
342+
Notes
343+
-----
344+
If your dictionary is nested you will need to flatten it before using this method.
340345
"""
341346

342347
# First create the root node
343-
# TODO there is a real bug here where what if root_data is of type DataTree?
344348
root_data = d.pop("/", None)
345-
obj = cls(name=name, data=root_data, parent=None, children=None) # type: ignore[arg-type]
349+
obj = cls(name=name, data=root_data, parent=None, children=None)
346350

347351
if d:
348352
# Populate tree with children determined from data_objects mapping
349353
for path, data in d.items():
350354
# Create and set new node
351355
node_name = NodePath(path).name
352-
new_node = cls(name=node_name, data=data) # type: ignore[arg-type]
356+
new_node = cls(name=node_name, data=data)
353357
obj._set_item(
354358
path,
355359
new_node,
@@ -358,6 +362,16 @@ def from_dict(
358362
)
359363
return obj
360364

365+
def to_dict(self) -> Dict[str, Any]:
366+
"""
367+
Create a dictionary mapping of absolute node paths to the data contained in those nodes.
368+
369+
Returns
370+
-------
371+
Dict
372+
"""
373+
return {node.path: node.ds for node in self.subtree}
374+
361375
@property
362376
def nbytes(self) -> int:
363377
return sum(node.ds.nbytes if node.has_data else 0 for node in self.subtree)

datatree/tests/test_datatree.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,20 @@ def test_full(self):
328328
"/set3",
329329
]
330330

331+
def test_roundtrip(self):
332+
dt = create_test_datatree()
333+
roundtrip = DataTree.from_dict(dt.to_dict())
334+
assert roundtrip.equals(dt)
335+
336+
@pytest.mark.xfail
337+
def test_roundtrip_unnamed_root(self):
338+
# See GH81
339+
340+
dt = create_test_datatree()
341+
dt.name = "root"
342+
roundtrip = DataTree.from_dict(dt.to_dict())
343+
assert roundtrip.equals(dt)
344+
331345

332346
class TestBrowsing:
333347
...

docs/source/api.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,13 @@ I/O
215215

216216
open_datatree
217217
DataTree.from_dict
218+
DataTree.to_dict
218219
DataTree.to_netcdf
219220
DataTree.to_zarr
220221

221222
..
222223
223224
Missing
224-
DataTree.to_dict
225225
open_mfdatatree
226226

227227
Exceptions

docs/source/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ New Features
3838
By `Tom Nicholas <https://github.com/TomNicholas>`_.
3939
- New delitem method so you can delete nodes. (:pull:`88`)
4040
By `Tom Nicholas <https://github.com/TomNicholas>`_.
41+
- New ``to_dict`` method. (:pull:`82`)
42+
By `Tom Nicholas <https://github.com/TomNicholas>`_.
4143

4244
Breaking changes
4345
~~~~~~~~~~~~~~~~

0 commit comments

Comments
 (0)