Skip to content

Commit a16e8ea

Browse files
authored
approx: use exact comparison for bool
Fixes #9353
1 parent b938e70 commit a16e8ea

File tree

3 files changed

+49
-19
lines changed

3 files changed

+49
-19
lines changed

changelog/9353.bugfix.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:func:`pytest.approx` now uses strict equality when given booleans.

src/_pytest/python_api.py

+26-18
Original file line numberDiff line numberDiff line change
@@ -259,19 +259,22 @@ def _repr_compare(self, other_side: Mapping[object, float]) -> list[str]:
259259
):
260260
if approx_value != other_value:
261261
if approx_value.expected is not None and other_value is not None:
262-
max_abs_diff = max(
263-
max_abs_diff, abs(approx_value.expected - other_value)
264-
)
265-
if approx_value.expected == 0.0:
266-
max_rel_diff = math.inf
267-
else:
268-
max_rel_diff = max(
269-
max_rel_diff,
270-
abs(
271-
(approx_value.expected - other_value)
272-
/ approx_value.expected
273-
),
262+
try:
263+
max_abs_diff = max(
264+
max_abs_diff, abs(approx_value.expected - other_value)
274265
)
266+
if approx_value.expected == 0.0:
267+
max_rel_diff = math.inf
268+
else:
269+
max_rel_diff = max(
270+
max_rel_diff,
271+
abs(
272+
(approx_value.expected - other_value)
273+
/ approx_value.expected
274+
),
275+
)
276+
except ZeroDivisionError:
277+
pass
275278
different_ids.append(approx_key)
276279

277280
message_data = [
@@ -395,8 +398,10 @@ def __repr__(self) -> str:
395398
# Don't show a tolerance for values that aren't compared using
396399
# tolerances, i.e. non-numerics and infinities. Need to call abs to
397400
# handle complex numbers, e.g. (inf + 1j).
398-
if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf(
399-
abs(self.expected)
401+
if (
402+
isinstance(self.expected, bool)
403+
or (not isinstance(self.expected, (Complex, Decimal)))
404+
or math.isinf(abs(self.expected) or isinstance(self.expected, bool))
400405
):
401406
return str(self.expected)
402407

@@ -428,14 +433,17 @@ def __eq__(self, actual) -> bool:
428433
# numpy<1.13. See #3748.
429434
return all(self.__eq__(a) for a in asarray.flat)
430435

431-
# Short-circuit exact equality.
432-
if actual == self.expected:
436+
# Short-circuit exact equality, except for bool
437+
if isinstance(self.expected, bool) and not isinstance(actual, bool):
438+
return False
439+
elif actual == self.expected:
433440
return True
434441

435442
# If either type is non-numeric, fall back to strict equality.
436443
# NB: we need Complex, rather than just Number, to ensure that __abs__,
437-
# __sub__, and __float__ are defined.
438-
if not (
444+
# __sub__, and __float__ are defined. Also, consider bool to be
445+
# nonnumeric, even though it has the required arithmetic.
446+
if isinstance(self.expected, bool) or not (
439447
isinstance(self.expected, (Complex, Decimal))
440448
and isinstance(actual, (Complex, Decimal))
441449
):

testing/python/approx.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,26 @@ def do_assert(lhs, rhs, expected_message, verbosity_level=0):
9090
return do_assert
9191

9292

93-
SOME_FLOAT = r"[+-]?([0-9]*[.])?[0-9]+\s*"
93+
SOME_FLOAT = r"[+-]?((?:([0-9]*[.])?[0-9]+(e-?[0-9]+)?)|inf|nan)\s*"
9494
SOME_INT = r"[0-9]+\s*"
9595
SOME_TOLERANCE = rf"({SOME_FLOAT}|[+-]?[0-9]+(\.[0-9]+)?[eE][+-]?[0-9]+\s*)"
9696

9797

9898
class TestApprox:
9999
def test_error_messages_native_dtypes(self, assert_approx_raises_regex):
100+
# Treat bool exactly.
101+
assert_approx_raises_regex(
102+
{"a": 1.0, "b": True},
103+
{"a": 1.0, "b": False},
104+
[
105+
"",
106+
" comparison failed. Mismatched elements: 1 / 2:",
107+
f" Max absolute difference: {SOME_FLOAT}",
108+
f" Max relative difference: {SOME_FLOAT}",
109+
r" Index\s+\| Obtained\s+\| Expected",
110+
r".*(True|False)\s+",
111+
],
112+
)
100113
assert_approx_raises_regex(
101114
2.0,
102115
1.0,
@@ -596,6 +609,13 @@ def test_complex(self):
596609
assert approx(x, rel=5e-6, abs=0) == a
597610
assert approx(x, rel=5e-7, abs=0) != a
598611

612+
def test_expecting_bool(self) -> None:
613+
assert True == approx(True) # noqa: E712
614+
assert False == approx(False) # noqa: E712
615+
assert True != approx(False) # noqa: E712
616+
assert True != approx(False, abs=2) # noqa: E712
617+
assert 1 != approx(True)
618+
599619
def test_list(self):
600620
actual = [1 + 1e-7, 2 + 1e-8]
601621
expected = [1, 2]
@@ -661,6 +681,7 @@ def test_dict_wrong_len(self):
661681
def test_dict_nonnumeric(self):
662682
assert {"a": 1.0, "b": None} == pytest.approx({"a": 1.0, "b": None})
663683
assert {"a": 1.0, "b": 1} != pytest.approx({"a": 1.0, "b": None})
684+
assert {"a": 1.0, "b": True} != pytest.approx({"a": 1.0, "b": False}, abs=2)
664685

665686
def test_dict_vs_other(self):
666687
assert 1 != approx({"a": 0})

0 commit comments

Comments
 (0)