Skip to content

WIP: testing.assert_* check dtype #4760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4ea8703
WIP: maybe check dtype in equals & identical
mathause Jan 3, 2021
36ffb4c
Merge branch 'master' into assert_check_dtype
mathause Jan 3, 2021
72611e4
Merge branch 'master' into assert_check_dtype
mathause Jan 4, 2021
f9764cd
add check_dtype to allclose_or_equiv
mathause Jan 4, 2021
3139c0b
replace kwargs
mathause Jan 4, 2021
e3e7ac7
default: False
mathause Jan 4, 2021
82289a0
Merge branch 'master' into assert_check_dtype
mathause Jan 4, 2021
8c6d65a
Merge branch 'master' into assert_check_dtype
mathause Jan 9, 2021
c7d9d2a
Merge branch 'master' into assert_check_dtype
mathause Jan 12, 2021
408c5bd
Merge branch 'master' into assert_check_dtype
mathause Jan 16, 2021
d8e78f7
correctly do the lazy dtype check
mathause Jan 17, 2021
5ced90c
add docstrings
mathause Jan 19, 2021
63ace8c
Merge branch 'master' into assert_check_dtype
mathause Jan 28, 2021
54f9701
add kwarg to tests version of assert_*
mathause Feb 2, 2021
fd27add
update comment
mathause Feb 2, 2021
73338e7
return bool and not np.bool_
mathause Feb 2, 2021
788082f
add first tests
mathause Feb 2, 2021
a0f686e
Merge branch 'master' into assert_check_dtype
mathause Feb 2, 2021
beca29a
also test the default
mathause Feb 4, 2021
b515dbc
Merge branch 'master' into assert_check_dtype
mathause Feb 8, 2021
26bc455
Merge branch 'master' into assert_check_dtype
mathause Feb 8, 2021
9fa8830
Merge branch 'master' into assert_check_dtype
mathause Feb 26, 2021
7c93f2d
Merge branch 'master' into assert_check_dtype
mathause Feb 26, 2021
6f8e885
updates
mathause Feb 26, 2021
51b7de5
test_formatting
mathause Mar 1, 2021
5578b6e
Merge branch 'master' into assert_check_dtype
mathause Apr 25, 2021
acacfb0
Merge branch 'master' into assert_check_dtype
mathause May 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2907,32 +2907,40 @@ def from_iris(cls, cube: "iris_Cube") -> "DataArray":

return from_iris(cube)

def _all_compat(self, other: "DataArray", compat_str: str) -> bool:
def _all_compat(
self, other: "DataArray", compat_str: str, check_dtype: bool = False
) -> bool:
"""Helper function for equals, broadcast_equals, and identical"""

def compat(x, y):
return getattr(x.variable, compat_str)(y.variable)
return getattr(x.variable, compat_str)(y.variable, check_dtype=check_dtype)

return utils.dict_equiv(self.coords, other.coords, compat=compat) and compat(
self, other
)

def broadcast_equals(self, other: "DataArray") -> bool:
def broadcast_equals(self, other: "DataArray", check_dtype: bool = False) -> bool:
"""Two DataArrays are broadcast equal if they are equal after
broadcasting them against each other such that they have the same
dimensions.

Parameters
----------
check_dtype : bool, default: False
Whether to check if the objects' dtypes are identical. Compares the
dtypes of the data and the coords.

See Also
--------
DataArray.equals
DataArray.identical
"""
try:
return self._all_compat(other, "broadcast_equals")
return self._all_compat(other, "broadcast_equals", check_dtype=check_dtype)
except (TypeError, AttributeError):
return False

