Skip to content

Commit 6807504

Browse files
authored
Check isomorphism xarray-contrib/datatree#31
* pseudocode ideas for generalizing map_over_subtree * pseudocode for a generalized map_over_subtree (still only one return arg) + a new mapping.py file * pseudocode for mapping but now multiple return values * pseudocode for mapping but with multiple return values * check_isomorphism works and has tests * cleaned up the mapping tests a bit * remove WIP from other branch * ensure tests pass * map_over_subtree in the public API properly * linting
1 parent fa69ad7 commit 6807504

File tree

7 files changed

+346
-132
lines changed

7 files changed

+346
-132
lines changed

datatree/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# flake8: noqa
22
# Ignoring F401: imported but unused
3-
from .datatree import DataNode, DataTree, map_over_subtree
3+
from .datatree import DataNode, DataTree
44
from .io import open_datatree
5+
from .mapping import map_over_subtree

datatree/datatree.py

+1-57
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import functools
43
import textwrap
54
from typing import Any, Callable, Dict, Hashable, Iterable, List, Mapping, Union
65

@@ -14,6 +13,7 @@
1413
from xarray.core.ops import NAN_CUM_METHODS, NAN_REDUCE_METHODS, REDUCE_METHODS
1514
from xarray.core.variable import Variable
1615

16+
from .mapping import map_over_subtree
1717
from .treenode import PathType, TreeNode, _init_single_treenode
1818

1919
"""
@@ -50,62 +50,6 @@
5050
"""
5151

5252

53-
def map_over_subtree(func):
54-
"""
55-
Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees.
56-
57-
Applies a function to every dataset in this subtree, returning a new tree which stores the results.
58-
59-
The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the
60-
descendant nodes. The returned tree will have the same structure as the original subtree.
61-
62-
func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each
63-
result will be assigned to its respective node of new tree via `DataTree.__setitem__`.
64-
65-
Parameters
66-
----------
67-
func : callable
68-
Function to apply to datasets with signature:
69-
`func(node.ds, *args, **kwargs) -> Dataset`.
70-
71-
Function will not be applied to any nodes without datasets.
72-
*args : tuple, optional
73-
Positional arguments passed on to `func`.
74-
**kwargs : Any
75-
Keyword arguments passed on to `func`.
76-
77-
Returns
78-
-------
79-
mapped : callable
80-
Wrapped function which returns tree created from results of applying ``func`` to the dataset at each node.
81-
82-
See also
83-
--------
84-
DataTree.map_over_subtree
85-
DataTree.map_over_subtree_inplace
86-
"""
87-
88-
@functools.wraps(func)
89-
def _map_over_subtree(tree, *args, **kwargs):
90-
"""Internal function which maps func over every node in tree, returning a tree of the results."""
91-
92-
# Recreate and act on root node
93-
out_tree = DataNode(name=tree.name, data=tree.ds)
94-
if out_tree.has_data:
95-
out_tree.ds = func(out_tree.ds, *args, **kwargs)
96-
97-
# Act on every other node in the tree, and rebuild from results
98-
for node in tree.descendants:
99-
# TODO make a proper relative_path method
100-
relative_path = node.pathstr.replace(tree.pathstr, "")
101-
result = func(node.ds, *args, **kwargs) if node.has_data else None
102-
out_tree[relative_path] = result
103-
104-
return out_tree
105-
106-
return _map_over_subtree
107-
108-
10953
class DatasetPropertiesMixin:
11054
"""Expose properties of wrapped Dataset"""
11155

datatree/mapping.py

+139
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import functools
2+
3+
from anytree.iterators import LevelOrderIter
4+
5+
from .treenode import TreeNode
6+
7+
8+
class TreeIsomorphismError(ValueError):
    """
    Error raised if two tree objects are not isomorphic to one another when they need to be.

    Subclasses ``ValueError``, so code that already handles ``ValueError`` will also catch it.
    """
