Skip to content

Commit 1f2cc5b

Browse files
committed
fix
1 parent e771e01 commit 1f2cc5b

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

Diff for: src/array_api_extra/testing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def test_myfunc(xp):
143143
}
144144
try:
145145
func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
146-
except AttributeError: # numpy.ufunc
146+
except AttributeError: # @cython.vectorize
147147
_ufuncs_tags[func] = tags
148148

149149

Diff for: tests/test_testing.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import numpy as np
55
import pytest
6-
from numpy import min as numpy_min
76

87
from array_api_extra._lib import Backend
98
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
@@ -205,21 +204,30 @@ def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Arra
205204
xp_assert_equal(func(x, n=1, flag=True), xp.asarray([2.0, 4.0]))
206205

207206

208-
lazy_xp_function(numpy_min, static_argnames="axis")
207+
try:
208+
# Test an arbitrary Cython ufunc (@cython.vectorize).
209+
# When SCIPY_ARRAY_API is not set, this is the same as
210+
# scipy.special.erf.
211+
from scipy.special._ufuncs import erf # type: ignore[import-not-found]
209212

213+
lazy_xp_function(erf) # pyright: ignore[reportUnknownArgumentType]
214+
except ImportError:
215+
erf = None
210216

211-
def test_lazy_xp_function_ufunc(xp: ModuleType, library: Backend):
212-
x = xp.asarray([[1, 4], [3, 2]])
213-
if library in (Backend.ARRAY_API_STRICT, Backend.TORCH, Backend.JAX):
214-
# array-api-strict, torch and jax don't define __array_ufunc__
215-
# numpy ufuncs can't auto-convert to numpy from torch
217+
218+
@pytest.mark.filterwarnings("ignore:__array_wrap__:DeprecationWarning") # torch
219+
def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
220+
pytest.importorskip("scipy")
221+
assert erf is not None
222+
x = xp.asarray([6.0, 7.0])
223+
if library in (Backend.ARRAY_API_STRICT, Backend.JAX):
216224
# array-api-strict arrays are auto-converted to numpy
217225
# eager jax arrays are auto-converted to numpy in eager jax
218226
# and fail in jax.jit (which lazy_xp_function tests here)
219227
with pytest.raises((TypeError, AssertionError)):
220-
xp_assert_equal(numpy_min(x, axis=0), xp.asarray([1, 2]))
228+
xp_assert_equal(erf(x), xp.asarray([1.0, 1.0]))
221229
else:
222230
# cupy, dask and sparse define __array_ufunc__ and dispatch accordingly
223231
# note that when sparse reduces to scalar it returns a np.generic, which
224232
# would make xp_assert_equal fail.
225-
xp_assert_equal(numpy_min(x, axis=0), xp.asarray([1, 2]))
233+
xp_assert_equal(erf(x), xp.asarray([1.0, 1.0]))

0 commit comments

Comments
 (0)