forked from data-apis/array-api-compat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_common.py
316 lines (256 loc) · 10.1 KB
/
test_common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
import math
import pytest
import numpy as np
import array
from numpy.testing import assert_allclose
from array_api_compat import ( # noqa: F401
is_numpy_array, is_cupy_array, is_torch_array,
is_dask_array, is_jax_array, is_pydata_sparse_array,
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
)
from array_api_compat import (
device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device
)
from ._helpers import import_, wrapped_libraries, all_libraries
# Short tag used in the array_api_compat helper names (``is_<tag>_array`` /
# ``is_<tag>_namespace``), keyed by the importable module of each library.
_LIBRARY_TAGS = {
    'numpy': 'numpy',
    'cupy': 'cupy',
    'torch': 'torch',
    'dask.array': 'dask',
    'jax.numpy': 'jax',
    'sparse': 'pydata_sparse',
}
# Module name -> name of its ``is_*_array`` helper; the tests resolve these
# names through ``globals()``.
is_array_functions = {
    module: f'is_{tag}_array' for module, tag in _LIBRARY_TAGS.items()
}
# Module name -> name of its ``is_*_namespace`` helper.
is_namespace_functions = {
    module: f'is_{tag}_namespace' for module, tag in _LIBRARY_TAGS.items()
}
@pytest.mark.parametrize('library', is_array_functions.keys())
@pytest.mark.parametrize('func', is_array_functions.values())
def test_is_xp_array(library, func):
    """Check every (library, is_*_array helper) combination.

    A helper must return True only when it is the one registered for
    ``library``; ``is_array_api_obj`` must accept the array in all cases.
    """
    xp_module = import_(library)
    checker = globals()[func]
    arr = xp_module.asarray([1, 2, 3])
    expected = is_array_functions[library] == func
    assert checker(arr) == expected
    assert is_array_api_obj(arr)
@pytest.mark.parametrize('library', is_namespace_functions.keys())
@pytest.mark.parametrize('func', is_namespace_functions.values())
def test_is_xp_namespace(library, func):
    """Each is_*_namespace helper matches only its own library's module."""
    namespace = import_(library)
    checker = globals()[func]
    is_own_helper = func == is_namespace_functions[library]
    assert checker(namespace) == is_own_helper
@pytest.mark.parametrize('library', all_libraries)
def test_xp_is_array_generics(library):
    """
    Indexing a single element out of an xp array must produce an object
    recognized by exactly one is_*_array helper: either the one belonging
    to the same library, or is_numpy_array.
    """
    element = import_(library).asarray([1, 2, 3])[0]
    matches = [
        candidate
        for candidate, func in is_array_functions.items()
        if globals()[func](element)
    ]
    assert matches in ([library], ["numpy"])
@pytest.mark.parametrize("library", all_libraries)
def test_is_writeable_array(library):
    """is_writeable_array() must agree with whether item assignment works."""
    arr = import_(library).asarray([1, 2, 3])
    if not is_writeable_array(arr):
        # Read-only arrays are expected to reject assignment outright.
        with pytest.raises((TypeError, ValueError)):
            arr[1] = 4
        return
    arr[1] = 4
def test_is_writeable_array_numpy():
    """Clearing NumPy's write flag is reflected by is_writeable_array()."""
    arr = np.asarray([1, 2, 3])
    assert is_writeable_array(arr)
    arr.setflags(write=False)
    assert not is_writeable_array(arr)
@pytest.mark.parametrize("library", all_libraries)
def test_size(library):
    """size() reports the element count of a small 1-D array."""
    assert size(import_(library).asarray([1, 2, 3])) == 3
@pytest.mark.parametrize("library", all_libraries)
def test_size_none(library):
    """size() of a boolean-masked array is either 5 or None (unknown)."""
    if library == "sparse":
        pytest.skip("No arange(); no indexing by sparse arrays")
    xp = import_(library)
    full = xp.arange(10)
    masked = full[full < 5]
    # The length produced by boolean-mask indexing depends on the data:
    # - eager libraries know it immediately: shape=(5, ), size 5;
    # - dask.array reports shape=(nan, ) and ndonnx shape=(None, ).
    # For this assert to hold, size() has to report such unknown sizes
    # as None (a NaN would not compare equal to anything in the tuple).
    assert size(masked) in (None, 5)
