Skip to content

Commit 4323b19

Browse files
authored
Ensure TreeNode doesn't copy in-place (#9482)
* test from #9196 but on TreeNode * move assignment and copying of children to TreeNode constructor * move copy methods over to TreeNode * change copying behaviour to be in line with #9196 * explicitly test that ._copy_subtree works for TreeNode * reimplement ._copy_subtree using recursion * change treenode.py tests to match expected non-in-place behaviour * fix bug created in DataTree.__init__ * add type hints for Generic TreeNode back in * update typing of ._copy_node * remove redundant setting of _name
1 parent aeaa082 commit 4323b19

File tree

3 files changed

+148
-91
lines changed

3 files changed

+148
-91
lines changed

xarray/core/datatree.py

+7-58
Original file line numberDiff line numberDiff line change
@@ -447,14 +447,10 @@ def __init__(
447447
--------
448448
DataTree.from_dict
449449
"""
450-
if children is None:
451-
children = {}
452-
453-
super().__init__(name=name)
454450
self._set_node_data(_to_new_dataset(dataset))
455451

456-
# shallow copy to avoid modifying arguments in-place (see GH issue #9196)
457-
self.children = {name: child.copy() for name, child in children.items()}
452+
# comes after setting node data as this will check for clashes between child names and existing variable names
453+
super().__init__(name=name, children=children)
458454

459455
def _set_node_data(self, dataset: Dataset):
460456
data_vars, coord_vars = _collect_data_and_coord_variables(dataset)
@@ -775,67 +771,20 @@ def _replace_node(
775771

776772
self.children = children
777773

778-
def copy(
779-
self: DataTree,
780-
deep: bool = False,
781-
) -> DataTree:
782-
"""
783-
Returns a copy of this subtree.
784-
785-
Copies this node and all child nodes.
786-
787-
If `deep=True`, a deep copy is made of each of the component variables.
788-
Otherwise, a shallow copy of each of the component variable is made, so
789-
that the underlying memory region of the new datatree is the same as in
790-
the original datatree.
791-
792-
Parameters
793-
----------
794-
deep : bool, default: False
795-
Whether each component variable is loaded into memory and copied onto
796-
the new object. Default is False.
797-
798-
Returns
799-
-------
800-
object : DataTree
801-
New object with dimensions, attributes, coordinates, name, encoding,
802-
and data of this node and all child nodes copied from original.
803-
804-
See Also
805-
--------
806-
xarray.Dataset.copy
807-
pandas.DataFrame.copy
808-
"""
809-
return self._copy_subtree(deep=deep)
810-
811-
def _copy_subtree(
812-
self: DataTree,
813-
deep: bool = False,
814-
memo: dict[int, Any] | None = None,
815-
) -> DataTree:
816-
"""Copy entire subtree"""
817-
new_tree = self._copy_node(deep=deep)
818-
for node in self.descendants:
819-
path = node.relative_to(self)
820-
new_tree[path] = node._copy_node(deep=deep)
821-
return new_tree
822-
823774
def _copy_node(
824775
self: DataTree,
825776
deep: bool = False,
826777
) -> DataTree:
827778
"""Copy just one node of a tree"""
779+
780+
new_node = super()._copy_node()
781+
828782
data = self._to_dataset_view(rebuild_dims=False, inherited=False)
829783
if deep:
830784
data = data.copy(deep=True)
831-
new_node = DataTree(data, name=self.name)
832-
return new_node
833-
834-
def __copy__(self: DataTree) -> DataTree:
835-
return self._copy_subtree(deep=False)
785+
new_node._set_node_data(data)
836786

837-
def __deepcopy__(self: DataTree, memo: dict[int, Any] | None = None) -> DataTree:
838-
return self._copy_subtree(deep=True, memo=memo)
787+
return new_node
839788

840789
def get( # type: ignore[override]
841790
self: DataTree, key: str, default: DataTree | DataArray | None = None

xarray/core/treenode.py

+75-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pathlib import PurePosixPath
66
from typing import (
77
TYPE_CHECKING,
8+
Any,
89
Generic,
910
TypeVar,
1011
)
@@ -78,8 +79,10 @@ def __init__(self, children: Mapping[str, Tree] | None = None):
7879
"""Create a parentless node."""
7980
self._parent = None
8081
self._children = {}
81-
if children is not None:
82-
self.children = children
82+
83+
if children:
84+
# shallow copy to avoid modifying arguments in-place (see GH issue #9196)
85+
self.children = {name: child.copy() for name, child in children.items()}
8386

8487
@property
8588
def parent(self) -> Tree | None:
@@ -235,6 +238,67 @@ def _post_attach_children(self: Tree, children: Mapping[str, Tree]) -> None:
235238
"""Method call after attaching `children`."""
236239
pass
237240

241+
def copy(
242+
self: Tree,
243+
deep: bool = False,
244+
) -> Tree:
245+
"""
246+
Returns a copy of this subtree.
247+
248+
Copies this node and all child nodes.
249+
250+
If `deep=True`, a deep copy is made of each of the component variables.
251+
Otherwise, a shallow copy of each of the component variable is made, so
252+
that the underlying memory region of the new datatree is the same as in
253+
the original datatree.
254+
255+
Parameters
256+
----------
257+
deep : bool, default: False
258+
Whether each component variable is loaded into memory and copied onto
259+
the new object. Default is False.
260+
261+
Returns
262+
-------
263+
object : DataTree
264+
New object with dimensions, attributes, coordinates, name, encoding,
265+
and data of this node and all child nodes copied from original.
266+
267+
See Also
268+
--------
269+
xarray.Dataset.copy
270+
pandas.DataFrame.copy
271+
"""
272+
return self._copy_subtree(deep=deep)
273+
274+
def _copy_subtree(
275+
self: Tree,
276+
deep: bool = False,
277+
memo: dict[int, Any] | None = None,
278+
) -> Tree:
279+
"""Copy entire subtree recursively."""
280+
281+
new_tree = self._copy_node(deep=deep)
282+
for name, child in self.children.items():
283+
# TODO use `.children[name] = ...` once #9477 is implemented
284+
new_tree._set(name, child._copy_subtree(deep=deep))
285+
286+
return new_tree
287+
288+
def _copy_node(
289+
self: Tree,
290+
deep: bool = False,
291+
) -> Tree:
292+
"""Copy just one node of a tree"""
293+
new_empty_node = type(self)()
294+
return new_empty_node
295+
296+
def __copy__(self: Tree) -> Tree:
297+
return self._copy_subtree(deep=False)
298+
299+
def __deepcopy__(self: Tree, memo: dict[int, Any] | None = None) -> Tree:
300+
return self._copy_subtree(deep=True, memo=memo)
301+
238302
def _iter_parents(self: Tree) -> Iterator[Tree]:
239303
"""Iterate up the tree, starting from the current node's parent."""
240304
node: Tree | None = self.parent
@@ -619,6 +683,15 @@ def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None:
619683
"""Ensures child has name attribute corresponding to key under which it has been stored."""
620684
self.name = name
621685

686+
def _copy_node(
687+
self: AnyNamedNode,
688+
deep: bool = False,
689+
) -> AnyNamedNode:
690+
"""Copy just one node of a tree"""
691+
new_node = super()._copy_node()
692+
new_node._name = self.name
693+
return new_node
694+
622695
@property
623696
def path(self) -> str:
624697
"""Return the file-like path from the root to this node."""

xarray/tests/test_treenode.py

+66-31
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,28 @@ def test_forbid_setting_parent_directly(self):
6464
):
6565
mary.parent = john
6666