12+
13+
14+
def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False):
    """
    Check that two trees have the same structure, raising an error if not.

    Does not check the actual data in the nodes, but it does check that if one node does/doesn't have data then its
    counterpart in the other tree also does/doesn't have data.

    Also does not check that the root nodes of each tree have the same parent - so this function checks that subtrees
    are isomorphic, not the entire tree above (if it exists).

    Can optionally check if respective nodes should have the same name.

    Parameters
    ----------
    subtree_a : DataTree
    subtree_b : DataTree
    require_names_equal : bool, optional
        Whether or not to also check that each node has the same name as its counterpart. Default is False.

    Returns
    -------
    None
        The function only signals failure, by raising.

    Raises
    ------
    TypeError
        If either subtree_a or subtree_b are not tree objects.
    TreeIsomorphismError
        If subtree_a and subtree_b are tree objects, but are not isomorphic to one another, or one contains data at a
        location the other does not. Also optionally raised if their structure is isomorphic, but the names of any two
        respective nodes are not equal.
    """
    # TODO turn this into a public function called assert_isomorphic

    for arg_name, subtree in (("subtree_a", subtree_a), ("subtree_b", subtree_b)):
        if not isinstance(subtree, TreeNode):
            raise TypeError(
                f"Argument `{arg_name}` is not a tree, it is of type {type(subtree)}"
            )

    # Walking nodes in "level-order" fashion means walking down from the root breadth-first.
    # Checking by walking in this way implicitly assumes that the tree is an ordered tree (which it is so long as
    # children are stored in a tuple or list rather than in a set).
    #
    # zip() stops at the shorter iterator, but that cannot hide a mismatch here: the two walks can only differ in
    # length if some pair of counterpart nodes has a different number of children, and that pair is compared (and
    # rejected by the child-count check below) before either walk runs out.
    for node_a, node_b in zip(LevelOrderIter(subtree_a), LevelOrderIter(subtree_b)):
        path_a, path_b = node_a.pathstr, node_b.pathstr

        if require_names_equal and node_a.name != node_b.name:
            raise TreeIsomorphismError(
                f"Trees are not isomorphic because node '{path_a}' in the first tree has "
                f"name '{node_a.name}', whereas its counterpart node '{path_b}' in the "
                f"second tree has name '{node_b.name}'."
            )

        if node_a.has_data != node_b.has_data:
            # Prefix is "no " for the side lacking data, so the message reads "has no data" / "has data".
            dat_a = "no " if not node_a.has_data else ""
            dat_b = "no " if not node_b.has_data else ""
            raise TreeIsomorphismError(
                f"Trees are not isomorphic because node '{path_a}' in the first tree has "
                f"{dat_a}data, whereas its counterpart node '{path_b}' in the second tree "
                f"has {dat_b}data."
            )

        if len(node_a.children) != len(node_b.children):
            raise TreeIsomorphismError(
                f"Trees are not isomorphic because node '{path_a}' in the first tree has "
                f"{len(node_a.children)} children, whereas its counterpart node '{path_b}' in "
                f"the second tree has {len(node_b.children)} children."
            )
83+
84+
def map_over_subtree(func):
    """
    Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees.

    Applies a function to every dataset in this subtree, returning a new tree which stores the results.

    The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the
    descendant nodes. The returned tree will have the same structure as the original subtree.

    func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each
    result will be assigned to its respective node of new tree via `DataTree.__setitem__`.

    Parameters
    ----------
    func : callable
        Function to apply to datasets with signature:
        `func(node.ds, *args, **kwargs) -> Dataset`.

        Function will not be applied to any nodes without datasets.

    Returns
    -------
    mapped : callable
        Wrapped function with signature `mapped(tree, *args, **kwargs)`, which returns a tree created from the
        results of applying ``func`` to the dataset at each node. Any ``*args`` and ``**kwargs`` given to the
        wrapped function are passed on to ``func`` unchanged for every node.

    See also
    --------
    DataTree.map_over_subtree
    DataTree.map_over_subtree_inplace
    """

    @functools.wraps(func)
    def _map_over_subtree(tree, *args, **kwargs):
        """Internal function which maps func over every node in tree, returning a tree of the results."""

        # Imported here rather than at module level: datatree.py imports map_over_subtree from this module,
        # so a top-level import in the other direction would be circular.
        from .datatree import DataNode

        # Recreate and act on root node
        out_tree = DataNode(name=tree.name, data=tree.ds)
        if out_tree.has_data:
            out_tree.ds = func(out_tree.ds, *args, **kwargs)

        root_path = tree.pathstr

        # Act on every other node in the tree, and rebuild from results
        for node in tree.descendants:
            # TODO make a proper relative_path method
            # Strip only the *leading* root path. A blanket str.replace would also delete later occurrences,
            # e.g. root "a" and node "a/b/a" would become "/b/" instead of "/b/a".
            # NOTE(review): assumes every descendant's pathstr starts with the subtree root's pathstr - confirm.
            node_path = node.pathstr
            if node_path.startswith(root_path):
                relative_path = node_path[len(root_path):]
            else:
                relative_path = node_path

            result = func(node.ds, *args, **kwargs) if node.has_data else None
            out_tree[relative_path] = result

        return out_tree

    return _map_over_subtree

datatree/tests/test_dataset_api.py

+1-68
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,9 @@
11
import numpy as np
22
import pytest
33
import xarray as xr
4-
from test_datatree import create_test_datatree
54
from xarray.testing import assert_equal
65

