Skip to content

Commit becc7ab

Browse files
committed
MAINT: run self-tests even if a library is missing
1 parent 17beb40 commit becc7ab

File tree

4 files changed

+18
-6
lines changed

4 files changed

+18
-6
lines changed

Diff for: tests/test_array_namespace.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22
import sys
33
import warnings
44

5-
import jax
65
import numpy as np
76
import pytest
8-
import torch
97

108
import array_api_compat
119
from array_api_compat import array_namespace
@@ -76,6 +74,7 @@ def test_array_namespace(library, api_version, use_compat):
7674
subprocess.run([sys.executable, "-c", code], check=True)
7775

7876
def test_jax_zero_gradient():
77+
jax = import_("jax")
7978
jx = jax.numpy.arange(4)
8079
jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
8180
assert array_namespace(jax_zero) is array_namespace(jx)
@@ -89,11 +88,13 @@ def test_array_namespace_errors():
8988
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))
9089

9190
def test_array_namespace_errors_torch():
91+
torch = import_("torch")
9292
y = torch.asarray([1, 2])
9393
x = np.asarray([1, 2])
9494
pytest.raises(TypeError, lambda: array_namespace(x, y))
9595

9696
def test_api_version_torch():
97+
torch = import_("torch")
9798
x = torch.asarray([1, 2])
9899
torch_ = import_("torch", wrapper=True)
99100
assert array_namespace(x, api_version="2023.12") == torch_
@@ -118,6 +119,7 @@ def test_get_namespace():
118119
assert array_api_compat.get_namespace is array_namespace
119120

120121
def test_python_scalars():
122+
torch = import_("torch")
121123
a = torch.asarray([1, 2])
122124
xp = import_("torch", wrapper=True)
123125

Diff for: tests/test_dask.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
from contextlib import contextmanager
22

33
import array_api_strict
4-
import dask
54
import numpy as np
65
import pytest
7-
import dask.array as da
6+
7+
try:
8+
import dask
9+
import dask.array as da
10+
except ImportError:
11+
pytestmark = pytest.skip(allow_module_level=True, reason="dask not found")
812

913
from array_api_compat import array_namespace
1014

Diff for: tests/test_jax.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
import jax
2-
import jax.numpy as jnp
31
from numpy.testing import assert_equal
42
import pytest
53

64
from array_api_compat import device, to_device
75

6+
try:
7+
import jax
8+
import jax.numpy as jnp
9+
except ImportError:
10+
pytestmark = pytest.skip(allow_module_level=True, reason="jax not found")
11+
812
HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31"
913

1014

Diff for: tests/test_vendoring.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ def test_vendoring_cupy():
1616

1717

1818
def test_vendoring_torch():
19+
pytest.importorskip("torch")
1920
from vendor_test import uses_torch
2021

2122
uses_torch._test_torch()
2223

2324

2425
def test_vendoring_dask():
26+
pytest.importorskip("dask")
2527
from vendor_test import uses_dask
2628
uses_dask._test_dask()

0 commit comments

Comments
 (0)