Skip to content

Commit 0762f91

Browse files
committed
Updates to DataTree.equals and DataTree.identical
In contrast to `equals`, `identical` now also checks that any inherited variables are inherited on both objects. However, they do not need to be inherited from the same source. This aligns the behavior of `identical` with the DataTree `__repr__`. I've also removed the `from_root` argument from `equals` and `identical`. If a user wants to compare trees from their roots, a better (simpler) interface is to simply call these methods on the `.root` properties. I would also like to remove the `strict_names` argument, but that will require switching to use the new `zip_subtrees` (pydata#9623) first.
1 parent 33ead65 commit 0762f91

File tree

4 files changed

+135
-65
lines changed

4 files changed

+135
-65
lines changed

xarray/core/datatree.py

+19-16
Original file line numberDiff line numberDiff line change
@@ -1250,7 +1250,7 @@ def isomorphic(
12501250
except (TypeError, TreeIsomorphismError):
12511251
return False
12521252

1253-
def equals(self, other: DataTree, from_root: bool = True) -> bool:
1253+
def equals(self, other: DataTree) -> bool:
12541254
"""
12551255
Two DataTrees are equal if they have isomorphic node structures, with matching node names,
12561256
and if they have matching variables and coordinates, all of which are equal.
@@ -1261,50 +1261,53 @@ def equals(self, other: DataTree, from_root: bool = True) -> bool:
12611261
----------
12621262
other : DataTree
12631263
The other tree object to compare to.
1264-
from_root : bool, optional, default is True
1265-
Whether or not to first traverse to the root of the two trees before checking for isomorphism.
1266-
If neither tree has a parent then this has no effect.
12671264
12681265
See Also
12691266
--------
12701267
Dataset.equals
12711268
DataTree.isomorphic
12721269
DataTree.identical
12731270
"""
1274-
if not self.isomorphic(other, from_root=from_root, strict_names=True):
1271+
if not self.isomorphic(other, strict_names=True):
12751272
return False
12761273

12771274
return all(
1278-
[
1279-
node.dataset.equals(other_node.dataset)
1280-
for node, other_node in zip(self.subtree, other.subtree, strict=True)
1281-
]
1275+
node.dataset.equals(other_node.dataset)
1276+
for node, other_node in zip(self.subtree, other.subtree, strict=True)
12821277
)
12831278

1284-
def identical(self, other: DataTree, from_root=True) -> bool:
1279+
def identical(self, other: DataTree) -> bool:
12851280
"""
12861281
Like equals, but will also check all dataset attributes and the attributes on
12871282
all variables and coordinates.
12881283
1289-
By default this method will check the whole tree above the given node.
1290-
12911284
Parameters
12921285
----------
12931286
other : DataTree
12941287
The other tree object to compare to.
1295-
from_root : bool, optional, default is True
1296-
Whether or not to first traverse to the root of the two trees before checking for isomorphism.
1297-
If neither tree has a parent then this has no effect.
12981288
12991289
See Also
13001290
--------
13011291
Dataset.identical
13021292
DataTree.isomorphic
13031293
DataTree.equals
13041294
"""
1305-
if not self.isomorphic(other, from_root=from_root, strict_names=True):
1295+
if not self.isomorphic(other, strict_names=True):
1296+
return False
1297+
1298+
if self.name != other.name:
1299+
return False
1300+
1301+
# Check the root node's dataset twice, once with (below) and once without
1302+
# (here) inheritance. This ensures that even inheritance matches for
1303+
# identical DataTree objects, although inherited variables need not be
1304+
# defined at the same level.
1305+
self_ds = self._to_dataset_view(rebuild_dims=False, inherit=False)
1306+
other_ds = other._to_dataset_view(rebuild_dims=False, inherit=False)
1307+
if not self_ds.identical(other_ds):
13061308
return False
13071309

1310+
# TODO: switch to zip_subtrees, when available
13081311
return all(
13091312
node.dataset.identical(other_node.dataset)
13101313
for node, other_node in zip(self.subtree, other.subtree, strict=True)

xarray/testing/assertions.py

+4-41
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import functools
44
import warnings
55
from collections.abc import Hashable
6-
from typing import overload
76

87
import numpy as np
98
import pandas as pd
@@ -107,16 +106,8 @@ def maybe_transpose_dims(a, b, check_dim_order: bool):
107106
return b
108107

109108

110-
@overload
111-
def assert_equal(a, b): ...
112-
113-
114-
@overload
115-
def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ...
116-
117-
118109
@ensure_warnings
119-
def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
110+
def assert_equal(a, b, check_dim_order: bool = True):
120111
"""Like :py:func:`numpy.testing.assert_array_equal`, but for xarray
121112
objects.
122113
@@ -135,10 +126,6 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
135126
or xarray.core.datatree.DataTree. The first object to compare.
136127
b : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates
137128
or xarray.core.datatree.DataTree. The second object to compare.
138-
from_root : bool, optional, default is True
139-
Only used when comparing DataTree objects. Indicates whether or not to
140-
first traverse to the root of the trees before checking for isomorphism.
141-
If a & b have no parents then this has no effect.
142129
check_dim_order : bool, optional, default is True
143130
Whether dimensions must be in the same order.
144131
@@ -159,25 +146,13 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
159146
elif isinstance(a, Coordinates):
160147
assert a.equals(b), formatting.diff_coords_repr(a, b, "equals")
161148
elif isinstance(a, DataTree):
162-
if from_root:
163-
a = a.root
164-
b = b.root
165-
166-
assert a.equals(b, from_root=from_root), diff_datatree_repr(a, b, "equals")
149+
assert a.equals(b), diff_datatree_repr(a, b, "equals")
167150
else:
168151
raise TypeError(f"{type(a)} not supported by assertion comparison")
169152

170153

171-
@overload
172-
def assert_identical(a, b): ...
173-
174-
175-
@overload
176-
def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): ...
177-
178-
179154
@ensure_warnings
180-
def assert_identical(a, b, from_root=True):
155+
def assert_identical(a, b):
181156
"""Like :py:func:`xarray.testing.assert_equal`, but also matches the
182157
objects' names and attributes.
183158
@@ -193,12 +168,6 @@ def assert_identical(a, b, from_root=True):
193168
The first object to compare.
194169
b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates
195170
The second object to compare.
196-
from_root : bool, optional, default is True
197-
Only used when comparing DataTree objects. Indicates whether or not to
198-
first traverse to the root of the trees before checking for isomorphism.
199-
If a & b have no parents then this has no effect.
200-
check_dim_order : bool, optional, default is True
201-
Whether dimensions must be in the same order.
202171
203172
See Also
204173
--------
@@ -220,13 +189,7 @@ def assert_identical(a, b, from_root=True):
220189
elif isinstance(a, Coordinates):
221190
assert a.identical(b), formatting.diff_coords_repr(a, b, "identical")
222191
elif isinstance(a, DataTree):
223-
if from_root:
224-
a = a.root
225-
b = b.root
226-
227-
assert a.identical(b, from_root=from_root), diff_datatree_repr(
228-
a, b, "identical"
229-
)
192+
assert a.identical(b), diff_datatree_repr(a, b, "identical")
230193
else:
231194
raise TypeError(f"{type(a)} not supported by assertion comparison")
232195

xarray/tests/test_datatree.py

+111-7
Original file line numberDiff line numberDiff line change
@@ -1529,6 +1529,110 @@ def f(x, tree, y):
15291529
assert actual is dt and actual.attrs == attrs
15301530

15311531

1532+
class TestEqualsAndIdentical:
1533+
1534+
def test_minimal_variations(self):
1535+
tree = DataTree.from_dict(
1536+
{
1537+
"/": Dataset({"x": 1}),
1538+
"/child": Dataset({"x": 2}),
1539+
}
1540+
)
1541+
assert tree.equals(tree)
1542+
assert tree.identical(tree)
1543+
1544+
child = tree.children["child"]
1545+
assert child.equals(child)
1546+
assert child.identical(child)
1547+
1548+
new_child = DataTree(dataset=Dataset({"x": 2}), name="child")
1549+
assert child.equals(new_child)
1550+
assert child.identical(new_child)
1551+
1552+
anonymous_child = DataTree(dataset=Dataset({"x": 2}))
1553+
# TODO: re-enable this after fixing .equals() not to require matching
1554+
# names on the root node (i.e., after switching to use zip_subtrees)
1555+
# assert child.equals(anonymous_child)
1556+
assert not child.identical(anonymous_child)
1557+
1558+
different_variables = DataTree.from_dict(
1559+
{
1560+
"/": Dataset(),
1561+
"/other": Dataset({"x": 2}),
1562+
}
1563+
)
1564+
assert not tree.equals(different_variables)
1565+
assert not tree.identical(different_variables)
1566+
1567+
different_root_data = DataTree.from_dict(
1568+
{
1569+
"/": Dataset({"x": 4}),
1570+
"/child": Dataset({"x": 2}),
1571+
}
1572+
)
1573+
assert not tree.equals(different_root_data)
1574+
assert not tree.identical(different_root_data)
1575+
1576+
different_child_data = DataTree.from_dict(
1577+
{
1578+
"/": Dataset({"x": 1}),
1579+
"/child": Dataset({"x": 3}),
1580+
}
1581+
)
1582+
assert not tree.equals(different_child_data)
1583+
assert not tree.identical(different_child_data)
1584+
1585+
different_child_node_attrs = DataTree.from_dict(
1586+
{
1587+
"/": Dataset({"x": 1}),
1588+
"/child": Dataset({"x": 2}, attrs={"foo": "bar"}),
1589+
}
1590+
)
1591+
assert tree.equals(different_child_node_attrs)
1592+
assert not tree.identical(different_child_node_attrs)
1593+
1594+
different_child_variable_attrs = DataTree.from_dict(
1595+
{
1596+
"/": Dataset({"x": 1}),
1597+
"/child": Dataset({"x": ((), 2, {"foo": "bar"})}),
1598+
}
1599+
)
1600+
assert tree.equals(different_child_variable_attrs)
1601+
assert not tree.identical(different_child_variable_attrs)
1602+
1603+
different_name = DataTree.from_dict(
1604+
{
1605+
"/": Dataset({"x": 1}),
1606+
"/child": Dataset({"x": 2}),
1607+
},
1608+
name="different",
1609+
)
1610+
# TODO: re-enable this after fixing .equals() not to require matching
1611+
# names on the root node (i.e., after switching to use zip_subtrees)
1612+
# assert tree.equals(different_name)
1613+
assert not tree.identical(different_name)
1614+
1615+
def test_differently_inherited_coordinates(self):
1616+
root = DataTree.from_dict(
1617+
{
1618+
"/": Dataset(coords={"x": [1, 2]}),
1619+
"/child": Dataset(),
1620+
}
1621+
)
1622+
child = root.children["child"]
1623+
assert child.equals(child)
1624+
assert child.identical(child)
1625+
1626+
new_child = DataTree(dataset=Dataset(coords={"x": [1, 2]}), name="child")
1627+
assert child.equals(new_child)
1628+
assert not child.identical(new_child)
1629+
1630+
deeper_root = DataTree(children={"root": root})
1631+
grandchild = deeper_root["/root/child"]
1632+
assert child.equals(grandchild)
1633+
assert child.identical(grandchild)
1634+
1635+
15321636
class TestSubset:
15331637
def test_match(self):
15341638
# TODO is this example going to cause problems with case sensitivity?
@@ -1590,7 +1694,7 @@ def test_isel_siblings(self):
15901694
}
15911695
)
15921696
actual = tree.isel(x=-1)
1593-
assert_equal(actual, expected)
1697+
assert_identical(actual, expected)
15941698

