|
| 1 | +import functools |
| 2 | + |
| 3 | +from anytree.iterators import LevelOrderIter |
| 4 | + |
| 5 | +from .treenode import TreeNode |
| 6 | + |
| 7 | + |
| 8 | +class TreeIsomorphismError(ValueError): |
| 9 | + """Error raised if two tree objects are not isomorphic to one another when they need to be.""" |
| 10 | + |
| 11 | + pass |
| 12 | + |
| 13 | + |
| 14 | +def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False): |
| 15 | + """ |
| 16 | + Check that two trees have the same structure, raising an error if not. |
| 17 | +
|
| 18 | + Does not check the actual data in the nodes, but it does check that if one node does/doesn't have data then its |
| 19 | + counterpart in the other tree also does/doesn't have data. |
| 20 | +
|
| 21 | + Also does not check that the root nodes of each tree have the same parent - so this function checks that subtrees |
| 22 | + are isomorphic, not the entire tree above (if it exists). |
| 23 | +
|
| 24 | + Can optionally check if respective nodes should have the same name. |
| 25 | +
|
| 26 | + Parameters |
| 27 | + ---------- |
| 28 | + subtree_a : DataTree |
| 29 | + subtree_b : DataTree |
| 30 | + require_names_equal : Bool, optional |
| 31 | + Whether or not to also check that each node has the same name as its counterpart. Default is False. |
| 32 | +
|
| 33 | + Raises |
| 34 | + ------ |
| 35 | + TypeError |
| 36 | + If either subtree_a or subtree_b are not tree objects. |
| 37 | + TreeIsomorphismError |
| 38 | + If subtree_a and subtree_b are tree objects, but are not isomorphic to one another, or one contains data at a |
| 39 | + location the other does not. Also optionally raised if their structure is isomorphic, but the names of any two |
| 40 | + respective nodes are not equal. |
| 41 | + """ |
| 42 | + # TODO turn this into a public function called assert_isomorphic |
| 43 | + |
| 44 | + if not isinstance(subtree_a, TreeNode): |
| 45 | + raise TypeError( |
| 46 | + f"Argument `subtree_a is not a tree, it is of type {type(subtree_a)}" |
| 47 | + ) |
| 48 | + if not isinstance(subtree_b, TreeNode): |
| 49 | + raise TypeError( |
| 50 | + f"Argument `subtree_b is not a tree, it is of type {type(subtree_b)}" |
| 51 | + ) |
| 52 | + |
| 53 | + # Walking nodes in "level-order" fashion means walking down from the root breadth-first. |
| 54 | + # Checking by walking in this way implicitly assumes that the tree is an ordered tree (which it is so long as |
| 55 | + # children are stored in a tuple or list rather than in a set). |
| 56 | + for node_a, node_b in zip(LevelOrderIter(subtree_a), LevelOrderIter(subtree_b)): |
| 57 | + path_a, path_b = node_a.pathstr, node_b.pathstr |
| 58 | + |
| 59 | + if require_names_equal: |
| 60 | + if node_a.name != node_b.name: |
| 61 | + raise TreeIsomorphismError( |
| 62 | + f"Trees are not isomorphic because node '{path_a}' in the first tree has " |
| 63 | + f"name '{node_a.name}', whereas its counterpart node '{path_b}' in the " |
| 64 | + f"second tree has name '{node_b.name}'." |
| 65 | + ) |
| 66 | + |
| 67 | + if node_a.has_data != node_b.has_data: |
| 68 | + dat_a = "no " if not node_a.has_data else "" |
| 69 | + dat_b = "no " if not node_b.has_data else "" |
| 70 | + raise TreeIsomorphismError( |
| 71 | + f"Trees are not isomorphic because node '{path_a}' in the first tree has " |
| 72 | + f"{dat_a}data, whereas its counterpart node '{path_b}' in the second tree " |
| 73 | + f"has {dat_b}data." |
| 74 | + ) |
| 75 | + |
| 76 | + if len(node_a.children) != len(node_b.children): |
| 77 | + raise TreeIsomorphismError( |
| 78 | + f"Trees are not isomorphic because node '{path_a}' in the first tree has " |
| 79 | + f"{len(node_a.children)} children, whereas its counterpart node '{path_b}' in " |
| 80 | + f"the second tree has {len(node_b.children)} children." |
| 81 | + ) |
| 82 | + |
| 83 | + |
| 84 | +def map_over_subtree(func): |
| 85 | + """ |
| 86 | + Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees. |
| 87 | +
|
| 88 | + Applies a function to every dataset in this subtree, returning a new tree which stores the results. |
| 89 | +
|
| 90 | + The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the |
| 91 | + descendant nodes. The returned tree will have the same structure as the original subtree. |
| 92 | +
|
| 93 | + func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each |
| 94 | + result will be assigned to its respective node of new tree via `DataTree.__setitem__`. |
| 95 | +
|
| 96 | + Parameters |
| 97 | + ---------- |
| 98 | + func : callable |
| 99 | + Function to apply to datasets with signature: |
| 100 | + `func(node.ds, *args, **kwargs) -> Dataset`. |
| 101 | +
|
| 102 | + Function will not be applied to any nodes without datasets. |
| 103 | + *args : tuple, optional |
| 104 | + Positional arguments passed on to `func`. |
| 105 | + **kwargs : Any |
| 106 | + Keyword arguments passed on to `func`. |
| 107 | +
|
| 108 | + Returns |
| 109 | + ------- |
| 110 | + mapped : callable |
| 111 | + Wrapped function which returns tree created from results of applying ``func`` to the dataset at each node. |
| 112 | +
|
| 113 | + See also |
| 114 | + -------- |
| 115 | + DataTree.map_over_subtree |
| 116 | + DataTree.map_over_subtree_inplace |
| 117 | + """ |
| 118 | + |
| 119 | + @functools.wraps(func) |
| 120 | + def _map_over_subtree(tree, *args, **kwargs): |
| 121 | + """Internal function which maps func over every node in tree, returning a tree of the results.""" |
| 122 | + |
| 123 | + # Recreate and act on root node |
| 124 | + from .datatree import DataNode |
| 125 | + |
| 126 | + out_tree = DataNode(name=tree.name, data=tree.ds) |
| 127 | + if out_tree.has_data: |
| 128 | + out_tree.ds = func(out_tree.ds, *args, **kwargs) |
| 129 | + |
| 130 | + # Act on every other node in the tree, and rebuild from results |
| 131 | + for node in tree.descendants: |
| 132 | + # TODO make a proper relative_path method |
| 133 | + relative_path = node.pathstr.replace(tree.pathstr, "") |
| 134 | + result = func(node.ds, *args, **kwargs) if node.has_data else None |
| 135 | + out_tree[relative_path] = result |
| 136 | + |
| 137 | + return out_tree |
| 138 | + |
| 139 | + return _map_over_subtree |
0 commit comments