Skip to content

Commit 8d728bf

Browse files
ignamvTomNicholasmax-sixty
authored
Add argument check_dims to assert_allclose to allow transposed inputs (#5733) (#8991)
* Add argument check_dims to assert_allclose to allow transposed inputs * Update whats-new.rst * Add `check_dims` argument to assert_equal and assert_identical + tests * Assert that dimensions match before transposing or comparing values * Add docstring for check_dims to assert_equal and assert_identical * Update doc/whats-new.rst Co-authored-by: Tom Nicholas <[email protected]> * Undo fat finger Co-authored-by: Tom Nicholas <[email protected]> * Add attribution to whats-new.rst * Replace check_dims with bool argument check_dim_order, rename align_dims to maybe_transpose_dims * Remove left-over half-made test * Remove check_dim_order argument from assert_identical * assert_allclose/equal: emit full diff if dimensions don't match * Rename check_dim_order test, test Dataset with different dim orders * Update whats-new.rst * Hide maybe_transpose_dims from Pytest traceback Co-authored-by: Maximilian Roos <[email protected]> * Ignore mypy error due to missing functools.partial.__name__ --------- Co-authored-by: Tom Nicholas <[email protected]> Co-authored-by: Maximilian Roos <[email protected]>
1 parent c2b9429 commit 8d728bf

File tree

3 files changed

+46
-4
lines changed

3 files changed

+46
-4
lines changed

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ New Features
2929
for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray`
3030
then, such as broadcasting.
3131
By `Ilan Gold <https://github.com/ilan-gold>`_.
32+
- :py:func:`testing.assert_allclose`/:py:func:`testing.assert_equal` now accept a new argument `check_dims="transpose"`, controlling whether a transposed array is considered equal. (:issue:`5733`, :pull:`8991`)
33+
By `Ignacio Martinez Vazquez <https://github.com/ignamv>`_.
3234
- Added the option to avoid automatically creating 1D pandas indexes in :py:meth:`Dataset.expand_dims()`, by passing the new kwarg
3335
`create_index=False`. (:pull:`8960`)
3436
By `Tom Nicholas <https://github.com/TomNicholas>`_.

xarray/testing/assertions.py

+25-4
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,18 @@ def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False):
9595
raise TypeError(f"{type(a)} not of type DataTree")
9696

9797

98+
def maybe_transpose_dims(a, b, check_dim_order: bool):
99+
"""Helper for assert_equal/allclose/identical"""
100+
__tracebackhide__ = True
101+
if not isinstance(a, (Variable, DataArray, Dataset)):
102+
return b
103+
if not check_dim_order and set(a.dims) == set(b.dims):
104+
# Ensure transpose won't fail if a dimension is missing
105+
# If this is the case, the difference will be caught by the caller
106+
return b.transpose(*a.dims)
107+
return b
108+
109+
98110
@overload
99111
def assert_equal(a, b): ...
100112

@@ -104,7 +116,7 @@ def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ...
104116

105117

106118
@ensure_warnings
107-
def assert_equal(a, b, from_root=True):
119+
def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
108120
"""Like :py:func:`numpy.testing.assert_array_equal`, but for xarray
109121
objects.
110122
@@ -127,6 +139,8 @@ def assert_equal(a, b, from_root=True):
127139
Only used when comparing DataTree objects. Indicates whether or not to
128140
first traverse to the root of the trees before checking for isomorphism.
129141
If a & b have no parents then this has no effect.
142+
check_dim_order : bool, optional, default is True
143+
Whether dimensions must be in the same order.
130144
131145
See Also
132146
--------
@@ -137,6 +151,7 @@ def assert_equal(a, b, from_root=True):
137151
assert (
138152
type(a) == type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates)
139153
)
154+
b = maybe_transpose_dims(a, b, check_dim_order)
140155
if isinstance(a, (Variable, DataArray)):
141156
assert a.equals(b), formatting.diff_array_repr(a, b, "equals")
142157
elif isinstance(a, Dataset):
@@ -182,6 +197,8 @@ def assert_identical(a, b, from_root=True):
182197
Only used when comparing DataTree objects. Indicates whether or not to
183198
first traverse to the root of the trees before checking for isomorphism.
184199
If a & b have no parents then this has no effect.
200+
check_dim_order : bool, optional, default is True
201+
Whether dimensions must be in the same order.
185202
186203
See Also
187204
--------
@@ -213,7 +230,9 @@ def assert_identical(a, b, from_root=True):
213230

214231

215232
@ensure_warnings
216-
def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
233+
def assert_allclose(
234+
a, b, rtol=1e-05, atol=1e-08, decode_bytes=True, check_dim_order: bool = True
235+
):
217236
"""Like :py:func:`numpy.testing.assert_allclose`, but for xarray objects.
218237
219238
Raises an AssertionError if two objects are not equal up to desired
@@ -233,23 +252,25 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
233252
Whether byte dtypes should be decoded to strings as UTF-8 or not.
234253
This is useful for testing serialization methods on Python 3 that
235254
return saved strings as bytes.
255+
check_dim_order : bool, optional, default is True
256+
Whether dimensions must be in the same order.
236257
237258
See Also
238259
--------
239260
assert_identical, assert_equal, numpy.testing.assert_allclose
240261
"""
241262
__tracebackhide__ = True
242263
assert type(a) == type(b)
264+
b = maybe_transpose_dims(a, b, check_dim_order)
243265

244266
equiv = functools.partial(
245267
_data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes
246268
)
247-
equiv.__name__ = "allclose"
269+
equiv.__name__ = "allclose" # type: ignore[attr-defined]
248270

249271
def compat_variable(a, b):
250272
a = getattr(a, "variable", a)
251273
b = getattr(b, "variable", b)
252-
253274
return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data))
254275

255276
if isinstance(a, Variable):

xarray/tests/test_assertions.py

+19
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,25 @@ def test_allclose_regression() -> None:
5757
def test_assert_allclose(obj1, obj2) -> None:
5858
with pytest.raises(AssertionError):
5959
xr.testing.assert_allclose(obj1, obj2)
60+
with pytest.raises(AssertionError):
61+
xr.testing.assert_allclose(obj1, obj2, check_dim_order=False)
62+
63+
64+
@pytest.mark.parametrize("func", ["assert_equal", "assert_allclose"])
65+
def test_assert_allclose_equal_transpose(func) -> None:
66+
"""Transposed DataArray raises assertion unless check_dim_order=False."""
67+
obj1 = xr.DataArray([[0, 1, 2], [2, 3, 4]], dims=["a", "b"])
68+
obj2 = xr.DataArray([[0, 2], [1, 3], [2, 4]], dims=["b", "a"])
69+
with pytest.raises(AssertionError):
70+
getattr(xr.testing, func)(obj1, obj2)
71+
getattr(xr.testing, func)(obj1, obj2, check_dim_order=False)
72+
ds1 = obj1.to_dataset(name="varname")
73+
ds1["var2"] = obj1
74+
ds2 = obj1.to_dataset(name="varname")
75+
ds2["var2"] = obj1.transpose()
76+
with pytest.raises(AssertionError):
77+
getattr(xr.testing, func)(ds1, ds2)
78+
getattr(xr.testing, func)(ds1, ds2, check_dim_order=False)
6079

6180

6281
@pytest.mark.filterwarnings("error")

0 commit comments

Comments
 (0)