15951699
expected = DataTree.from_dict(
15961700
{
@@ -1599,13 +1703,13 @@ def test_isel_siblings(self):
15991703
}
16001704
)
16011705
actual = tree.isel(x=slice(1))
1602-
assert_equal(actual, expected)
1706+
assert_identical(actual, expected)
16031707

16041708
actual = tree.isel(x=[0])
1605-
assert_equal(actual, expected)
1709+
assert_identical(actual, expected)
16061710

16071711
actual = tree.isel(x=slice(None))
1608-
assert_equal(actual, tree)
1712+
assert_identical(actual, tree)
16091713

16101714
def test_isel_inherited(self):
16111715
tree = DataTree.from_dict(
@@ -1622,15 +1726,15 @@ def test_isel_inherited(self):
16221726
}
16231727
)
16241728
actual = tree.isel(x=-1)
1625-
assert_equal(actual, expected)
1729+
assert_identical(actual, expected)
16261730

16271731
expected = DataTree.from_dict(
16281732
{
16291733
"/child": xr.Dataset({"foo": 4}),
16301734
}
16311735
)
16321736
actual = tree.isel(x=-1, drop=True)
1633-
assert_equal(actual, expected)
1737+
assert_identical(actual, expected)
16341738

16351739
expected = DataTree.from_dict(
16361740
{
@@ -1639,7 +1743,7 @@ def test_isel_inherited(self):
16391743
}
16401744
)
16411745
actual = tree.isel(x=[0])
1642-
assert_equal(actual, expected)
1746+
assert_identical(actual, expected)
16431747

16441748
actual = tree.isel(x=slice(None))
16451749
assert_equal(actual, tree)

xarray/tests/test_datatree_mapping.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def times_ten(ds):
264264

265265
expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"]
266266
result_tree = times_ten(subtree)
267-
assert_equal(result_tree, expected, from_root=False)
267+
assert_equal(result_tree, expected)
268268

269269
def test_skip_empty_nodes_with_attrs(self, create_test_datatree):
270270
# inspired by xarray-datatree GH262

0 commit comments

Comments
 (0)