From 116558ab043a119f76bce1d32c0aa96fd5da7ae7 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 18 Feb 2025 13:53:38 +0000 Subject: [PATCH] BUG: `isclose` finite vs. infinite --- src/array_api_extra/_lib/_funcs.py | 8 ++++++-- tests/test_funcs.py | 22 ++++++++++++---------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index a5729559..bd1b5f06 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -386,8 +386,12 @@ def isclose( b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating")) if a_inexact or b_inexact: # FIXME: use scipy's lazywhere to suppress warnings on inf - out = xp.abs(a - b) <= (atol + rtol * xp.abs(b)) - out = xp.where(xp.isinf(a) & xp.isinf(b), xp.sign(a) == xp.sign(b), out) + out = xp.where( + xp.isinf(a) | xp.isinf(b), + xp.isinf(a) & xp.isinf(b) & (xp.sign(a) == xp.sign(b)), + # Note: inf <= inf is True! + xp.abs(a - b) <= (atol + rtol * xp.abs(b)), + ) if equal_nan: out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out) return out diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 84d2f5d1..f7a2c4fb 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -336,6 +336,7 @@ def test_xp(self, xp: ModuleType): class TestIsClose: # FIXME use lazywhere to avoid warnings on inf @pytest.mark.filterwarnings("ignore:invalid value encountered") + @pytest.mark.parametrize("swap", [False, True]) @pytest.mark.parametrize( ("a", "b"), [ @@ -353,9 +354,9 @@ class TestIsClose: (float("inf"), float("inf")), (float("inf"), 100.0), (float("inf"), float("-inf")), + (float("-inf"), float("-inf")), (float("nan"), float("nan")), - (float("nan"), 0.0), - (0.0, float("nan")), + (float("nan"), 100.0), (1e6, 1e6 + 1), # True - within rtol (1e6, 1e6 + 100), # False - outside rtol (1e-6, 1.1e-6), # False - outside atol @@ -364,7 +365,9 @@ class TestIsClose: (1e6 + 0j, 1e6 + 100j), # False - outside rtol ], ) - def test_basic(self, a: float, b: float, xp: ModuleType): + def test_basic(self, a: float, b: float, swap: bool, xp: ModuleType): + if swap: + b, a = a, b a_xp = xp.asarray(a) b_xp = xp.asarray(b) @@ -372,11 +375,10 @@ def test_basic(self, a: float, b: float, xp: ModuleType): with warnings.catch_warnings(): warnings.simplefilter("ignore") - r_xp = xp.asarray(np.arange(10), dtype=a_xp.dtype) - ar_xp = a_xp * r_xp - br_xp = b_xp * r_xp ar_np = a * np.arange(10) br_np = b * np.arange(10) + ar_xp = xp.asarray(ar_np) + br_xp = xp.asarray(br_np) xp_assert_equal(isclose(ar_xp, br_xp), xp.asarray(np.isclose(ar_np, br_np))) @@ -395,14 +397,14 @@ def test_broadcast(self, dtype: str, xp: ModuleType): # FIXME use lazywhere to avoid warnings on inf @pytest.mark.filterwarnings("ignore:invalid value encountered") def test_some_inf(self, xp: ModuleType): - a = xp.asarray([0.0, 1.0, float("inf"), float("inf"), float("inf")]) - b = xp.asarray([1e-9, 1.0, float("inf"), float("-inf"), 2.0]) + a = xp.asarray([0.0, 1.0, xp.inf, xp.inf, xp.inf]) + b = xp.asarray([1e-9, 1.0, xp.inf, -xp.inf, 2.0]) actual = isclose(a, b) xp_assert_equal(actual, xp.asarray([True, True, True, False, False])) def test_equal_nan(self, xp: ModuleType): - a = xp.asarray([float("nan"), float("nan"), 1.0]) - b = xp.asarray([float("nan"), 1.0, float("nan")]) + a = xp.asarray([xp.nan, xp.nan, 1.0]) + b = xp.asarray([xp.nan, 1.0, xp.nan]) xp_assert_equal(isclose(a, b), xp.asarray([False, False, False])) xp_assert_equal(isclose(a, b, equal_nan=True), xp.asarray([True, False, False]))