@@ -2613,6 +2613,62 @@ class CustomProtocolWithoutInitB(Protocol):
2613
2613
2614
2614
self .assertEqual (CustomProtocolWithoutInitA .__init__ , CustomProtocolWithoutInitB .__init__ )
2615
2615
2616
+ def test_protocol_generic_over_paramspec (self ):
2617
+ P = ParamSpec ("P" )
2618
+ T = TypeVar ("T" )
2619
+ T2 = TypeVar ("T2" )
2620
+
2621
+ class MemoizedFunc (Protocol [P , T , T2 ]):
2622
+ cache : typing .Dict [T2 , T ]
2623
+ def __call__ (self , * args : P .args , ** kwargs : P .kwargs ) -> T : ...
2624
+
2625
+ self .assertEqual (MemoizedFunc .__parameters__ , (P , T , T2 ))
2626
+ self .assertTrue (MemoizedFunc ._is_protocol )
2627
+
2628
+ with self .assertRaises (TypeError ):
2629
+ MemoizedFunc [[int , str , str ]]
2630
+
2631
+ if sys .version_info >= (3 , 10 ):
2632
+ # These unfortunately don't pass on <=3.9,
2633
+ # due to typing._type_check on older Python versions
2634
+ X = MemoizedFunc [[int , str , str ], T , T2 ]
2635
+ self .assertEqual (X .__parameters__ , (T , T2 ))
2636
+ self .assertEqual (X .__args__ , ((int , str , str ), T , T2 ))
2637
+
2638
+ Y = X [bytes , memoryview ]
2639
+ self .assertEqual (Y .__parameters__ , ())
2640
+ self .assertEqual (Y .__args__ , ((int , str , str ), bytes , memoryview ))
2641
+
2642
+ def test_protocol_generic_over_typevartuple (self ):
2643
+ Ts = TypeVarTuple ("Ts" )
2644
+ T = TypeVar ("T" )
2645
+ T2 = TypeVar ("T2" )
2646
+
2647
+ class MemoizedFunc (Protocol [Unpack [Ts ], T , T2 ]):
2648
+ cache : typing .Dict [T2 , T ]
2649
+ def __call__ (self , * args : Unpack [Ts ]) -> T : ...
2650
+
2651
+ self .assertEqual (MemoizedFunc .__parameters__ , (Ts , T , T2 ))
2652
+ self .assertTrue (MemoizedFunc ._is_protocol )
2653
+
2654
+ things = "arguments" if sys .version_info >= (3 , 11 ) else "parameters"
2655
+
2656
+ # A bug was fixed in 3.11.1
2657
+ # (https://github.com/python/cpython/commit/74920aa27d0c57443dd7f704d6272cca9c507ab3)
2658
+ # That means this assertion doesn't pass on 3.11.0,
2659
+ # but it passes on all other Python versions
2660
+ if sys .version_info [:3 ] != (3 , 11 , 0 ):
2661
+ with self .assertRaisesRegex (TypeError , f"Too few { things } " ):
2662
+ MemoizedFunc [int ]
2663
+
2664
+ X = MemoizedFunc [int , T , T2 ]
2665
+ self .assertEqual (X .__parameters__ , (T , T2 ))
2666
+ self .assertEqual (X .__args__ , (int , T , T2 ))
2667
+
2668
+ Y = X [bytes , memoryview ]
2669
+ self .assertEqual (Y .__parameters__ , ())
2670
+ self .assertEqual (Y .__args__ , (int , bytes , memoryview ))
2671
+
2616
2672
2617
2673
class Point2DGeneric (Generic [T ], TypedDict ):
2618
2674
a : T
@@ -3402,13 +3458,18 @@ def test_user_generics(self):
3402
3458
class X (Generic [T , P ]):
3403
3459
pass
3404
3460
3405
- G1 = X [int , P_2 ]
3406
- self .assertEqual (G1 .__args__ , (int , P_2 ))
3407
- self .assertEqual (G1 .__parameters__ , (P_2 ,))
3461
+ class Y (Protocol [T , P ]):
3462
+ pass
3463
+
3464
+ for klass in X , Y :
3465
+ with self .subTest (klass = klass .__name__ ):
3466
+ G1 = klass [int , P_2 ]
3467
+ self .assertEqual (G1 .__args__ , (int , P_2 ))
3468
+ self .assertEqual (G1 .__parameters__ , (P_2 ,))
3408
3469
3409
- G2 = X [int , Concatenate [int , P_2 ]]
3410
- self .assertEqual (G2 .__args__ , (int , Concatenate [int , P_2 ]))
3411
- self .assertEqual (G2 .__parameters__ , (P_2 ,))
3470
+ G2 = klass [int , Concatenate [int , P_2 ]]
3471
+ self .assertEqual (G2 .__args__ , (int , Concatenate [int , P_2 ]))
3472
+ self .assertEqual (G2 .__parameters__ , (P_2 ,))
3412
3473
3413
3474
# The following are some valid uses cases in PEP 612 that don't work:
3414
3475
# These do not work in 3.9, _type_check blocks the list and ellipsis.
@@ -3421,6 +3482,9 @@ class X(Generic[T, P]):
3421
3482
class Z (Generic [P ]):
3422
3483
pass
3423
3484
3485
+ class ProtoZ (Protocol [P ]):
3486
+ pass
3487
+
3424
3488
def test_pickle (self ):
3425
3489
global P , P_co , P_contra , P_default
3426
3490
P = ParamSpec ('P' )
@@ -3727,31 +3791,49 @@ def test_concatenation(self):
3727
3791
self .assertEqual (Tuple [int , Unpack [Xs ], str ].__args__ ,
3728
3792
(int , Unpack [Xs ], str ))
3729
3793
class C (Generic [Unpack [Xs ]]): pass
3730
- self .assertEqual (C [int , Unpack [Xs ]].__args__ , (int , Unpack [Xs ]))
3731
- self .assertEqual (C [Unpack [Xs ], int ].__args__ , (Unpack [Xs ], int ))
3732
- self .assertEqual (C [int , Unpack [Xs ], str ].__args__ ,
3733
- (int , Unpack [Xs ], str ))
3794
+ class D (Protocol [Unpack [Xs ]]): pass
3795
+ for klass in C , D :
3796
+ with self .subTest (klass = klass .__name__ ):
3797
+ self .assertEqual (klass [int , Unpack [Xs ]].__args__ , (int , Unpack [Xs ]))
3798
+ self .assertEqual (klass [Unpack [Xs ], int ].__args__ , (Unpack [Xs ], int ))
3799
+ self .assertEqual (klass [int , Unpack [Xs ], str ].__args__ ,
3800
+ (int , Unpack [Xs ], str ))
3734
3801
3735
3802
def test_class (self ):
3736
3803
Ts = TypeVarTuple ('Ts' )
3737
3804
3738
3805
class C (Generic [Unpack [Ts ]]): pass
3739
- self .assertEqual (C [int ].__args__ , (int ,))
3740
- self .assertEqual (C [int , str ].__args__ , (int , str ))
3806
+ class D (Protocol [Unpack [Ts ]]): pass
3807
+
3808
+ for klass in C , D :
3809
+ with self .subTest (klass = klass .__name__ ):
3810
+ self .assertEqual (klass [int ].__args__ , (int ,))
3811
+ self .assertEqual (klass [int , str ].__args__ , (int , str ))
3741
3812
3742
3813
with self .assertRaises (TypeError ):
3743
3814
class C (Generic [Unpack [Ts ], int ]): pass
3744
3815
3816
+ with self .assertRaises (TypeError ):
3817
+ class D (Protocol [Unpack [Ts ], int ]): pass
3818
+
3745
3819
T1 = TypeVar ('T' )
3746
3820
T2 = TypeVar ('T' )
3747
3821
class C (Generic [T1 , T2 , Unpack [Ts ]]): pass
3748
- self .assertEqual (C [int , str ].__args__ , (int , str ))
3749
- self .assertEqual (C [int , str , float ].__args__ , (int , str , float ))
3750
- self .assertEqual (C [int , str , float , bool ].__args__ , (int , str , float , bool ))
3751
- # TODO This should probably also fail on 3.11, pending changes to CPython.
3752
- if not TYPING_3_11_0 :
3753
- with self .assertRaises (TypeError ):
3754
- C [int ]
3822
+ class D (Protocol [T1 , T2 , Unpack [Ts ]]): pass
3823
+ for klass in C , D :
3824
+ with self .subTest (klass = klass .__name__ ):
3825
+ self .assertEqual (klass [int , str ].__args__ , (int , str ))
3826
+ self .assertEqual (klass [int , str , float ].__args__ , (int , str , float ))
3827
+ self .assertEqual (
3828
+ klass [int , str , float , bool ].__args__ , (int , str , float , bool )
3829
+ )
3830
+ # A bug was fixed in 3.11.1
3831
+ # (https://github.com/python/cpython/commit/74920aa27d0c57443dd7f704d6272cca9c507ab3)
3832
+ # That means this assertion doesn't pass on 3.11.0,
3833
+ # but it passes on all other Python versions
3834
+ if sys .version_info [:3 ] != (3 , 11 , 0 ):
3835
+ with self .assertRaises (TypeError ):
3836
+ klass [int ]
3755
3837
3756
3838
3757
3839
class TypeVarTupleTests (BaseTestCase ):
0 commit comments