Skip to content

Commit 2aaeda4

Browse files
authored
Reconsider constraints involving parameter specifications (#15272)
- Fixes #15037 - Fixes #15065 - Fixes #15073 - Fixes #15388 - Fixes #15086 Yet another part of #14903 that's finally been extracted!
1 parent 5617cdd commit 2aaeda4

File tree

4 files changed

+241
-24
lines changed

4 files changed

+241
-24
lines changed

mypy/constraints.py

+106-23
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,19 @@ def __repr__(self) -> str:
8282
op_str = "<:"
8383
if self.op == SUPERTYPE_OF:
8484
op_str = ":>"
85-
return f"{self.type_var} {op_str} {self.target}"
85+
return f"{self.origin_type_var} {op_str} {self.target}"
8686

8787
def __hash__(self) -> int:
88-
return hash((self.type_var, self.op, self.target))
88+
return hash((self.origin_type_var, self.op, self.target))
8989

9090
def __eq__(self, other: object) -> bool:
9191
if not isinstance(other, Constraint):
9292
return False
93-
return (self.type_var, self.op, self.target) == (other.type_var, other.op, other.target)
93+
return (self.origin_type_var, self.op, self.target) == (
94+
other.origin_type_var,
95+
other.op,
96+
other.target,
97+
)
9498

9599

96100
def infer_constraints_for_callable(
@@ -698,25 +702,54 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
698702
)
699703
elif isinstance(tvar, ParamSpecType) and isinstance(mapped_arg, ParamSpecType):
700704
suffix = get_proper_type(instance_arg)
705+
prefix = mapped_arg.prefix
706+
length = len(prefix.arg_types)
701707

702708
if isinstance(suffix, CallableType):
703-
prefix = mapped_arg.prefix
704709
from_concat = bool(prefix.arg_types) or suffix.from_concatenate
705710
suffix = suffix.copy_modified(from_concatenate=from_concat)
706711

707712
if isinstance(suffix, (Parameters, CallableType)):
708713
# no such thing as variance for ParamSpecs
709714
# TODO: is there a case I am missing?
710-
# TODO: constraints between prefixes
711-
prefix = mapped_arg.prefix
712-
suffix = suffix.copy_modified(
713-
suffix.arg_types[len(prefix.arg_types) :],
714-
suffix.arg_kinds[len(prefix.arg_kinds) :],
715-
suffix.arg_names[len(prefix.arg_names) :],
715+
length = min(length, len(suffix.arg_types))
716+
717+
constrained_to = suffix.copy_modified(
718+
suffix.arg_types[length:],
719+
suffix.arg_kinds[length:],
720+
suffix.arg_names[length:],
721+
)
722+
constrained_from = mapped_arg.copy_modified(
723+
prefix=prefix.copy_modified(
724+
prefix.arg_types[length:],
725+
prefix.arg_kinds[length:],
726+
prefix.arg_names[length:],
727+
)
716728
)
717-
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
729+
730+
res.append(Constraint(constrained_from, SUPERTYPE_OF, constrained_to))
731+
res.append(Constraint(constrained_from, SUBTYPE_OF, constrained_to))
718732
elif isinstance(suffix, ParamSpecType):
719-
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
733+
suffix_prefix = suffix.prefix
734+
length = min(length, len(suffix_prefix.arg_types))
735+
736+
constrained = suffix.copy_modified(
737+
prefix=suffix_prefix.copy_modified(
738+
suffix_prefix.arg_types[length:],
739+
suffix_prefix.arg_kinds[length:],
740+
suffix_prefix.arg_names[length:],
741+
)
742+
)
743+
constrained_from = mapped_arg.copy_modified(
744+
prefix=prefix.copy_modified(
745+
prefix.arg_types[length:],
746+
prefix.arg_kinds[length:],
747+
prefix.arg_names[length:],
748+
)
749+
)
750+
751+
res.append(Constraint(constrained_from, SUPERTYPE_OF, constrained))
752+
res.append(Constraint(constrained_from, SUBTYPE_OF, constrained))
720753
else:
721754
# This case should have been handled above.
722755
assert not isinstance(tvar, TypeVarTupleType)
@@ -768,26 +801,56 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
768801
template_arg, ParamSpecType
769802
):
770803
suffix = get_proper_type(mapped_arg)
804+
prefix = template_arg.prefix
805+
length = len(prefix.arg_types)
771806

772807
if isinstance(suffix, CallableType):
773808
prefix = template_arg.prefix
774809
from_concat = bool(prefix.arg_types) or suffix.from_concatenate
775810
suffix = suffix.copy_modified(from_concatenate=from_concat)
776811

812+
# TODO: this is almost a copy-paste of code above: make this into a function
777813
if isinstance(suffix, (Parameters, CallableType)):
778814
# no such thing as variance for ParamSpecs
779815
# TODO: is there a case I am missing?
780-
# TODO: constraints between prefixes
781-
prefix = template_arg.prefix
816+
length = min(length, len(suffix.arg_types))
782817

783-
suffix = suffix.copy_modified(
784-
suffix.arg_types[len(prefix.arg_types) :],
785-
suffix.arg_kinds[len(prefix.arg_kinds) :],
786-
suffix.arg_names[len(prefix.arg_names) :],
818+
constrained_to = suffix.copy_modified(
819+
suffix.arg_types[length:],
820+
suffix.arg_kinds[length:],
821+
suffix.arg_names[length:],
787822
)
788-
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
823+
constrained_from = template_arg.copy_modified(
824+
prefix=prefix.copy_modified(
825+
prefix.arg_types[length:],
826+
prefix.arg_kinds[length:],
827+
prefix.arg_names[length:],
828+
)
829+
)
830+
831+
res.append(Constraint(constrained_from, SUPERTYPE_OF, constrained_to))
832+
res.append(Constraint(constrained_from, SUBTYPE_OF, constrained_to))
789833
elif isinstance(suffix, ParamSpecType):
790-
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
834+
suffix_prefix = suffix.prefix
835+
length = min(length, len(suffix_prefix.arg_types))
836+
837+
constrained = suffix.copy_modified(
838+
prefix=suffix_prefix.copy_modified(
839+
suffix_prefix.arg_types[length:],
840+
suffix_prefix.arg_kinds[length:],
841+
suffix_prefix.arg_names[length:],
842+
)
843+
)
844+
constrained_from = template_arg.copy_modified(
845+
prefix=prefix.copy_modified(
846+
prefix.arg_types[length:],
847+
prefix.arg_kinds[length:],
848+
prefix.arg_names[length:],
849+
)
850+
)
851+
852+
res.append(Constraint(constrained_from, SUPERTYPE_OF, constrained))
853+
res.append(Constraint(constrained_from, SUBTYPE_OF, constrained))
791854
else:
792855
# This case should have been handled above.
793856
assert not isinstance(tvar, TypeVarTupleType)
@@ -954,9 +1017,19 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
9541017
prefix_len = len(prefix.arg_types)
9551018
cactual_ps = cactual.param_spec()
9561019

1020+
cactual_prefix: Parameters | CallableType
1021+
if cactual_ps:
1022+
cactual_prefix = cactual_ps.prefix
1023+
else:
1024+
cactual_prefix = cactual
1025+
1026+
max_prefix_len = len(
1027+
[k for k in cactual_prefix.arg_kinds if k in (ARG_POS, ARG_OPT)]
1028+
)
1029+
prefix_len = min(prefix_len, max_prefix_len)
1030+
1031+
# we could check the prefixes match here, but that should be caught elsewhere.
9571032
if not cactual_ps:
958-
max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)])
959-
prefix_len = min(prefix_len, max_prefix_len)
9601033
res.append(
9611034
Constraint(
9621035
param_spec,
@@ -970,7 +1043,17 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
9701043
)
9711044
)
9721045
else:
973-
res.append(Constraint(param_spec, SUBTYPE_OF, cactual_ps))
1046+
# earlier, cactual_prefix = cactual_ps.prefix. thus, this is guaranteed
1047+
assert isinstance(cactual_prefix, Parameters)
1048+
1049+
constrained_by = cactual_ps.copy_modified(
1050+
prefix=cactual_prefix.copy_modified(
1051+
cactual_prefix.arg_types[prefix_len:],
1052+
cactual_prefix.arg_kinds[prefix_len:],
1053+
cactual_prefix.arg_names[prefix_len:],
1054+
)
1055+
)
1056+
res.append(Constraint(param_spec, SUBTYPE_OF, constrained_by))
9741057