def equals(self, other: "DataArray") -> bool:
def equals(self, other: "DataArray", check_dtype: bool = False) -> bool:
"""True if two DataArrays have the same dimensions, coordinates and
values; otherwise False.

Expand All @@ -2942,27 +2950,41 @@ def equals(self, other: "DataArray") -> bool:
This method is necessary because `v1 == v2` for ``DataArray``
does element-wise comparisons (like numpy.ndarrays).

Parameters
----------
check_dtype : bool, default: False
Whether to check if the objects' dtypes are identical. Compares the
dtypes of the data and the coords.

See Also
--------
DataArray.broadcast_equals
DataArray.identical
"""
try:
return self._all_compat(other, "equals")
return self._all_compat(other, "equals", check_dtype=check_dtype)
except (TypeError, AttributeError):
return False

def identical(self, other: "DataArray") -> bool:
def identical(self, other: "DataArray", check_dtype: bool = False) -> bool:
"""Like equals, but also checks the array name and attributes, and
attributes on all coordinates.

Parameters
----------
check_dtype : bool, default: False
Whether to check if the objects' dtypes are identical. Compares the
dtypes of the data and the coords.

See Also
--------
DataArray.broadcast_equals
DataArray.equals
"""
try:
return self.name == other.name and self._all_compat(other, "identical")
return self.name == other.name and self._all_compat(
other, "identical", check_dtype=check_dtype
)
except (TypeError, AttributeError):
return False

Expand Down
34 changes: 26 additions & 8 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1537,37 +1537,43 @@ def __delitem__(self, key: Hashable) -> None:
# https://github.com/python/mypy/issues/4266
__hash__ = None # type: ignore[assignment]

def _all_compat(self, other: "Dataset", compat_str: str) -> bool:
def _all_compat(self, other: "Dataset", compat_str: str, check_dtype: bool) -> bool:
"""Helper function for equals and identical"""

# some stores (e.g., scipy) do not seem to preserve order, so don't
# require matching order for equality
def compat(x: Variable, y: Variable) -> bool:
return getattr(x, compat_str)(y)
return getattr(x, compat_str)(y, check_dtype=check_dtype)

return self._coord_names == other._coord_names and utils.dict_equiv(
self._variables, other._variables, compat=compat
)

def broadcast_equals(self, other: "Dataset") -> bool:
def broadcast_equals(self, other: "Dataset", check_dtype: bool = False) -> bool:
"""Two Datasets are broadcast equal if they are equal after
broadcasting all variables against each other.

For example, variables that are scalar in one dataset but non-scalar in
the other dataset can still be broadcast equal if the the non-scalar
variable is a constant.

Parameters
----------
check_dtype : bool, default: False
Whether to check if the objects' dtypes are identical. Compares the
dtypes of all data variables and coords.

See Also
--------
Dataset.equals
Dataset.identical
"""
try:
return self._all_compat(other, "broadcast_equals")
return self._all_compat(other, "broadcast_equals", check_dtype=check_dtype)
except (TypeError, AttributeError):
return False

def equals(self, other: "Dataset") -> bool:
def equals(self, other: "Dataset", check_dtype: bool = False) -> bool:
"""Two Datasets are equal if they have matching variables and
coordinates, all of which are equal.

Expand All @@ -1577,28 +1583,40 @@ def equals(self, other: "Dataset") -> bool:
This method is necessary because `v1 == v2` for ``Dataset``
does element-wise comparisons (like numpy.ndarrays).

Parameters
----------
check_dtype : bool, default: False
Whether to check if the objects' dtypes are identical. Compares the
dtypes of all data variables and coords.

See Also
--------
Dataset.broadcast_equals
Dataset.identical
"""
try:
return self._all_compat(other, "equals")
return self._all_compat(other, "equals", check_dtype=check_dtype)
except (TypeError, AttributeError):
return False

def identical(self, other: "Dataset") -> bool:
def identical(self, other: "Dataset", check_dtype: bool = False) -> bool:
"""Like equals, but also checks all dataset attributes and the
attributes on all variables and coordinates.

Parameters
----------
check_dtype : bool, default: False
Whether to check if the objects' dtypes are identical. Compares the
dtypes of all data variables and coords.

See Also
--------
Dataset.broadcast_equals
Dataset.equals
"""
try:
return utils.dict_equiv(self.attrs, other.attrs) and self._all_compat(
other, "identical"
other, "identical", check_dtype=check_dtype
)
except (TypeError, AttributeError):
return False
Expand Down
41 changes: 34 additions & 7 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,11 @@ def as_shared_dtype(scalars_or_arrays):
return [x.astype(out_type, copy=False) for x in arrays]