67+
def test_dont_modify_children_inplace(self):
68+
# GH issue 9196
69+
child: TreeNode = TreeNode()
70+
TreeNode(children={"child": child})
71+
assert child.parent is None
72+
6773
def test_multi_child_family(self):
68-
mary: TreeNode = TreeNode()
69-
kate: TreeNode = TreeNode()
70-
john: TreeNode = TreeNode(children={"Mary": mary, "Kate": kate})
71-
assert john.children["Mary"] is mary
72-
assert john.children["Kate"] is kate
74+
john: TreeNode = TreeNode(children={"Mary": TreeNode(), "Kate": TreeNode()})
75+
76+
assert "Mary" in john.children
77+
mary = john.children["Mary"]
78+
assert isinstance(mary, TreeNode)
7379
assert mary.parent is john
80+
81+
assert "Kate" in john.children
82+
kate = john.children["Kate"]
83+
assert isinstance(kate, TreeNode)
7484
assert kate.parent is john
7585

7686
def test_disown_child(self):
77-
mary: TreeNode = TreeNode()
78-
john: TreeNode = TreeNode(children={"Mary": mary})
87+
john: TreeNode = TreeNode(children={"Mary": TreeNode()})
88+
mary = john.children["Mary"]
7989
mary.orphan()
8090
assert mary.parent is None
8191
assert "Mary" not in john.children
@@ -96,29 +106,45 @@ def test_doppelganger_child(self):
96106
assert john.children["Kate"] is evil_kate
97107