9751058
# compare prefixes
9761059
cactual_prefix = cactual.copy_modified(

mypy/test/testconstraints.py

+62
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,65 @@ def test_var_length_tuple_with_fixed_length_tuple(self) -> None:
156156
Instance(fx.std_tuplei, [fx.a]),
157157
SUPERTYPE_OF,
158158
)
159+
160+
def test_paramspec_constrained_with_concatenate(self) -> None:
161+
# for legibility (and my own understanding), `Tester.normal()` is `Tester[P]`
162+
# and `Tester.concatenate()` is `Tester[Concatenate[A, P]]`
163+
# ... and 2nd arg to infer_constraints ends up on LHS of equality
164+
fx = self.fx
165+
166+
# I don't think we can parametrize...
167+
for direction in (SUPERTYPE_OF, SUBTYPE_OF):
168+
print(f"direction is {direction}")
169+
# equiv to: x: Tester[Q] = Tester.normal()
170+
assert set(
171+
infer_constraints(Instance(fx.gpsi, [fx.p]), Instance(fx.gpsi, [fx.q]), direction)
172+
) == {
173+
Constraint(type_var=fx.p, op=SUPERTYPE_OF, target=fx.q),
174+
Constraint(type_var=fx.p, op=SUBTYPE_OF, target=fx.q),
175+
}
176+
177+
# equiv to: x: Tester[Q] = Tester.concatenate()
178+
assert set(
179+
infer_constraints(
180+
Instance(fx.gpsi, [fx.p_concatenate]), Instance(fx.gpsi, [fx.q]), direction
181+
)
182+
) == {
183+
Constraint(type_var=fx.p_concatenate, op=SUPERTYPE_OF, target=fx.q),
184+
Constraint(type_var=fx.p_concatenate, op=SUBTYPE_OF, target=fx.q),
185+
}
186+
187+
# equiv to: x: Tester[Concatenate[B, Q]] = Tester.normal()
188+
assert set(
189+
infer_constraints(
190+
Instance(fx.gpsi, [fx.p]), Instance(fx.gpsi, [fx.q_concatenate]), direction
191+
)
192+
) == {
193+
Constraint(type_var=fx.p, op=SUPERTYPE_OF, target=fx.q_concatenate),
194+
Constraint(type_var=fx.p, op=SUBTYPE_OF, target=fx.q_concatenate),
195+
}
196+
197+
# equiv to: x: Tester[Concatenate[B, Q]] = Tester.concatenate()
198+
assert set(
199+
infer_constraints(
200+
Instance(fx.gpsi, [fx.p_concatenate]),
201+
Instance(fx.gpsi, [fx.q_concatenate]),
202+
direction,
203+
)
204+
) == {
205+
# this is correct as we assume other parts of mypy will warn that [B] != [A]
206+
Constraint(type_var=fx.p, op=SUPERTYPE_OF, target=fx.q),
207+
Constraint(type_var=fx.p, op=SUBTYPE_OF, target=fx.q),
208+
}
209+
210+
# equiv to: x: Tester[Concatenate[A, Q]] = Tester.concatenate()
211+
assert set(
212+
infer_constraints(
213+
Instance(fx.gpsi, [fx.p_concatenate]),
214+
Instance(fx.gpsi, [fx.q_concatenate]),
215+
direction,
216+
)
217+
) == {
218+
Constraint(type_var=fx.p, op=SUPERTYPE_OF, target=fx.q),
219+
Constraint(type_var=fx.p, op=SUBTYPE_OF, target=fx.q),
220+
}

