|
3 | 3 |
|
4 | 4 | import numpy as np
|
5 | 5 | import pytest
|
6 |
| -from numpy import min as numpy_min |
7 | 6 |
|
8 | 7 | from array_api_extra._lib import Backend
|
9 | 8 | from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
|
@@ -205,21 +204,30 @@ def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Arra
|
205 | 204 | xp_assert_equal(func(x, n=1, flag=True), xp.asarray([2.0, 4.0]))
|
206 | 205 |
|
207 | 206 |
|
208 |
| -lazy_xp_function(numpy_min, static_argnames="axis") |
| 207 | +try: |
| 208 | + # Test an arbitrary Cython ufunc (@cython.vectorize). |
| 209 | + # When SCIPY_ARRAY_API is not set, this is the same as |
| 210 | + # scipy.special.erf. |
| 211 | + from scipy.special._ufuncs import erf # type: ignore[import-not-found] |
209 | 212 |
|
| 213 | + lazy_xp_function(erf) # pyright: ignore[reportUnknownArgumentType] |
| 214 | +except ImportError: |
| 215 | + erf = None |
210 | 216 |
|
211 |
| -def test_lazy_xp_function_ufunc(xp: ModuleType, library: Backend): |
212 |
| - x = xp.asarray([[1, 4], [3, 2]]) |
213 |
| - if library in (Backend.ARRAY_API_STRICT, Backend.TORCH, Backend.JAX): |
214 |
| - # array-api-strict, torch and jax don't define __array_ufunc__ |
215 |
| - # numpy ufuncs can't auto-convert to numpy from torch |
| 217 | + |
| 218 | +@pytest.mark.filterwarnings("ignore:__array_wrap__:DeprecationWarning") # torch |
| 219 | +def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend): |
| 220 | + pytest.importorskip("scipy") |
| 221 | + assert erf is not None |
| 222 | + x = xp.asarray([6.0, 7.0]) |
| 223 | + if library in (Backend.ARRAY_API_STRICT, Backend.JAX): |
216 | 224 | # array-api-strict arrays are auto-converted to numpy
|
217 | 225 | # eager jax arrays are auto-converted to numpy in eager jax
|
218 | 226 | # and fail in jax.jit (which lazy_xp_function tests here)
|
219 | 227 | with pytest.raises((TypeError, AssertionError)):
|
220 |
| - xp_assert_equal(numpy_min(x, axis=0), xp.asarray([1, 2])) |
| 228 | + xp_assert_equal(erf(x), xp.asarray([1.0, 1.0])) |
221 | 229 | else:
|
222 | 230 | # cupy, dask and sparse define __array_ufunc__ and dispatch accordingly
|
223 | 231 | # note that when sparse reduces to scalar it returns a np.generic, which
|
224 | 232 | # would make xp_assert_equal fail.
|
225 |
| - xp_assert_equal(numpy_min(x, axis=0), xp.asarray([1, 2])) |
| 233 | + xp_assert_equal(erf(x), xp.asarray([1.0, 1.0])) |
0 commit comments