From c857a3058eebdd1f68f6afb356406a85f8bcb59b Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Thu, 26 Aug 2021 01:18:59 -0400 Subject: [PATCH 01/10] pseudocode ideas for generalizing map_over_subtree --- datatree/datatree.py | 33 ++++++++++++++++---- datatree/tests/test_dataset_api.py | 48 +++++++++++++++++++++++++++--- datatree/tests/test_datatree.py | 17 +++++++++-- datatree/treenode.py | 2 +- 4 files changed, 87 insertions(+), 13 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index 416e1894..2ffb32e7 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -50,6 +50,15 @@ """ +def _check_trees_match(*trees): + """ + Function to check that trees have the same structure. Does not require the names (and therefore paths) of the nodes + to be equal. Also does not check the data in the nodes (but it does check that data does/doesn't exist for all nodes + at the location. + """ + ... + + def map_over_subtree(func): """ Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees. @@ -66,13 +75,13 @@ def map_over_subtree(func): ---------- func : callable Function to apply to datasets with signature: - `func(node.ds, *args, **kwargs) -> Dataset`. + `func(*args, **kwargs) -> Dataset`. Function will not be applied to any nodes without datasets. *args : tuple, optional - Positional arguments passed on to `func`. + Positional arguments passed on to `func`. Will be converted to Datasets via .ds if DataTrees. **kwargs : Any - Keyword arguments passed on to `func`. + Keyword arguments passed on to `func`. Will be converted to Datasets via .ds if DataTrees. Returns ------- @@ -86,15 +95,27 @@ def map_over_subtree(func): """ @functools.wraps(func) - def _map_over_subtree(tree, *args, **kwargs): + def _map_over_subtree(*args, **kwargs): """Internal function which maps func over every node in tree, returning a tree of the results.""" - # Recreate and act on root node + all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [a for a in args if isinstance(a, DataTree)] + first_tree = _check_trees_match(all_tree_inputs) + + args_as_datasets = [a.ds if isinstance(a, DataTree) else a for a in args] + kwargs_as_datasets = {k: v.ds if isinstance(v, DataTree) else v for k, v in kwargs} + + + # Recreate root node out_tree = DataNode(name=tree.name, data=tree.ds) + + # Act on root node if out_tree.has_data: - out_tree.ds = func(out_tree.ds, *args, **kwargs) + out_tree.ds = func(*args_as_datasets, **kwargs_as_datasets) # Act on every other node in the tree, and rebuild from results + + # TODO walk all tree arguments simultaneously, applying func to the all nodes that lie in same position in different trees + for node in tree.descendants: # TODO make a proper relative_path method relative_path = node.pathstr.replace(tree.pathstr, "") diff --git a/datatree/tests/test_dataset_api.py b/datatree/tests/test_dataset_api.py index 82d8871e..843eafff 100644 --- a/datatree/tests/test_dataset_api.py +++ b/datatree/tests/test_dataset_api.py @@ -4,11 +4,26 @@ from test_datatree import create_test_datatree from xarray.testing import assert_equal +from test_datatree import assert_tree_equal from datatree import DataNode, DataTree, map_over_subtree +class TestCheckTreesMatch: + def test_different_widths(self): + ... + + def test_different_heights(self): + ... + + def test_only_some_have_data(self): + ... + + def test_incompatible_dt_args(self): + ... + + class TestMapOverSubTree: - def test_map_over_subtree(self): + def test_single_dt_arg(self): dt = create_test_datatree() @map_over_subtree @@ -29,7 +44,7 @@ def times_ten(ds): else: assert not result_node.has_data - def test_map_over_subtree_with_args_and_kwargs(self): + def test_single_dt_arg_plus_args_and_kwargs(self): dt = create_test_datatree() @map_over_subtree @@ -49,7 +64,30 @@ def multiply_then_add(ds, times, add=0.0): else: assert not result_node.has_data - def test_map_over_subtree_method(self): + def test_multiple_dt_args(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt = DataNode("root", data=ds) + DataNode("results", data=ds + 0.2, parent=dt) + + @map_over_subtree + def add(ds1, ds2): + return ds1 + ds2 + + expected = DataNode("root", data=ds * 2) + DataNode("results", data=(ds + 0.2) * 2, parent=expected) + + result = add(dt, dt) + + #dt1 = create_test_datatree() + #dt2 = create_test_datatree() + #expected = create_test_datatree(modify=lambda ds: 2 * ds) + + assert_tree_equal(result, expected) + + def test_dt_as_kwarg(self): + ... + + def test_dt_method(self): dt = create_test_datatree() def multiply_then_add(ds, times, add=0.0): @@ -68,7 +106,9 @@ def multiply_then_add(ds, times, add=0.0): else: assert not result_node.has_data - @pytest.mark.xfail + +@pytest.mark.xfail +class TestMapOverSubTreeInplace: def test_map_over_subtree_inplace(self): raise NotImplementedError diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index b82ae984..a3f42e4d 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -7,7 +7,21 @@ from datatree.io import open_datatree -def create_test_datatree(): +def assert_tree_equal(dt_a, dt_b): + assert dt_a.name == dt_b.name + assert dt_a.parent is dt_b.parent + + assert dt_a.ds.equals(dt_b.ds) + for a, b in zip(dt_a.descendants, dt_b.descendants): + assert a.name == b.name + assert a.pathstr == b.pathstr + if a.has_data: + assert a.ds.equals(b.ds) + else: + assert a.ds is b.ds + + +def create_test_datatree(modify=lambda ds: ds): """ Create a test datatree with this structure: @@ -42,7 +56,6 @@ def create_test_datatree(): root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) # Avoid using __init__ so we can independently test it - # TODO change so it has a DataTree at the bottom root = DataNode(name="root", data=root_data) set1 = DataNode(name="set1", parent=root, data=set1_data) DataNode(name="set1", parent=set1) diff --git a/datatree/treenode.py b/datatree/treenode.py index 898ee12d..276577e7 100644 --- a/datatree/treenode.py +++ b/datatree/treenode.py @@ -84,7 +84,7 @@ def _pre_attach(self, parent: TreeNode) -> None: """ if self.name in list(c.name for c in parent.children): raise KeyError( - f"parent {str(parent)} already has a child named {self.name}" + f"parent {parent.name} already has a child named {self.name}" ) def add_child(self, child: TreeNode) -> None: From 871802abd4b32d1e6d8513cc689e3f96470fd674 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Thu, 26 Aug 2021 17:25:34 -0400 Subject: [PATCH 02/10] pseudocode for a generalized map_over_subtree (still only one return arg) + a new mapping.py file --- datatree/datatree.py | 79 +---------------- datatree/mapping.py | 154 +++++++++++++++++++++++++++++++++ datatree/tests/test_mapping.py | 0 3 files changed, 155 insertions(+), 78 deletions(-) create mode 100644 datatree/mapping.py create mode 100644 datatree/tests/test_mapping.py diff --git a/datatree/datatree.py b/datatree/datatree.py index 2ffb32e7..a999b463 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1,6 +1,5 @@ from __future__ import annotations -import functools import textwrap from typing import Any, Callable, Dict, Hashable, Iterable, List, Mapping, Union @@ -14,6 +13,7 @@ from xarray.core.ops import NAN_CUM_METHODS, NAN_REDUCE_METHODS, REDUCE_METHODS from xarray.core.variable import Variable +from .mapping import map_over_subtree from .treenode import PathType, TreeNode, _init_single_treenode """ @@ -50,83 +50,6 @@ """ -def _check_trees_match(*trees): - """ - Function to check that trees have the same structure. Does not require the names (and therefore paths) of the nodes - to be equal. Also does not check the data in the nodes (but it does check that data does/doesn't exist for all nodes - at the location. - """ - ... - - -def map_over_subtree(func): - """ - Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees. - - Applies a function to every dataset in this subtree, returning a new tree which stores the results. - - The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the - descendant nodes. The returned tree will have the same structure as the original subtree. - - func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each - result will be assigned to its respective node of new tree via `DataTree.__setitem__`. - - Parameters - ---------- - func : callable - Function to apply to datasets with signature: - `func(*args, **kwargs) -> Dataset`. - - Function will not be applied to any nodes without datasets. - *args : tuple, optional - Positional arguments passed on to `func`. Will be converted to Datasets via .ds if DataTrees. - **kwargs : Any - Keyword arguments passed on to `func`. Will be converted to Datasets via .ds if DataTrees. - - Returns - ------- - mapped : callable - Wrapped function which returns tree created from results of applying ``func`` to the dataset at each node. - - See also - -------- - DataTree.map_over_subtree - DataTree.map_over_subtree_inplace - """ - - @functools.wraps(func) - def _map_over_subtree(*args, **kwargs): - """Internal function which maps func over every node in tree, returning a tree of the results.""" - - all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [a for a in args if isinstance(a, DataTree)] - first_tree = _check_trees_match(all_tree_inputs) - - args_as_datasets = [a.ds if isinstance(a, DataTree) else a for a in args] - kwargs_as_datasets = {k: v.ds if isinstance(v, DataTree) else v for k, v in kwargs} - - - # Recreate root node - out_tree = DataNode(name=tree.name, data=tree.ds) - - # Act on root node - if out_tree.has_data: - out_tree.ds = func(*args_as_datasets, **kwargs_as_datasets) - - # Act on every other node in the tree, and rebuild from results - - # TODO walk all tree arguments simultaneously, applying func to the all nodes that lie in same position in different trees - - for node in tree.descendants: - # TODO make a proper relative_path method - relative_path = node.pathstr.replace(tree.pathstr, "") - result = func(node.ds, *args, **kwargs) if node.has_data else None - out_tree[relative_path] = result - - return out_tree - - return _map_over_subtree - - class DatasetPropertiesMixin: """Expose properties of wrapped Dataset""" diff --git a/datatree/mapping.py b/datatree/mapping.py new file mode 100644 index 00000000..06424a27 --- /dev/null +++ b/datatree/mapping.py @@ -0,0 +1,154 @@ +import functools + +from anytree.iterators import LevelOrderIter +from xarray import Dataset, DataArray + +from .treenode import TreeNode +from .datatree import DataNode, DataTree + + +class TreeIsomorphismError(ValueError): + """Error raised if two tree objects are not isomorphic to one another when they need to be.""" + + pass + + +def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False): + """ + Check that two trees have the same structure, raising an error if not. + + Does not check the actual data in the nodes, but it does check that if one node does/doesn't have data then its + counterpart in the other tree also does/doesn't have data. + + Also does not check that the root nodes of each tree have the same parent - so this function checks that subtrees + are isomorphic, not the entire tree above (if it exists). + + Can optionally check if respective nodes should have the same name. + + Parameters + ---------- + subtree_a : DataTree + subtree_b : DataTree + require_names_equal : Bool, optional + Whether or not to also check that each node has the same name as its counterpart. Default is False. + + Raises + ------ + TypeError + If either subtree_a or subtree_b are not tree objects. + TreeIsomorphismError + If subtree_a and subtree_b are tree objects, but are not isomorphic to one another, or one contains data at a + location the other does not. Also optionally raised if their structure is isomorphic, but the names of any two + respective nodes are not equal. + """ + # TODO turn this into a public function called assert_isomorphic + + for i, dt in enumerate(subtree_a, subtree_b): + if not isinstance(dt, TreeNode): + raise TypeError(f"Argument number {i+1} is not a tree, it is of type {type(dt)}") + + # Walking nodes in "level-order" fashion means walking down from the root breadth-first. + # Checking by walking in this way implicitly assumes that the tree is an ordered tree (which it is so long as + # children are stored in a tuple or list rather than in a set). + for node_a, node_b in zip(LevelOrderIter(subtree_a), LevelOrderIter(subtree_b)): + path_a, path_b = node_a.pathstr, node_b.pathstr + + if require_names_equal: + if node_a.name != node_b.name: + raise TreeIsomorphismError(f"Trees are not isomorphic because node {path_a} in the first tree has name" + f"{node_a.name}, whereas its counterpart node {path_b} in the second tree " + f"has name {node_b.name}.") + + if node_a.has_data != node_b.has_data: + dat_a = 'no ' if not node_a.has_data else '' + dat_b = 'no ' if not node_b.has_data else '' + raise TreeIsomorphismError(f"Trees are not isomorphic because node {path_a} in the first tree has " + f"{dat_a}data, whereas its counterpart node {path_b} in the second tree " + f"has {dat_b}data.") + + if len(node_a.children) != len(node_b.children): + raise TreeIsomorphismError(f"Trees are not isomorphic because node {path_a} in the first tree has " + f"{len(node_a.children)} children, whereas its counterpart node {path_b} in the " + f"second tree has {len(node_b.children)} children.") + + +def map_over_subtree(func): + """ + Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees. + + Applies a function to every dataset in this subtree, returning one or more new trees which store the results. + + The function will be applied to any dataset stored in any of the nodes in the trees. The returned trees will have + the same structure as the supplied trees. + + `func` needs to return one Datasets, DataArrays, or None in order to be able to rebuild the subtrees after + mapping, as each result will be assigned to its respective node of a new tree via `DataTree.__setitem__`. Any + returned value that is one of these types will be stacked into a separate tree before returning all of them. + + Parameters + ---------- + func : callable + Function to apply to datasets with signature: + `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`. + + (i.e. func must accept at least one Dataset and return at least one Dataset.) + Function will not be applied to any nodes without datasets. + *args : tuple, optional + Positional arguments passed on to `func`. Will be converted to Datasets via .ds if DataTrees. + **kwargs : Any + Keyword arguments passed on to `func`. Will be converted to Datasets via .ds if DataTrees. + + Returns + ------- + mapped : callable + Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset at + each node. + + See also + -------- + DataTree.map_over_subtree + DataTree.map_over_subtree_inplace + """ + + # TODO inspect function to work out immediately if the wrong number of arguments were passed for it? + + @functools.wraps(func) + def _map_over_subtree(*args, **kwargs): + """Internal function which maps func over every node in tree, returning a tree of the results.""" + + all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [ + a for a in kwargs.values() if isinstance(a, DataTree) + ] + + if len(all_tree_inputs) > 0: + first_tree, *other_trees = all_tree_inputs + else: + raise TypeError("Must pass at least one tree object") + + for other_tree in other_trees: + # isomorphism is transitive + _check_isomorphic(first_tree, other_tree, require_names_equal=False) + + # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees + out_data_objects = {} + for nodes in zip(dt.subtree for dt in all_tree_inputs): + + node_first_tree, *_ = nodes + + # TODO make a proper relative_path method + relative_path = node_first_tree.pathstr.replace(first_tree.pathstr, "") + + node_args_as_datasets = [a.ds if isinstance(a, DataTree) else a for a in args] + node_kwargs_as_datasets = { + k: v.ds if isinstance(v, DataTree) else v for k, v in kwargs + } + + # TODO should we allow mapping functions that return zero datasets? + # TODO generalise to functions that return multiple values + result = func(*node_args_as_datasets, **node_kwargs_as_datasets) if node_first_tree.has_data else None + out_data_objects[relative_path] = result + + # TODO: Possible bug - what happens if another tree argument does not have root named the same way? + return DataTree(name=first_tree.name, data_objects=out_data_objects) + + return _map_over_subtree diff --git a/datatree/tests/test_mapping.py b/datatree/tests/test_mapping.py new file mode 100644 index 00000000..e69de29b From 2b61af363dbdbcefdaa67c0384469b5a71150875 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Thu, 26 Aug 2021 22:35:14 -0400 Subject: [PATCH 03/10] pseudocode for mapping but now multiple return values --- datatree/mapping.py | 88 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 70 insertions(+), 18 deletions(-) diff --git a/datatree/mapping.py b/datatree/mapping.py index 06424a27..b1c1be2c 100644 --- a/datatree/mapping.py +++ b/datatree/mapping.py @@ -4,7 +4,7 @@ from xarray import Dataset, DataArray from .treenode import TreeNode -from .datatree import DataNode, DataTree +#from datatree import DataTree class TreeIsomorphismError(ValueError): @@ -43,9 +43,10 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False): """ # TODO turn this into a public function called assert_isomorphic - for i, dt in enumerate(subtree_a, subtree_b): - if not isinstance(dt, TreeNode): - raise TypeError(f"Argument number {i+1} is not a tree, it is of type {type(dt)}") + if not isinstance(subtree_a, TreeNode): + raise TypeError(f"Argument `subtree_a is not a tree, it is of type {type(subtree_a)}") + if not isinstance(subtree_b, TreeNode): + raise TypeError(f"Argument `subtree_b is not a tree, it is of type {type(subtree_b)}") # Walking nodes in "level-order" fashion means walking down from the root breadth-first. # Checking by walking in this way implicitly assumes that the tree is an ordered tree (which it is so long as @@ -55,21 +56,21 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False): if require_names_equal: if node_a.name != node_b.name: - raise TreeIsomorphismError(f"Trees are not isomorphic because node {path_a} in the first tree has name" - f"{node_a.name}, whereas its counterpart node {path_b} in the second tree " - f"has name {node_b.name}.") + raise TreeIsomorphismError(f"Trees are not isomorphic because node '{path_a}' in the first tree has name" + f"'{node_a.name}', whereas its counterpart node '{path_b}' in the second " + f"tree has name '{node_b.name}'.") if node_a.has_data != node_b.has_data: dat_a = 'no ' if not node_a.has_data else '' dat_b = 'no ' if not node_b.has_data else '' - raise TreeIsomorphismError(f"Trees are not isomorphic because node {path_a} in the first tree has " - f"{dat_a}data, whereas its counterpart node {path_b} in the second tree " + raise TreeIsomorphismError(f"Trees are not isomorphic because node '{path_a}' in the first tree has " + f"{dat_a}data, whereas its counterpart node '{path_b}' in the second tree " f"has {dat_b}data.") if len(node_a.children) != len(node_b.children): - raise TreeIsomorphismError(f"Trees are not isomorphic because node {path_a} in the first tree has " - f"{len(node_a.children)} children, whereas its counterpart node {path_b} in the " - f"second tree has {len(node_b.children)} children.") + raise TreeIsomorphismError(f"Trees are not isomorphic because node '{path_a}' in the first tree has " + f"{len(node_a.children)} children, whereas its counterpart node '{path_b}' in " + f"the second tree has {len(node_b.children)} children.") def map_over_subtree(func): @@ -115,6 +116,7 @@ def map_over_subtree(func): @functools.wraps(func) def _map_over_subtree(*args, **kwargs): """Internal function which maps func over every node in tree, returning a tree of the results.""" + from .datatree import DataTree all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [ a for a in kwargs.values() if isinstance(a, DataTree) @@ -143,12 +145,62 @@ def _map_over_subtree(*args, **kwargs): k: v.ds if isinstance(v, DataTree) else v for k, v in kwargs } - # TODO should we allow mapping functions that return zero datasets? - # TODO generalise to functions that return multiple values - result = func(*node_args_as_datasets, **node_kwargs_as_datasets) if node_first_tree.has_data else None - out_data_objects[relative_path] = result + results = func(*node_args_as_datasets, **node_kwargs_as_datasets) if node_first_tree.has_data else None + out_data_objects[relative_path] = results - # TODO: Possible bug - what happens if another tree argument does not have root named the same way? - return DataTree(name=first_tree.name, data_objects=out_data_objects) + # Find out how many return values we had + num_return_values = _check_return_values(out_data_objects) + + # Reconstruct potentially multiple subtrees from the dict of results + # Fill in all nodes of all result trees + result_trees = [] + for _ in range(num_return_values): + out_tree_contents = {} + for n in first_tree.subtree: + p = n.pathstr + out_tree_contents[p] = out_data_objects[p] if p in out_data_objects.keys() else None + + # TODO: Possible bug - what happens if another tree argument does not have root named the same way? + new_tree = DataTree(name=first_tree.name, data_objects=out_tree_contents) + result_trees.append(new_tree) + + # If only one result then don't wrap it in a tuple + if len(result_trees) == 1: + return next(result_trees) + else: + return tuple(result_trees) return _map_over_subtree + + +def _check_return_values(returned_objects): + """Walk through all values returned by mapping func over subtrees, raising on any invalid types or inconsistency.""" + result_data_objects = [(p, r) for p, r in returned_objects if r is not None] + + if result_data_objects is None: + raise TypeError("Called supplied function on all nodes but found a return value of None for" + "all of them.") + + prev_path, prev_obj = result_data_objects[0] + prev_num_return_values, num_return_values = None, None + for p, obj in result_data_objects[1:]: + if isinstance(obj, (Dataset, DataArray)): + num_return_values = 1 + elif isinstance(obj, tuple): + for r in enumerate(obj): + if not isinstance(r, (Dataset, DataArray)): + raise TypeError(f"One of the results of calling func on datasets on the nodes at position {p} is " + f"of type {type(r)}, not Dataset or DataArray.") + + num_return_values = len(tuple) + else: + raise TypeError(f"The result of calling func on the node at position {p} is of type {type(obj)}, not " + f"Dataset or DataArray, nor a tuple of such types.") + + if num_return_values != prev_num_return_values and prev_num_return_values is not None: + raise TypeError(f"Calling func on the nodes at position {p} returns {num_return_values}, whereas calling" + f"func on the nodes at position {prev_path} returns {prev_num_return_values}.") + + prev_path, prev_obj, prev_num_return_values = p, obj, num_return_values + + return num_return_values From 9045c23e6965be8e9cf5c05ef7401869b8a61070 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Thu, 26 Aug 2021 22:39:51 -0400 Subject: [PATCH 04/10] pseudocode for mapping but with multiple return values --- datatree/tests/test_dataset_api.py | 109 +----------------- datatree/tests/test_mapping.py | 174 +++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+), 108 deletions(-) diff --git a/datatree/tests/test_dataset_api.py b/datatree/tests/test_dataset_api.py index 843eafff..e930f49f 100644 --- a/datatree/tests/test_dataset_api.py +++ b/datatree/tests/test_dataset_api.py @@ -1,116 +1,9 @@ import numpy as np import pytest import xarray as xr -from test_datatree import create_test_datatree from xarray.testing import assert_equal -from test_datatree import assert_tree_equal -from datatree import DataNode, DataTree, map_over_subtree - - -class TestCheckTreesMatch: - def test_different_widths(self): - ... - - def test_different_heights(self): - ... - - def test_only_some_have_data(self): - ... - - def test_incompatible_dt_args(self): - ... - - -class TestMapOverSubTree: - def test_single_dt_arg(self): - dt = create_test_datatree() - - @map_over_subtree - def times_ten(ds): - return 10.0 * ds - - result_tree = times_ten(dt) - - # TODO write an assert_tree_equal function - for ( - result_node, - original_node, - ) in zip(result_tree.subtree, dt.subtree): - assert isinstance(result_node, DataTree) - - if original_node.has_data: - assert_equal(result_node.ds, original_node.ds * 10.0) - else: - assert not result_node.has_data - - def test_single_dt_arg_plus_args_and_kwargs(self): - dt = create_test_datatree() - - @map_over_subtree - def multiply_then_add(ds, times, add=0.0): - return times * ds + add - - result_tree = multiply_then_add(dt, 10.0, add=2.0) - - for ( - result_node, - original_node, - ) in zip(result_tree.subtree, dt.subtree): - assert isinstance(result_node, DataTree) - - if original_node.has_data: - assert_equal(result_node.ds, (original_node.ds * 10.0) + 2.0) - else: - assert not result_node.has_data - - def test_multiple_dt_args(self): - ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt = DataNode("root", data=ds) - DataNode("results", data=ds + 0.2, parent=dt) - - @map_over_subtree - def add(ds1, ds2): - return ds1 + ds2 - - expected = DataNode("root", data=ds * 2) - DataNode("results", data=(ds + 0.2) * 2, parent=expected) - - result = add(dt, dt) - - #dt1 = create_test_datatree() - #dt2 = create_test_datatree() - #expected = create_test_datatree(modify=lambda ds: 2 * ds) - - assert_tree_equal(result, expected) - - def test_dt_as_kwarg(self): - ... - - def test_dt_method(self): - dt = create_test_datatree() - - def multiply_then_add(ds, times, add=0.0): - return times * ds + add - - result_tree = dt.map_over_subtree(multiply_then_add, 10.0, add=2.0) - - for ( - result_node, - original_node, - ) in zip(result_tree.subtree, dt.subtree): - assert isinstance(result_node, DataTree) - - if original_node.has_data: - assert_equal(result_node.ds, (original_node.ds * 10.0) + 2.0) - else: - assert not result_node.has_data - - -@pytest.mark.xfail -class TestMapOverSubTreeInplace: - def test_map_over_subtree_inplace(self): - raise NotImplementedError +from datatree import DataNode class TestDSProperties: diff --git a/datatree/tests/test_mapping.py b/datatree/tests/test_mapping.py index e69de29b..73f5cfbe 100644 --- a/datatree/tests/test_mapping.py +++ b/datatree/tests/test_mapping.py @@ -0,0 +1,174 @@ +import pytest +import xarray as xr +from xarray.testing import assert_equal + +from datatree.datatree import DataNode, DataTree +from datatree.mapping import _check_isomorphic, TreeIsomorphismError, map_over_subtree + +from test_datatree import assert_tree_equal, create_test_datatree + + +empty = xr.Dataset() + + +class TestCheckTreesIsomorphic: + def test_not_a_tree(self): + with pytest.raises(TypeError, match="not a tree"): + _check_isomorphic('s', 1) + + def test_different_widths(self): + dt1 = DataTree(data_objects={'a': empty}) + dt2 = DataTree(data_objects={'a': empty, 'b': empty}) + expected_err_str = "'root' in the first tree has 1 children, whereas its counterpart node 'root' in the " \ + "second tree has 2 children" + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + _check_isomorphic(dt1, dt2) + + def test_different_heights(self): + dt1 = DataTree(data_objects={'a': empty}) + print(dt1) + dt2 = DataTree(data_objects={'a': empty, 'a/b': empty}) + print(dt2) + expected_err_str = "'root/a' in the first tree has 0 children, whereas its counterpart node 'root/a' in the " \ + "second tree has 1 children" + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + _check_isomorphic(dt1, dt2) + + def test_only_one_has_data(self): + dt1 = DataTree(data_objects={'/': None}) + dt2 = DataTree(data_objects={'a': empty}) + expected_err_str = "'root/a' in the first tree has data, whereas its counterpart node 'root/a' in the " \ + "second tree has no data" + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + _check_isomorphic(dt1, dt2) + + def test_names_different(self): + dt1 = DataTree(data_objects={'a': xr.Dataset()}) + dt2 = DataTree(data_objects={'b': empty}) + expected_err_str = "'root/a' in the first tree has name a, whereas its counterpart node 'root/b' in the " \ + "second tree has name b" + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + _check_isomorphic(dt1, dt2, require_names_equal=True) + + def test_isomorphic_names_equal(self): + dt1 = DataTree(data_objects={'a': empty, 'b': empty, 'b/c': empty, 'b/d': empty}) + dt2 = DataTree(data_objects={'a': empty, 'b': empty, 'b/c': empty, 'b/d': empty}) + expected_err_str = "'root/a' in the first tree has name a, whereas its counterpart node 'root/b' in the " \ + "second tree has name b" + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + _check_isomorphic(dt1, dt2, require_names_equal=True) + + def test_isomorphic_names_not_equal(self): + dt1 = DataTree(data_objects={'a': empty, 'b': empty, 'b/c': empty, 'b/d': empty}) + dt2 = DataTree(data_objects={'A': empty, 'B': empty, 'B/C': empty, 'B/D': empty}) + expected_err_str = "'root/a' in the first tree has name a, whereas its counterpart node 'root/b' in the " \ + "second tree has name b" + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + _check_isomorphic(dt1, dt2, require_names_equal=True) + + +class TestMapOverSubTree: + def test_no_trees_passed(self): + ... + + def test_not_isomorphic(self): + ... + + def test_no_trees_returned(self): + ... + + def test_single_dt_arg(self): + dt = create_test_datatree() + + @map_over_subtree + def times_ten(ds): + return 10.0 * ds + + result_tree = times_ten(dt) + + # TODO write an assert_tree_equal function + for ( + result_node, + original_node, + ) in zip(result_tree.subtree, dt.subtree): + assert isinstance(result_node, DataTree) + + if original_node.has_data: + assert_equal(result_node.ds, original_node.ds * 10.0) + else: + assert not result_node.has_data + + def test_single_dt_arg_plus_args_and_kwargs(self): + dt = create_test_datatree() + + @map_over_subtree + def multiply_then_add(ds, times, add=0.0): + return times * ds + add + + result_tree = multiply_then_add(dt, 10.0, add=2.0) + + for ( + result_node, + original_node, + ) in zip(result_tree.subtree, dt.subtree): + assert isinstance(result_node, DataTree) + + if original_node.has_data: + assert_equal(result_node.ds, (original_node.ds * 10.0) + 2.0) + else: + assert not result_node.has_data + + def test_multiple_dt_args(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt = DataNode("root", data=ds) + DataNode("results", data=ds + 0.2, parent=dt) + + @map_over_subtree + def add(ds1, ds2): + return ds1 + ds2 + + expected = DataNode("root", data=ds * 2) + DataNode("results", data=(ds + 0.2) * 2, parent=expected) + + result = add(dt, dt) + + # dt1 = create_test_datatree() + # dt2 = create_test_datatree() + # expected = create_test_datatree(modify=lambda ds: 2 * ds) + + assert_tree_equal(result, expected) + + def test_dt_as_kwarg(self): + ... + + @pytest.mark.xfail + def test_return_multiple_dts(self): + raise NotImplementedError + + def test_return_no_dts(self): + ... + + def test_dt_method(self): + dt = create_test_datatree() + + def multiply_then_add(ds, times, add=0.0): + return times * ds + add + + result_tree = dt.map_over_subtree(multiply_then_add, 10.0, add=2.0) + + for ( + result_node, + original_node, + ) in zip(result_tree.subtree, dt.subtree): + assert isinstance(result_node, DataTree) + + if original_node.has_data: + assert_equal(result_node.ds, (original_node.ds * 10.0) + 2.0) + else: + assert not result_node.has_data + + +@pytest.mark.xfail +class TestMapOverSubTreeInplace: + def test_map_over_subtree_inplace(self): + raise NotImplementedError From 1e4c68d2293614ab1b28a9694e13fed8c5ef4cfb Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Fri, 27 Aug 2021 10:25:09 -0400 Subject: [PATCH 05/10] check_isomorphism works and has tests --- datatree/mapping.py | 7 +++---- datatree/tests/test_mapping.py | 33 +++++++++++++++++++-------------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/datatree/mapping.py b/datatree/mapping.py index b1c1be2c..8878825b 100644 --- a/datatree/mapping.py +++ b/datatree/mapping.py @@ -4,7 +4,6 @@ from xarray import Dataset, DataArray from .treenode import TreeNode -#from datatree import DataTree class TreeIsomorphismError(ValueError): @@ -56,9 +55,9 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False): if require_names_equal: if node_a.name != node_b.name: - raise TreeIsomorphismError(f"Trees are not isomorphic because node '{path_a}' in the first tree has name" - f"'{node_a.name}', whereas its counterpart node '{path_b}' in the second " - f"tree has name '{node_b.name}'.") + raise TreeIsomorphismError(f"Trees are not isomorphic because node '{path_a}' in the first tree has " + f"name '{node_a.name}', whereas its counterpart node '{path_b}' in the " + f"second tree has name '{node_b.name}'.") if node_a.has_data != node_b.has_data: dat_a = 'no ' if not node_a.has_data else '' diff --git a/datatree/tests/test_mapping.py b/datatree/tests/test_mapping.py index 73f5cfbe..71c979f0 100644 --- a/datatree/tests/test_mapping.py +++ b/datatree/tests/test_mapping.py @@ -2,6 +2,7 @@ import xarray as xr from xarray.testing import assert_equal +from datatree.treenode import TreeNode from datatree.datatree import DataNode, DataTree from datatree.mapping import _check_isomorphic, TreeIsomorphismError, map_over_subtree @@ -26,17 +27,15 @@ def test_different_widths(self): def test_different_heights(self): dt1 = DataTree(data_objects={'a': empty}) - print(dt1) dt2 = DataTree(data_objects={'a': empty, 'a/b': empty}) - print(dt2) expected_err_str = "'root/a' in the first tree has 0 children, whereas its counterpart node 'root/a' in the " \ "second tree has 1 children" with pytest.raises(TreeIsomorphismError, match=expected_err_str): _check_isomorphic(dt1, dt2) def test_only_one_has_data(self): - dt1 = DataTree(data_objects={'/': None}) - dt2 = DataTree(data_objects={'a': empty}) + dt1 = DataTree(data_objects={'a': xr.Dataset({'a': 0})}) + dt2 = DataTree(data_objects={'a': None}) expected_err_str = "'root/a' in the first tree has data, whereas its counterpart node 'root/a' in the " \ "second tree has no data" with pytest.raises(TreeIsomorphismError, match=expected_err_str): @@ -45,26 +44,32 @@ def test_only_one_has_data(self): def test_names_different(self): dt1 = DataTree(data_objects={'a': xr.Dataset()}) dt2 = DataTree(data_objects={'b': empty}) - expected_err_str = "'root/a' in the first tree has name a, whereas its counterpart node 'root/b' in the " \ - "second tree has name b" + expected_err_str = "'root/a' in the first tree has name 'a', whereas its counterpart node 'root/b' in the " \ + "second tree has name 'b'" with pytest.raises(TreeIsomorphismError, match=expected_err_str): _check_isomorphic(dt1, dt2, require_names_equal=True) def test_isomorphic_names_equal(self): dt1 = DataTree(data_objects={'a': empty, 'b': empty, 'b/c': empty, 'b/d': empty}) dt2 = DataTree(data_objects={'a': empty, 'b': empty, 'b/c': empty, 'b/d': empty}) - expected_err_str = "'root/a' in the first tree has name a, whereas its counterpart node 'root/b' in the " \ - "second tree has name b" - with pytest.raises(TreeIsomorphismError, match=expected_err_str): - _check_isomorphic(dt1, dt2, require_names_equal=True) + _check_isomorphic(dt1, dt2, require_names_equal=True) + + def test_isomorphic_ordering(self): + dt1 = DataTree(data_objects={'a': empty, 'b': empty, 'b/d': empty, 'b/c': empty}) + dt2 = DataTree(data_objects={'a': empty, 'b': empty, 'b/c': empty, 'b/d': empty}) + _check_isomorphic(dt1, dt2, require_names_equal=False) def test_isomorphic_names_not_equal(self): dt1 = DataTree(data_objects={'a': empty, 'b': empty, 'b/c': empty, 'b/d': empty}) dt2 = DataTree(data_objects={'A': empty, 'B': empty, 'B/C': empty, 'B/D': empty}) - expected_err_str = "'root/a' in the first tree has name a, whereas its counterpart node 'root/b' in the " \ - "second tree has name b" - with pytest.raises(TreeIsomorphismError, match=expected_err_str): - _check_isomorphic(dt1, dt2, require_names_equal=True) + _check_isomorphic(dt1, dt2) + + def test_not_isomorphic_complex_tree(self): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + dt2.set_node('set1/set2', TreeNode('set3')) + with pytest.raises(TreeIsomorphismError, match="root/set1/set2"): + _check_isomorphic(dt1, dt2) class TestMapOverSubTree: From e14f7a913bec48131c9ae32a68d23b0bda11310d Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Fri, 27 Aug 2021 17:07:41 -0400 Subject: [PATCH 06/10] cleaned up the mapping tests a bit --- datatree/tests/test_mapping.py | 29 ++++------------------------- 1 file changed, 4 insertions(+), 25 deletions(-) diff --git a/datatree/tests/test_mapping.py b/datatree/tests/test_mapping.py index 71c979f0..c126fe58 100644 --- a/datatree/tests/test_mapping.py +++ b/datatree/tests/test_mapping.py @@ -83,45 +83,24 @@ def test_no_trees_returned(self): ... def test_single_dt_arg(self): - dt = create_test_datatree() + dt = create_test_datatree(modify=lambda ds: 10.0 * ds) @map_over_subtree def times_ten(ds): return 10.0 * ds result_tree = times_ten(dt) - - # TODO write an assert_tree_equal function - for ( - result_node, - original_node, - ) in zip(result_tree.subtree, dt.subtree): - assert isinstance(result_node, DataTree) - - if original_node.has_data: - assert_equal(result_node.ds, original_node.ds * 10.0) - else: - assert not result_node.has_data + assert_tree_equal(result_tree, dt) def test_single_dt_arg_plus_args_and_kwargs(self): - dt = create_test_datatree() + dt = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0) @map_over_subtree def multiply_then_add(ds, times, add=0.0): return times * ds + add result_tree = multiply_then_add(dt, 10.0, add=2.0) - - for ( - result_node, - original_node, - ) in zip(result_tree.subtree, dt.subtree): - assert isinstance(result_node, DataTree) - - if original_node.has_data: - assert_equal(result_node.ds, (original_node.ds * 10.0) + 2.0) - else: - assert not result_node.has_data + assert_tree_equal(result_tree, dt) def test_multiple_dt_args(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) From 86e01a85931b7fc710be58e682e072740affb59f Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Fri, 27 Aug 2021 18:14:35 -0400 Subject: [PATCH 07/10] remove WIP from oter branch --- datatree/mapping.py | 122 +++++++++----------------------------------- 1 file changed, 23 insertions(+), 99 deletions(-) diff --git a/datatree/mapping.py b/datatree/mapping.py index 8878825b..690825a3 100644 --- a/datatree/mapping.py +++ b/datatree/mapping.py @@ -74,35 +74,32 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False): def map_over_subtree(func): """ - Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees. + Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees. - Applies a function to every dataset in this subtree, returning one or more new trees which store the results. + Applies a function to every dataset in this subtree, returning a new tree which stores the results. - The function will be applied to any dataset stored in any of the nodes in the trees. The returned trees will have - the same structure as the supplied trees. + The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the + descendant nodes. The returned tree will have the same structure as the original subtree. - `func` needs to return one Datasets, DataArrays, or None in order to be able to rebuild the subtrees after - mapping, as each result will be assigned to its respective node of a new tree via `DataTree.__setitem__`. Any - returned value that is one of these types will be stacked into a separate tree before returning all of them. + func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each + result will be assigned to its respective node of new tree via `DataTree.__setitem__`. Parameters ---------- func : callable Function to apply to datasets with signature: - `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`. + `func(node.ds, *args, **kwargs) -> Dataset`. - (i.e. func must accept at least one Dataset and return at least one Dataset.) Function will not be applied to any nodes without datasets. *args : tuple, optional - Positional arguments passed on to `func`. Will be converted to Datasets via .ds if DataTrees. + Positional arguments passed on to `func`. **kwargs : Any - Keyword arguments passed on to `func`. Will be converted to Datasets via .ds if DataTrees. + Keyword arguments passed on to `func`. Returns ------- mapped : callable - Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset at - each node. + Wrapped function which returns tree created from results of applying ``func`` to the dataset at each node. See also -------- @@ -110,96 +107,23 @@ def map_over_subtree(func): DataTree.map_over_subtree_inplace """ - # TODO inspect function to work out immediately if the wrong number of arguments were passed for it? - @functools.wraps(func) - def _map_over_subtree(*args, **kwargs): + def _map_over_subtree(tree, *args, **kwargs): """Internal function which maps func over every node in tree, returning a tree of the results.""" - from .datatree import DataTree - - all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [ - a for a in kwargs.values() if isinstance(a, DataTree) - ] - - if len(all_tree_inputs) > 0: - first_tree, *other_trees = all_tree_inputs - else: - raise TypeError("Must pass at least one tree object") - - for other_tree in other_trees: - # isomorphism is transitive - _check_isomorphic(first_tree, other_tree, require_names_equal=False) - # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees - out_data_objects = {} - for nodes in zip(dt.subtree for dt in all_tree_inputs): - - node_first_tree, *_ = nodes + # Recreate and act on root node + from .datatree import DataNode + out_tree = DataNode(name=tree.name, data=tree.ds) + if out_tree.has_data: + out_tree.ds = func(out_tree.ds, *args, **kwargs) + # Act on every other node in the tree, and rebuild from results + for node in tree.descendants: # TODO make a proper relative_path method - relative_path = node_first_tree.pathstr.replace(first_tree.pathstr, "") - - node_args_as_datasets = [a.ds if isinstance(a, DataTree) else a for a in args] - node_kwargs_as_datasets = { - k: v.ds if isinstance(v, DataTree) else v for k, v in kwargs - } - - results = func(*node_args_as_datasets, **node_kwargs_as_datasets) if node_first_tree.has_data else None - out_data_objects[relative_path] = results - - # Find out how many return values we had - num_return_values = _check_return_values(out_data_objects) - - # Reconstruct potentially multiple subtrees from the dict of results - # Fill in all nodes of all result trees - result_trees = [] - for _ in range(num_return_values): - out_tree_contents = {} - for n in first_tree.subtree: - p = n.pathstr - out_tree_contents[p] = out_data_objects[p] if p in out_data_objects.keys() else None - - # TODO: Possible bug - what happens if another tree argument does not have root named the same way? - new_tree = DataTree(name=first_tree.name, data_objects=out_tree_contents) - result_trees.append(new_tree) - - # If only one result then don't wrap it in a tuple - if len(result_trees) == 1: - return next(result_trees) - else: - return tuple(result_trees) - - return _map_over_subtree + relative_path = node.pathstr.replace(tree.pathstr, "") + result = func(node.ds, *args, **kwargs) if node.has_data else None + out_tree[relative_path] = result + return out_tree -def _check_return_values(returned_objects): - """Walk through all values returned by mapping func over subtrees, raising on any invalid types or inconsistency.""" - result_data_objects = [(p, r) for p, r in returned_objects if r is not None] - - if result_data_objects is None: - raise TypeError("Called supplied function on all nodes but found a return value of None for" - "all of them.") - - prev_path, prev_obj = result_data_objects[0] - prev_num_return_values, num_return_values = None, None - for p, obj in result_data_objects[1:]: - if isinstance(obj, (Dataset, DataArray)): - num_return_values = 1 - elif isinstance(obj, tuple): - for r in enumerate(obj): - if not isinstance(r, (Dataset, DataArray)): - raise TypeError(f"One of the results of calling func on datasets on the nodes at position {p} is " - f"of type {type(r)}, not Dataset or DataArray.") - - num_return_values = len(tuple) - else: - raise TypeError(f"The result of calling func on the node at position {p} is of type {type(obj)}, not " - f"Dataset or DataArray, nor a tuple of such types.") - - if num_return_values != prev_num_return_values and prev_num_return_values is not None: - raise TypeError(f"Calling func on the nodes at position {p} returns {num_return_values}, whereas calling" - f"func on the nodes at position {prev_path} returns {prev_num_return_values}.") - - prev_path, prev_obj, prev_num_return_values = p, obj, num_return_values - - return num_return_values + return _map_over_subtree From 2bd80f8e573f1bb21ebfc401678c134f2d9d43d9 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Fri, 27 Aug 2021 18:14:56 -0400 Subject: [PATCH 08/10] ensure tests pass --- datatree/tests/test_datatree.py | 6 +++--- datatree/tests/test_mapping.py | 26 +++++++++++++++++--------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index 404465f9..f13a7f3c 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -51,9 +51,9 @@ def create_test_datatree(modify=lambda ds: ds): The structure has deliberately repeated names of tags, variables, and dimensions in order to better check for bugs caused by name conflicts. """ - set1_data = xr.Dataset({"a": 0, "b": 1}) - set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}) - root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = modify(xr.Dataset({"a": 0, "b": 1})) + set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})) + root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})) # Avoid using __init__ so we can independently test it root = DataNode(name="root", data=root_data) diff --git a/datatree/tests/test_mapping.py b/datatree/tests/test_mapping.py index c126fe58..6f70c66b 100644 --- a/datatree/tests/test_mapping.py +++ b/datatree/tests/test_mapping.py @@ -73,35 +73,41 @@ def test_not_isomorphic_complex_tree(self): class TestMapOverSubTree: + @pytest.mark.xfail def test_no_trees_passed(self): - ... + raise NotImplementedError + @pytest.mark.xfail def test_not_isomorphic(self): - ... + raise NotImplementedError + @pytest.mark.xfail def test_no_trees_returned(self): - ... + raise NotImplementedError def test_single_dt_arg(self): - dt = create_test_datatree(modify=lambda ds: 10.0 * ds) + dt = create_test_datatree() @map_over_subtree def times_ten(ds): return 10.0 * ds result_tree = times_ten(dt) - assert_tree_equal(result_tree, dt) + expected = create_test_datatree(modify=lambda ds: 10.0 * ds) + assert_tree_equal(result_tree, expected) def test_single_dt_arg_plus_args_and_kwargs(self): - dt = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0) + dt = create_test_datatree() @map_over_subtree def multiply_then_add(ds, times, add=0.0): return times * ds + add result_tree = multiply_then_add(dt, 10.0, add=2.0) - assert_tree_equal(result_tree, dt) + expected = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0) + assert_tree_equal(result_tree, expected) + @pytest.mark.xfail def test_multiple_dt_args(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) dt = DataNode("root", data=ds) @@ -122,15 +128,17 @@ def add(ds1, ds2): assert_tree_equal(result, expected) + @pytest.mark.xfail def test_dt_as_kwarg(self): - ... + raise NotImplementedError @pytest.mark.xfail def test_return_multiple_dts(self): raise NotImplementedError + @pytest.mark.xfail def test_return_no_dts(self): - ... + raise NotImplementedError def test_dt_method(self): dt = create_test_datatree() From 78b35167bb050993cbc177f165ef41f6fcee2618 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Fri, 27 Aug 2021 18:25:37 -0400 Subject: [PATCH 09/10] map_over_subtree in the public API properly --- datatree/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datatree/__init__.py b/datatree/__init__.py index f83edbb0..fbe1cba7 100644 --- a/datatree/__init__.py +++ b/datatree/__init__.py @@ -1,4 +1,5 @@ # flake8: noqa # Ignoring F401: imported but unused -from .datatree import DataNode, DataTree, map_over_subtree +from .datatree import DataNode, DataTree from .io import open_datatree +from .mapping import map_over_subtree From 84ee174805abde8ca36ae6122e05be389667673a Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Fri, 27 Aug 2021 18:26:36 -0400 Subject: [PATCH 10/10] linting --- datatree/mapping.py | 38 ++++++++++------- datatree/tests/test_mapping.py | 76 +++++++++++++++++++++------------- 2 files changed, 71 insertions(+), 43 deletions(-) diff --git a/datatree/mapping.py b/datatree/mapping.py index 690825a3..b0ff2b22 100644 --- a/datatree/mapping.py +++ b/datatree/mapping.py @@ -1,7 +1,6 @@ import functools from anytree.iterators import LevelOrderIter -from xarray import Dataset, DataArray from .treenode import TreeNode @@ -43,9 +42,13 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False): # TODO turn this into a public function called assert_isomorphic if not isinstance(subtree_a, TreeNode): - raise TypeError(f"Argument `subtree_a is not a tree, it is of type {type(subtree_a)}") + raise TypeError( + f"Argument `subtree_a is not a tree, it is of type {type(subtree_a)}" + ) if not isinstance(subtree_b, TreeNode): - raise TypeError(f"Argument `subtree_b is not a tree, it is of type {type(subtree_b)}") + raise TypeError( + f"Argument `subtree_b is not a tree, it is of type {type(subtree_b)}" + ) # Walking nodes in "level-order" fashion means walking down from the root breadth-first. # Checking by walking in this way implicitly assumes that the tree is an ordered tree (which it is so long as @@ -55,21 +58,27 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False): if require_names_equal: if node_a.name != node_b.name: - raise TreeIsomorphismError(f"Trees are not isomorphic because node '{path_a}' in the first tree has " - f"name '{node_a.name}', whereas its counterpart node '{path_b}' in the " - f"second tree has name '{node_b.name}'.") + raise TreeIsomorphismError( + f"Trees are not isomorphic because node '{path_a}' in the first tree has " + f"name '{node_a.name}', whereas its counterpart node '{path_b}' in the " + f"second tree has name '{node_b.name}'." + ) if node_a.has_data != node_b.has_data: - dat_a = 'no ' if not node_a.has_data else '' - dat_b = 'no ' if not node_b.has_data else '' - raise TreeIsomorphismError(f"Trees are not isomorphic because node '{path_a}' in the first tree has " - f"{dat_a}data, whereas its counterpart node '{path_b}' in the second tree " - f"has {dat_b}data.") + dat_a = "no " if not node_a.has_data else "" + dat_b = "no " if not node_b.has_data else "" + raise TreeIsomorphismError( + f"Trees are not isomorphic because node '{path_a}' in the first tree has " + f"{dat_a}data, whereas its counterpart node '{path_b}' in the second tree " + f"has {dat_b}data." + ) if len(node_a.children) != len(node_b.children): - raise TreeIsomorphismError(f"Trees are not isomorphic because node '{path_a}' in the first tree has " - f"{len(node_a.children)} children, whereas its counterpart node '{path_b}' in " - f"the second tree has {len(node_b.children)} children.") + raise TreeIsomorphismError( + f"Trees are not isomorphic because node '{path_a}' in the first tree has " + f"{len(node_a.children)} children, whereas its counterpart node '{path_b}' in " + f"the second tree has {len(node_b.children)} children." + ) def map_over_subtree(func): @@ -113,6 +122,7 @@ def _map_over_subtree(tree, *args, **kwargs): # Recreate and act on root node from .datatree import DataNode + out_tree = DataNode(name=tree.name, data=tree.ds) if out_tree.has_data: out_tree.ds = func(out_tree.ds, *args, **kwargs) diff --git a/datatree/tests/test_mapping.py b/datatree/tests/test_mapping.py index 6f70c66b..da2ad8be 100644 --- a/datatree/tests/test_mapping.py +++ b/datatree/tests/test_mapping.py @@ -1,13 +1,11 @@ import pytest import xarray as xr +from test_datatree import assert_tree_equal, create_test_datatree from xarray.testing import assert_equal -from datatree.treenode import TreeNode from datatree.datatree import DataNode, DataTree -from datatree.mapping import _check_isomorphic, TreeIsomorphismError, map_over_subtree - -from test_datatree import assert_tree_equal, create_test_datatree - +from datatree.mapping import TreeIsomorphismError, _check_isomorphic, map_over_subtree +from datatree.treenode import TreeNode empty = xr.Dataset() @@ -15,59 +13,79 @@ class TestCheckTreesIsomorphic: def test_not_a_tree(self): with pytest.raises(TypeError, match="not a tree"): - _check_isomorphic('s', 1) + _check_isomorphic("s", 1) def test_different_widths(self): - dt1 = DataTree(data_objects={'a': empty}) - dt2 = DataTree(data_objects={'a': empty, 'b': empty}) - expected_err_str = "'root' in the first tree has 1 children, whereas its counterpart node 'root' in the " \ - "second tree has 2 children" + dt1 = DataTree(data_objects={"a": empty}) + dt2 = DataTree(data_objects={"a": empty, "b": empty}) + expected_err_str = ( + "'root' in the first tree has 1 children, whereas its counterpart node 'root' in the " + "second tree has 2 children" + ) with pytest.raises(TreeIsomorphismError, match=expected_err_str): _check_isomorphic(dt1, dt2) def test_different_heights(self): - dt1 = DataTree(data_objects={'a': empty}) - dt2 = DataTree(data_objects={'a': empty, 'a/b': empty}) - expected_err_str = "'root/a' in the first tree has 0 children, whereas its counterpart node 'root/a' in the " \ - "second tree has 1 children" + dt1 = DataTree(data_objects={"a": empty}) + dt2 = DataTree(data_objects={"a": empty, "a/b": empty}) + expected_err_str = ( + "'root/a' in the first tree has 0 children, whereas its counterpart node 'root/a' in the " + "second tree has 1 children" + ) with pytest.raises(TreeIsomorphismError, match=expected_err_str): _check_isomorphic(dt1, dt2) def test_only_one_has_data(self): - dt1 = DataTree(data_objects={'a': xr.Dataset({'a': 0})}) - dt2 = DataTree(data_objects={'a': None}) - expected_err_str = "'root/a' in the first tree has data, whereas its counterpart node 'root/a' in the " \ - "second tree has no data" + dt1 = DataTree(data_objects={"a": xr.Dataset({"a": 0})}) + dt2 = DataTree(data_objects={"a": None}) + expected_err_str = ( + "'root/a' in the first tree has data, whereas its counterpart node 'root/a' in the " + "second tree has no data" + ) with pytest.raises(TreeIsomorphismError, match=expected_err_str): _check_isomorphic(dt1, dt2) def test_names_different(self): - dt1 = DataTree(data_objects={'a': xr.Dataset()}) - dt2 = DataTree(data_objects={'b': empty}) - expected_err_str = "'root/a' in the first tree has name 'a', whereas its counterpart node 'root/b' in the " \ - "second tree has name 'b'" + dt1 = DataTree(data_objects={"a": xr.Dataset()}) + dt2 = DataTree(data_objects={"b": empty}) + expected_err_str = ( + "'root/a' in the first tree has name 'a', whereas its counterpart node 'root/b' in the " + "second tree has name 'b'" + ) with pytest.raises(TreeIsomorphismError, match=expected_err_str): _check_isomorphic(dt1, dt2, require_names_equal=True) def test_isomorphic_names_equal(self): - dt1 = DataTree(data_objects={'a': empty, 'b': empty, 'b/c': empty, 'b/d': empty}) - dt2 = DataTree(data_objects={'a': empty, 'b': empty, 'b/c': empty, 'b/d': empty}) + dt1 = DataTree( + data_objects={"a": empty, "b": empty, "b/c": empty, "b/d": empty} + ) + dt2 = DataTree( + data_objects={"a": empty, "b": empty, "b/c": empty, "b/d": empty} + ) _check_isomorphic(dt1, dt2, require_names_equal=True) def test_isomorphic_ordering(self): - dt1 = DataTree(data_objects={'a': empty, 'b': empty, 'b/d': empty, 'b/c': empty}) - dt2 = DataTree(data_objects={'a': empty, 'b': empty, 'b/c': empty, 'b/d': empty}) + dt1 = DataTree( + data_objects={"a": empty, "b": empty, "b/d": empty, "b/c": empty} + ) + dt2 = DataTree( + data_objects={"a": empty, "b": empty, "b/c": empty, "b/d": empty} + ) _check_isomorphic(dt1, dt2, require_names_equal=False) def test_isomorphic_names_not_equal(self): - dt1 = DataTree(data_objects={'a': empty, 'b': empty, 'b/c': empty, 'b/d': empty}) - dt2 = DataTree(data_objects={'A': empty, 'B': empty, 'B/C': empty, 'B/D': empty}) + dt1 = DataTree( + data_objects={"a": empty, "b": empty, "b/c": empty, "b/d": empty} + ) + dt2 = DataTree( + data_objects={"A": empty, "B": empty, "B/C": empty, "B/D": empty} + ) _check_isomorphic(dt1, dt2) def test_not_isomorphic_complex_tree(self): dt1 = create_test_datatree() dt2 = create_test_datatree() - dt2.set_node('set1/set2', TreeNode('set3')) + dt2.set_node("set1/set2", TreeNode("set3")) with pytest.raises(TreeIsomorphismError, match="root/set1/set2"): _check_isomorphic(dt1, dt2)