|
1 | 1 | import functools
|
| 2 | +from itertools import repeat |
2 | 3 |
|
3 | 4 | from anytree.iterators import LevelOrderIter
|
| 5 | +from xarray import DataArray, Dataset |
4 | 6 |
|
5 | 7 | from .treenode import TreeNode
|
6 | 8 |
|
@@ -43,11 +45,11 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False):
|
43 | 45 |
|
44 | 46 | if not isinstance(subtree_a, TreeNode):
|
45 | 47 | raise TypeError(
|
46 |
| - f"Argument `subtree_a is not a tree, it is of type {type(subtree_a)}" |
| 48 | + f"Argument `subtree_a` is not a tree, it is of type {type(subtree_a)}" |
47 | 49 | )
|
48 | 50 | if not isinstance(subtree_b, TreeNode):
|
49 | 51 | raise TypeError(
|
50 |
| - f"Argument `subtree_b is not a tree, it is of type {type(subtree_b)}" |
| 52 | + f"Argument `subtree_b` is not a tree, it is of type {type(subtree_b)}" |
51 | 53 | )
|
52 | 54 |
|
53 | 55 | # Walking nodes in "level-order" fashion means walking down from the root breadth-first.
|
@@ -83,57 +85,195 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False):
|
83 | 85 |
|
84 | 86 | def map_over_subtree(func):
|
85 | 87 | """
|
86 |
| - Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees. |
| 88 | + Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees. |
87 | 89 |
|
88 |
| - Applies a function to every dataset in this subtree, returning a new tree which stores the results. |
| 90 | + Applies a function to every dataset in one or more subtrees, returning new trees which store the results. |
89 | 91 |
|
90 |
| - The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the |
91 |
| - descendant nodes. The returned tree will have the same structure as the original subtree. |
| 92 | + The function will be applied to any dataset stored in any of the nodes in the trees. The returned trees will have |
| 93 | + the same structure as the supplied trees. |
92 | 94 |
|
93 |
| - func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each |
94 |
| - result will be assigned to its respective node of new tree via `DataTree.__setitem__`. |
| 95 | + `func` needs to return one Datasets, DataArrays, or None in order to be able to rebuild the subtrees after |
| 96 | + mapping, as each result will be assigned to its respective node of a new tree via `DataTree.__setitem__`. Any |
| 97 | + returned value that is one of these types will be stacked into a separate tree before returning all of them. |
| 98 | +
|
| 99 | + The trees passed to the resulting function must all be isomorphic to one another. Their nodes need not be named |
| 100 | + similarly, but all the output trees will have nodes named in the same way as the first tree passed. |
95 | 101 |
|
96 | 102 | Parameters
|
97 | 103 | ----------
|
98 | 104 | func : callable
|
99 | 105 | Function to apply to datasets with signature:
|
100 |
| - `func(node.ds, *args, **kwargs) -> Dataset`. |
101 | 106 |
|
| 107 | + `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`. |
| 108 | +
|
| 109 | + (i.e. func must accept at least one Dataset and return at least one Dataset.) |
102 | 110 | Function will not be applied to any nodes without datasets.
|
103 | 111 | *args : tuple, optional
|
104 |
| - Positional arguments passed on to `func`. |
| 112 | + Positional arguments passed on to `func`. If DataTrees any data-containing nodes will be converted to Datasets \ |
| 113 | + via .ds . |
105 | 114 | **kwargs : Any
|
106 |
| - Keyword arguments passed on to `func`. |
| 115 | + Keyword arguments passed on to `func`. If DataTrees any data-containing nodes will be converted to Datasets |
| 116 | + via .ds . |
107 | 117 |
|
108 | 118 | Returns
|
109 | 119 | -------
|
110 | 120 | mapped : callable
|
111 |
| - Wrapped function which returns tree created from results of applying ``func`` to the dataset at each node. |
| 121 | + Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset at |
| 122 | + each node. |
112 | 123 |
|
113 | 124 | See also
|
114 | 125 | --------
|
115 | 126 | DataTree.map_over_subtree
|
116 | 127 | DataTree.map_over_subtree_inplace
|
| 128 | + DataTree.subtree |
117 | 129 | """
|
118 | 130 |
|
| 131 | + # TODO examples in the docstring |
| 132 | + |
| 133 | + # TODO inspect function to work out immediately if the wrong number of arguments were passed for it? |
| 134 | + |
119 | 135 | @functools.wraps(func)
|
120 |
| - def _map_over_subtree(tree, *args, **kwargs): |
| 136 | + def _map_over_subtree(*args, **kwargs): |
121 | 137 | """Internal function which maps func over every node in tree, returning a tree of the results."""
|
| 138 | + from .datatree import DataTree |
| 139 | + |
| 140 | + all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [ |
| 141 | + a for a in kwargs.values() if isinstance(a, DataTree) |
| 142 | + ] |
| 143 | + |
| 144 | + if len(all_tree_inputs) > 0: |
| 145 | + first_tree, *other_trees = all_tree_inputs |
| 146 | + else: |
| 147 | + raise TypeError("Must pass at least one tree object") |
| 148 | + |
| 149 | + for other_tree in other_trees: |
| 150 | + # isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic |
| 151 | + _check_isomorphic(first_tree, other_tree, require_names_equal=False) |
| 152 | + |
| 153 | + # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees |
| 154 | + # We don't know which arguments are DataTrees so we zip all arguments together as iterables |
| 155 | + # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return |
| 156 | + out_data_objects = {} |
| 157 | + args_as_tree_length_iterables = [ |
| 158 | + a.subtree if isinstance(a, DataTree) else repeat(a) for a in args |
| 159 | + ] |
| 160 | + n_args = len(args_as_tree_length_iterables) |
| 161 | + kwargs_as_tree_length_iterables = { |
| 162 | + k: v.subtree if isinstance(v, DataTree) else repeat(v) |
| 163 | + for k, v in kwargs.items() |
| 164 | + } |
| 165 | + for node_of_first_tree, *all_node_args in zip( |
| 166 | + first_tree.subtree, |
| 167 | + *args_as_tree_length_iterables, |
| 168 | + *list(kwargs_as_tree_length_iterables.values()), |
| 169 | + ): |
| 170 | + node_args_as_datasets = [ |
| 171 | + a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args] |
| 172 | + ] |
| 173 | + node_kwargs_as_datasets = dict( |
| 174 | + zip( |
| 175 | + [k for k in kwargs_as_tree_length_iterables.keys()], |
| 176 | + [ |
| 177 | + v.ds if isinstance(v, DataTree) else v |
| 178 | + for v in all_node_args[n_args:] |
| 179 | + ], |
| 180 | + ) |
| 181 | + ) |
122 | 182 |
|
123 |
| - # Recreate and act on root node |
124 |
| - from .datatree import DataNode |
| 183 | + # Now we can call func on the data in this particular set of corresponding nodes |
| 184 | + results = ( |
| 185 | + func(*node_args_as_datasets, **node_kwargs_as_datasets) |
| 186 | + if node_of_first_tree.has_data |
| 187 | + else None |
| 188 | + ) |
125 | 189 |
|
126 |
| - out_tree = DataNode(name=tree.name, data=tree.ds) |
127 |
| - if out_tree.has_data: |
128 |
| - out_tree.ds = func(out_tree.ds, *args, **kwargs) |
| 190 | + # TODO implement mapping over multiple trees in-place using if conditions from here on? |
| 191 | + out_data_objects[node_of_first_tree.pathstr] = results |
| 192 | + |
| 193 | + # Find out how many return values we received |
| 194 | + num_return_values = _check_all_return_values(out_data_objects) |
| 195 | + |
| 196 | + # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees |
| 197 | + result_trees = [] |
| 198 | + for i in range(num_return_values): |
| 199 | + out_tree_contents = {} |
| 200 | + for n in first_tree.subtree: |
| 201 | + p = n.pathstr |
| 202 | + if p in out_data_objects.keys(): |
| 203 | + if isinstance(out_data_objects[p], tuple): |
| 204 | + output_node_data = out_data_objects[p][i] |
| 205 | + else: |
| 206 | + output_node_data = out_data_objects[p] |
| 207 | + else: |
| 208 | + output_node_data = None |
| 209 | + out_tree_contents[p] = output_node_data |
| 210 | + |
| 211 | + new_tree = DataTree(name=first_tree.name, data_objects=out_tree_contents) |
| 212 | + result_trees.append(new_tree) |
| 213 | + |
| 214 | + # If only one result then don't wrap it in a tuple |
| 215 | + if len(result_trees) == 1: |
| 216 | + return result_trees[0] |
| 217 | + else: |
| 218 | + return tuple(result_trees) |
129 | 219 |
|
130 |
| - # Act on every other node in the tree, and rebuild from results |
131 |
| - for node in tree.descendants: |
132 |
| - # TODO make a proper relative_path method |
133 |
| - relative_path = node.pathstr.replace(tree.pathstr, "") |
134 |
| - result = func(node.ds, *args, **kwargs) if node.has_data else None |
135 |
| - out_tree[relative_path] = result |
| 220 | + return _map_over_subtree |
136 | 221 |
|
137 |
| - return out_tree |
138 | 222 |
|
139 |
| - return _map_over_subtree |
| 223 | +def _check_single_set_return_values(path_to_node, obj): |
| 224 | + """Check types returned from single evaluation of func, and return number of return values received from func.""" |
| 225 | + if isinstance(obj, (Dataset, DataArray)): |
| 226 | + return 1 |
| 227 | + elif isinstance(obj, tuple): |
| 228 | + for r in obj: |
| 229 | + if not isinstance(r, (Dataset, DataArray)): |
| 230 | + raise TypeError( |
| 231 | + f"One of the results of calling func on datasets on the nodes at position {path_to_node} is " |
| 232 | + f"of type {type(r)}, not Dataset or DataArray." |
| 233 | + ) |
| 234 | + return len(obj) |
| 235 | + else: |
| 236 | + raise TypeError( |
| 237 | + f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not " |
| 238 | + f"Dataset or DataArray, nor a tuple of such types." |
| 239 | + ) |
| 240 | + |
| 241 | + |
| 242 | +def _check_all_return_values(returned_objects): |
| 243 | + """Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types.""" |
| 244 | + |
| 245 | + if all(r is None for r in returned_objects.values()): |
| 246 | + raise TypeError( |
| 247 | + "Called supplied function on all nodes but found a return value of None for" |
| 248 | + "all of them." |
| 249 | + ) |
| 250 | + |
| 251 | + result_data_objects = [ |
| 252 | + (path_to_node, r) |
| 253 | + for path_to_node, r in returned_objects.items() |
| 254 | + if r is not None |
| 255 | + ] |
| 256 | + |
| 257 | + if len(result_data_objects) == 1: |
| 258 | + # Only one node in the tree: no need to check consistency of results between nodes |
| 259 | + path_to_node, result = result_data_objects[0] |
| 260 | + num_return_values = _check_single_set_return_values(path_to_node, result) |
| 261 | + else: |
| 262 | + prev_path, _ = result_data_objects[0] |
| 263 | + prev_num_return_values, num_return_values = None, None |
| 264 | + for path_to_node, obj in result_data_objects[1:]: |
| 265 | + num_return_values = _check_single_set_return_values(path_to_node, obj) |
| 266 | + |
| 267 | + if ( |
| 268 | + num_return_values != prev_num_return_values |
| 269 | + and prev_num_return_values is not None |
| 270 | + ): |
| 271 | + raise TypeError( |
| 272 | + f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return " |
| 273 | + f"values, whereas calling func on the nodes at position {prev_path} instead returns " |
| 274 | + f"{prev_num_return_values} separate return values." |
| 275 | + ) |
| 276 | + |
| 277 | + prev_path, prev_num_return_values = path_to_node, num_return_values |
| 278 | + |
| 279 | + return num_return_values |
0 commit comments