15
15
from typing import Pattern
16
16
from typing import Tuple
17
17
from typing import Type
18
+ from typing import TYPE_CHECKING
18
19
from typing import TypeVar
19
20
from typing import Union
20
21
22
+ if TYPE_CHECKING :
23
+ from numpy import ndarray
24
+
25
+
21
26
import _pytest ._code
22
27
from _pytest .compat import final
23
28
from _pytest .compat import STRING_TYPES
@@ -232,10 +237,11 @@ def __repr__(self) -> str:
232
237
def __eq__ (self , actual ) -> bool :
233
238
"""Return whether the given value is equal to the expected value
234
239
within the pre-specified tolerance."""
235
- if _is_numpy_array (actual ):
240
+ asarray = _as_numpy_array (actual )
241
+ if asarray is not None :
236
242
# Call ``__eq__()`` manually to prevent infinite-recursion with
237
243
# numpy<1.13. See #3748.
238
- return all (self .__eq__ (a ) for a in actual .flat )
244
+ return all (self .__eq__ (a ) for a in asarray .flat )
239
245
240
246
# Short-circuit exact equality.
241
247
if actual == self .expected :
@@ -521,6 +527,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
521
527
elif isinstance (expected , Mapping ):
522
528
cls = ApproxMapping
523
529
elif _is_numpy_array (expected ):
530
+ expected = _as_numpy_array (expected )
524
531
cls = ApproxNumpy
525
532
elif (
526
533
isinstance (expected , Iterable )
@@ -536,16 +543,30 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
536
543
537
544
538
545
def _is_numpy_array (obj : object ) -> bool :
539
- """Return true if the given object is a numpy array.
546
+ """
547
+ Return true if the given object is implicitly convertible to ndarray,
548
+ and numpy is already imported.
549
+ """
550
+ return _as_numpy_array (obj ) is not None
551
+
540
552
541
- A special effort is made to avoid importing numpy unless it's really necessary.
553
+ def _as_numpy_array (obj : object ) -> Optional ["ndarray" ]:
554
+ """
555
+ Return an ndarray if the given object is implicitly convertible to ndarray,
556
+ and numpy is already imported, otherwise None.
542
557
"""
543
558
import sys
544
559
545
560
np : Any = sys .modules .get ("numpy" )
546
561
if np is not None :
547
- return isinstance (obj , np .ndarray )
548
- return False
562
+ # avoid infinite recursion on numpy scalars, which have __array__
563
+ if np .isscalar (obj ):
564
+ return None
565
+ elif isinstance (obj , np .ndarray ):
566
+ return obj
567
+ elif hasattr (obj , "__array__" ) or hasattr ("obj" , "__array_interface__" ):
568
+ return np .asarray (obj )
569
+ return None
549
570
550
571
551
572
# builtin pytest.raises helper
0 commit comments