def lazy_array_equiv(arr1, arr2):
def lazy_array_equiv(arr1, arr2, check_dtype=False):
"""Like array_equal, but doesn't actually compare values.
Returns True when arr1, arr2 identical or their dask tokens are equal.
Returns False when shapes are not equal.
Returns False if dtype does not match and check_dtype is True.
Returns None when equality cannot determined: one or both of arr1, arr2 are numpy arrays;
or their dask tokens are not equal
"""
Expand All @@ -231,6 +232,9 @@ def lazy_array_equiv(arr1, arr2):
arr2 = asarray(arr2)
if arr1.shape != arr2.shape:
return False
# "is False" needed -> should not return on None
if check_dtype and same_dtype(arr1, arr2, lazy=True) is False:
return False
if dask_array and is_duck_dask_array(arr1) and is_duck_dask_array(arr2):
# GH3068, GH4221
if tokenize(arr1) == tokenize(arr2):
Expand All @@ -240,26 +244,47 @@ def lazy_array_equiv(arr1, arr2):
return None


def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
def same_dtype(arr1, arr2, lazy):
    """Return whether arr1 and arr2 share the same dtype.

    Returns None (undecidable) when ``lazy`` is True and one of the arrays is
    an object-dtype dask array, since such arrays can change dtype on compute.
    When ``lazy`` is False those arrays are computed first.
    """
    resolved = []
    for arr in (arr1, arr2):
        # object dask arrays can change dtype -> need to compute them
        if arr.dtype == object and is_duck_dask_array(arr):
            if lazy:
                return None
            # arr.compute() can return a scalar -> wrap in an array
            arr = asarray(arr.compute())
        resolved.append(arr)

    return resolved[0].dtype == resolved[1].dtype


def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8, check_dtype=False):
    """Like np.allclose, but also allows values to be NaN in both arrays.

    Parameters
    ----------
    rtol, atol : float
        Relative and absolute tolerances, as in ``np.isclose``.
    check_dtype : bool, default: False
        If True, arrays with differing dtypes compare unequal.
    """
    arr1 = asarray(arr1)
    arr2 = asarray(arr2)

    lazy_equiv = lazy_array_equiv(arr1, arr2, check_dtype=check_dtype)
    if lazy_equiv is None:
        # equality could not be decided without computing -> compare values
        if check_dtype and not same_dtype(arr1, arr2, lazy=False):
            return False
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
            return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
    else:
        return lazy_equiv


def array_equiv(arr1, arr2):
def array_equiv(arr1, arr2, check_dtype=False):
"""Like np.array_equal, but also allows values to be NaN in both arrays"""
arr1 = asarray(arr1)
arr2 = asarray(arr2)
lazy_equiv = lazy_array_equiv(arr1, arr2)
lazy_equiv = lazy_array_equiv(arr1, arr2, check_dtype=check_dtype)
if lazy_equiv is None:
if check_dtype and not same_dtype(arr1, arr2, lazy=False):
return False
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
Expand All @@ -268,14 +293,16 @@ def array_equiv(arr1, arr2):
return lazy_equiv


def array_notnull_equiv(arr1, arr2):
def array_notnull_equiv(arr1, arr2, check_dtype=False):
"""Like np.array_equal, but also allows values to be NaN in either or both
arrays
"""
arr1 = asarray(arr1)
arr2 = asarray(arr2)
lazy_equiv = lazy_array_equiv(arr1, arr2)
lazy_equiv = lazy_array_equiv(arr1, arr2, check_dtype=check_dtype)
if lazy_equiv is None:
if check_dtype and not same_dtype(arr1, arr2, lazy=False):
return False
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)
Expand Down
45 changes: 34 additions & 11 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,9 @@ def diff_dim_summary(a, b):
return ""


