forked from data-apis/array-api-extra
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconftest.py
158 lines (125 loc) · 5.54 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""Pytest fixtures."""
from collections.abc import Callable
from contextlib import suppress
from functools import partial, wraps
from types import ModuleType
from typing import ParamSpec, TypeVar, cast
import numpy as np
import pytest
from array_api_extra._lib import Backend
from array_api_extra._lib._testing import xfail
from array_api_extra._lib._utils._compat import array_namespace
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._typing import Device
from array_api_extra.testing import patch_lazy_xp_functions
# Type variables used to preserve the signature of functions wrapped by
# NumPyReadOnly._wrap (return type T, parameter spec P).
T = TypeVar("T")
P = ParamSpec("P")
# The array-api-compat namespace for NumPy, obtained from an empty array.
# NumPyReadOnly forwards every attribute lookup to this namespace.
np_compat = array_namespace(np.empty(0)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
@pytest.fixture(params=tuple(Backend))
def library(request: pytest.FixtureRequest) -> Backend:  # numpydoc ignore=PR01,RT03
    """
    Parameterized fixture that iterates on all libraries.

    Individual tests can opt out of a backend with the ``skip_xp_backend`` and
    ``xfail_xp_backend`` markers. The backend is taken from the marker's
    ``library=`` keyword, falling back to its first positional argument, and
    must be a ``Backend`` member. An optional ``reason=`` keyword is appended
    to the backend name in the skip/xfail message.

    Returns
    -------
    The current Backend enum.
    """
    current = cast(Backend, request.param)

    # Skip markers are examined before xfail markers.
    handlers = (
        ("skip_xp_backend", pytest.skip),
        ("xfail_xp_backend", partial(xfail, request)),
    )
    for marker_name, handler in handlers:
        for marker in request.node.iter_markers(marker_name):
            target = marker.kwargs.get("library") or marker.args[0]  # type: ignore[no-untyped-usage]
            if not isinstance(target, Backend):
                msg = f"argument of {marker_name} must be a Backend enum"
                raise TypeError(msg)
            if target != current:
                continue

            reason = target.value
            if "reason" in marker.kwargs:
                reason += ":" + cast(str, marker.kwargs["reason"])
            handler(reason=reason)

    return current
class NumPyReadOnly:
    """
    Variant of array_api_compat.numpy producing read-only arrays.

    Read-only NumPy arrays fail on `__iadd__` etc., whereas read-only libraries such as
    JAX and Sparse simply don't define those methods, which makes calls to `+=` fall
    back to `__add__`.

    Note that this is not a full read-only Array API library. Notably,
    `array_namespace(x)` returns array_api_compat.numpy. This is actually the desired
    behaviour, so that when a tested function internally calls `xp =
    array_namespace(*args) or xp`, it will internally create writeable arrays.

    For this reason, tests that explicitly pass xp=xp to the tested functions may
    misbehave and should be skipped for NUMPY_READONLY.
    """

    def __getattr__(self, name: str) -> object:  # numpydoc ignore=PR01,RT01
        """Wrap all functions that return arrays to make their output read-only."""
        func = getattr(np_compat, name)
        # Non-callables (constants, submodules) and classes (e.g. dtypes) are
        # returned unchanged; only plain callables get wrapped.
        if not callable(func) or isinstance(func, type):
            return func
        return self._wrap(func)

    @staticmethod
    def _wrap(func: Callable[P, T]) -> Callable[P, T]:  # numpydoc ignore=PR01,RT01
        """Wrap func to make all np.ndarrays it returns read-only."""

        def as_readonly(o: T) -> T:  # numpydoc ignore=PR01,RT01
            """Unset the writeable flag in o, recursing into tuples and lists."""
            try:
                # Don't use is_numpy_array(o), as it includes np.generic
                if isinstance(o, np.ndarray):
                    o.flags.writeable = False
            except TypeError:
                # Cannot interpret as a data type
                return o

            if isinstance(o, tuple | list):
                items = [as_readonly(i) for i in o]
                # namedtuples take one positional argument per field, whereas
                # plain tuples/lists (and their ordinary subclasses) take a single
                # iterable. Calling tuple(*items) would raise TypeError for two or
                # more items and would iterate the lone array for exactly one.
                if isinstance(o, tuple) and hasattr(o, "_fields"):
                    return type(o)(*items)  # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType,reportUnknownArgumentType]
                return type(o)(items)  # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType,reportUnknownArgumentType]

            return o

        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:  # numpydoc ignore=GL08
            return as_readonly(func(*args, **kwargs))

        return wrapper
@pytest.fixture
def xp(
    library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
) -> ModuleType:  # numpydoc ignore=PR01,RT03
    """
    Parameterized fixture that iterates on all libraries.

    For NUMPY_READONLY a ``NumPyReadOnly`` instance is returned directly.
    Every other backend is imported by its ``Backend`` value (the test is
    skipped if the import fails) and normalized through ``array_namespace``.
    Functions tagged by ``lazy_xp_function`` are then monkey-patched. For
    JAX, ``jax_enable_x64`` is switched on via ``jax.config.update``; this
    fixture does not revert it.

    Returns
    -------
    The current array namespace.
    """
    if library == Backend.NUMPY_READONLY:
        return NumPyReadOnly()  # type: ignore[return-value] # pyright: ignore[reportReturnType]

    module = pytest.importorskip(library.value)
    # Possibly wrap module with array_api_compat
    namespace = array_namespace(module.empty(0))

    # On Dask and JAX, monkey-patch all functions tagged by `lazy_xp_function`
    # in the global scope of the module containing the test function.
    patch_lazy_xp_functions(request, monkeypatch, xp=namespace)

    if library == Backend.JAX:
        import jax

        # suppress unused-ignore to run mypy in -e lint as well as -e dev
        jax.config.update("jax_enable_x64", True)  # type: ignore[no-untyped-call,unused-ignore]

    return namespace
@pytest.fixture(params=[Backend.DASK])  # Can select the test with `pytest -k dask`
def da(
    request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
) -> ModuleType:  # numpydoc ignore=PR01,RT01
    """
    Variant of the `xp` fixture that only yields dask.array.

    The test is skipped if ``dask.array`` cannot be imported. Functions tagged
    by ``lazy_xp_function`` are monkey-patched through
    ``patch_lazy_xp_functions``, the same call the `xp` fixture makes.
    """
    dask_array = pytest.importorskip("dask.array")
    namespace = array_namespace(dask_array.empty(0))
    patch_lazy_xp_functions(request, monkeypatch, xp=namespace)
    return namespace
@pytest.fixture
def device(
    library: Backend, xp: ModuleType
) -> Device:  # numpydoc ignore=PR01,RT01,RT03
    """
    Return a valid device for the backend.

    Where possible, return a device that is not the default one. Currently
    that is only done for array-api-strict, whose ``Device("device1")`` is
    asserted to differ from the device of a freshly created array. All other
    backends return the device of ``xp.empty(0)``.
    """
    if library != Backend.ARRAY_API_STRICT:
        return get_device(xp.empty(0))

    non_default = xp.Device("device1")
    assert get_device(xp.empty(0)) != non_default
    return non_default