Skip to content

ENH: new function isclose #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
cov
create_diagonal
expand_dims
isclose
kron
nunique
pad
Expand Down
253 changes: 125 additions & 128 deletions pixi.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import pad
from ._delegation import isclose, pad
from ._lib._at import at
from ._lib._funcs import (
atleast_nd,
Expand All @@ -23,6 +23,7 @@
"cov",
"create_diagonal",
"expand_dims",
"isclose",
"kron",
"nunique",
"pad",
Expand Down
92 changes: 91 additions & 1 deletion src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ._lib._utils._compat import array_namespace
from ._lib._utils._typing import Array

__all__ = ["pad"]
__all__ = ["isclose", "pad"]


def _delegate(xp: ModuleType, *backends: Backend) -> bool:
Expand All @@ -30,6 +30,96 @@ def _delegate(xp: ModuleType, *backends: Backend) -> bool:
return any(backend.is_namespace(xp) for backend in backends)


def isclose(
    a: Array,
    b: Array,
    *,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
    xp: ModuleType | None = None,
) -> Array:
    """
    Test element-wise whether two arrays are equal within a tolerance.

    Both tolerances are expected to be positive and usually tiny. An element of `a`
    is deemed close to the matching element of `b` when their absolute difference
    does not exceed the sum of the absolute tolerance `atol` and the relative
    tolerance term ``rtol * abs(b)``.

    NaNs in the same position compare as equal only when ``equal_nan=True``.
    Infinities in the same position compare as equal when they share the same sign.

    Parameters
    ----------
    a, b : Array
        Input arrays to compare.
    rtol : float, optional
        Relative tolerance (see Notes). Default: 1e-05.
    atol : float, optional
        Absolute tolerance (see Notes). Default: 1e-08.
    equal_nan : bool, optional
        If True, a NaN in `a` is considered equal to a NaN at the same position
        in `b`. Default: False.
    xp : array_namespace, optional
        The standard-compatible namespace for `a` and `b`. Default: infer.

    Returns
    -------
    Array
        Boolean array with the shape obtained by broadcasting `a` against `b`;
        ``True`` where `a` is close to `b`, ``False`` elsewhere.

    Warnings
    --------
    The default `atol` is a poor choice when the compared numbers are much
    smaller than one in magnitude (see Notes).

    See Also
    --------
    math.isclose : Similar function in stdlib for Python scalars.

    Notes
    -----
    For finite values the test applied is::

        absolute(a - b) <= (atol + rtol * absolute(b))

    In contrast to the built-in `math.isclose`, this formula is asymmetric in
    `a` and `b`; in rare cases ``isclose(a, b)`` and ``isclose(b, a)`` may
    therefore disagree.

    The default `atol` does not suit a reference value `b` whose magnitude is
    below one. For instance ``a = 1e-9`` and ``b = 2e-9`` should hardly count as
    "close", yet ``isclose(1e-9, 2e-9)`` is ``True`` with the defaults. Choose
    `atol` for the problem at hand, in particular to set the threshold under
    which a non-zero value in `a` counts as "close" to a very small or zero
    value in `b`.

    `a` and `b` are compared with standard broadcasting, so their shapes do not
    have to be identical.

    `isclose` is not defined for non-numeric data types.
    ``bool`` is considered a numeric data-type for this purpose.
    """
    if xp is None:
        xp = array_namespace(a, b)

    # Backends that ship their own isclose with this same keyword signature.
    native_backends = (
        Backend.NUMPY,
        Backend.CUPY,
        Backend.DASK,
        Backend.JAX,
        Backend.TORCH,
    )
    options = {"rtol": rtol, "atol": atol, "equal_nan": equal_nan}

    if _delegate(xp, *native_backends):
        return xp.isclose(a, b, **options)

    # Everything else goes through the generic Array API implementation.
    return _funcs.isclose(a, b, xp=xp, **options)


def pad(
x: Array,
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
Expand Down
34 changes: 34 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,40 @@ def expand_dims(
return a


def isclose(
a: Array,
b: Array,
*,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
xp: ModuleType,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""

a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
if a_inexact or b_inexact:
# FIXME: use scipy's lazywhere to suppress warnings on inf
out = xp.abs(a - b) <= (atol + rtol * xp.abs(b))
out = xp.where(xp.isinf(a) & xp.isinf(b), xp.sign(a) == xp.sign(b), out)
if equal_nan:
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)
return out

if xp.isdtype(a.dtype, "bool") or xp.isdtype(b.dtype, "bool"):
if atol >= 1 or rtol >= 1:
return xp.ones_like(a == b)
Copy link
Contributor Author

@crusaderky crusaderky Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On eager backends, this is less performant than

return xp.ones(xp.broadcast_arrays(a, b)[0], dtype=bool, device=a.device)

but it supports backends with NaN shapes like Dask.
Both jax.jit and dask with non-NaN shape should elide the comparison away.

return a == b

# integer types
atol = int(atol)
if rtol == 0:
return xp.abs(a - b) <= atol
nrtol = int(1.0 / rtol)
return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol)


def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
Kronecker product of two arrays.
Expand Down
29 changes: 27 additions & 2 deletions src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
Note that this is private API; don't expect it to be stable.
"""

import math
from types import ModuleType

from ._utils._compat import (
array_namespace,
is_cupy_namespace,
is_dask_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
)
Expand Down Expand Up @@ -40,8 +42,16 @@ def _check_ns_shape_dtype(
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
assert actual_xp == desired_xp, msg

msg = f"shapes do not match: {actual.shape} != f{desired.shape}"
assert actual.shape == desired.shape, msg
actual_shape = actual.shape
desired_shape = desired.shape
if is_dask_namespace(desired_xp):
if any(math.isnan(i) for i in actual_shape):
actual_shape = actual.compute().shape
if any(math.isnan(i) for i in desired_shape):
desired_shape = desired.compute().shape

msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
assert actual_shape == desired_shape, msg
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will have to be replicated on the scipy PR for dask support


msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
assert actual.dtype == desired.dtype, msg
Expand All @@ -61,6 +71,11 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
The expected array (typically hardcoded).
err_msg : str, optional
Error message to display on failure.

See Also
--------
xp_assert_close : Similar function for inexact equality checks.
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
"""
xp = _check_ns_shape_dtype(actual, desired)

Expand Down Expand Up @@ -112,6 +127,16 @@ def xp_assert_close(
Absolute tolerance. Default: 0.
err_msg : str, optional
Error message to display on failure.

See Also
--------
xp_assert_equal : Similar function for exact equality checks.
isclose : Public function for checking closeness.
numpy.testing.assert_allclose : Similar function for NumPy arrays.

Notes
-----
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
"""
xp = _check_ns_shape_dtype(actual, desired)

Expand Down
136 changes: 135 additions & 1 deletion tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
cov,
create_diagonal,
expand_dims,
isclose,
kron,
nunique,
pad,
Expand All @@ -23,7 +24,7 @@
from array_api_extra._lib._utils._typing import Array, Device

# some xp backends are untyped
# mypy: disable-error-code=no-untyped-usage
# mypy: disable-error-code=no-untyped-def


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
Expand Down Expand Up @@ -252,6 +253,139 @@ def test_xp(self, xp: ModuleType):
assert y.shape == (1, 1, 1, 3)


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
class TestIsClose:
    """Tests for `isclose`, run once per `xp` backend (Backend.SPARSE is skipped)."""

    # FIXME use lazywhere to avoid warnings on inf
    @pytest.mark.filterwarnings("ignore:invalid value encountered")
    @pytest.mark.parametrize(
        ("a", "b"),
        [
            (0.0, 0.0),
            (1.0, 1.0),
            (1.0, 2.0),
            (1.0, -1.0),
            (100.0, 101.0),
            (0, 0),
            (1, 1),
            (1, 2),
            (1, -1),
            (1.0 + 1j, 1.0 + 1j),
            (1.0 + 1j, 1.0 - 1j),
            (float("inf"), float("inf")),
            (float("inf"), 100.0),
            (float("inf"), float("-inf")),
            (float("nan"), float("nan")),
            (float("nan"), 0.0),
            (0.0, float("nan")),
            (1e6, 1e6 + 1),  # True - within rtol
            (1e6, 1e6 + 100),  # False - outside rtol
            (1e-6, 1.1e-6),  # False - outside atol
            (1e-7, 1.1e-7),  # True - within atol
            (1e6 + 0j, 1e6 + 1j),  # True - within rtol
            (1e6 + 0j, 1e6 + 100j),  # False - outside rtol
        ],
    )
    def test_basic(self, a: float, b: float, xp: ModuleType):
        """Compare against `np.isclose` for scalar pairs and for scaled 1-D arrays."""
        a_xp = xp.asarray(a)
        b_xp = xp.asarray(b)

        xp_assert_equal(isclose(a_xp, b_xp), xp.asarray(np.isclose(a, b)))

        # Scale both operands by 0..9 to obtain 1-D inputs; any warning raised
        # while building them is silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            r_xp = xp.asarray(np.arange(10), dtype=a_xp.dtype)
            ar_xp = a_xp * r_xp
            br_xp = b_xp * r_xp
            ar_np = a * np.arange(10)
            br_np = b * np.arange(10)

        xp_assert_equal(isclose(ar_xp, br_xp), xp.asarray(np.isclose(ar_np, br_np)))

    @pytest.mark.parametrize("dtype", ["float32", "int32"])
    def test_broadcast(self, dtype: str, xp: ModuleType):
        """A (3,) array against a (2, 1) array yields a (2, 3) boolean result."""
        dtype = getattr(xp, dtype)
        a = xp.asarray([1, 2, 3], dtype=dtype)
        b = xp.asarray([[1], [5]], dtype=dtype)
        actual = isclose(a, b)
        expect = xp.asarray(
            [[True, False, False], [False, False, False]], dtype=xp.bool
        )

        xp_assert_equal(actual, expect)

    # FIXME use lazywhere to avoid warnings on inf
    @pytest.mark.filterwarnings("ignore:invalid value encountered")
    def test_some_inf(self, xp: ModuleType):
        """Infinities match only an infinity of the same sign at the same position."""
        a = xp.asarray([0.0, 1.0, float("inf"), float("inf"), float("inf")])
        b = xp.asarray([1e-9, 1.0, float("inf"), float("-inf"), 2.0])
        actual = isclose(a, b)
        xp_assert_equal(actual, xp.asarray([True, True, True, False, False]))

    def test_equal_nan(self, xp: ModuleType):
        """NaN pairs compare equal only when ``equal_nan=True``."""
        a = xp.asarray([float("nan"), float("nan"), 1.0])
        b = xp.asarray([float("nan"), 1.0, float("nan")])
        xp_assert_equal(isclose(a, b), xp.asarray([False, False, False]))
        xp_assert_equal(isclose(a, b, equal_nan=True), xp.asarray([True, False, False]))

    @pytest.mark.parametrize("dtype", ["float32", "complex64", "int32"])
    def test_tolerance(self, dtype: str, xp: ModuleType):
        """Exercise `atol` and `rtol` on float, complex and integer inputs."""
        dtype = getattr(xp, dtype)
        a = xp.asarray([100, 100], dtype=dtype)
        b = xp.asarray([101, 102], dtype=dtype)
        xp_assert_equal(isclose(a, b), xp.asarray([False, False]))
        xp_assert_equal(isclose(a, b, atol=1), xp.asarray([True, False]))
        xp_assert_equal(isclose(a, b, rtol=0.01), xp.asarray([True, False]))

        # Attempt to trigger division by 0 in rtol on int dtype
        xp_assert_equal(isclose(a, b, rtol=0), xp.asarray([False, False]))
        xp_assert_equal(isclose(a, b, atol=1, rtol=0), xp.asarray([True, False]))

    def test_very_small_numbers(self, xp: ModuleType):
        """Values around 1e-9: the default `atol` dominates unless it is set to 0."""
        a = xp.asarray([1e-9, 1e-9])
        b = xp.asarray([1.0001e-9, 1.00001e-9])
        # Difference is below default atol
        xp_assert_equal(isclose(a, b), xp.asarray([True, True]))
        # Use only rtol
        xp_assert_equal(isclose(a, b, atol=0), xp.asarray([False, True]))
        xp_assert_equal(isclose(a, b, atol=0, rtol=0), xp.asarray([False, False]))

    def test_bool_dtype(self, xp: ModuleType):
        """Boolean inputs: plain equality, or all-True once a tolerance is >= 1."""
        a = xp.asarray([False, True, False])
        b = xp.asarray([True, True, False])
        xp_assert_equal(isclose(a, b), xp.asarray([False, True, True]))
        xp_assert_equal(isclose(a, b, atol=1), xp.asarray([True, True, True]))
        xp_assert_equal(isclose(a, b, atol=2), xp.asarray([True, True, True]))
        xp_assert_equal(isclose(a, b, rtol=1), xp.asarray([True, True, True]))
        xp_assert_equal(isclose(a, b, rtol=2), xp.asarray([True, True, True]))

        # Test broadcasting
        xp_assert_equal(
            isclose(a, xp.asarray(True), atol=1), xp.asarray([True, True, True])
        )
        xp_assert_equal(
            isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True])
        )

    def test_none_shape(self, xp: ModuleType):
        """Inputs produced by boolean-mask indexing (integer data)."""
        a = xp.asarray([1, 5, 0])
        b = xp.asarray([1, 4, 2])
        b = b[a < 5]
        a = a[a < 5]
        xp_assert_equal(isclose(a, b), xp.asarray([True, False]))

    def test_none_shape_bool(self, xp: ModuleType):
        """Inputs produced by boolean-mask indexing (boolean data)."""
        a = xp.asarray([True, True, False])
        b = xp.asarray([True, False, True])
        b = b[a]
        a = a[a]
        xp_assert_equal(isclose(a, b), xp.asarray([True, False]))

    def test_xp(self, xp: ModuleType):
        """Passing the namespace explicitly through the `xp` keyword."""
        a = xp.asarray([0.0, 0.0])
        b = xp.asarray([1e-9, 1e-4])
        xp_assert_equal(isclose(a, b, xp=xp), xp.asarray([True, False]))


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
class TestKron:
def test_basic(self, xp: ModuleType):
Expand Down
Loading
Loading