mypy/test/typefixture.py

+42
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from __future__ import annotations
77

8+
from typing import Sequence
9+
810
from mypy.nodes import (
911
ARG_OPT,
1012
ARG_POS,
@@ -26,6 +28,9 @@
2628
Instance,
2729
LiteralType,
2830
NoneType,
31+
Parameters,
32+
ParamSpecFlavor,
33+
ParamSpecType,
2934
Type,
3035
TypeAliasType,
3136
TypeOfAny,
@@ -238,6 +243,31 @@ def make_type_var_tuple(name: str, id: int, upper_bound: Type) -> TypeVarTupleTy
238243
"GV2", mro=[self.oi], typevars=["T", "Ts", "S"], typevar_tuple_index=1
239244
)
240245

246+
def make_parameter_specification(
247+
name: str, id: int, concatenate: Sequence[Type]
248+
) -> ParamSpecType:
249+
return ParamSpecType(
250+
name,
251+
name,
252+
id,
253+
ParamSpecFlavor.BARE,
254+
self.o,
255+
AnyType(TypeOfAny.from_omitted_generics),
256+
prefix=Parameters(
257+
concatenate, [ARG_POS for _ in concatenate], [None for _ in concatenate]
258+
),
259+
)
260+
261+
self.p = make_parameter_specification("P", 1, [])
262+
self.p_concatenate = make_parameter_specification("P", 1, [self.a])
263+
self.q = make_parameter_specification("Q", 2, [])
264+
self.q_concatenate = make_parameter_specification("Q", 2, [self.b])
265+
self.q_concatenate_a = make_parameter_specification("Q", 2, [self.a])
266+
267+
self.gpsi = self.make_type_info(
268+
"GPS", mro=[self.oi], typevars=["P"], paramspec_indexes={0}
269+
)
270+
241271
def _add_bool_dunder(self, type_info: TypeInfo) -> None:
242272
signature = CallableType([], [], [], Instance(self.bool_type_info, []), self.function)
243273
bool_func = FuncDef("__bool__", [], Block([]))
@@ -299,6 +329,7 @@ def make_type_info(
299329
bases: list[Instance] | None = None,
300330
typevars: list[str] | None = None,
301331
typevar_tuple_index: int | None = None,
332+
paramspec_indexes: set[int] | None = None,
302333
variances: list[int] | None = None,
303334
) -> TypeInfo:
304335
"""Make a TypeInfo suitable for use in unit tests."""
@@ -326,6 +357,17 @@ def make_type_info(
326357
AnyType(TypeOfAny.from_omitted_generics),
327358
)
328359
)
360+
elif paramspec_indexes is not None and id - 1 in paramspec_indexes:
361+
v.append(
362+
ParamSpecType(
363+
n,
364+
n,
365+
id,
366+
ParamSpecFlavor.BARE,
367+
self.o,
368+
AnyType(TypeOfAny.from_omitted_generics),
369+
)
370+
)
329371
else:
330372
if variances:
331373
variance = variances[id - 1]