@pytest.mark.parametrize("library", all_libraries)
def test_is_lazy_array(library):
    """is_lazy_array() returns a genuine bool for every backend."""
    arr = import_(library).asarray([1, 2, 3])
    verdict = is_lazy_array(arr)
    assert isinstance(verdict, bool)
@pytest.mark.parametrize(
    "shape",
    [(math.nan,), (1, math.nan), (None, ), (1, None)],
)
def test_is_lazy_array_nan_size(shape, monkeypatch):
    """An otherwise unknown Array API object whose shape contains NaN
    (Dask-style) or None (ndonnx-style) must be classified as lazy.
    """
    strict = import_("array_api_strict")
    arr = strict.asarray(1)
    assert not is_lazy_array(arr)
    # Override ``shape`` on the class itself so this instance reports an
    # unknown dimension; the monkeypatch fixture undoes it at teardown.
    monkeypatch.setattr(type(arr), "shape", shape)
    assert is_lazy_array(arr)
@pytest.mark.parametrize("exc", [TypeError, AssertionError])
def test_is_lazy_array_bool_raises(exc, monkeypatch):
    """An unknown Array API object whose bool() raises must count as lazy.

    Two flavours are covered:
    - TypeError, as raised by e.g. jitted JAX; this is the exception the
      Array API specification asks lazy arrays to raise;
    - any other exception, as with e.g. Dask, where bool() triggers
      compute() and may therefore fail in arbitrary ways.
    """
    arr = import_("array_api_strict").asarray(1)
    assert not is_lazy_array(arr)

    def raising_bool(self):
        raise exc("Hello world")

    monkeypatch.setattr(type(arr), "__bool__", raising_bool)
    assert is_lazy_array(arr)
@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
    """Moving an array to its own device must leave the device unchanged."""
    xp = import_(library, wrapper=True)
    # Beyond x.to_device(x.device) succeeding, little about device() and
    # to_device() can be checked portably across libraries.
    arr = xp.asarray([1, 2, 3])
    moved = to_device(arr, device(arr))
    assert device(arr) == device(moved)
@pytest.mark.parametrize("library", wrapped_libraries)
def test_to_device_host(library):
    """to_device(x, "cpu") yields data that can be compared on the host.

    Libraries differ in how device-to-host transfers work; this checks the
    portable shim array_api_compat offers for them. See
    https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
    """
    xp = import_(library, wrapper=True)
    expected = np.array([1, 2, 3])
    host = to_device(xp.asarray([1, 2, 3]), "cpu")
    # What device(host) reports varies (torch returns a real Device object,
    # other libraries do something else), so the portable check is simply
    # that the data compares equal to a NumPy array afterwards.
    assert_allclose(host, expected)
@pytest.mark.parametrize("target_library", is_array_functions.keys())
@pytest.mark.parametrize("source_library", is_array_functions.keys())
def test_asarray_cross_library(source_library, target_library, request):
    """The target library's asarray() must accept another library's array.

    Pairs that cannot convert implicitly are skipped (CuPy device transfer,
    sparse densification); a known Dask -> torch bug is marked xfail.
    """
    if source_library == "dask.array" and target_library == "torch":
        # Allow rest of test to execute instead of immediately xfailing
        # xref https://github.com/pandas-dev/pandas/issues/38902
        # TODO: remove xfail once
        # https://github.com/dask/dask/issues/8260 is resolved
        request.applymarker(
            pytest.mark.xfail(reason="Bug in dask raising error on conversion")
        )
    if source_library == "cupy" and target_library != "cupy":
        # cupy explicitly disallows implicit conversions to CPU
        pytest.skip(reason="cupy does not support implicit conversion to CPU")
    elif source_library == "sparse" and target_library != "sparse":
        pytest.skip(reason="`sparse` does not allow implicit densification")
    src_lib = import_(source_library, wrapper=True)
    tgt_lib = import_(target_library, wrapper=True)
    is_tgt_type = globals()[is_array_functions[target_library]]
    a = src_lib.asarray([1, 2, 3])
    b = tgt_lib.asarray(a)
    # The failure message names the target library instead of reading
    # ``tgt_lib.ndarray``: that attribute is not guaranteed to exist on every
    # wrapped namespace (NOTE(review): e.g. torch has none), and an
    # AttributeError raised while formatting would hide the real failure.
    assert is_tgt_type(b), (
        f"Expected {b} to be a {target_library} array, but was {type(b)}"
    )
