@@ -82,15 +82,19 @@ def __repr__(self) -> str:
82
82
op_str = "<:"
83
83
if self .op == SUPERTYPE_OF :
84
84
op_str = ":>"
85
- return f"{ self .type_var } { op_str } { self .target } "
85
+ return f"{ self .origin_type_var } { op_str } { self .target } "
86
86
87
87
def __hash__ (self ) -> int :
88
- return hash ((self .type_var , self .op , self .target ))
88
+ return hash ((self .origin_type_var , self .op , self .target ))
89
89
90
90
def __eq__ (self , other : object ) -> bool :
91
91
if not isinstance (other , Constraint ):
92
92
return False
93
- return (self .type_var , self .op , self .target ) == (other .type_var , other .op , other .target )
93
+ return (self .origin_type_var , self .op , self .target ) == (
94
+ other .origin_type_var ,
95
+ other .op ,
96
+ other .target ,
97
+ )
94
98
95
99
96
100
def infer_constraints_for_callable (
@@ -698,25 +702,54 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
698
702
)
699
703
elif isinstance (tvar , ParamSpecType ) and isinstance (mapped_arg , ParamSpecType ):
700
704
suffix = get_proper_type (instance_arg )
705
+ prefix = mapped_arg .prefix
706
+ length = len (prefix .arg_types )
701
707
702
708
if isinstance (suffix , CallableType ):
703
- prefix = mapped_arg .prefix
704
709
from_concat = bool (prefix .arg_types ) or suffix .from_concatenate
705
710
suffix = suffix .copy_modified (from_concatenate = from_concat )
706
711
707
712
if isinstance (suffix , (Parameters , CallableType )):
708
713
# no such thing as variance for ParamSpecs
709
714
# TODO: is there a case I am missing?
710
- # TODO: constraints between prefixes
711
- prefix = mapped_arg .prefix
712
- suffix = suffix .copy_modified (
713
- suffix .arg_types [len (prefix .arg_types ) :],
714
- suffix .arg_kinds [len (prefix .arg_kinds ) :],
715
- suffix .arg_names [len (prefix .arg_names ) :],
715
+ length = min (length , len (suffix .arg_types ))
716
+
717
+ constrained_to = suffix .copy_modified (
718
+ suffix .arg_types [length :],
719
+ suffix .arg_kinds [length :],
720
+ suffix .arg_names [length :],
721
+ )
722
+ constrained_from = mapped_arg .copy_modified (
723
+ prefix = prefix .copy_modified (
724
+ prefix .arg_types [length :],
725
+ prefix .arg_kinds [length :],
726
+ prefix .arg_names [length :],
727
+ )
716
728
)
717
- res .append (Constraint (mapped_arg , SUPERTYPE_OF , suffix ))
729
+
730
+ res .append (Constraint (constrained_from , SUPERTYPE_OF , constrained_to ))
731
+ res .append (Constraint (constrained_from , SUBTYPE_OF , constrained_to ))
718
732
elif isinstance (suffix , ParamSpecType ):
719
- res .append (Constraint (mapped_arg , SUPERTYPE_OF , suffix ))
733
+ suffix_prefix = suffix .prefix
734
+ length = min (length , len (suffix_prefix .arg_types ))
735
+
736
+ constrained = suffix .copy_modified (
737
+ prefix = suffix_prefix .copy_modified (
738
+ suffix_prefix .arg_types [length :],
739
+ suffix_prefix .arg_kinds [length :],
740
+ suffix_prefix .arg_names [length :],
741
+ )
742
+ )
743
+ constrained_from = mapped_arg .copy_modified (
744
+ prefix = prefix .copy_modified (
745
+ prefix .arg_types [length :],
746
+ prefix .arg_kinds [length :],
747
+ prefix .arg_names [length :],
748
+ )
749
+ )
750
+
751
+ res .append (Constraint (constrained_from , SUPERTYPE_OF , constrained ))
752
+ res .append (Constraint (constrained_from , SUBTYPE_OF , constrained ))
720
753
else :
721
754
# This case should have been handled above.
722
755
assert not isinstance (tvar , TypeVarTupleType )
@@ -768,26 +801,56 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
768
801
template_arg , ParamSpecType
769
802
):
770
803
suffix = get_proper_type (mapped_arg )
804
+ prefix = template_arg .prefix
805
+ length = len (prefix .arg_types )
771
806
772
807
if isinstance (suffix , CallableType ):
773
808
prefix = template_arg .prefix
774
809
from_concat = bool (prefix .arg_types ) or suffix .from_concatenate
775
810
suffix = suffix .copy_modified (from_concatenate = from_concat )
776
811
812
+ # TODO: this is almost a copy-paste of code above: make this into a function
777
813
if isinstance (suffix , (Parameters , CallableType )):
778
814
# no such thing as variance for ParamSpecs
779
815
# TODO: is there a case I am missing?
780
- # TODO: constraints between prefixes
781
- prefix = template_arg .prefix
816
+ length = min (length , len (suffix .arg_types ))
782
817
783
- suffix = suffix .copy_modified (
784
- suffix .arg_types [len ( prefix . arg_types ) :],
785
- suffix .arg_kinds [len ( prefix . arg_kinds ) :],
786
- suffix .arg_names [len ( prefix . arg_names ) :],
818
+ constrained_to = suffix .copy_modified (
819
+ suffix .arg_types [length :],
820
+ suffix .arg_kinds [length :],
821
+ suffix .arg_names [length :],
787
822
)
788
- res .append (Constraint (template_arg , SUPERTYPE_OF , suffix ))
823
+ constrained_from = template_arg .copy_modified (
824
+ prefix = prefix .copy_modified (
825
+ prefix .arg_types [length :],
826
+ prefix .arg_kinds [length :],
827
+ prefix .arg_names [length :],
828
+ )
829
+ )
830
+
831
+ res .append (Constraint (constrained_from , SUPERTYPE_OF , constrained_to ))
832
+ res .append (Constraint (constrained_from , SUBTYPE_OF , constrained_to ))
789
833
elif isinstance (suffix , ParamSpecType ):
790
- res .append (Constraint (template_arg , SUPERTYPE_OF , suffix ))
834
+ suffix_prefix = suffix .prefix
835
+ length = min (length , len (suffix_prefix .arg_types ))
836
+
837
+ constrained = suffix .copy_modified (
838
+ prefix = suffix_prefix .copy_modified (
839
+ suffix_prefix .arg_types [length :],
840
+ suffix_prefix .arg_kinds [length :],
841
+ suffix_prefix .arg_names [length :],
842
+ )
843
+ )
844
+ constrained_from = template_arg .copy_modified (
845
+ prefix = prefix .copy_modified (
846
+ prefix .arg_types [length :],
847
+ prefix .arg_kinds [length :],
848
+ prefix .arg_names [length :],
849
+ )
850
+ )
851
+
852
+ res .append (Constraint (constrained_from , SUPERTYPE_OF , constrained ))
853
+ res .append (Constraint (constrained_from , SUBTYPE_OF , constrained ))
791
854
else :
792
855
# This case should have been handled above.
793
856
assert not isinstance (tvar , TypeVarTupleType )
@@ -954,9 +1017,19 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
954
1017
prefix_len = len (prefix .arg_types )
955
1018
cactual_ps = cactual .param_spec ()
956
1019
1020
+ cactual_prefix : Parameters | CallableType
1021
+ if cactual_ps :
1022
+ cactual_prefix = cactual_ps .prefix
1023
+ else :
1024
+ cactual_prefix = cactual
1025
+
1026
+ max_prefix_len = len (
1027
+ [k for k in cactual_prefix .arg_kinds if k in (ARG_POS , ARG_OPT )]
1028
+ )
1029
+ prefix_len = min (prefix_len , max_prefix_len )
1030
+
1031
+ # we could check the prefixes match here, but that should be caught elsewhere.
957
1032
if not cactual_ps :
958
- max_prefix_len = len ([k for k in cactual .arg_kinds if k in (ARG_POS , ARG_OPT )])
959
- prefix_len = min (prefix_len , max_prefix_len )
960
1033
res .append (
961
1034
Constraint (
962
1035
param_spec ,
@@ -970,7 +1043,17 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
970
1043
)
971
1044
)
972
1045
else :
973
- res .append (Constraint (param_spec , SUBTYPE_OF , cactual_ps ))
1046
+ # earlier, cactual_prefix = cactual_ps.prefix. thus, this is guaranteed
1047
+ assert isinstance (cactual_prefix , Parameters )
1048
+
1049
+ constrained_by = cactual_ps .copy_modified (
1050
+ prefix = cactual_prefix .copy_modified (
1051
+ cactual_prefix .arg_types [prefix_len :],
1052
+ cactual_prefix .arg_kinds [prefix_len :],
1053
+ cactual_prefix .arg_names [prefix_len :],
1054
+ )
1055
+ )
1056
+ res .append (Constraint (param_spec , SUBTYPE_OF , constrained_by ))
974
1057
975
1058
# compare prefixes
976
1059
cactual_prefix = cactual .copy_modified (
0 commit comments