Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Commit 84d4814

Browse files
authored
Map over multiple subtrees (#32)
* pseudocode ideas for generalizing map_over_subtree * pseudocode for a generalized map_over_subtree (still only one return arg) + a new mapping.py file * pseudocode for mapping but now multiple return values * pseudocode for mapping but with multiple return values * check_isomorphism works and has tests * cleaned up the mapping tests a bit * tests for mapping over multiple trees * incorrect pseudocode attempt to map over multiple subtrees * small improvements * fixed test * zipping of multiple arguments * passes for mapping over a single tree * successfully maps over multiple trees * successfully returns multiple trees * filled out all tests * checking types now works for trees with only one node * improved docstring
1 parent 3f68eea commit 84d4814

File tree

4 files changed

+283
-70
lines changed

4 files changed

+283
-70
lines changed

datatree/datatree.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -424,10 +424,12 @@ def __init__(
424424
else:
425425
node_path, node_name = "/", path
426426

427+
relative_path = node_path.replace(self.name, "")
428+
427429
# Create and set new node
428430
new_node = DataNode(name=node_name, data=data)
429431
self.set_node(
430-
node_path,
432+
relative_path,
431433
new_node,
432434
allow_overwrite=False,
433435
new_nodes_along_path=True,

datatree/mapping.py

+166-26
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import functools
2+
from itertools import repeat
23

34
from anytree.iterators import LevelOrderIter
5+
from xarray import DataArray, Dataset
46

57
from .treenode import TreeNode
68

@@ -43,11 +45,11 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False):
4345

4446
if not isinstance(subtree_a, TreeNode):
4547
raise TypeError(
46-
f"Argument `subtree_a is not a tree, it is of type {type(subtree_a)}"
48+
f"Argument `subtree_a` is not a tree, it is of type {type(subtree_a)}"
4749
)
4850
if not isinstance(subtree_b, TreeNode):
4951
raise TypeError(
50-
f"Argument `subtree_b is not a tree, it is of type {type(subtree_b)}"
52+
f"Argument `subtree_b` is not a tree, it is of type {type(subtree_b)}"
5153
)
5254

5355
# Walking nodes in "level-order" fashion means walking down from the root breadth-first.
@@ -83,57 +85,195 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False):
8385

8486
def map_over_subtree(func):
8587
"""
86-
Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees.
88+
Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.
8789
88-
Applies a function to every dataset in this subtree, returning a new tree which stores the results.
90+
Applies a function to every dataset in one or more subtrees, returning new trees which store the results.
8991
90-
The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the
91-
descendant nodes. The returned tree will have the same structure as the original subtree.
92+
The function will be applied to any dataset stored in any of the nodes in the trees. The returned trees will have
93+
the same structure as the supplied trees.
9294
93-
func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each
94-
result will be assigned to its respective node of new tree via `DataTree.__setitem__`.
95+
`func` needs to return one or more Datasets or DataArrays, or None, in order to be able to rebuild the subtrees after
96+
mapping, as each result will be assigned to its respective node of a new tree via `DataTree.__setitem__`. Any
97+
returned value that is one of these types will be stacked into a separate tree before returning all of them.
98+
99+
The trees passed to the resulting function must all be isomorphic to one another. Their nodes need not be named
100+
similarly, but all the output trees will have nodes named in the same way as the first tree passed.
95101
96102
Parameters
97103
----------
98104
func : callable
99105
Function to apply to datasets with signature:
100-
`func(node.ds, *args, **kwargs) -> Dataset`.
101106
107+
`func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.
108+
109+
(i.e. func must accept at least one Dataset and return at least one Dataset.)
102110
Function will not be applied to any nodes without datasets.
103111
*args : tuple, optional
104-
Positional arguments passed on to `func`.
112+
Positional arguments passed on to `func`. If any are DataTrees, their data-containing nodes will be converted to Datasets
113+
via .ds .
105114
**kwargs : Any
106-
Keyword arguments passed on to `func`.
115+
Keyword arguments passed on to `func`. If any are DataTrees, their data-containing nodes will be converted to Datasets
116+
via .ds .
107117
108118
Returns
109119
-------
110120
mapped : callable
111-
Wrapped function which returns tree created from results of applying ``func`` to the dataset at each node.
121+
Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset at
122+
each node.
112123
113124
See also
114125
--------
115126
DataTree.map_over_subtree
116127
DataTree.map_over_subtree_inplace
128+
DataTree.subtree
117129
"""
118130

131+
# TODO examples in the docstring
132+
133+
# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?
134+
119135
@functools.wraps(func)
120-
def _map_over_subtree(tree, *args, **kwargs):
136+
def _map_over_subtree(*args, **kwargs):
121137
"""Internal function which maps func over every node in tree, returning a tree of the results."""
138+
from .datatree import DataTree
139+
140+
all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
141+
a for a in kwargs.values() if isinstance(a, DataTree)
142+
]
143+
144+
if len(all_tree_inputs) > 0:
145+
first_tree, *other_trees = all_tree_inputs
146+
else:
147+
raise TypeError("Must pass at least one tree object")
148+
149+
for other_tree in other_trees:
150+
# isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
151+
_check_isomorphic(first_tree, other_tree, require_names_equal=False)
152+
153+
# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
154+
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
155+
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
156+
out_data_objects = {}
157+
args_as_tree_length_iterables = [
158+
a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
159+
]
160+
n_args = len(args_as_tree_length_iterables)
161+
kwargs_as_tree_length_iterables = {
162+
k: v.subtree if isinstance(v, DataTree) else repeat(v)
163+
for k, v in kwargs.items()
164+
}
165+
for node_of_first_tree, *all_node_args in zip(
166+
first_tree.subtree,
167+
*args_as_tree_length_iterables,
168+
*list(kwargs_as_tree_length_iterables.values()),
169+
):
170+
node_args_as_datasets = [
171+
a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args]
172+
]
173+
node_kwargs_as_datasets = dict(
174+
zip(
175+
[k for k in kwargs_as_tree_length_iterables.keys()],
176+
[
177+
v.ds if isinstance(v, DataTree) else v
178+
for v in all_node_args[n_args:]
179+
],
180+
)
181+
)
122182

