Skip to content

Commit b1e6639

Browse files
committed
test the right way
1 parent be8a4b1 commit b1e6639

File tree

4 files changed

+9
-16
lines changed

4 files changed

+9
-16
lines changed

Diff for: src/array_api_extra/_lib/_backends.py

-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
2424
"""
2525

2626
ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
27-
ARRAY_API_STRICT_DEVICE1 = "array_api_strict", _compat.is_array_api_strict_namespace
2827
NUMPY = "numpy", _compat.is_numpy_namespace
2928
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace
3029
CUPY = "cupy", _compat.is_cupy_namespace

Diff for: src/array_api_extra/_lib/_testing.py

-6
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from typing import cast
1111

1212
import pytest
13-
from array_api_compat import is_array_api_strict_namespace
1413

1514
from ._utils._compat import (
1615
array_namespace,
@@ -107,11 +106,6 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
107106
desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
108107

109108
# JAX uses `np.testing`
110-
if is_array_api_strict_namespace(xp):
111-
# Have to move to CPU for array API strict devices before
112-
# we're allowed to convert into numpy
113-
actual = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
114-
desired = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
115109
np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
116110

117111

Diff for: tests/conftest.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def xp(
113113
The current array namespace.
114114
"""
115115
if library == Backend.NUMPY_READONLY:
116-
return NumPyReadOnly(), None # type: ignore[return-value] # pyright: ignore[reportReturnType]
116+
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
117117
xp = pytest.importorskip(library.value)
118118
# Possibly wrap module with array_api_compat
119119
xp = array_namespace(xp.empty(0))
@@ -131,11 +131,7 @@ def xp(
131131
# suppress unused-ignore to run mypy in -e lint as well as -e dev
132132
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
133133

134-
device = None
135-
if library == Backend.ARRAY_API_STRICT_DEVICE1:
136-
import array_api_strict
137-
device = array_api_strict.Device("device1")
138-
return xp, device
134+
return xp
139135

140136

141137
@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`

Diff for: tests/test_funcs.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -633,9 +633,8 @@ def test_some_inf(self, xp: ModuleType):
633633
xp_assert_equal(actual, xp.asarray([True, True, True, False, False]))
634634

635635
def test_equal_nan(self, xp: ModuleType):
636-
xp, device = xp
637-
a = xp.asarray([xp.nan, xp.nan, 1.0], device=device)
638-
b = xp.asarray([xp.nan, 1.0, xp.nan], device=device)
636+
a = xp.asarray([xp.nan, xp.nan, 1.0])
637+
b = xp.asarray([xp.nan, 1.0, xp.nan])
639638
xp_assert_equal(isclose(a, b), xp.asarray([False, False, False]))
640639
xp_assert_equal(isclose(a, b, equal_nan=True), xp.asarray([True, False, False]))
641640

@@ -721,6 +720,11 @@ def test_xp(self, xp: ModuleType):
721720
b = xp.asarray([1e-9, 1e-4])
722721
xp_assert_equal(isclose(a, b, xp=xp), xp.asarray([True, False]))
723722

723+
@pytest.mark.parametrize("equal_nan", [True, False])
724+
def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
725+
a = xp.asarray([0.0, 0.0, xp.nan], device=device)
726+
b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
727+
xp_assert_equal(isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan]))
724728

725729
class TestKron:
726730
def test_basic(self, xp: ModuleType):

0 commit comments

Comments
 (0)