Skip to content

Commit a228a13

Browse files
committed
BUG: isclose PyTorch Array API 2024.12 compliance
1 parent 70c7c80 commit a228a13

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/array_api_extra/_delegation.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from ._lib import Backend, _funcs
88
from ._lib._utils._compat import array_namespace
9+
from ._lib._utils._helpers import asarrays
910
from ._lib._utils._typing import Array
1011

1112
__all__ = ["isclose", "pad"]
@@ -112,11 +113,14 @@ def isclose(
112113
Backend.NUMPY,
113114
Backend.CUPY,
114115
Backend.DASK,
115-
Backend.JAX,
116-
Backend.TORCH,
116+
Backend.JAX
117117
):
118118
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
119119

120+
if _delegate(xp, Backend.TORCH):
121+
a, b = asarrays(a, b, xp=xp) # Array API 2024.12 support
122+
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
123+
120124
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
121125

122126

tests/test_funcs.py

-1
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,6 @@ def test_none_shape_bool(self, xp: ModuleType):
689689
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
690690

691691
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
692-
@pytest.mark.xfail_xp_backend(Backend.TORCH, reason="Array API 2024.12 support")
693692
def test_python_scalar(self, xp: ModuleType):
694693
a = xp.asarray([0.0, 0.1], dtype=xp.float32)
695694
xp_assert_equal(isclose(a, 0.0), xp.asarray([True, False]))

0 commit comments

Comments
 (0)