test-data/unit/check-parameter-specification.test

+31-1
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ _P = ParamSpec("_P")
776776

777777
class Job(Generic[_P]):
778778
def __init__(self, target: Callable[_P, None]) -> None:
779-
self.target = target
779+
...
780780

781781
def func(
782782
action: Union[Job[int], Callable[[int], None]],
@@ -1535,6 +1535,36 @@ def identity(func: Callable[P, None]) -> Callable[P, None]: ...
15351535
def f(f: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ...
15361536
[builtins fixtures/paramspec.pyi]
15371537

1538+
[case testComplicatedParamSpecReturnType]
1539+
# regression test for https://github.com/python/mypy/issues/15073
1540+
from typing import TypeVar, Callable
1541+
from typing_extensions import ParamSpec, Concatenate
1542+
1543+
R = TypeVar("R")
1544+
P = ParamSpec("P")
1545+
1546+
def f(
1547+
) -> Callable[[Callable[Concatenate[Callable[P, R], P], R]], Callable[P, R]]:
1548+
def r(fn: Callable[Concatenate[Callable[P, R], P], R]) -> Callable[P, R]: ...
1549+
return r
1550+
[builtins fixtures/paramspec.pyi]
1551+
1552+
[case testParamSpecToParamSpecAssignment]
1553+
# minimized from https://github.com/python/mypy/issues/15037
1554+
# ~ the same as https://github.com/python/mypy/issues/15065
1555+
from typing import Callable
1556+
from typing_extensions import Concatenate, ParamSpec
1557+
1558+
P = ParamSpec("P")
1559+
1560+
def f(f: Callable[Concatenate[int, P], None]) -> Callable[P, None]: ...
1561+
1562+
x: Callable[
1563+
[Callable[Concatenate[int, P], None]],
1564+
Callable[P, None],
1565+
] = f
1566+
[builtins fixtures/paramspec.pyi]
1567+
15381568
[case testParamSpecDecoratorAppliedToGeneric]
15391569
# flags: --new-type-inference
15401570
from typing import Callable, List, TypeVar

0 commit comments

Comments
 (0)