@@ -46,60 +46,13 @@ def _yield_comparisons(self, actual):
46
46
raise NotImplementedError
47
47
48
48
49
- class ApproxNumpyBase (ApproxBase ):
49
+ class ApproxNumpy (ApproxBase ):
50
50
"""
51
51
Perform approximate comparisons for numpy arrays.
52
-
53
- This class should not be used directly. Instead, the `inherit_ndarray()`
54
- class method should be used to make a subclass that also inherits from
55
- `np.ndarray`. This indirection is necessary because the object doing the
56
- approximate comparison must inherit from `np.ndarray`, or it will only work
57
- on the left side of the `==` operator. But importing numpy is relatively
58
- expensive, so we also want to avoid that unless we actually have a numpy
59
- array to compare.
60
-
61
- The reason why the approx object needs to inherit from `np.ndarray` has to
62
- do with how python decides whether to call `a.__eq__()` or `b.__eq__()`
63
- when it parses `a == b`. If `a` and `b` are not related by inheritance,
64
- `a` gets priority. So as long as `a.__eq__` is defined, it will be called.
65
- Because most implementations of `a.__eq__` end up calling `b.__eq__`, this
66
- detail usually doesn't matter. However, `np.ndarray.__eq__` treats the
67
- approx object as a scalar and builds a new array by comparing it to each
68
- item in the original array. `b.__eq__` is called to compare against each
69
- individual element in the array, but it has no way (that I can see) to
70
- prevent the return value from being an boolean array, and boolean arrays
71
- can't be used with assert because "the truth value of an array with more
72
- than one element is ambiguous."
73
-
74
- The trick is that the priority rules change if `a` and `b` are related
75
- by inheritance. Specifically, `b.__eq__` gets priority if `b` is a
76
- subclass of `a`. So by inheriting from `np.ndarray`, we can guarantee that
77
- `ApproxNumpy.__eq__` gets called no matter which side of the `==` operator
78
- it appears on.
79
52
"""
80
53
81
- subclass = None
82
-
83
- @classmethod
84
- def inherit_ndarray (cls ):
85
- import numpy as np
86
- assert not isinstance (cls , np .ndarray )
87
-
88
- if cls .subclass is None :
89
- cls .subclass = type ('ApproxNumpy' , (cls , np .ndarray ), {})
90
-
91
- return cls .subclass
92
-
93
- def __new__ (cls , expected , rel = None , abs = None , nan_ok = False ):
94
- """
95
- Numpy uses __new__ (rather than __init__) to initialize objects.
96
-
97
- The `expected` argument must be a numpy array. This should be
98
- ensured by the approx() delegator function.
99
- """
100
- obj = super (ApproxNumpyBase , cls ).__new__ (cls , ())
101
- obj .__init__ (expected , rel , abs , nan_ok )
102
- return obj
54
+ # Tell numpy to use our `__eq__` operator instead of its.
55
+ __array_priority__ = 100
103
56
104
57
def __repr__ (self ):
105
58
# It might be nice to rewrite this function to account for the
@@ -113,7 +66,7 @@ def __eq__(self, actual):
113
66
try :
114
67
actual = np .asarray (actual )
115
68
except :
116
- raise ValueError ("cannot cast '{0}' to numpy.ndarray" .format (actual ))
69
+ raise TypeError ("cannot compare '{0}' to numpy.ndarray" .format (actual ))
117
70
118
71
if actual .shape != self .expected .shape :
119
72
return False
@@ -157,6 +110,9 @@ class ApproxSequence(ApproxBase):
157
110
Perform approximate comparisons for sequences of numbers.
158
111
"""
159
112
113
+ # Tell numpy to use our `__eq__` operator instead of its.
114
+ __array_priority__ = 100
115
+
160
116
def __repr__ (self ):
161
117
seq_type = type (self .expected )
162
118
if seq_type not in (tuple , list , set ):
@@ -422,9 +378,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
422
378
# their keys, which is probably not what most people would expect.
423
379
424
380
if _is_numpy_array (expected ):
425
- # Create the delegate class on the fly. This allow us to inherit from
426
- # ``np.ndarray`` while still not importing numpy unless we need to.
427
- cls = ApproxNumpyBase .inherit_ndarray ()
381
+ cls = ApproxNumpy
428
382
elif isinstance (expected , Mapping ):
429
383
cls = ApproxMapping
430
384
elif isinstance (expected , Sequence ) and not isinstance (expected , String ):
0 commit comments