Skip to content

Commit 438310c

Browse files
committed
BUG: isclose PyTorch Array API 2024.12 compliance
1 parent 70c7c80 commit 438310c

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

src/array_api_extra/_delegation.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from ._lib import Backend, _funcs
88
from ._lib._utils._compat import array_namespace
9+
from ._lib._utils._helpers import asarrays
910
from ._lib._utils._typing import Array
1011

1112
__all__ = ["isclose", "pad"]
@@ -107,14 +108,11 @@ def isclose(
107108
"""
108109
xp = array_namespace(a, b) if xp is None else xp
109110

110-
if _delegate(
111-
xp,
112-
Backend.NUMPY,
113-
Backend.CUPY,
114-
Backend.DASK,
115-
Backend.JAX,
116-
Backend.TORCH,
117-
):
111+
if _delegate(xp, Backend.NUMPY, Backend.CUPY, Backend.DASK, Backend.JAX):
112+
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
113+
114+
if _delegate(xp, Backend.TORCH):
115+
a, b = asarrays(a, b, xp=xp) # Array API 2024.12 support
118116
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
119117

120118
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)

tests/test_funcs.py

-1
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,6 @@ def test_none_shape_bool(self, xp: ModuleType):
689689
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
690690

691691
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
692-
@pytest.mark.xfail_xp_backend(Backend.TORCH, reason="Array API 2024.12 support")
693692
def test_python_scalar(self, xp: ModuleType):
694693
a = xp.asarray([0.0, 0.1], dtype=xp.float32)
695694
xp_assert_equal(isclose(a, 0.0), xp.asarray([True, False]))

0 commit comments

Comments
 (0)