Skip to content

Commit 1fe521a

Browse files
committed
ENH: new functions isclose and allclose
1 parent 48fb66a commit 1fe521a

File tree

7 files changed

+317
-4
lines changed

7 files changed

+317
-4
lines changed

Diff for: docs/api-reference.md

+2
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
:nosignatures:
77
:toctree: generated
88
9+
allclose
910
at
1011
atleast_nd
1112
cov
1213
create_diagonal
1314
expand_dims
15+
isclose
1416
kron
1517
nunique
1618
pad

Diff for: pixi.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ checks = [
293293
"all", # report on all checks, except the below
294294
"EX01", # most docstrings do not need an example
295295
"SA01", # data-apis/array-api-extra#87
296+
"SA04", # Missing description for See Also cross-reference
296297
"ES01", # most docstrings do not need an extended summary
297298
]
298299
exclude = [ # don't report on objects that match any of these regex

Diff for: src/array_api_extra/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import pad
3+
from ._delegation import allclose, isclose, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
66
atleast_nd,
@@ -18,11 +18,13 @@
1818
# pylint: disable=duplicate-code
1919
__all__ = [
2020
"__version__",
21+
"allclose",
2122
"at",
2223
"atleast_nd",
2324
"cov",
2425
"create_diagonal",
2526
"expand_dims",
27+
"isclose",
2628
"kron",
2729
"nunique",
2830
"pad",

Diff for: src/array_api_extra/_delegation.py

+139-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ._lib._utils._compat import array_namespace
88
from ._lib._utils._typing import Array
99

10-
__all__ = ["pad"]
10+
__all__ = ["allclose", "isclose", "pad"]
1111

1212

1313
def _delegate(xp: ModuleType, *backends: Backend) -> bool:
@@ -29,6 +29,144 @@ def _delegate(xp: ModuleType, *backends: Backend) -> bool:
2929
return any(backend.is_namespace(xp) for backend in backends)
3030

3131

32+
def allclose(
33+
a: Array,
34+
b: Array,
35+
*,
36+
rtol: float = 1e-05,
37+
atol: float = 1e-08,
38+
equal_nan: bool = False,
39+
xp: ModuleType | None = None,
40+
) -> Array:
41+
"""
42+
Return True if two arrays are element-wise equal within a tolerance.
43+
44+
This is a simple convenience reduction around `isclose`.
45+
46+
Parameters
47+
----------
48+
a, b : Array
49+
Input arrays to compare.
50+
rtol : array_like, optional
51+
The relative tolerance parameter.
52+
atol : array_like, optional
53+
The absolute tolerance parameter.
54+
equal_nan : bool, optional
55+
Whether to compare NaN's as equal. If True, NaN's in `a` will be considered
56+
equal to NaN's in `b` in the output array.
57+
xp : array_namespace, optional
58+
The standard-compatible namespace for `a` and `b`. Default: infer.
59+
60+
Returns
61+
-------
62+
Array
63+
A 0-dimensional boolean array, containing `True` if all `a` is elementwise close
64+
to `b` and `False` otherwise.
65+
66+
See Also
67+
--------
68+
isclose
69+
math.isclose
70+
71+
Notes
72+
-----
73+
If `xp` is a lazy backend (e.g. Dask, JAX), you may not be able to test the result
74+
contents with ``bool(allclose(a, b))`` or ``if allclose(a, b): ...``.
75+
"""
76+
xp = array_namespace(a, b) if xp is None else xp
77+
return xp.all(isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp))
78+
79+
80+
def isclose(
81+
a: Array,
82+
b: Array,
83+
*,
84+
rtol: float = 1e-05,
85+
atol: float = 1e-08,
86+
equal_nan: bool = False,
87+
xp: ModuleType | None = None,
88+
) -> Array:
89+
"""
90+
Return a boolean array where two arrays are element-wise equal within a tolerance.
91+
92+
The tolerance values are positive, typically very small numbers. The relative
93+
difference `(rtol * abs(b))` and the absolute difference atol are added together to
94+
compare against the absolute difference between a and b.
95+
96+
NaNs are treated as equal if they are in the same place and if equal_nan=True. Infs
97+
are treated as equal if they are in the same place and of the same sign in both
98+
arrays.
99+
100+
Parameters
101+
----------
102+
a, b : Array
103+
Input arrays to compare.
104+
rtol : array_like, optional
105+
The relative tolerance parameter (see Notes).
106+
atol : array_like, optional
107+
The absolute tolerance parameter (see Notes).
108+
equal_nan : bool, optional
109+
Whether to compare NaN's as equal. If True, NaN's in `a` will be considered
110+
equal to NaN's in `b` in the output array.
111+
xp : array_namespace, optional
112+
The standard-compatible namespace for `a` and `b`. Default: infer.
113+
114+
Returns
115+
-------
116+
Array
117+
A boolean array of shape broadcasted from `a` and `b`, containing `True` where
118+
``a`` is close to ``b``, and `False` otherwise.
119+
120+
Warnings
121+
--------
122+
The default atol is not appropriate for comparing numbers with magnitudes much
123+
smaller than one ) (see notes).
124+
125+
See Also
126+
--------
127+
allclose
128+
math.isclose
129+
130+
Notes
131+
-----
132+
For finite values, `isclose` uses the following equation to test whether two
133+
floating point values are equivalent::
134+
135+
absolute(a - b) <= (atol + rtol * absolute(b))
136+
137+
Unlike the built-in `math.isclose`, the above equation is not symmetric in a and b,
138+
so that `isclose(a, b)` might be different from `isclose(b, a)` in some rare
139+
cases.
140+
141+
The default value of `atol` is not appropriate when the reference value `b` has
142+
magnitude smaller than one. For example, it is unlikely that ``a = 1e-9`` and
143+
``b = 2e-9`` should be considered "close", yet ``isclose(1e-9, 2e-9)`` is `True`
144+
with default settings. Be sure to select atol for the use case at hand, especially
145+
for defining the threshold below which a non-zero value in `a` will be considered
146+
"close" to a very small or zero value in `b`.
147+
148+
The comparison of `a` and `b` uses standard broadcasting, which means that `a` and
149+
`b` need not have the same shape in order for `isclose(a, b)` to evaluate to
150+
`True`.
151+
152+
`isclose` is not defined for non-numeric data types. `bool` is considered a numeric
153+
data-type for this purpose.
154+
"""
155+
xp = array_namespace(a, b) if xp is None else xp
156+
157+
if _delegate(
158+
xp,
159+
Backend.NUMPY,
160+
Backend.CUPY,
161+
Backend.DASK,
162+
Backend.JAX,
163+
Backend.TORCH,
164+
):
165+
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
166+
167+
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
168+
169+
32170
def pad(
33171
x: Array,
34172
pad_width: int | tuple[int, int] | list[tuple[int, int]],

Diff for: src/array_api_extra/_lib/_funcs.py

+35
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,41 @@ def expand_dims(
304304
return a
305305

306306

307+
def isclose(
308+
a: Array,
309+
b: Array,
310+
*,
311+
rtol: float = 1e-05,
312+
atol: float = 1e-08,
313+
equal_nan: bool = False,
314+
xp: ModuleType | None = None,
315+
) -> Array: # numpydoc ignore=PR01,RT01
316+
"""See docstring in array_api_extra._delegation."""
317+
xp = array_namespace(a, b) if xp is None else xp
318+
319+
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
320+
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
321+
if a_inexact or b_inexact:
322+
# FIXME: use scipy's lazywhere to suppress warnings on inf
323+
out = xp.abs(a - b) <= (atol + rtol * xp.abs(b))
324+
out = xp.where(xp.isinf(a) & xp.isinf(b), xp.sign(a) == xp.sign(b), out)
325+
if equal_nan:
326+
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)
327+
return out
328+
329+
if xp.isdtype(a.dtype, "bool") or xp.isdtype(b.dtype, "bool"):
330+
if atol >= 1 or rtol >= 1:
331+
return xp.ones_like(a == b)
332+
return a == b
333+
334+
# integer types
335+
atol = int(atol)
336+
if rtol == 0:
337+
return xp.abs(a - b) <= atol
338+
nrtol = int(1.0 / rtol)
339+
return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol)
340+
341+
307342
def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
308343
"""
309344
Kronecker product of two arrays.

0 commit comments

Comments
 (0)