Skip to content

Commit 5add91d

Browse files
authored
unittest: Improve self.assert(Not)AlmostEqual(s) (#8066)
1 parent 7de1ed9 commit 5add91d

File tree

3 files changed

+67
-22
lines changed

3 files changed

+67
-22
lines changed

stdlib/_typeshed/__init__.pyi

+3
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ class SupportsAdd(Protocol[_T_contra, _T_co]):
7575
class SupportsRAdd(Protocol[_T_contra, _T_co]):
7676
def __radd__(self, __x: _T_contra) -> _T_co: ...
7777

78+
class SupportsSub(Protocol[_T_contra, _T_co]):
79+
def __sub__(self, __x: _T_contra) -> _T_co: ...
80+
7881
class SupportsDivMod(Protocol[_T_contra, _T_co]):
7982
def __divmod__(self, __other: _T_contra) -> _T_co: ...
8083

stdlib/unittest/case.pyi

+35-22
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,32 @@
1-
import datetime
21
import logging
32
import sys
43
import unittest.result
5-
from _typeshed import Self
4+
from _typeshed import Self, SupportsDunderGE, SupportsSub
65
from collections.abc import Callable, Container, Iterable, Mapping, Sequence, Set as AbstractSet
76
from contextlib import AbstractContextManager
87
from types import TracebackType
9-
from typing import Any, AnyStr, ClassVar, Generic, NamedTuple, NoReturn, Pattern, TypeVar, overload
8+
from typing import (
9+
Any,
10+
AnyStr,
11+
ClassVar,
12+
Generic,
13+
NamedTuple,
14+
NoReturn,
15+
Pattern,
16+
Protocol,
17+
SupportsAbs,
18+
SupportsRound,
19+
TypeVar,
20+
overload,
21+
)
1022
from typing_extensions import ParamSpec
1123
from warnings import WarningMessage
1224

1325
if sys.version_info >= (3, 9):
1426
from types import GenericAlias
1527

1628
_T = TypeVar("_T")
29+
_S = TypeVar("_S", bound=SupportsSub[Any, Any])
1730
_E = TypeVar("_E", bound=BaseException)
1831
_FT = TypeVar("_FT", bound=Callable[..., Any])
1932
_P = ParamSpec("_P")
@@ -62,6 +75,8 @@ def skipUnless(condition: object, reason: str) -> Callable[[_FT], _FT]: ...
6275
class SkipTest(Exception):
6376
def __init__(self, reason: str) -> None: ...
6477

78+
class _SupportsAbsAndDunderGE(SupportsDunderGE, SupportsAbs[Any], Protocol): ...
79+
6580
class TestCase:
6681
failureException: type[BaseException]
6782
longMessage: bool
@@ -165,33 +180,35 @@ class TestCase:
165180
self, logger: str | logging.Logger | None = ..., level: int | str | None = ...
166181
) -> _AssertLogsContext[None]: ...
167182

