Skip to content

New type inference: complete transitive closure #15754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 26 additions & 46 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,8 +733,10 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
# def foo(x: str) -> str: ...
#
# See Python 2's map function for a concrete example of this kind of overload.
current_class = self.scope.active_class()
type_vars = current_class.defn.type_vars if current_class else []
with state.strict_optional_set(True):
if is_unsafe_overlapping_overload_signatures(sig1, sig2):
if is_unsafe_overlapping_overload_signatures(sig1, sig2, type_vars):
self.msg.overloaded_signatures_overlap(i + 1, i + j + 2, item.func)

if impl_type is not None:
Expand Down Expand Up @@ -1697,7 +1699,9 @@ def is_unsafe_overlapping_op(
first = forward_tweaked
second = reverse_tweaked

return is_unsafe_overlapping_overload_signatures(first, second)
current_class = self.scope.active_class()
type_vars = current_class.defn.type_vars if current_class else []
return is_unsafe_overlapping_overload_signatures(first, second, type_vars)

def check_inplace_operator_method(self, defn: FuncBase) -> None:
"""Check an inplace operator method such as __iadd__.
Expand Down Expand Up @@ -3913,11 +3917,12 @@ def is_valid_defaultdict_partial_value_type(self, t: ProperType) -> bool:
return True
if len(t.args) == 1:
arg = get_proper_type(t.args[0])
# TODO: This is too permissive -- we only allow TypeVarType since
# they leak in cases like defaultdict(list) due to a bug.
# This can result in incorrect types being inferred, but only
# in rare cases.
if isinstance(arg, (TypeVarType, UninhabitedType, NoneType)):
if self.options.new_type_inference:
allowed = isinstance(arg, (UninhabitedType, NoneType))
else:
# Allow leaked TypeVars for legacy inference logic.
allowed = isinstance(arg, (UninhabitedType, NoneType, TypeVarType))
if allowed:
return True
return False

Expand Down Expand Up @@ -7153,7 +7158,7 @@ def are_argument_counts_overlapping(t: CallableType, s: CallableType) -> bool:


def is_unsafe_overlapping_overload_signatures(
signature: CallableType, other: CallableType
signature: CallableType, other: CallableType, class_type_vars: list[TypeVarLikeType]
) -> bool:
"""Check if two overloaded signatures are unsafely overlapping or partially overlapping.

Expand All @@ -7172,8 +7177,8 @@ def is_unsafe_overlapping_overload_signatures(
# This lets us identify cases where the two signatures use completely
# incompatible types -- e.g. see the testOverloadingInferUnionReturnWithMixedTypevars
# test case.
signature = detach_callable(signature)
other = detach_callable(other)
signature = detach_callable(signature, class_type_vars)
other = detach_callable(other, class_type_vars)

# Note: We repeat this check twice in both directions due to a slight
# asymmetry in 'is_callable_compatible'. When checking for partial overlaps,
Expand Down Expand Up @@ -7204,7 +7209,7 @@ def is_unsafe_overlapping_overload_signatures(
)


def detach_callable(typ: CallableType) -> CallableType:
def detach_callable(typ: CallableType, class_type_vars: list[TypeVarLikeType]) -> CallableType:
"""Ensures that the callable's type variables are 'detached' and independent of the context.

A callable normally keeps track of the type variables it uses within its 'variables' field.
Expand All @@ -7214,42 +7219,17 @@ def detach_callable(typ: CallableType) -> CallableType:
This function will traverse the callable and find all used type vars and add them to the
variables field if it isn't already present.

The caller can then unify on all type variables whether or not the callable is originally
from a class or not."""
type_list = typ.arg_types + [typ.ret_type]

appear_map: dict[str, list[int]] = {}
for i, inner_type in enumerate(type_list):
typevars_available = get_type_vars(inner_type)
for var in typevars_available:
if var.fullname not in appear_map:
appear_map[var.fullname] = []
appear_map[var.fullname].append(i)

used_type_var_names = set()
for var_name, appearances in appear_map.items():
used_type_var_names.add(var_name)

all_type_vars = get_type_vars(typ)
new_variables = []
for var in set(all_type_vars):
if var.fullname not in used_type_var_names:
continue
new_variables.append(
TypeVarType(
name=var.name,
fullname=var.fullname,
id=var.id,
values=var.values,
upper_bound=var.upper_bound,
default=var.default,
variance=var.variance,
)
)
out = typ.copy_modified(
variables=new_variables, arg_types=type_list[:-1], ret_type=type_list[-1]
The caller can then unify on all type variables whether the callable is originally from
the class or not."""
if not class_type_vars:
# Fast path, nothing to update.
return typ
seen_type_vars = set()
for t in typ.arg_types + [typ.ret_type]:
seen_type_vars |= set(get_type_vars(t))
return typ.copy_modified(
variables=list(typ.variables) + [tv for tv in class_type_vars if tv in seen_type_vars]
)
return out


def overload_can_never_match(signature: CallableType, other: CallableType) -> bool:
Expand Down
40 changes: 19 additions & 21 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,7 +1852,7 @@ def infer_function_type_arguments_using_context(
# expects_literal(identity(3)) # Should type-check
if not is_generic_instance(ctx) and not is_literal_type_like(ctx):
return callable.copy_modified()
args = infer_type_arguments(callable.type_var_ids(), ret_type, erased_ctx)
args = infer_type_arguments(callable.variables, ret_type, erased_ctx)
# Only substitute non-Uninhabited and non-erased types.
new_args: list[Type | None] = []
for arg in args:
Expand Down Expand Up @@ -1901,7 +1901,7 @@ def infer_function_type_arguments(
else:
pass1_args.append(arg)

inferred_args = infer_function_type_arguments(
inferred_args, _ = infer_function_type_arguments(
callee_type,
pass1_args,
arg_kinds,
Expand Down Expand Up @@ -1943,7 +1943,7 @@ def infer_function_type_arguments(
# variables while allowing for polymorphic solutions, i.e. for solutions
# potentially involving free variables.
# TODO: support the similar inference for return type context.
poly_inferred_args = infer_function_type_arguments(
poly_inferred_args, free_vars = infer_function_type_arguments(
callee_type,
arg_types,
arg_kinds,
Expand All @@ -1952,30 +1952,28 @@ def infer_function_type_arguments(
strict=self.chk.in_checked_function(),
allow_polymorphic=True,
)
for i, pa in enumerate(get_proper_types(poly_inferred_args)):
if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa):
# Indicate that free variables should not be applied in the call below.
poly_inferred_args[i] = None
poly_callee_type = self.apply_generic_arguments(
callee_type, poly_inferred_args, context
)
yes_vars = poly_callee_type.variables
no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables}
if not set(get_type_vars(poly_callee_type)) & no_vars:
# Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can
# be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed.
applied = apply_poly(poly_callee_type, yes_vars)
if applied is not None and poly_inferred_args != [UninhabitedType()] * len(
poly_inferred_args
):
freeze_all_type_vars(applied)
return applied
# Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can
# be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed.
applied = apply_poly(poly_callee_type, free_vars)
if applied is not None and all(
a is not None and not isinstance(get_proper_type(a), UninhabitedType)
for a in poly_inferred_args
):
freeze_all_type_vars(applied)
return applied
# If it didn't work, erase free variables as <nothing>, to avoid confusing errors.
unknown = UninhabitedType()
unknown.ambiguous = True
inferred_args = [
expand_type(a, {v.id: UninhabitedType() for v in callee_type.variables})
expand_type(
a, {v.id: unknown for v in list(callee_type.variables) + free_vars}
)
if a is not None
else None
for a in inferred_args
for a in poly_inferred_args
]
else:
# In dynamically typed functions use implicit 'Any' types for
Expand Down Expand Up @@ -2014,7 +2012,7 @@ def infer_function_type_arguments_pass2(

arg_types = self.infer_arg_types_in_context(callee_type, args, arg_kinds, formal_to_actual)

inferred_args = infer_function_type_arguments(
inferred_args, _ = infer_function_type_arguments(
callee_type,
arg_types,
arg_kinds,
Expand Down
63 changes: 41 additions & 22 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def __init__(self, type_var: TypeVarLikeType, op: int, target: Type) -> None:
self.op = op
self.target = target
self.origin_type_var = type_var
# These are additional type variables that should be solved for together with type_var.
# TODO: A cleaner solution may be to modify the return type of infer_constraints()
# to include these instead, but this is a rather big refactoring.
self.extra_tvars: list[TypeVarLikeType] = []

def __repr__(self) -> str:
op_str = "<:"
Expand Down Expand Up @@ -168,7 +172,9 @@ def infer_constraints_for_callable(
return constraints


def infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]:
def infer_constraints(
template: Type, actual: Type, direction: int, skip_neg_op: bool = False
) -> list[Constraint]:
"""Infer type constraints.

Match a template type, which may contain type variable references,
Expand All @@ -187,7 +193,9 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons
((T, S), (X, Y)) --> T :> X and S :> Y
(X[T], Any) --> T <: Any and T :> Any

The constraints are represented as Constraint objects.
The constraints are represented as Constraint objects. If skip_neg_op == True,
then skip adding reverse (polymorphic) constraints (since this is already a call
to infer such constraints).
"""
if any(
get_proper_type(template) == get_proper_type(t)
Expand All @@ -202,13 +210,15 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons
# Return early on an empty branch.
return []
type_state.inferring.append((template, actual))
res = _infer_constraints(template, actual, direction)
res = _infer_constraints(template, actual, direction, skip_neg_op)
type_state.inferring.pop()
return res
return _infer_constraints(template, actual, direction)
return _infer_constraints(template, actual, direction, skip_neg_op)


def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]:
def _infer_constraints(
template: Type, actual: Type, direction: int, skip_neg_op: bool
) -> list[Constraint]:
orig_template = template
template = get_proper_type(template)
actual = get_proper_type(actual)
Expand Down Expand Up @@ -284,7 +294,7 @@ def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Con
return []

# Remaining cases are handled by ConstraintBuilderVisitor.
return template.accept(ConstraintBuilderVisitor(actual, direction))
return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op))


def infer_constraints_if_possible(
Expand Down Expand Up @@ -510,10 +520,14 @@ class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]):
# TODO: The value may be None. Is that actually correct?
actual: ProperType

def __init__(self, actual: ProperType, direction: int) -> None:
def __init__(self, actual: ProperType, direction: int, skip_neg_op: bool) -> None:
# Direction must be SUBTYPE_OF or SUPERTYPE_OF.
self.actual = actual
self.direction = direction
# Whether to skip polymorphic inference (involves inference in opposite direction)
# this is used to prevent infinite recursion when both template and actual are
# generic callables.
self.skip_neg_op = skip_neg_op

# Trivial leaf types

Expand Down Expand Up @@ -648,13 +662,13 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
assert mapped.type.type_var_tuple_prefix is not None
assert mapped.type.type_var_tuple_suffix is not None

unpack_constraints, mapped_args, instance_args = build_constraints_for_unpack(
mapped.args,
mapped.type.type_var_tuple_prefix,
mapped.type.type_var_tuple_suffix,
unpack_constraints, instance_args, mapped_args = build_constraints_for_unpack(
instance.args,
instance.type.type_var_tuple_prefix,
instance.type.type_var_tuple_suffix,
mapped.args,
mapped.type.type_var_tuple_prefix,
mapped.type.type_var_tuple_suffix,
self.direction,
)
res.extend(unpack_constraints)
Expand Down Expand Up @@ -879,6 +893,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
# Note that non-normalized callables can be created in annotations
# using e.g. callback protocols.
template = template.with_unpacked_kwargs()
extra_tvars = False
if isinstance(self.actual, CallableType):
res: list[Constraint] = []
cactual = self.actual.with_unpacked_kwargs()
Expand All @@ -890,25 +905,23 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
type_state.infer_polymorphic
and cactual.variables
and cactual.param_spec() is None
and not self.skip_neg_op
# Technically, the correct inferred type for application of e.g.
# Callable[..., T] -> Callable[..., T] (with literal ellipsis), to a generic
# like U -> U, should be Callable[..., Any], but if U is a self-type, we can
# allow it to leak, to be later bound to self. A bunch of existing code
# depends on this old behaviour.
and not any(tv.id.raw_id == 0 for tv in cactual.variables)
):
# If actual is generic, unify it with template. Note: this is
# not an ideal solution (which would be adding the generic variables
# to the constraint inference set), but it's a good first approximation,
# and this will prevent leaking these variables in the solutions.
# Note: this may infer constraints like T <: S or T <: List[S]
# that contain variables in the target.
unified = mypy.subtypes.unify_generic_callable(
cactual, template, ignore_return=True
# If the actual callable is generic, infer constraints in the opposite
# direction, and indicate to the solver there are extra type variables
# to solve for (see more details in mypy/solve.py).
res.extend(
infer_constraints(
cactual, template, neg_op(self.direction), skip_neg_op=True
)
)
if unified is not None:
cactual = unified
res.extend(infer_constraints(cactual, template, neg_op(self.direction)))
extra_tvars = True

# We can't infer constraints from arguments if the template is Callable[..., T]
# (with literal '...').
Expand Down Expand Up @@ -978,6 +991,9 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
cactual_ret_type = cactual.type_guard

res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction))
if extra_tvars:
for c in res:
c.extra_tvars = list(cactual.variables)
return res
elif isinstance(self.actual, AnyType):
param_spec = template.param_spec()
Expand Down Expand Up @@ -1205,6 +1221,9 @@ def find_and_build_constraints_for_unpack(


def build_constraints_for_unpack(
# TODO: this naming is misleading, these should be "actual", not "mapped"
# both template and actual can be mapped before, depending on direction.
# Also the convention is to put template related args first.
mapped: tuple[Type, ...],
mapped_prefix_len: int | None,
mapped_suffix_len: int | None,
Expand Down
5 changes: 5 additions & 0 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
return repl

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
# Sometimes solver may need to expand a type variable with (a copy of) itself
# (usually together with other TypeVars, but it is hard to filter out TypeVarTuples).
repl = self.variables[t.id]
if isinstance(repl, TypeVarTupleType):
return repl
raise NotImplementedError

def visit_unpack_type(self, t: UnpackType) -> Type:
Expand Down
Loading