forked from data-apis/array-api-extra
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_testing.py
171 lines (142 loc) · 4.93 KB
/
_testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
Testing utilities.
Note that this is private API; don't expect it to be stable.
"""
import math
from types import ModuleType
from ._utils._compat import (
array_namespace,
is_cupy_namespace,
is_dask_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
)
from ._utils._typing import Array
__all__ = ["xp_assert_close", "xp_assert_equal"]
def _check_ns_shape_dtype(
    actual: Array, desired: Array
) -> ModuleType:  # numpydoc ignore=RT03
    """
    Assert that namespace, shape and dtype of the two arrays match.

    Parameters
    ----------
    actual : Array
        The array produced by the tested function.
    desired : Array
        The expected array (typically hardcoded).

    Returns
    -------
    Arrays namespace.

    Raises
    ------
    AssertionError
        If the namespaces, the shapes or the dtypes of the two arrays differ.
    """
    actual_xp = array_namespace(actual)  # Raises on scalars and lists
    desired_xp = array_namespace(desired)

    msg = f"namespaces do not match: {actual_xp} != {desired_xp}"
    assert actual_xp == desired_xp, msg

    actual_shape = actual.shape
    desired_shape = desired.shape
    if is_dask_namespace(desired_xp):
        # A shape holding NaN entries cannot be compared as-is (NaN != NaN),
        # so compute the array to obtain its concrete shape first.
        if any(math.isnan(i) for i in actual_shape):
            actual_shape = actual.compute().shape
        if any(math.isnan(i) for i in desired_shape):
            desired_shape = desired.compute().shape

    msg = f"shapes do not match: {actual_shape} != {desired_shape}"
    assert actual_shape == desired_shape, msg

    msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
    assert actual.dtype == desired.dtype, msg

    return desired_xp
def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
    """
    Array-API compatible version of `np.testing.assert_array_equal`.

    The namespace, shape and dtype of both arrays are validated first; the
    element-wise comparison is then delegated to the testing helper of the
    detected backend (CuPy, PyTorch, or NumPy for everything else).

    Parameters
    ----------
    actual : Array
        The array produced by the tested function.
    desired : Array
        The expected array (typically hardcoded).
    err_msg : str, optional
        Error message to display on failure.

    See Also
    --------
    xp_assert_close : Similar function for inexact equality checks.
    numpy.testing.assert_array_equal : Similar function for NumPy arrays.
    """
    namespace = _check_ns_shape_dtype(actual, desired)

    if is_cupy_namespace(namespace):
        namespace.testing.assert_array_equal(actual, desired, err_msg=err_msg)
        return

    if is_torch_namespace(namespace):
        # PyTorch recommends using `rtol=0, atol=0` like this
        # to test for exact equality
        namespace.testing.assert_close(
            actual,
            desired,
            rtol=0,
            atol=0,
            equal_nan=True,
            check_dtype=False,
            msg=err_msg or None,
        )
        return

    # Every remaining backend is compared through NumPy's testing helpers
    # (JAX uses `np.testing`).
    import numpy as np  # pylint: disable=import-outside-toplevel

    lhs, rhs = actual, desired
    if is_pydata_sparse_namespace(namespace):
        # Sparse arrays are densified before being handed to NumPy.
        lhs, rhs = actual.todense(), desired.todense()

    np.testing.assert_array_equal(lhs, rhs, err_msg=err_msg)
def xp_assert_close(
    actual: Array,
    desired: Array,
    *,
    rtol: float | None = None,
    atol: float = 0,
    err_msg: str = "",
) -> None:
    """
    Array-API compatible version of `np.testing.assert_allclose`.

    Parameters
    ----------
    actual : Array
        The array produced by the tested function.
    desired : Array
        The expected array (typically hardcoded).
    rtol : float, optional
        Relative tolerance. Default: dtype-dependent; ``4 * sqrt(eps)`` of the
        dtype of `actual` for real and complex floating dtypes, ``1e-7``
        otherwise.
    atol : float, optional
        Absolute tolerance. Default: 0.
    err_msg : str, optional
        Error message to display on failure.

    See Also
    --------
    xp_assert_equal : Similar function for exact equality checks.
    isclose : Public function for checking closeness.
    numpy.testing.assert_allclose : Similar function for NumPy arrays.

    Notes
    -----
    The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
    """
    xp = _check_ns_shape_dtype(actual, desired)

    floating = xp.isdtype(actual.dtype, ("real floating", "complex floating"))
    if rtol is None:
        if floating:
            # multiplier of 4 is used as for `np.float64` this puts the default
            # `rtol` roughly half way between sqrt(eps) and the default for
            # `numpy.testing.assert_allclose`, 1e-7
            rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
        else:
            rtol = 1e-7

    if is_cupy_namespace(xp):
        xp.testing.assert_allclose(
            actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
        )
    elif is_torch_namespace(xp):
        xp.testing.assert_close(
            actual, desired, rtol=rtol, atol=atol, equal_nan=True, msg=err_msg or None
        )
    else:
        import numpy as np  # pylint: disable=import-outside-toplevel

        if is_pydata_sparse_namespace(xp):
            # Same densifying call as in `xp_assert_equal`; `to_dense` is not
            # the pydata/sparse spelling.
            actual = actual.todense()
            desired = desired.todense()

        # JAX uses `np.testing`
        # Only narrows the type for static checkers; an integer tolerance
        # such as `rtol=0` must stay valid here, as it is on the other paths.
        assert rtol is not None
        np.testing.assert_allclose(
            actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
        )