98108
def test_sibling_relationships(self):
99-
mary: TreeNode = TreeNode()
100-
kate: TreeNode = TreeNode()
101-
ashley: TreeNode = TreeNode()
102-
TreeNode(children={"Mary": mary, "Kate": kate, "Ashley": ashley})
103-
assert kate.siblings["Mary"] is mary
104-
assert kate.siblings["Ashley"] is ashley
109+
john: TreeNode = TreeNode(
110+
children={"Mary": TreeNode(), "Kate": TreeNode(), "Ashley": TreeNode()}
111+
)
112+
kate = john.children["Kate"]
113+
assert list(kate.siblings) == ["Mary", "Ashley"]
105114
assert "Kate" not in kate.siblings
106115

107-
def test_ancestors(self):
116+
def test_copy_subtree(self):
108117
tony: TreeNode = TreeNode()
109118
michael: TreeNode = TreeNode(children={"Tony": tony})
110119
vito = TreeNode(children={"Michael": michael})
120+
121+
# check that children of assigned children are also copied (i.e. that ._copy_subtree works)
122+
copied_tony = vito.children["Michael"].children["Tony"]
123+
assert copied_tony is not tony
124+
125+
def test_parents(self):
126+
vito: TreeNode = TreeNode(
127+
children={"Michael": TreeNode(children={"Tony": TreeNode()})},
128+
)
129+
michael = vito.children["Michael"]
130+
tony = michael.children["Tony"]
131+
111132
assert tony.root is vito
112133
assert tony.parents == (michael, vito)
113-
assert tony.ancestors == (vito, michael, tony)
114134

115135

116136
class TestGetNodes:
117137
def test_get_child(self):
118-
steven: TreeNode = TreeNode()
119-
sue = TreeNode(children={"Steven": steven})
120-
mary = TreeNode(children={"Sue": sue})
121-
john = TreeNode(children={"Mary": mary})
138+
john: TreeNode = TreeNode(
139+
children={
140+
"Mary": TreeNode(
141+
children={"Sue": TreeNode(children={"Steven": TreeNode()})}
142+
)
143+
}
144+
)
145+
mary = john.children["Mary"]
146+
sue = mary.children["Sue"]
147+
steven = sue.children["Steven"]
122148

123149
# get child
124150
assert john._get_item("Mary") is mary
@@ -138,10 +164,14 @@ def test_get_child(self):
138164
assert mary._get_item("Sue/Steven") is steven
139165

140166
def test_get_upwards(self):
141-
sue: TreeNode = TreeNode()
142-
kate: TreeNode = TreeNode()
143-
mary = TreeNode(children={"Sue": sue, "Kate": kate})
144-
john = TreeNode(children={"Mary": mary})
167+
john: TreeNode = TreeNode(
168+
children={
169+
"Mary": TreeNode(children={"Sue": TreeNode(), "Kate": TreeNode()})
170+
}
171+
)
172+
mary = john.children["Mary"]
173+
sue = mary.children["Sue"]
174+
kate = mary.children["Kate"]
145175

146176
assert sue._get_item("../") is mary
147177
assert sue._get_item("../../") is john
@@ -150,9 +180,11 @@ def test_get_upwards(self):
150180
assert sue._get_item("../Kate") is kate
151181

152182
def test_get_from_root(self):
153-
sue: TreeNode = TreeNode()
154-
mary = TreeNode(children={"Sue": sue})
155-
john = TreeNode(children={"Mary": mary}) # noqa
183+
john: TreeNode = TreeNode(
184+
children={"Mary": TreeNode(children={"Sue": TreeNode()})}
185+
)
186+
mary = john.children["Mary"]
187+
sue = mary.children["Sue"]
156188

157189
assert sue._get_item("/Mary") is mary
158190

@@ -367,11 +399,14 @@ def test_levels(self):
367399

368400
class TestRenderTree:
369401
def test_render_nodetree(self):
370-
sam: NamedNode = NamedNode()
371-
ben: NamedNode = NamedNode()
372-
mary: NamedNode = NamedNode(children={"Sam": sam, "Ben": ben})
373-
kate: NamedNode = NamedNode()
374-
john: NamedNode = NamedNode(children={"Mary": mary, "Kate": kate})
402+
john: NamedNode = NamedNode(
403+
children={
404+
"Mary": NamedNode(children={"Sam": NamedNode(), "Ben": NamedNode()}),
405+
"Kate": NamedNode(),
406+
}
407+
)
408+
mary = john.children["Mary"]
409+
375410
expected_nodes = [
376411
"NamedNode()",
377412
"\tNamedNode('Mary')",

0 commit comments

Comments
 (0)