Skip to content

Fix some higher-order (?) ParamSpec usage #14903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
LiteralValue,
NoneType,
Overloaded,
Parameters,
ParamSpecFlavor,
ParamSpecType,
PartialType,
Expand Down Expand Up @@ -1429,6 +1430,7 @@ def check_callable_call(
need_refresh = any(
isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables
)
old_callee = callee
callee = freshen_function_type_vars(callee)
callee = self.infer_function_type_arguments_using_context(callee, context)
if need_refresh:
Expand All @@ -1443,7 +1445,7 @@ def check_callable_call(
lambda i: self.accept(args[i]),
)
callee = self.infer_function_type_arguments(
callee, args, arg_kinds, formal_to_actual, context
callee, args, arg_kinds, formal_to_actual, context, old_callee
)
if need_refresh:
formal_to_actual = map_actuals_to_formals(
Expand Down Expand Up @@ -1733,6 +1735,7 @@ def infer_function_type_arguments(
arg_kinds: list[ArgKind],
formal_to_actual: list[list[int]],
context: Context,
unfreshened_callee_type: CallableType,
) -> CallableType:
"""Infer the type arguments for a generic callee type.

Expand Down Expand Up @@ -1776,6 +1779,28 @@ def infer_function_type_arguments(
callee_type, args, arg_kinds, formal_to_actual, inferred_args, context
)

return_type = get_proper_type(callee_type.ret_type)
if isinstance(return_type, CallableType):
# fixup:
# def [T] () -> def (T) -> T
# into
# def () -> def [T] (T) -> T
for i, argument in enumerate(inferred_args):
if isinstance(get_proper_type(argument), UninhabitedType):
# un-"freshen" the type variable :^)
variable = unfreshened_callee_type.variables[i]
inferred_args[i] = variable

# handle multiple type variables
return_type = return_type.copy_modified(
variables=[*return_type.variables, variable]
)

callee_type = callee_type.copy_modified(
# am I allowed to assign the get_proper_type'd thing?
ret_type=return_type
)

if (
callee_type.special_sig == "dict"
and len(inferred_args) == 2
Expand Down Expand Up @@ -4070,6 +4095,20 @@ def apply_type_arguments_to_callable(
tp = get_proper_type(tp)

if isinstance(tp, CallableType):
if (
len(tp.variables) == 1
and isinstance(tp.variables[0], ParamSpecType)
and (
len(args) != 1
or not isinstance(
get_proper_type(args[0]), (Parameters, ParamSpecType, AnyType)
)
)
):
# TODO: I don't think AnyType here is valid in the general case, there's 2 cases:
# 1. invalid paramspec expression (in which case we should transform it into an ellipsis)
# 2. user passed it (in which case we should pass it into Parameters(...))
args = [Parameters(args, [nodes.ARG_POS for _ in args], [None for _ in args])]
if len(tp.variables) != len(args):
if tp.is_type_obj() and tp.type_object().fullname == "builtins.tuple":
# TODO: Specialize the callable for the type arguments
Expand Down
98 changes: 79 additions & 19 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, List, Sequence
from typing import TYPE_CHECKING, Iterable, List, Sequence, Union
from typing_extensions import Final

import mypy.subtypes
Expand Down Expand Up @@ -176,12 +176,13 @@ def infer_constraints_for_callable(
def infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]:
"""Infer type constraints.

Match a template type, which may contain type variable references,
recursively against a type which does not contain (the same) type
variable references. The result is a list of type constrains of
form 'T is a supertype/subtype of x', where T is a type variable
present in the template and x is a type without reference to type
variables present in the template.
Match a template type, which may contain type variable and parameter
specification references, recursively against a type which does not
contain (the same) type variable and parameter specification references.
The result is a list of type constraints of form 'T is a supertype/subtype
of x', where T is a type variable present in the template or a parameter
specification without its prefix and x is a type without reference to type
variables nor parameters present in the template.

Assume T and S are type variables. Now the following results can be
calculated (read as '(template, actual) --> result'):
Expand All @@ -192,6 +193,23 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons
((T, S), (X, Y)) --> T :> X and S :> Y
(X[T], Any) --> T <: Any and T :> Any

Assume P and Q are prefix-less parameter specifications. The following
results can be calculated in a similar format:

(P, [...W]) --> P :> [...W]
(X[P], X[[...W]]) --> P :> [...W]
// note that parameter specifications are *always* contravariant as
// they echo Callable arguments.
((P, P), ([...W], [...U])) --> P :> [...W] and P :> [...U]
((P, Q), ([...W], [...U])) --> P :> [...W] and Q :> [...U]
(P, ...) --> P :> ...

With prefixes (note that I am not sure these cases are implemented):

([...Z, P], [...Z, ...W]) --> P :> [...W]
([...Z, P], Q) --> [...Z, P] :> Q
(P, [...Z, Q]) --> P :> [...Z, Q]

The constraints are represented as Constraint objects.
"""
if any(
Expand Down Expand Up @@ -695,19 +713,37 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
from_concat = bool(prefix.arg_types) or suffix.from_concatenate
suffix = suffix.copy_modified(from_concatenate=from_concat)

prefix = mapped_arg.prefix
length = len(prefix.arg_types)
if isinstance(suffix, Parameters) or isinstance(suffix, CallableType):
# no such thing as variance for ParamSpecs
# TODO: is there a case I am missing?
# TODO: constraints between prefixes
prefix = mapped_arg.prefix
suffix = suffix.copy_modified(
suffix.arg_types[len(prefix.arg_types) :],
suffix.arg_kinds[len(prefix.arg_kinds) :],
suffix.arg_names[len(prefix.arg_names) :],
res.append(
Constraint(
mapped_arg,
SUPERTYPE_OF,
suffix.copy_modified(
arg_types=suffix.arg_types[length:],
arg_kinds=suffix.arg_kinds[length:],
arg_names=suffix.arg_names[length:],
),
)
)
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
elif isinstance(suffix, ParamSpecType):
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
suffix_prefix = suffix.prefix
res.append(
Constraint(
mapped_arg,
SUPERTYPE_OF,
suffix.copy_modified(
prefix=suffix_prefix.copy_modified(
arg_types=suffix_prefix.arg_types[length:],
arg_kinds=suffix_prefix.arg_kinds[length:],
arg_names=suffix_prefix.arg_names[length:],
)
),
)
)
else:
# This case should have been handled above.
assert not isinstance(tvar, TypeVarTupleType)
Expand Down Expand Up @@ -918,14 +954,23 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
# sometimes, it appears we try to get constraints between two paramspec callables?

# TODO: Direction
# TODO: check the prefixes match
prefix = param_spec.prefix
prefix_len = len(prefix.arg_types)
cactual_ps = cactual.param_spec()

cactual_prefix: Union[Parameters, CallableType]
if cactual_ps:
cactual_prefix = cactual_ps.prefix
else:
cactual_prefix = cactual

max_prefix_len = len(
[k for k in cactual_prefix.arg_kinds if k in (ARG_POS, ARG_OPT)]
)
prefix_len = min(prefix_len, max_prefix_len)

# we could check the prefixes match here, but that should be caught elsewhere.
if not cactual_ps:
max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)])
prefix_len = min(prefix_len, max_prefix_len)
res.append(
Constraint(
param_spec,
Expand All @@ -939,7 +984,22 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
)
)
else:
res.append(Constraint(param_spec, SUBTYPE_OF, cactual_ps))
# guaranteed due to if conditions
assert isinstance(cactual_prefix, Parameters)

res.append(
Constraint(
param_spec,
SUBTYPE_OF,
cactual_ps.copy_modified(
prefix=cactual_prefix.copy_modified(
arg_types=cactual_prefix.arg_types[prefix_len:],
arg_kinds=cactual_prefix.arg_kinds[prefix_len:],
arg_names=cactual_prefix.arg_names[prefix_len:],
)
),
)
)

# compare prefixes
cactual_prefix = cactual.copy_modified(
Expand Down
6 changes: 5 additions & 1 deletion mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,11 @@ def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:

def visit_param_spec(self, t: ParamSpecType) -> Type:
if self.erase_id(t.id):
return self.replacement
return t.prefix.copy_modified(
arg_types=t.prefix.arg_types + [self.replacement, self.replacement],
arg_kinds=t.prefix.arg_kinds + [ARG_STAR, ARG_STAR2],
arg_names=t.prefix.arg_names + [None, None],
)
return t

def visit_type_alias_type(self, t: TypeAliasType) -> Type:
Expand Down
1 change: 0 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def freshen_function_type_vars(callee: F) -> F:
if isinstance(v, TypeVarType):
tv: TypeVarLikeType = TypeVarType.new_unification_variable(v)
elif isinstance(v, TypeVarTupleType):
assert isinstance(v, TypeVarTupleType)
tv = TypeVarTupleType.new_unification_variable(v)
else:
assert isinstance(v, ParamSpecType)
Expand Down
6 changes: 4 additions & 2 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType:
else:
# ParamSpec type variables behave the same, independent of variance
if not is_equivalent(ta, sa):
return get_proper_type(type_var.upper_bound)
return object_from_instance(t)
new_type = join_types(ta, sa, self)
assert new_type is not None
args.append(new_type)
Expand Down Expand Up @@ -311,9 +311,11 @@ def visit_type_var(self, t: TypeVarType) -> ProperType:
return self.default(self.s)

def visit_param_spec(self, t: ParamSpecType) -> ProperType:
# TODO: should this mirror the `isinstance(...) ...` above?
if self.s == t:
return t
return self.default(self.s)
else:
return self.default(self.s)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
if self.s == t:
Expand Down
6 changes: 6 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2519,6 +2519,12 @@ class ParamSpecExpr(TypeVarLikeExpr):

__match_args__ = ("name", "upper_bound")

# TODO: Technically the variance cannot be customized. Nor can the upper bound.
def __init__(
self, name: str, fullname: str, upper_bound: mypy.types.Type, variance: int = INVARIANT
) -> None:
super().__init__(name, fullname, upper_bound, variance)

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_paramspec_expr(self)

Expand Down
10 changes: 9 additions & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4168,7 +4168,7 @@ def process_paramspec_declaration(self, s: AssignmentStmt) -> bool:

if not call.analyzed:
paramspec_var = ParamSpecExpr(
name, self.qualified_name(name), self.object_type(), INVARIANT
name, self.qualified_name(name), self.top_caller(), INVARIANT
)
paramspec_var.line = call.line
call.analyzed = paramspec_var
Expand Down Expand Up @@ -5612,6 +5612,14 @@ def lookup_fully_qualified_or_none(self, fullname: str) -> SymbolTableNode | Non
def object_type(self) -> Instance:
return self.named_type("builtins.object")

def top_caller(self) -> Parameters:
return Parameters(
arg_types=[self.object_type(), self.object_type()],
arg_kinds=[ARG_STAR, ARG_STAR2],
arg_names=[None, None],
is_ellipsis_args=True,
)

def str_type(self) -> Instance:
return self.named_type("builtins.str")

Expand Down
5 changes: 3 additions & 2 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,9 @@ def visit_paramspec_expr(self, o: mypy.nodes.ParamSpecExpr) -> str:
a += ["Variance(COVARIANT)"]
if o.variance == mypy.nodes.CONTRAVARIANT:
a += ["Variance(CONTRAVARIANT)"]
if not mypy.types.is_named_instance(o.upper_bound, "builtins.object"):
a += [f"UpperBound({o.upper_bound})"]
# ParamSpecs do not have upper bounds!!! (should this be left for future proofing?)
# if not mypy.types.is_named_instance(o.upper_bound, "builtins.object"):
# a += [f"UpperBound({o.upper_bound})"]
return self.dump(a, o)

def visit_type_var_tuple_expr(self, o: mypy.nodes.TypeVarTupleExpr) -> str:
Expand Down
2 changes: 2 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,8 @@ def check_mixed(
):
nominal = False
else:
# TODO: I'm *pretty* sure `CONTRAVARIANT` should be here...
# But it's erroring!
if not check_type_parameter(
lefta, righta, COVARIANT, self.proper_subtype, self.subtype_context
):
Expand Down
3 changes: 3 additions & 0 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,9 @@ def pack_paramspec_args(self, an_args: Sequence[Type]) -> list[Type]:
if count > 0:
first_arg = get_proper_type(an_args[0])
if not (count == 1 and isinstance(first_arg, (Parameters, ParamSpecType, AnyType))):
# TODO: I don't think AnyType here is valid in the general case, there's 2 cases:
# 1. invalid paramspec expression (in which case we should transform it into an ellipsis)
# 2. user passed it (in which case we should pass it into Parameters(...))
return [Parameters(an_args, [ARG_POS] * count, [None] * count)]
return list(an_args)

Expand Down
19 changes: 16 additions & 3 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,8 @@ class ParamSpecType(TypeVarLikeType):
The upper_bound is really used as a fallback type -- it's shared
with TypeVarType for simplicity. It can't be specified by the user
and the value is directly derived from the flavor (currently
always just 'object').
always just '(*Any, **Any)' or '(*object, **object)' depending on
context).
"""

__slots__ = ("flavor", "prefix")
Expand Down Expand Up @@ -696,13 +697,14 @@ def copy_modified(
id: Bogus[TypeVarId | int] = _dummy,
flavor: int = _dummy_int,
prefix: Bogus[Parameters] = _dummy,
upper_bound: Bogus[Type] = _dummy,
) -> ParamSpecType:
return ParamSpecType(
self.name,
self.fullname,
id if id is not _dummy else self.id,
flavor if flavor != _dummy_int else self.flavor,
self.upper_bound,
upper_bound if upper_bound is not _dummy else self.upper_bound,
line=self.line,
column=self.column,
prefix=prefix if prefix is not _dummy else self.prefix,
Expand Down Expand Up @@ -1986,7 +1988,18 @@ def param_spec(self) -> ParamSpecType | None:
# TODO: confirm that all arg kinds are positional
prefix = Parameters(self.arg_types[:-2], self.arg_kinds[:-2], self.arg_names[:-2])

return arg_type.copy_modified(flavor=ParamSpecFlavor.BARE, prefix=prefix)
# TODO: should this take in `object`s?
any_type = AnyType(TypeOfAny.special_form)
return arg_type.copy_modified(
flavor=ParamSpecFlavor.BARE,
prefix=prefix,
upper_bound=Parameters(
arg_types=[any_type, any_type],
arg_kinds=[ARG_STAR, ARG_STAR2],
arg_names=[None, None],
is_ellipsis_args=True,
),
)

def expand_param_spec(
self, c: CallableType | Parameters, no_prefix: bool = False
Expand Down
Loading