def _diff_mapping_repr(a_mapping, b_mapping, compat, title, summarizer, col_width=None):
def _diff_mapping_repr(
a_mapping, b_mapping, compat, title, summarizer, col_width=None, check_dtype=False
):
def extra_items_repr(extra_keys, mapping, ab_side):
extra_repr = [summarizer(k, mapping[k], col_width) for k in extra_keys]
if extra_repr:
Expand All @@ -586,9 +588,15 @@ def extra_items_repr(extra_keys, mapping, ab_side):
try:
# compare xarray variable
if not callable(compat):
compatible = getattr(a_mapping[k], compat)(b_mapping[k])
compatible = getattr(a_mapping[k], compat)(
b_mapping[k], check_dtype=check_dtype
)
else:
compatible = compat(a_mapping[k], b_mapping[k])
compatible = compat(
a_mapping[k],
b_mapping[k],
check_dtype=check_dtype,
)
is_variable = True
except AttributeError:
# compare attribute value
Expand Down Expand Up @@ -620,8 +628,9 @@ def extra_items_repr(extra_keys, mapping, ab_side):

diff_items += [ab_side + s[1:] for ab_side, s in zip(("L", "R"), temp)]

maybe_dtype = " (values and/ or dtype)" if check_dtype else ""
if diff_items:
summary += [f"Differing {title.lower()}:"] + diff_items
summary += [f"Differing {title.lower()}{maybe_dtype}:"] + diff_items

summary += extra_items_repr(a_keys - b_keys, a_mapping, "left")
summary += extra_items_repr(b_keys - a_keys, b_mapping, "right")
Expand Down Expand Up @@ -656,7 +665,7 @@ def _compat_to_str(compat):
return compat


def diff_array_repr(a, b, compat):
def diff_array_repr(a, b, compat, check_dtype=False):
# used for DataArray, Variable and IndexVariable
summary = [
"Left and right {} objects are not {}".format(
Expand All @@ -670,18 +679,22 @@ def diff_array_repr(a, b, compat):
else:
equiv = array_equiv

if not equiv(a.data, b.data):
maybe_dtype = " or dtype" if check_dtype else ""

if not equiv(a.data, b.data, check_dtype=check_dtype):
temp = [wrap_indent(short_numpy_repr(obj), start=" ") for obj in (a, b)]
diff_data_repr = [
ab_side + "\n" + ab_data_repr
for ab_side, ab_data_repr in zip(("L", "R"), temp)
]
summary += ["Differing values:"] + diff_data_repr
summary += [f"Differing values{maybe_dtype}:"] + diff_data_repr

if hasattr(a, "coords"):
col_width = _calculate_col_width(set(a.coords) | set(b.coords))
summary.append(
diff_coords_repr(a.coords, b.coords, compat, col_width=col_width)
diff_coords_repr(
a.coords, b.coords, compat, col_width=col_width, check_dtype=check_dtype
)
)

if compat == "identical":
Expand All @@ -690,7 +703,7 @@ def diff_array_repr(a, b, compat):
return "\n".join(summary)


def diff_dataset_repr(a, b, compat):
def diff_dataset_repr(a, b, compat, check_dtype=False):
summary = [
"Left and right {} objects are not {}".format(
type(a).__name__, _compat_to_str(compat)
Expand All @@ -702,9 +715,19 @@ def diff_dataset_repr(a, b, compat):
)

summary.append(diff_dim_summary(a, b))
summary.append(diff_coords_repr(a.coords, b.coords, compat, col_width=col_width))
summary.append(
diff_data_vars_repr(a.data_vars, b.data_vars, compat, col_width=col_width)
diff_coords_repr(
a.coords, b.coords, compat, col_width=col_width, check_dtype=check_dtype
)
)
summary.append(
diff_data_vars_repr(
a.data_vars,
b.data_vars,
compat,
col_width=col_width,
check_dtype=check_dtype,
)
)

if compat == "identical":
Expand Down
Loading