183+
@overload
184+
def assertAlmostEqual(self, first: _S, second: _S, places: None, msg: Any, delta: _SupportsAbsAndDunderGE) -> None: ...
168185
@overload
169186
def assertAlmostEqual(
170-
self, first: float, second: float, places: int | None = ..., msg: Any = ..., delta: float | None = ...
187+
self, first: _S, second: _S, places: None = ..., msg: Any = ..., *, delta: _SupportsAbsAndDunderGE
171188
) -> None: ...
172189
@overload
173190
def assertAlmostEqual(
174191
self,
175-
first: datetime.datetime,
176-
second: datetime.datetime,
192+
first: SupportsSub[_T, SupportsAbs[SupportsRound[object]]],
193+
second: _T,
177194
places: int | None = ...,
178195
msg: Any = ...,
179-
delta: datetime.timedelta | None = ...,
196+
delta: None = ...,
180197
) -> None: ...
181198
@overload
182-
def assertNotAlmostEqual(self, first: float, second: float, *, msg: Any = ...) -> None: ...
199+
def assertNotAlmostEqual(self, first: _S, second: _S, places: None, msg: Any, delta: _SupportsAbsAndDunderGE) -> None: ...
183200
@overload
184-
def assertNotAlmostEqual(self, first: float, second: float, places: int | None = ..., msg: Any = ...) -> None: ...
185-
@overload
186-
def assertNotAlmostEqual(self, first: float, second: float, *, msg: Any = ..., delta: float | None = ...) -> None: ...
201+
def assertNotAlmostEqual(
202+
self, first: _S, second: _S, places: None = ..., msg: Any = ..., *, delta: _SupportsAbsAndDunderGE
203+
) -> None: ...
187204
@overload
188205
def assertNotAlmostEqual(
189206
self,
190-
first: datetime.datetime,
191-
second: datetime.datetime,
207+
first: SupportsSub[_T, SupportsAbs[SupportsRound[object]]],
208+
second: _T,
192209
places: int | None = ...,
193210
msg: Any = ...,
194-
delta: datetime.timedelta | None = ...,
211+
delta: None = ...,
195212
) -> None: ...
196213
def assertRegex(self, text: AnyStr, expected_regex: AnyStr | Pattern[AnyStr], msg: Any = ...) -> None: ...
197214
def assertNotRegex(self, text: AnyStr, unexpected_regex: AnyStr | Pattern[AnyStr], msg: Any = ...) -> None: ...
@@ -249,14 +266,10 @@ class TestCase:
249266
) -> None: ...
250267
@overload
251268
def failUnlessRaises(self, exception: type[_E] | tuple[type[_E], ...], msg: Any = ...) -> _AssertRaisesContext[_E]: ...
252-
def failUnlessAlmostEqual(self, first: float, second: float, places: int = ..., msg: Any = ...) -> None: ...
253-
def assertAlmostEquals(
254-
self, first: float, second: float, places: int = ..., msg: Any = ..., delta: float = ...
255-
) -> None: ...
256-
def failIfAlmostEqual(self, first: float, second: float, places: int = ..., msg: Any = ...) -> None: ...
257-
def assertNotAlmostEquals(
258-
self, first: float, second: float, places: int = ..., msg: Any = ..., delta: float = ...
259-
) -> None: ...
269+
failUnlessAlmostEqual = assertAlmostEqual
270+
assertAlmostEquals = assertAlmostEqual
271+
failIfAlmostEqual = assertNotAlmostEqual
272+
assertNotAlmostEquals = assertNotAlmostEqual
260273
def assertRegexpMatches(self, text: AnyStr, regex: AnyStr | Pattern[AnyStr], msg: Any = ...) -> None: ...
261274
def assertNotRegexpMatches(self, text: AnyStr, regex: AnyStr | Pattern[AnyStr], msg: Any = ...) -> None: ...
262275
@overload

test_cases/stdlib/test_unittest.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# pyright: reportUnnecessaryTypeIgnoreComment=true
2+
3+
import unittest
4+
from datetime import datetime, timedelta
5+
from decimal import Decimal
6+
from fractions import Fraction
7+
8+
case = unittest.TestCase()
9+
10+
case.assertAlmostEqual(2.4, 2.41)
11+
case.assertAlmostEqual(Fraction(49, 50), Fraction(48, 50))
12+
case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), delta=timedelta(hours=1))
13+
case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), None, "foo", timedelta(hours=1))
14+
case.assertAlmostEqual(Decimal("1.1"), Decimal("1.11"))
15+
case.assertAlmostEqual(2.4, 2.41, places=8)
16+
case.assertAlmostEqual(2.4, 2.41, delta=0.02)
17+
case.assertAlmostEqual(2.4, 2.41, None, "foo", 0.02)
18+
19+
case.assertAlmostEqual(2.4, 2.41, places=9, delta=0.02) # type: ignore[call-overload]
20+
case.assertAlmostEqual("foo", "bar") # type: ignore[call-overload]
21+
case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1)) # type: ignore[arg-type]
22+
23+
case.assertNotAlmostEqual(Fraction(49, 50), Fraction(48, 50))
24+
case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), delta=timedelta(hours=1))
25+
case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), None, "foo", timedelta(hours=1))
26+
27+
case.assertNotAlmostEqual(2.4, 2.41, places=9, delta=0.02) # type: ignore[call-overload]
28+
case.assertNotAlmostEqual("foo", "bar") # type: ignore[call-overload]
29+
case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1)) # type: ignore[arg-type]

0 commit comments

Comments
 (0)