forked from data-apis/array-api-extra
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_delegation.py
171 lines (143 loc) · 5.97 KB
/
_delegation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""Delegation to existing implementations for Public API Functions."""
from collections.abc import Sequence
from types import ModuleType
from typing import Literal
from ._lib import _funcs
from ._lib._utils._compat import (
array_namespace,
is_cupy_namespace,
is_dask_namespace,
is_jax_namespace,
is_numpy_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
)
from ._lib._utils._helpers import asarrays
from ._lib._utils._typing import Array
# Public API of this module. Each function forwards to a backend's native
# implementation when the namespace has one, and otherwise falls back to the
# generic implementation in ``_funcs``.
__all__ = ["isclose", "pad"]
def isclose(
    a: Array | complex,
    b: Array | complex,
    *,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
    xp: ModuleType | None = None,
) -> Array:
    """
    Return a boolean array where two arrays are element-wise equal within a tolerance.

    The tolerance values are positive, typically very small numbers. The relative
    difference ``(rtol * abs(b))`` and the absolute difference `atol` are added together
    to compare against the absolute difference between `a` and `b`.

    NaNs are treated as equal if they are in the same place and if ``equal_nan=True``.
    Infs are treated as equal if they are in the same place and of the same sign in both
    arrays.

    Parameters
    ----------
    a, b : Array | int | float | complex | bool
        Input objects to compare. At least one must be an array.
    rtol : float, optional
        The relative tolerance parameter (see Notes). Default: ``1e-05``.
    atol : float, optional
        The absolute tolerance parameter (see Notes). Default: ``1e-08``.
    equal_nan : bool, optional
        Whether to compare NaN's as equal. If True, NaN's in `a` will be considered
        equal to NaN's in `b` in the output array. Default: ``False``.
    xp : array_namespace, optional
        The standard-compatible namespace for `a` and `b`. Default: infer.

    Returns
    -------
    Array
        A boolean array of shape broadcasted from `a` and `b`, containing ``True`` where
        `a` is close to `b`, and ``False`` otherwise.

    Warnings
    --------
    The default `atol` is not appropriate for comparing numbers with magnitudes much
    smaller than one (see Notes).

    See Also
    --------
    math.isclose : Similar function in stdlib for Python scalars.

    Notes
    -----
    For finite values, `isclose` uses the following equation to test whether two
    floating point values are equivalent::

        absolute(a - b) <= (atol + rtol * absolute(b))

    Unlike the built-in `math.isclose`,
    the above equation is not symmetric in `a` and `b`,
    so that ``isclose(a, b)`` might be different from ``isclose(b, a)`` in some rare
    cases.

    The default value of `atol` is not appropriate when the reference value `b` has
    magnitude smaller than one. For example, it is unlikely that ``a = 1e-9`` and
    ``b = 2e-9`` should be considered "close", yet ``isclose(1e-9, 2e-9)`` is ``True``
    with default settings. Be sure to select `atol` for the use case at hand, especially
    for defining the threshold below which a non-zero value in `a` will be considered
    "close" to a very small or zero value in `b`.

    The comparison of `a` and `b` uses standard broadcasting, which means that `a` and
    `b` need not have the same shape in order for ``isclose(a, b)`` to evaluate to
    ``True``.

    `isclose` is not defined for non-numeric data types.
    ``bool`` is considered a numeric data-type for this purpose.
    """
    xp = array_namespace(a, b) if xp is None else xp

    if (
        is_numpy_namespace(xp)
        or is_cupy_namespace(xp)
        or is_dask_namespace(xp)
        or is_jax_namespace(xp)
    ):
        # These backends provide a native ``isclose`` with this keyword signature;
        # ``a`` and ``b`` are forwarded unchanged, Python scalars included.
        return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

    if is_torch_namespace(xp):
        # Unlike the branch above, operands are first converted to arrays with
        # ``asarrays`` so Python scalars are accepted (Array API 2024.12 support);
        # torch's ``isclose`` takes tensor operands.
        a, b = asarrays(a, b, xp=xp)
        return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

    # Any other namespace uses the generic implementation in ``_funcs``.
    return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
def pad(
    x: Array,
    pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
    mode: Literal["constant"] = "constant",
    *,
    constant_values: complex = 0,
    xp: ModuleType | None = None,
) -> Array:
    """
    Pad the input array.

    Parameters
    ----------
    x : array
        Input array.
    pad_width : int or tuple of ints or sequence of pairs of ints
        Number of elements to add on each side. A sequence of pairs,
        ``[(before_0, after_0), ... (before_N, after_N)]``, gives one pair per
        axis of ``x``. A single pair, ``(before, after)``, is applied to every
        axis, as if repeated ``x.ndim`` times.
    mode : str, optional
        Padding mode. Only ``"constant"`` is currently supported; it fills the
        new elements with `constant_values`.
    constant_values : python scalar, optional
        Value used for the padded elements. Default is zero.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array
        A copy of `x` padded with ``pad_width`` elements equal to
        ``constant_values``.

    Raises
    ------
    NotImplementedError
        If `mode` is anything other than ``"constant"``.
    """
    if xp is None:
        xp = array_namespace(x)

    if mode != "constant":
        msg = "Only `'constant'` mode is currently supported"
        raise NotImplementedError(msg)

    if is_torch_namespace(xp):
        # torch's pad takes one flat sequence of widths starting from the LAST
        # dimension: (before_last, after_last, before_prev, after_prev, ...).
        # Expand the request to one (before, after) row per axis, reverse the
        # row order, then flatten. Mirrors:
        # https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
        per_axis = xp.broadcast_to(xp.asarray(pad_width), (x.ndim, 2))
        flat_widths = xp.flip(per_axis, axis=(0,)).flatten()
        return xp.nn.functional.pad(x, tuple(flat_widths), value=constant_values)  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

    # Backends whose own ``pad`` accepts (x, pad_width, mode, constant_values=...).
    has_native_pad = (
        is_numpy_namespace,
        is_cupy_namespace,
        is_jax_namespace,
        is_pydata_sparse_namespace,
    )
    if any(check(xp) for check in has_native_pad):
        return xp.pad(x, pad_width, mode, constant_values=constant_values)

    # Everything else goes through the generic implementation in ``_funcs``.
    return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)