7-
from datatree import DataNode, DataTree, map_over_subtree
8-
9-
10-
class TestMapOverSubTree:
11-
def test_map_over_subtree(self):
12-
dt = create_test_datatree()
13-
14-
@map_over_subtree
15-
def times_ten(ds):
16-
return 10.0 * ds
17-
18-
result_tree = times_ten(dt)
19-
20-
# TODO write an assert_tree_equal function
21-
for (
22-
result_node,
23-
original_node,
24-
) in zip(result_tree.subtree, dt.subtree):
25-
assert isinstance(result_node, DataTree)
26-
27-
if original_node.has_data:
28-
assert_equal(result_node.ds, original_node.ds * 10.0)
29-
else:
30-
assert not result_node.has_data
31-
32-
def test_map_over_subtree_with_args_and_kwargs(self):
33-
dt = create_test_datatree()
34-
35-
@map_over_subtree
36-
def multiply_then_add(ds, times, add=0.0):
37-
return times * ds + add
38-
39-
result_tree = multiply_then_add(dt, 10.0, add=2.0)
40-
41-
for (
42-
result_node,
43-
original_node,
44-
) in zip(result_tree.subtree, dt.subtree):
45-
assert isinstance(result_node, DataTree)
46-
47-
if original_node.has_data:
48-
assert_equal(result_node.ds, (original_node.ds * 10.0) + 2.0)
49-
else:
50-
assert not result_node.has_data
51-
52-
def test_map_over_subtree_method(self):
53-
dt = create_test_datatree()
54-
55-
def multiply_then_add(ds, times, add=0.0):
56-
return times * ds + add
57-
58-
result_tree = dt.map_over_subtree(multiply_then_add, 10.0, add=2.0)
59-
60-
for (
61-
result_node,
62-
original_node,
63-
) in zip(result_tree.subtree, dt.subtree):
64-
assert isinstance(result_node, DataTree)
65-
66-
if original_node.has_data:
67-
assert_equal(result_node.ds, (original_node.ds * 10.0) + 2.0)
68-
else:
69-
assert not result_node.has_data
70-
71-
@pytest.mark.xfail
72-
def test_map_over_subtree_inplace(self):
73-
raise NotImplementedError
6+
from datatree import DataNode
747

758

769
class TestDSProperties:

datatree/tests/test_datatree.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,21 @@
77
from datatree.io import open_datatree
88

99

10-
def create_test_datatree():
10+
def assert_tree_equal(dt_a, dt_b):
    """
    Assert that two trees are equal: same root name and parent, and the same descendants with equal data.

    Descendants are compared pairwise in the order given by each tree's ``descendants`` attribute; for each pair
    the name, ``pathstr`` and stored dataset must match. The two trees must also have the same number of
    descendants.

    Parameters
    ----------
    dt_a : DataTree
    dt_b : DataTree

    Raises
    ------
    AssertionError
        If any of the compared properties differ.
    """

    def _assert_same_data(node_a, node_b):
        """Assert both nodes hold equal datasets, or share the same placeholder when node_a has no data."""
        if node_a.has_data:
            assert node_a.ds.equals(node_b.ds)
        else:
            # A node without data has nothing to call `.equals` on, so compare by identity instead.
            assert node_a.ds is node_b.ds

    assert dt_a.name == dt_b.name
    assert dt_a.parent is dt_b.parent

    # The root may legitimately hold no data, so it goes through the same guarded comparison as its descendants.
    _assert_same_data(dt_a, dt_b)

    descendants_a = tuple(dt_a.descendants)
    descendants_b = tuple(dt_b.descendants)

    # zip() stops at the shorter sequence, so check the lengths explicitly or extra nodes would go unnoticed.
    assert len(descendants_a) == len(descendants_b)

    for a, b in zip(descendants_a, descendants_b):
        assert a.name == b.name
        assert a.pathstr == b.pathstr
        _assert_same_data(a, b)
22+
23+
24+
def create_test_datatree(modify=lambda ds: ds):
1125
"""
1226
Create a test datatree with this structure:
1327
@@ -37,12 +51,11 @@ def create_test_datatree():
3751
The structure has deliberately repeated names of tags, variables, and
3852
dimensions in order to better check for bugs caused by name conflicts.
3953
"""
40-
set1_data = xr.Dataset({"a": 0, "b": 1})
41-
set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})
42-
root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
54+
set1_data = modify(xr.Dataset({"a": 0, "b": 1}))
55+
set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}))
56+
root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}))
4357

4458
# Avoid using __init__ so we can independently test it
45-
# TODO change so it has a DataTree at the bottom
4659
root = DataNode(name="root", data=root_data)
4760
set1 = DataNode(name="set1", parent=root, data=set1_data)
4861
DataNode(name="set1", parent=set1)

0 commit comments

Comments
 (0)