123-
# Recreate and act on root node
124-
from .datatree import DataNode
183+
# Now we can call func on the data in this particular set of corresponding nodes
184+
results = (
185+
func(*node_args_as_datasets, **node_kwargs_as_datasets)
186+
if node_of_first_tree.has_data
187+
else None
188+
)
125189

126-
out_tree = DataNode(name=tree.name, data=tree.ds)
127-
if out_tree.has_data:
128-
out_tree.ds = func(out_tree.ds, *args, **kwargs)
190+
# TODO implement mapping over multiple trees in-place using if conditions from here on?
191+
out_data_objects[node_of_first_tree.pathstr] = results
192+
193+
# Find out how many return values we received
194+
num_return_values = _check_all_return_values(out_data_objects)
195+
196+
# Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
197+
result_trees = []
198+
for i in range(num_return_values):
199+
out_tree_contents = {}
200+
for n in first_tree.subtree:
201+
p = n.pathstr
202+
if p in out_data_objects.keys():
203+
if isinstance(out_data_objects[p], tuple):
204+
output_node_data = out_data_objects[p][i]
205+
else:
206+
output_node_data = out_data_objects[p]
207+
else:
208+
output_node_data = None
209+
out_tree_contents[p] = output_node_data
210+
211+
new_tree = DataTree(name=first_tree.name, data_objects=out_tree_contents)
212+
result_trees.append(new_tree)
213+
214+
# If only one result then don't wrap it in a tuple
215+
if len(result_trees) == 1:
216+
return result_trees[0]
217+
else:
218+
return tuple(result_trees)
129219

130-
# Act on every other node in the tree, and rebuild from results
131-
for node in tree.descendants:
132-
# TODO make a proper relative_path method
133-
relative_path = node.pathstr.replace(tree.pathstr, "")
134-
result = func(node.ds, *args, **kwargs) if node.has_data else None
135-
out_tree[relative_path] = result
220+
return _map_over_subtree
136221

137-
return out_tree
138222

139-
return _map_over_subtree
223+
def _check_single_set_return_values(path_to_node, obj):
224+
"""Check types returned from single evaluation of func, and return number of return values received from func."""
225+
if isinstance(obj, (Dataset, DataArray)):
226+
return 1
227+
elif isinstance(obj, tuple):
228+
for r in obj:
229+
if not isinstance(r, (Dataset, DataArray)):
230+
raise TypeError(
231+
f"One of the results of calling func on datasets on the nodes at position {path_to_node} is "
232+
f"of type {type(r)}, not Dataset or DataArray."
233+
)
234+
return len(obj)
235+
else:
236+
raise TypeError(
237+
f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not "
238+
f"Dataset or DataArray, nor a tuple of such types."
239+
)
240+
241+
242+
def _check_all_return_values(returned_objects):
243+
"""Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""
244+
245+
if all(r is None for r in returned_objects.values()):
246+
raise TypeError(
247+
"Called supplied function on all nodes but found a return value of None for"
248+
"all of them."
249+
)
250+
251+
result_data_objects = [
252+
(path_to_node, r)
253+
for path_to_node, r in returned_objects.items()
254+
if r is not None
255+
]
256+
257+
if len(result_data_objects) == 1:
258+
# Only one node in the tree: no need to check consistency of results between nodes
259+
path_to_node, result = result_data_objects[0]
260+
num_return_values = _check_single_set_return_values(path_to_node, result)
261+
else:
262+
prev_path, _ = result_data_objects[0]
263+
prev_num_return_values, num_return_values = None, None
264+
for path_to_node, obj in result_data_objects[1:]:
265+
num_return_values = _check_single_set_return_values(path_to_node, obj)
266+
267+
if (
268+
num_return_values != prev_num_return_values
269+
and prev_num_return_values is not None
270+
):
271+
raise TypeError(
272+
f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return "
273+
f"values, whereas calling func on the nodes at position {prev_path} instead returns "
274+
f"{prev_num_return_values} separate return values."
275+
)
276+
277+
prev_path, prev_num_return_values = path_to_node, num_return_values
278+
279+
return num_return_values

datatree/tests/test_datatree.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@
88

99

1010
def assert_tree_equal(dt_a, dt_b):
11-
assert dt_a.name == dt_b.name
1211
assert dt_a.parent is dt_b.parent
1312

14-
assert dt_a.ds.equals(dt_b.ds)
15-
for a, b in zip(dt_a.descendants, dt_b.descendants):
13+
for a, b in zip(dt_a.subtree, dt_b.subtree):
1614
assert a.name == b.name
1715
assert a.pathstr == b.pathstr
1816
if a.has_data:
@@ -321,7 +319,6 @@ def test_to_netcdf(self, tmpdir):
321319
original_dt.to_netcdf(filepath, engine="netcdf4")
322320

323321
roundtrip_dt = open_datatree(filepath)
324-
325322
assert_tree_equal(original_dt, roundtrip_dt)
326323

327324
def test_to_zarr(self, tmpdir):
@@ -332,5 +329,4 @@ def test_to_zarr(self, tmpdir):
332329
original_dt.to_zarr(filepath)
333330

334331
roundtrip_dt = open_datatree(filepath, engine="zarr")
335-
336332
assert_tree_equal(original_dt, roundtrip_dt)

0 commit comments

Comments
 (0)