Skip to content

Commit 780044b

Browse files
authored
Merge pull request #8147 from nicoddemus/backport-8137
[6.2.x] python_api: handle array-like args in approx() #8137
2 parents 8b8b121 + 8354995 commit 780044b

File tree

3 files changed

+67
-6
lines changed

3 files changed

+67
-6
lines changed

changelog/8132.bugfix.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Fixed regression in ``approx``: in 6.2.0 ``approx`` no longer raises
2+
``TypeError`` when dealing with non-numeric types, falling back to normal comparison.
3+
Before 6.2.0, array types like tf.DeviceArray fell through to the scalar case,
4+
and happened to compare correctly to a scalar if they had only one element.
5+
After 6.2.0, these types began failing, because they inherited neither from
6+
standard Python number hierarchy nor from ``numpy.ndarray``.
7+
8+
``approx`` now converts arguments to ``numpy.ndarray`` if they expose the array
9+
protocol and are not scalars. This treats array-like objects like numpy arrays,
10+
regardless of size.

src/_pytest/python_api.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,14 @@
1515
from typing import Pattern
1616
from typing import Tuple
1717
from typing import Type
18+
from typing import TYPE_CHECKING
1819
from typing import TypeVar
1920
from typing import Union
2021

22+
if TYPE_CHECKING:
23+
from numpy import ndarray
24+
25+
2126
import _pytest._code
2227
from _pytest.compat import final
2328
from _pytest.compat import STRING_TYPES
@@ -232,10 +237,11 @@ def __repr__(self) -> str:
232237
def __eq__(self, actual) -> bool:
233238
"""Return whether the given value is equal to the expected value
234239
within the pre-specified tolerance."""
235-
if _is_numpy_array(actual):
240+
asarray = _as_numpy_array(actual)
241+
if asarray is not None:
236242
# Call ``__eq__()`` manually to prevent infinite-recursion with
237243
# numpy<1.13. See #3748.
238-
return all(self.__eq__(a) for a in actual.flat)
244+
return all(self.__eq__(a) for a in asarray.flat)
239245

240246
# Short-circuit exact equality.
241247
if actual == self.expected:
@@ -521,6 +527,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
521527
elif isinstance(expected, Mapping):
522528
cls = ApproxMapping
523529
elif _is_numpy_array(expected):
530+
expected = _as_numpy_array(expected)
524531
cls = ApproxNumpy
525532
elif (
526533
isinstance(expected, Iterable)
@@ -536,16 +543,30 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
536543

537544

538545
def _is_numpy_array(obj: object) -> bool:
539-
"""Return true if the given object is a numpy array.
546+
"""
547+
Return true if the given object is implicitly convertible to ndarray,
548+
and numpy is already imported.
549+
"""
550+
return _as_numpy_array(obj) is not None
551+
540552

541-
A special effort is made to avoid importing numpy unless it's really necessary.
553+
def _as_numpy_array(obj: object) -> Optional["ndarray"]:
554+
"""
555+
Return an ndarray if the given object is implicitly convertible to ndarray,
556+
and numpy is already imported, otherwise None.
542557
"""
543558
import sys
544559

545560
np: Any = sys.modules.get("numpy")
546561
if np is not None:
547-
return isinstance(obj, np.ndarray)
548-
return False
562+
# avoid infinite recursion on numpy scalars, which have __array__
563+
if np.isscalar(obj):
564+
return None
565+
elif isinstance(obj, np.ndarray):
566+
return obj
567+
elif hasattr(obj, "__array__") or hasattr("obj", "__array_interface__"):
568+
return np.asarray(obj)
569+
return None
549570

550571

551572
# builtin pytest.raises helper

testing/python/approx.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,36 @@ def test_numpy_array_wrong_shape(self):
447447
assert a12 != approx(a21)
448448
assert a21 != approx(a12)
449449

450+
def test_numpy_array_protocol(self):
451+
"""
452+
array-like objects such as tensorflow's DeviceArray are handled like ndarray.
453+
See issue #8132
454+
"""
455+
np = pytest.importorskip("numpy")
456+
457+
class DeviceArray:
458+
def __init__(self, value, size):
459+
self.value = value
460+
self.size = size
461+
462+
def __array__(self):
463+
return self.value * np.ones(self.size)
464+
465+
class DeviceScalar:
466+
def __init__(self, value):
467+
self.value = value
468+
469+
def __array__(self):
470+
return np.array(self.value)
471+
472+
expected = 1
473+
actual = 1 + 1e-6
474+
assert approx(expected) == DeviceArray(actual, size=1)
475+
assert approx(expected) == DeviceArray(actual, size=2)
476+
assert approx(expected) == DeviceScalar(actual)
477+
assert approx(DeviceScalar(expected)) == actual
478+
assert approx(DeviceScalar(expected)) == DeviceScalar(actual)
479+
450480
def test_doctests(self, mocked_doctest_runner) -> None:
451481
import doctest
452482

0 commit comments

Comments
 (0)