Skip to content

Commit 88a7f68

Browse files
authored
Have Protocol inherit from typing.Generic on 3.8+ (#184)
1 parent b306e56 commit 88a7f68

File tree

3 files changed

+274
-129
lines changed

3 files changed

+274
-129
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
- Change deprecated `@runtime` to formal API `@runtime_checkable` in the error
44
message. Patch by Xuehai Pan.
5+
- Fix regression in 4.6.0 where attempting to define a `Protocol` that was
6+
generic over a `ParamSpec` or a `TypeVarTuple` would cause `TypeError` to be
7+
raised. Patch by Alex Waygood.
58

69
# Release 4.6.0 (May 22, 2023)
710

src/test_typing_extensions.py

+101-19
Original file line numberDiff line numberDiff line change
@@ -2613,6 +2613,62 @@ class CustomProtocolWithoutInitB(Protocol):
26132613

26142614
self.assertEqual(CustomProtocolWithoutInitA.__init__, CustomProtocolWithoutInitB.__init__)
26152615

2616+
def test_protocol_generic_over_paramspec(self):
2617+
P = ParamSpec("P")
2618+
T = TypeVar("T")
2619+
T2 = TypeVar("T2")
2620+
2621+
class MemoizedFunc(Protocol[P, T, T2]):
2622+
cache: typing.Dict[T2, T]
2623+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...
2624+
2625+
self.assertEqual(MemoizedFunc.__parameters__, (P, T, T2))
2626+
self.assertTrue(MemoizedFunc._is_protocol)
2627+
2628+
with self.assertRaises(TypeError):
2629+
MemoizedFunc[[int, str, str]]
2630+
2631+
if sys.version_info >= (3, 10):
2632+
# These unfortunately don't pass on <=3.9,
2633+
# due to typing._type_check on older Python versions
2634+
X = MemoizedFunc[[int, str, str], T, T2]
2635+
self.assertEqual(X.__parameters__, (T, T2))
2636+
self.assertEqual(X.__args__, ((int, str, str), T, T2))
2637+
2638+
Y = X[bytes, memoryview]
2639+
self.assertEqual(Y.__parameters__, ())
2640+
self.assertEqual(Y.__args__, ((int, str, str), bytes, memoryview))
2641+
2642+
def test_protocol_generic_over_typevartuple(self):
2643+
Ts = TypeVarTuple("Ts")
2644+
T = TypeVar("T")
2645+
T2 = TypeVar("T2")
2646+
2647+
class MemoizedFunc(Protocol[Unpack[Ts], T, T2]):
2648+
cache: typing.Dict[T2, T]
2649+
def __call__(self, *args: Unpack[Ts]) -> T: ...
2650+
2651+
self.assertEqual(MemoizedFunc.__parameters__, (Ts, T, T2))
2652+
self.assertTrue(MemoizedFunc._is_protocol)
2653+
2654+
things = "arguments" if sys.version_info >= (3, 11) else "parameters"
2655+
2656+
# A bug was fixed in 3.11.1
2657+
# (https://github.com/python/cpython/commit/74920aa27d0c57443dd7f704d6272cca9c507ab3)
2658+
# That means this assertion doesn't pass on 3.11.0,
2659+
# but it passes on all other Python versions
2660+
if sys.version_info[:3] != (3, 11, 0):
2661+
with self.assertRaisesRegex(TypeError, f"Too few {things}"):
2662+
MemoizedFunc[int]
2663+
2664+
X = MemoizedFunc[int, T, T2]
2665+
self.assertEqual(X.__parameters__, (T, T2))
2666+
self.assertEqual(X.__args__, (int, T, T2))
2667+
2668+
Y = X[bytes, memoryview]
2669+
self.assertEqual(Y.__parameters__, ())
2670+
self.assertEqual(Y.__args__, (int, bytes, memoryview))
2671+
26162672

26172673
class Point2DGeneric(Generic[T], TypedDict):
26182674
a: T
@@ -3402,13 +3458,18 @@ def test_user_generics(self):
34023458
class X(Generic[T, P]):
34033459
pass
34043460

3405-
G1 = X[int, P_2]
3406-
self.assertEqual(G1.__args__, (int, P_2))
3407-
self.assertEqual(G1.__parameters__, (P_2,))
3461+
class Y(Protocol[T, P]):
3462+
pass
3463+
3464+
for klass in X, Y:
3465+
with self.subTest(klass=klass.__name__):
3466+
G1 = klass[int, P_2]
3467+
self.assertEqual(G1.__args__, (int, P_2))
3468+
self.assertEqual(G1.__parameters__, (P_2,))
34083469

3409-
G2 = X[int, Concatenate[int, P_2]]
3410-
self.assertEqual(G2.__args__, (int, Concatenate[int, P_2]))
3411-
self.assertEqual(G2.__parameters__, (P_2,))
3470+
G2 = klass[int, Concatenate[int, P_2]]
3471+
self.assertEqual(G2.__args__, (int, Concatenate[int, P_2]))
3472+
self.assertEqual(G2.__parameters__, (P_2,))
34123473

34133474
# The following are some valid uses cases in PEP 612 that don't work:
34143475
# These do not work in 3.9, _type_check blocks the list and ellipsis.
@@ -3421,6 +3482,9 @@ class X(Generic[T, P]):
34213482
class Z(Generic[P]):
34223483
pass
34233484

3485+
class ProtoZ(Protocol[P]):
3486+
pass
3487+
34243488
def test_pickle(self):
34253489
global P, P_co, P_contra, P_default
34263490
P = ParamSpec('P')
@@ -3727,31 +3791,49 @@ def test_concatenation(self):
37273791
self.assertEqual(Tuple[int, Unpack[Xs], str].__args__,
37283792
(int, Unpack[Xs], str))
37293793
class C(Generic[Unpack[Xs]]): pass
3730-
self.assertEqual(C[int, Unpack[Xs]].__args__, (int, Unpack[Xs]))
3731-
self.assertEqual(C[Unpack[Xs], int].__args__, (Unpack[Xs], int))
3732-
self.assertEqual(C[int, Unpack[Xs], str].__args__,
3733-
(int, Unpack[Xs], str))
3794+
class D(Protocol[Unpack[Xs]]): pass
3795+
for klass in C, D:
3796+
with self.subTest(klass=klass.__name__):
3797+
self.assertEqual(klass[int, Unpack[Xs]].__args__, (int, Unpack[Xs]))
3798+
self.assertEqual(klass[Unpack[Xs], int].__args__, (Unpack[Xs], int))
3799+
self.assertEqual(klass[int, Unpack[Xs], str].__args__,
3800+
(int, Unpack[Xs], str))
37343801

37353802
def test_class(self):
37363803
Ts = TypeVarTuple('Ts')
37373804

37383805
class C(Generic[Unpack[Ts]]): pass
3739-
self.assertEqual(C[int].__args__, (int,))
3740-
self.assertEqual(C[int, str].__args__, (int, str))
3806+
class D(Protocol[Unpack[Ts]]): pass
3807+
3808+
for klass in C, D:
3809+
with self.subTest(klass=klass.__name__):
3810+
self.assertEqual(klass[int].__args__, (int,))
3811+
self.assertEqual(klass[int, str].__args__, (int, str))
37413812

37423813
with self.assertRaises(TypeError):
37433814
class C(Generic[Unpack[Ts], int]): pass
37443815

3816+
with self.assertRaises(TypeError):
3817+
class D(Protocol[Unpack[Ts], int]): pass
3818+
37453819
T1 = TypeVar('T')
37463820
T2 = TypeVar('T')
37473821
class C(Generic[T1, T2, Unpack[Ts]]): pass
3748-
self.assertEqual(C[int, str].__args__, (int, str))
3749-
self.assertEqual(C[int, str, float].__args__, (int, str, float))
3750-
self.assertEqual(C[int, str, float, bool].__args__, (int, str, float, bool))
3751-
# TODO This should probably also fail on 3.11, pending changes to CPython.
3752-
if not TYPING_3_11_0:
3753-
with self.assertRaises(TypeError):
3754-
C[int]
3822+
class D(Protocol[T1, T2, Unpack[Ts]]): pass
3823+
for klass in C, D:
3824+
with self.subTest(klass=klass.__name__):
3825+
self.assertEqual(klass[int, str].__args__, (int, str))
3826+
self.assertEqual(klass[int, str, float].__args__, (int, str, float))
3827+
self.assertEqual(
3828+
klass[int, str, float, bool].__args__, (int, str, float, bool)
3829+
)
3830+
# A bug was fixed in 3.11.1
3831+
# (https://github.com/python/cpython/commit/74920aa27d0c57443dd7f704d6272cca9c507ab3)
3832+
# That means this assertion doesn't pass on 3.11.0,
3833+
# but it passes on all other Python versions
3834+
if sys.version_info[:3] != (3, 11, 0):
3835+
with self.assertRaises(TypeError):
3836+
klass[int]
37553837

37563838

37573839
class TypeVarTupleTests(BaseTestCase):

0 commit comments

Comments
 (0)