Skip to content

Commit de3fce8

Browse files
authored
Updates to DataTree.equals and DataTree.identical (#9627)
* Updates to DataTree.equals and DataTree.identical. In contrast to `equals`, `identical` now also checks that any inherited variables are inherited on both objects. However, they do not need to be inherited from the same source. This aligns the behavior of `identical` with the DataTree `__repr__`. I've also removed the `from_root` argument from `equals` and `identical`. If a user wants to compare trees from their roots, a better (simpler) interface is to simply call these methods on the `.root` properties. I would also like to remove the `strict_names` argument, but that will require switching to use the new `zip_subtrees` (#9623) first.
* More efficient check for inherited coordinates
1 parent 3c01ced commit de3fce8

File tree

4 files changed

+138
-71
lines changed

4 files changed

+138
-71
lines changed

xarray/core/datatree.py

+22-22
Original file line numberDiff line numberDiff line change
@@ -1252,61 +1252,61 @@ def isomorphic(
12521252
except (TypeError, TreeIsomorphismError):
12531253
return False
12541254

1255-
def equals(self, other: DataTree, from_root: bool = True) -> bool:
1255+
def equals(self, other: DataTree) -> bool:
12561256
"""
1257-
Two DataTrees are equal if they have isomorphic node structures, with matching node names,
1258-
and if they have matching variables and coordinates, all of which are equal.
1259-
1260-
By default this method will check the whole tree above the given node.
1257+
Two DataTrees are equal if they have isomorphic node structures, with
1258+
matching node names, and if they have matching variables and
1259+
coordinates, all of which are equal.
12611260
12621261
Parameters
12631262
----------
12641263
other : DataTree
12651264
The other tree object to compare to.
1266-
from_root : bool, optional, default is True
1267-
Whether or not to first traverse to the root of the two trees before checking for isomorphism.
1268-
If neither tree has a parent then this has no effect.
12691265
12701266
See Also
12711267
--------
12721268
Dataset.equals
12731269
DataTree.isomorphic
12741270
DataTree.identical
12751271
"""
1276-
if not self.isomorphic(other, from_root=from_root, strict_names=True):
1272+
if not self.isomorphic(other, strict_names=True):
12771273
return False
12781274

12791275
return all(
1280-
[
1281-
node.dataset.equals(other_node.dataset)
1282-
for node, other_node in zip(self.subtree, other.subtree, strict=True)
1283-
]
1276+
node.dataset.equals(other_node.dataset)
1277+
for node, other_node in zip(self.subtree, other.subtree, strict=True)
12841278
)
12851279

1286-
def identical(self, other: DataTree, from_root=True) -> bool:
1287-
"""
1288-
Like equals, but will also check all dataset attributes and the attributes on
1289-
all variables and coordinates.
1280+
def _inherited_coords_set(self) -> set[str]:
1281+
return set(self.parent.coords if self.parent else [])
12901282

1291-
By default this method will check the whole tree above the given node.
1283+
def identical(self, other: DataTree) -> bool:
1284+
"""
1285+
Like equals, but also checks attributes on all datasets, variables and
1286+
coordinates, and requires that any inherited coordinates at the tree
1287+
root are also inherited on the other tree.
12921288
12931289
Parameters
12941290
----------
12951291
other : DataTree
12961292
The other tree object to compare to.
1297-
from_root : bool, optional, default is True
1298-
Whether or not to first traverse to the root of the two trees before checking for isomorphism.
1299-
If neither tree has a parent then this has no effect.
13001293
13011294
See Also
13021295
--------
13031296
Dataset.identical
13041297
DataTree.isomorphic
13051298
DataTree.equals
13061299
"""
1307-
if not self.isomorphic(other, from_root=from_root, strict_names=True):
1300+
if not self.isomorphic(other, strict_names=True):
1301+
return False
1302+
1303+
if self.name != other.name:
1304+
return False
1305+
1306+
if self._inherited_coords_set() != other._inherited_coords_set():
13081307
return False
13091308

1309+
# TODO: switch to zip_subtrees, when available
13101310
return all(
13111311
node.dataset.identical(other_node.dataset)
13121312
for node, other_node in zip(self.subtree, other.subtree, strict=True)

xarray/testing/assertions.py

+4-41
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import functools
44
import warnings
55
from collections.abc import Hashable
6-
from typing import overload
76

87
import numpy as np
98
import pandas as pd
@@ -107,16 +106,8 @@ def maybe_transpose_dims(a, b, check_dim_order: bool):
107106
return b
108107

109108

110-
@overload
111-
def assert_equal(a, b): ...
112-
113-
114-
@overload
115-
def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ...
116-
117-
118109
@ensure_warnings
119-
def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
110+
def assert_equal(a, b, check_dim_order: bool = True):
120111
"""Like :py:func:`numpy.testing.assert_array_equal`, but for xarray
121112
objects.
122113
@@ -135,10 +126,6 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
135126
or xarray.core.datatree.DataTree. The first object to compare.
136127
b : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates
137128
or xarray.core.datatree.DataTree. The second object to compare.
138-
from_root : bool, optional, default is True
139-
Only used when comparing DataTree objects. Indicates whether or not to
140-
first traverse to the root of the trees before checking for isomorphism.
141-
If a & b have no parents then this has no effect.
142129
check_dim_order : bool, optional, default is True
143130
Whether dimensions must be in the same order.
144131
@@ -159,25 +146,13 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
159146
elif isinstance(a, Coordinates):
160147
assert a.equals(b), formatting.diff_coords_repr(a, b, "equals")
161148
elif isinstance(a, DataTree):
162-
if from_root:
163-
a = a.root
164-
b = b.root
165-
166-
assert a.equals(b, from_root=from_root), diff_datatree_repr(a, b, "equals")
149+
assert a.equals(b), diff_datatree_repr(a, b, "equals")
167150
else:
168151
raise TypeError(f"{type(a)} not supported by assertion comparison")
169152

170153

171-
@overload
172-
def assert_identical(a, b): ...
173-
174-
175-
@overload
176-
def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): ...
177-
178-
179154
@ensure_warnings
180-
def assert_identical(a, b, from_root=True):
155+
def assert_identical(a, b):
181156
"""Like :py:func:`xarray.testing.assert_equal`, but also matches the
182157
objects' names and attributes.
183158
@@ -193,12 +168,6 @@ def assert_identical(a, b, from_root=True):
193168
The first object to compare.
194169
b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates
195170
The second object to compare.
196-
from_root : bool, optional, default is True
197-
Only used when comparing DataTree objects. Indicates whether or not to
198-
first traverse to the root of the trees before checking for isomorphism.
199-
If a & b have no parents then this has no effect.
200-
check_dim_order : bool, optional, default is True
201-
Whether dimensions must be in the same order.
202171
203172
See Also
204173
--------
@@ -220,13 +189,7 @@ def assert_identical(a, b, from_root=True):
220189
elif isinstance(a, Coordinates):
221190
assert a.identical(b), formatting.diff_coords_repr(a, b, "identical")
222191
elif isinstance(a, DataTree):
223-
if from_root:
224-
a = a.root
225-
b = b.root
226-
227-
assert a.identical(b, from_root=from_root), diff_datatree_repr(
228-
a, b, "identical"
229-
)
192+
assert a.identical(b), diff_datatree_repr(a, b, "identical")
230193
else:
231194
raise TypeError(f"{type(a)} not supported by assertion comparison")
232195

xarray/tests/test_datatree.py

+111-7
Original file line numberDiff line numberDiff line change
@@ -1538,6 +1538,110 @@ def f(x, tree, y):
15381538
assert actual is dt and actual.attrs == attrs
15391539

15401540

1541+
class TestEqualsAndIdentical:
1542+
1543+
def test_minimal_variations(self):
1544+
tree = DataTree.from_dict(
1545+
{
1546+
"/": Dataset({"x": 1}),
1547+
"/child": Dataset({"x": 2}),
1548+
}
1549+
)
1550+
assert tree.equals(tree)
1551+
assert tree.identical(tree)
1552+
1553+
child = tree.children["child"]
1554+
assert child.equals(child)
1555+
assert child.identical(child)
1556+
1557+
new_child = DataTree(dataset=Dataset({"x": 2}), name="child")
1558+
assert child.equals(new_child)
1559+
assert child.identical(new_child)
1560+
1561+
anonymous_child = DataTree(dataset=Dataset({"x": 2}))
1562+
# TODO: re-enable this after fixing .equals() not to require matching
1563+
# names on the root node (i.e., after switching to use zip_subtrees)
1564+
# assert child.equals(anonymous_child)
1565+
assert not child.identical(anonymous_child)
1566+
1567+
different_variables = DataTree.from_dict(
1568+
{
1569+
"/": Dataset(),
1570+
"/other": Dataset({"x": 2}),
1571+
}
1572+
)
1573+
assert not tree.equals(different_variables)
1574+
assert not tree.identical(different_variables)
1575+
1576+
different_root_data = DataTree.from_dict(
1577+
{
1578+
"/": Dataset({"x": 4}),
1579+
"/child": Dataset({"x": 2}),
1580+
}
1581+
)
1582+
assert not tree.equals(different_root_data)
1583+
assert not tree.identical(different_root_data)
1584+
1585+
different_child_data = DataTree.from_dict(
1586+
{
1587+
"/": Dataset({"x": 1}),
1588+
"/child": Dataset({"x": 3}),
1589+
}
1590+
)
1591+
assert not tree.equals(different_child_data)
1592+
assert not tree.identical(different_child_data)
1593+
1594+
different_child_node_attrs = DataTree.from_dict(
1595+
{
1596+
"/": Dataset({"x": 1}),
1597+
"/child": Dataset({"x": 2}, attrs={"foo": "bar"}),
1598+
}
1599+
)
1600+
assert tree.equals(different_child_node_attrs)
1601+
assert not tree.identical(different_child_node_attrs)
1602+
1603+
different_child_variable_attrs = DataTree.from_dict(
1604+
{
1605+
"/": Dataset({"x": 1}),
1606+
"/child": Dataset({"x": ((), 2, {"foo": "bar"})}),
1607+
}
1608+
)
1609+
assert tree.equals(different_child_variable_attrs)
1610+
assert not tree.identical(different_child_variable_attrs)
1611+
1612+
different_name = DataTree.from_dict(
1613+
{
1614+
"/": Dataset({"x": 1}),
1615+
"/child": Dataset({"x": 2}),
1616+
},
1617+
name="different",
1618+
)
1619+
# TODO: re-enable this after fixing .equals() not to require matching
1620+
# names on the root node (i.e., after switching to use zip_subtrees)
1621+
# assert tree.equals(different_name)
1622+
assert not tree.identical(different_name)
1623+
1624+
def test_differently_inherited_coordinates(self):
1625+
root = DataTree.from_dict(
1626+
{
1627+
"/": Dataset(coords={"x": [1, 2]}),
1628+
"/child": Dataset(),
1629+
}
1630+
)
1631+
child = root.children["child"]
1632+
assert child.equals(child)
1633+
assert child.identical(child)
1634+
1635+
new_child = DataTree(dataset=Dataset(coords={"x": [1, 2]}), name="child")
1636+
assert child.equals(new_child)
1637+
assert not child.identical(new_child)
1638+
1639+
deeper_root = DataTree(children={"root": root})
1640+
grandchild = deeper_root["/root/child"]
1641+
assert child.equals(grandchild)
1642+
assert child.identical(grandchild)
1643+
1644+
15411645
class TestSubset:
15421646
def test_match(self) -> None:
15431647
# TODO is this example going to cause problems with case sensitivity?
@@ -1599,7 +1703,7 @@ def test_isel_siblings(self) -> None:
15991703
}
16001704
)
16011705
actual = tree.isel(x=-1)
1602-
assert_equal(actual, expected)
1706+
assert_identical(actual, expected)
16031707

16041708
expected = DataTree.from_dict(
16051709
{
@@ -1608,13 +1712,13 @@ def test_isel_siblings(self) -> None:
16081712
}
16091713
)
16101714
actual = tree.isel(x=slice(1))
1611-
assert_equal(actual, expected)
1715+
assert_identical(actual, expected)
16121716

16131717
actual = tree.isel(x=[0])
1614-
assert_equal(actual, expected)
1718+
assert_identical(actual, expected)
16151719

16161720
actual = tree.isel(x=slice(None))
1617-
assert_equal(actual, tree)
1721+
assert_identical(actual, tree)
16181722

16191723
def test_isel_inherited(self) -> None:
16201724
tree = DataTree.from_dict(
@@ -1631,15 +1735,15 @@ def test_isel_inherited(self) -> None:
16311735
}
16321736
)
16331737
actual = tree.isel(x=-1)
1634-
assert_equal(actual, expected)
1738+
assert_identical(actual, expected)
16351739

16361740
expected = DataTree.from_dict(
16371741
{
16381742
"/child": xr.Dataset({"foo": 4}),
16391743
}
16401744
)
16411745
actual = tree.isel(x=-1, drop=True)
1642-
assert_equal(actual, expected)
1746+
assert_identical(actual, expected)
16431747

16441748
expected = DataTree.from_dict(
16451749
{
@@ -1648,7 +1752,7 @@ def test_isel_inherited(self) -> None:
16481752
}
16491753
)
16501754
actual = tree.isel(x=[0])
1651-
assert_equal(actual, expected)
1755+
assert_identical(actual, expected)
16521756

16531757
actual = tree.isel(x=slice(None))
16541758

xarray/tests/test_datatree_mapping.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def times_ten(ds):
264264

265265
expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"]
266266
result_tree = times_ten(subtree)
267-
assert_equal(result_tree, expected, from_root=False)
267+
assert_equal(result_tree, expected)
268268

269269
def test_skip_empty_nodes_with_attrs(self, create_test_datatree):
270270
# inspired by xarray-datatree GH262

0 commit comments

Comments
 (0)