Skip to content

Commit 1b732fe

Browse files
authored
Merge pull request #2606 from kalekundert/simplify-numpy
Make approx more compatible with numpy
2 parents b35554c + ebc7346 commit 1b732fe

File tree

2 files changed

+9
-55
lines changed

2 files changed

+9
-55
lines changed

_pytest/python_api.py

Lines changed: 8 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -46,60 +46,13 @@ def _yield_comparisons(self, actual):
4646
raise NotImplementedError
4747

4848

49-
class ApproxNumpyBase(ApproxBase):
49+
class ApproxNumpy(ApproxBase):
5050
"""
5151
Perform approximate comparisons for numpy arrays.
52-
53-
This class should not be used directly. Instead, the `inherit_ndarray()`
54-
class method should be used to make a subclass that also inherits from
55-
`np.ndarray`. This indirection is necessary because the object doing the
56-
approximate comparison must inherit from `np.ndarray`, or it will only work
57-
on the left side of the `==` operator. But importing numpy is relatively
58-
expensive, so we also want to avoid that unless we actually have a numpy
59-
array to compare.
60-
61-
The reason why the approx object needs to inherit from `np.ndarray` has to
62-
do with how python decides whether to call `a.__eq__()` or `b.__eq__()`
63-
when it parses `a == b`. If `a` and `b` are not related by inheritance,
64-
`a` gets priority. So as long as `a.__eq__` is defined, it will be called.
65-
Because most implementations of `a.__eq__` end up calling `b.__eq__`, this
66-
detail usually doesn't matter. However, `np.ndarray.__eq__` treats the
67-
approx object as a scalar and builds a new array by comparing it to each
68-
item in the original array. `b.__eq__` is called to compare against each
69-
individual element in the array, but it has no way (that I can see) to
70-
prevent the return value from being an boolean array, and boolean arrays
71-
can't be used with assert because "the truth value of an array with more
72-
than one element is ambiguous."
73-
74-
The trick is that the priority rules change if `a` and `b` are related
75-
by inheritance. Specifically, `b.__eq__` gets priority if `b` is a
76-
subclass of `a`. So by inheriting from `np.ndarray`, we can guarantee that
77-
`ApproxNumpy.__eq__` gets called no matter which side of the `==` operator
78-
it appears on.
7952
"""
8053

81-
subclass = None
82-
83-
@classmethod
84-
def inherit_ndarray(cls):
85-
import numpy as np
86-
assert not isinstance(cls, np.ndarray)
87-
88-
if cls.subclass is None:
89-
cls.subclass = type('ApproxNumpy', (cls, np.ndarray), {})
90-
91-
return cls.subclass
92-
93-
def __new__(cls, expected, rel=None, abs=None, nan_ok=False):
94-
"""
95-
Numpy uses __new__ (rather than __init__) to initialize objects.
96-
97-
The `expected` argument must be a numpy array. This should be
98-
ensured by the approx() delegator function.
99-
"""
100-
obj = super(ApproxNumpyBase, cls).__new__(cls, ())
101-
obj.__init__(expected, rel, abs, nan_ok)
102-
return obj
54+
# Tell numpy to use our `__eq__` operator instead of its.
55+
__array_priority__ = 100
10356

10457
def __repr__(self):
10558
# It might be nice to rewrite this function to account for the
@@ -113,7 +66,7 @@ def __eq__(self, actual):
11366
try:
11467
actual = np.asarray(actual)
11568
except:
116-
raise ValueError("cannot cast '{0}' to numpy.ndarray".format(actual))
69+
raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual))
11770

11871
if actual.shape != self.expected.shape:
11972
return False
@@ -157,6 +110,9 @@ class ApproxSequence(ApproxBase):
157110
Perform approximate comparisons for sequences of numbers.
158111
"""
159112

113+
# Tell numpy to use our `__eq__` operator instead of its.
114+
__array_priority__ = 100
115+
160116
def __repr__(self):
161117
seq_type = type(self.expected)
162118
if seq_type not in (tuple, list, set):
@@ -422,9 +378,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
422378
# their keys, which is probably not what most people would expect.
423379

424380
if _is_numpy_array(expected):
425-
# Create the delegate class on the fly. This allow us to inherit from
426-
# ``np.ndarray`` while still not importing numpy unless we need to.
427-
cls = ApproxNumpyBase.inherit_ndarray()
381+
cls = ApproxNumpy
428382
elif isinstance(expected, Mapping):
429383
cls = ApproxMapping
430384
elif isinstance(expected, Sequence) and not isinstance(expected, String):

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ envlist =
1212
py36
1313
py37
1414
pypy
15-
{py27,py35}-{pexpect,xdist,trial}
15+
{py27,py35}-{pexpect,xdist,trial,numpy}
1616
py27-nobyte
1717
doctesting
1818
freeze

0 commit comments

Comments
 (0)