@pytest.mark.parametrize("library", wrapped_libraries)
def test_asarray_copy(library):
    """Exercise the ``copy`` keyword of the wrapped ``asarray()``.

    Covers copy=True / copy=False / copy=None for array inputs (with and
    without a dtype change), for Python built-in scalars and lists, and for
    buffer-protocol objects (the stdlib ``array.array``).
    """
    # Note, we have this test here because the test suite currently doesn't
    # test the copy flag to asarray() very rigorously. Once
    # https://github.com/data-apis/array-api-tests/issues/241 is fixed we
    # should be able to delete this.
    xp = import_(library, wrapper=True)
    asarray = xp.asarray
    is_lib_func = globals()[is_array_functions[library]]

    # Truth-test results with the library's own ``all``. Dask reductions are
    # lazy and must be computed first. Named ``xp_all`` so that the builtin
    # ``all`` is not shadowed.
    if library == 'dask.array':
        def xp_all(x):
            return xp.all(x).compute()
    else:
        xp_all = xp.all

    # copy=False raises NotImplementedError on NumPy < 2 without
    # ``_CopyMode``, and on CuPy and Dask. Parse the major version as an int
    # rather than comparing the first character of the version string.
    if library == 'numpy':
        numpy_major = int(xp.__version__.split('.')[0])
        supports_copy_false = numpy_major >= 2 or hasattr(xp, '_CopyMode')
    else:
        supports_copy_false = library not in ('cupy', 'dask.array')
    # Exception expected whenever copy=False cannot be honoured.
    copy_false_error = ValueError if supports_copy_false else NotImplementedError

    # copy=True always yields an independent copy.
    a = asarray([1])
    b = asarray(a, copy=True)
    assert is_lib_func(b)
    a[0] = 0
    assert xp_all(b[0] == 1)
    assert xp_all(a[0] == 0)

    # copy=False shares memory with the input where supported.
    a = asarray([1])
    if supports_copy_false:
        b = asarray(a, copy=False)
        assert is_lib_func(b)
        a[0] = 0
        assert xp_all(b[0] == 0)
    else:
        with pytest.raises(NotImplementedError):
            asarray(a, copy=False)

    # copy=False combined with a dtype change cannot be satisfied.
    a = asarray([1])
    with pytest.raises(copy_false_error):
        asarray(a, copy=False, dtype=xp.float64)

    # copy=None reuses the input when no conversion is needed...
    a = asarray([1])
    b = asarray(a, copy=None)
    assert is_lib_func(b)
    a[0] = 0
    assert xp_all(b[0] == 0)

    # ...copies when the dtype must change...
    a = asarray([1.0], dtype=xp.float32)
    assert a.dtype == xp.float32
    b = asarray(a, dtype=xp.float64, copy=None)
    assert is_lib_func(b)
    assert b.dtype == xp.float64
    a[0] = 0.0
    assert xp_all(b[0] == 1.0)

    # ...and does not copy when the requested dtype already matches.
    a = asarray([1.0], dtype=xp.float64)
    assert a.dtype == xp.float64
    b = asarray(a, dtype=xp.float64, copy=None)
    assert is_lib_func(b)
    assert b.dtype == xp.float64
    a[0] = 0.0
    assert xp_all(b[0] == 0.0)

    # Python built-in types: copy=True/None work, copy=False is rejected.
    for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
        asarray(obj, copy=True)  # No error
        asarray(obj, copy=None)  # No error
        with pytest.raises(copy_false_error):
            asarray(obj, copy=False)

    # Use the standard library array to test the buffer protocol
    a = array.array('f', [1.0])
    b = asarray(a, copy=True)
    assert is_lib_func(b)
    a[0] = 0.0
    assert xp_all(b[0] == 1.0)

    a = array.array('f', [1.0])
    if supports_copy_false:
        b = asarray(a, copy=False)
        assert is_lib_func(b)
        a[0] = 0.0
        assert xp_all(b[0] == 0.0)
    else:
        with pytest.raises(NotImplementedError):
            asarray(a, copy=False)

    a = array.array('f', [1.0])
    b = asarray(a, copy=None)
    assert is_lib_func(b)
    a[0] = 0.0
    if library in ('cupy', 'dask.array'):
        # A copy is required for libraries where the default device is not CPU
        # dask changed behaviour of copy=None in 2024.12 to copy;
        # this wrapper ensures the same behaviour in older versions too.
        # https://github.com/dask/dask/pull/11524/
        assert xp_all(b[0] == 1.0)
    else:
        assert xp_all(b[0] == 0.0)