From e4601897b4f1a792a82ef35df0cc29982ded32f6 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 23 Jun 2023 23:03:35 +0100 Subject: [PATCH 01/14] Start with a plan (and a test) --- mypy/constraints.py | 17 +++++------------ test-data/unit/check-generics.test | 17 ++++++++++++++++- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 803b9819be6f..3cac01133f56 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -898,18 +898,11 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # depends on this old behaviour. and not any(tv.id.raw_id == 0 for tv in cactual.variables) ): - # If actual is generic, unify it with template. Note: this is - # not an ideal solution (which would be adding the generic variables - # to the constraint inference set), but it's a good first approximation, - # and this will prevent leaking these variables in the solutions. - # Note: this may infer constraints like T <: S or T <: List[S] - # that contain variables in the target. - unified = mypy.subtypes.unify_generic_callable( - cactual, template, ignore_return=True - ) - if unified is not None: - cactual = unified - res.extend(infer_constraints(cactual, template, neg_op(self.direction))) + # TODO: prevent infinite neg_op() + # TODO: notify solver of extra type vars + # TODO: implement secondary constraints + # TODO: fix polymorphic application to support new vars + res.extend(infer_constraints(cactual, template, neg_op(self.direction))) # We can't infer constraints from arguments if the template is Callable[..., T] # (with literal '...'). diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index b78fd21d4817..14d810560c78 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2900,7 +2900,22 @@ def test2(x: V, y: V) -> V: ... reveal_type(dec1(test1)) # N: Revealed type is "def () -> def [T] (T`1) -> T`1" # TODO: support this situation reveal_type(dec2(test2)) # N: Revealed type is "def (builtins.object) -> def (builtins.object) -> builtins.object" -[builtins fixtures/paramspec.pyi] +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableNewVariable] +# flags: --new-type-inference +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +def dec(f: Callable[[S], T]) -> Callable[[S], T]: + ... +def test(x: List[U]) -> List[U]: + ... +reveal_type(dec(test)) # N: Revealed type is "def [U] (builtins.list[U`1]) -> builtins.list[U`1]" +[builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableGenericAlias] # flags: --new-type-inference From 495259b9c3e43bdff719707c6936fcd407560b24 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 2 Jul 2023 22:15:05 +0100 Subject: [PATCH 02/14] Make small progress --- mypy/constraints.py | 32 ++++++++++++++++++++++++-------- mypy/solve.py | 12 +++++++++--- mypy_self_check.ini | 1 + 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 3cac01133f56..4820547118d2 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -74,6 +74,10 @@ def __init__(self, type_var: TypeVarLikeType, op: int, target: Type) -> None: self.op = op self.target = target self.origin_type_var = type_var + # These are additional type variables that should be solved for together with type_var. 
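+        # For example, when the actual type in a constraint comes from a generic callable
+        # such as [S] (S) -> S, its variable S is recorded here so the solver can treat it
+        # as one more variable to solve for.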
+ # TODO: A cleaner solution may be to modify the return type of infer_constraints() + # to include these instead, but this is a rather big refactoring. + self.extra_tvars: list[TypeVarLikeType] = [] def __repr__(self) -> str: op_str = "<:" @@ -169,7 +173,9 @@ def infer_constraints_for_callable( return constraints -def infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]: +def infer_constraints( + template: Type, actual: Type, direction: int, skip_neg_op: bool = False +) -> list[Constraint]: """Infer type constraints. Match a template type, which may contain type variable references, @@ -203,13 +209,15 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons # Return early on an empty branch. return [] type_state.inferring.append((template, actual)) - res = _infer_constraints(template, actual, direction) + res = _infer_constraints(template, actual, direction, skip_neg_op) type_state.inferring.pop() return res - return _infer_constraints(template, actual, direction) + return _infer_constraints(template, actual, direction, skip_neg_op) -def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]: +def _infer_constraints( + template: Type, actual: Type, direction: int, skip_neg_op: bool +) -> list[Constraint]: orig_template = template template = get_proper_type(template) actual = get_proper_type(actual) @@ -285,7 +293,7 @@ def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Con return [] # Remaining cases are handled by ConstraintBuilderVisitor. - return template.accept(ConstraintBuilderVisitor(actual, direction)) + return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op)) def infer_constraints_if_possible( @@ -511,10 +519,14 @@ class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]): # TODO: The value may be None. Is that actually correct? actual: ProperType - def __init__(self, actual: ProperType, direction: int) -> None: + def __init__(self, actual: ProperType, direction: int, skip_neg_op: bool) -> None: # Direction must be SUBTYPE_OF or SUPERTYPE_OF. self.actual = actual self.direction = direction + # Whether to skip polymorphic inference (involves inference in opposite direction) + # this is used to prevent infinite recursion when both template and actual are + # generic callables. + self.skip_neg_op = skip_neg_op # Trivial leaf types @@ -880,6 +892,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # Note that non-normalized callables can be created in annotations # using e.g. callback protocols. template = template.with_unpacked_kwargs() + extra_tvars = False if isinstance(self.actual, CallableType): res: list[Constraint] = [] cactual = self.actual.with_unpacked_kwargs() @@ -891,6 +904,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: type_state.infer_polymorphic and cactual.variables and cactual.param_spec() is None + and not self.skip_neg_op # Technically, the correct inferred type for application of e.g. # Callable[..., T] -> Callable[..., T] (with literal ellipsis), to a generic # like U -> U, should be Callable[..., Any], but if U is a self-type, we can @@ -898,11 +912,10 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # depends on this old behaviour. 
and not any(tv.id.raw_id == 0 for tv in cactual.variables) ): - # TODO: prevent infinite neg_op() - # TODO: notify solver of extra type vars # TODO: implement secondary constraints # TODO: fix polymorphic application to support new vars res.extend(infer_constraints(cactual, template, neg_op(self.direction))) + extra_tvars = True # We can't infer constraints from arguments if the template is Callable[..., T] # (with literal '...'). @@ -972,6 +985,9 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: cactual_ret_type = cactual.type_guard res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction)) + if extra_tvars: + for c in res: + c.extra_tvars = list(cactual.variables) return res elif isinstance(self.actual, AnyType): param_spec = template.param_spec() diff --git a/mypy/solve.py b/mypy/solve.py index 6693d66f3479..6b43d2434e17 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -48,14 +48,20 @@ def solve_constraints( # represented differently. Normalize the constraint list w.r.t this equivalence. constraints = normalize_constraints(constraints, vars) + extra_vars: list[TypeVarId] = [] + if allow_polymorphic: + # Get additional variables from generic actuals. + for c in constraints: + extra_vars.extend([v.id for v in c.extra_tvars if v.id not in vars + extra_vars]) + # Collect a list of constraints for each type variable. - cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars} + cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars + extra_vars} for con in constraints: - if con.type_var in vars: + if con.type_var in vars + extra_vars: cmap[con.type_var].append(con) if allow_polymorphic: - solutions = solve_non_linear(vars, constraints, cmap) + solutions = solve_non_linear(vars + extra_vars, constraints, cmap) else: solutions = {} for tv, cs in cmap.items(): diff --git a/mypy_self_check.ini b/mypy_self_check.ini index d20fcd60a9cb..5a3e46e9de09 100644 --- a/mypy_self_check.ini +++ b/mypy_self_check.ini @@ -8,6 +8,7 @@ always_false = MYPYC plugins = misc/proper_plugin.py python_version = 3.7 exclude = mypy/typeshed/|mypyc/test-data/|mypyc/lib-rt/ +new_type_inference = True enable_error_code = ignore-without-code,redundant-expr [mypy-mypy.visitor] From c9e4454a59f78f5444dc681efc23c46473f81836 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 16 Jul 2023 21:36:21 +0100 Subject: [PATCH 03/14] Complete transitive closure (not used correctly yet) --- mypy/solve.py | 51 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/mypy/solve.py b/mypy/solve.py index 6b43d2434e17..8a7dff79e469 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -4,7 +4,7 @@ from typing import Iterable -from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, neg_op +from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, neg_op, infer_constraints from mypy.expandtype import expand_type from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort from mypy.join import join_types @@ -191,10 +191,12 @@ def solve_iteratively( result = solve_one(lowers[solvable_tv], uppers[solvable_tv], not_allowed_vars) solutions[solvable_tv] = result if result is None: - # TODO: support backtracking lower/upper bound choices + # TODO: support backtracking lower/upper bound choices and order within SCCs. # (will require switching this function from iterative to recursive). continue # Update the (transitive) constraints if there is a solution. 
+ # TODO: do update from graph even within batch + # TODO: no need to update lowers/uppers/cmap within batch subs = {solvable_tv: result} lowers = {tv: {expand_type(l, subs) for l in lowers[tv]} for tv in lowers} uppers = {tv: {expand_type(u, subs) for u in uppers[tv]} for tv in uppers} @@ -325,32 +327,20 @@ def transitive_closure( * {T} <: S <: {U, int} * {T, S} <: U <: {int} """ - # TODO: merge propagate_constraints_for() into this function. - # TODO: add secondary constraints here to make the algorithm complete. uppers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars} lowers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars} - graph: set[tuple[TypeVarId, TypeVarId]] = set() + graph: set[tuple[TypeVarId, TypeVarId]] = {(tv, tv) for tv in tvars} - # Prime the closure with the initial trivial values. - for c in constraints: - if isinstance(c.target, TypeVarType) and c.target.id in tvars: - if c.op == SUBTYPE_OF: - graph.add((c.type_var, c.target.id)) - else: - graph.add((c.target.id, c.type_var)) - if c.op == SUBTYPE_OF: - uppers[c.type_var].add(c.target) - else: - lowers[c.type_var].add(c.target) - - # At this stage we know that constant bounds have been propagated already, so we - # only need to propagate linear constraints. - for c in constraints: + remaining = set(constraints) + while remaining: + c = remaining.pop() if isinstance(c.target, TypeVarType) and c.target.id in tvars: if c.op == SUBTYPE_OF: lower, upper = c.type_var, c.target.id else: lower, upper = c.target.id, c.type_var + if (lower, upper) in graph: + continue extras = { (l, u) for l in tvars for u in tvars if (l, lower) in graph and (upper, u) in graph } @@ -361,6 +351,27 @@ def transitive_closure( for l in tvars: if (l, lower) in graph: uppers[l] |= uppers[upper] + for lt in lowers[lower]: + for ut in uppers[upper]: + # TODO: what if secondary constraints result in inference + # against polymorphic actual (also in below branches)? + remaining |= set(infer_constraints(lt, ut, SUBTYPE_OF)) + remaining |= set(infer_constraints(ut, lt, SUPERTYPE_OF)) + elif c.op == SUBTYPE_OF: + for l in tvars: + if (l, c.type_var) in graph: + uppers[l].add(c.target) + for lt in lowers[c.type_var]: + remaining |= set(infer_constraints(lt, c.target, SUBTYPE_OF)) + remaining |= set(infer_constraints(c.target, lt, SUPERTYPE_OF)) + else: + assert c.op == SUPERTYPE_OF + for u in tvars: + if (c.type_var, u) in graph: + lowers[u].add(c.target) + for ut in uppers[c.type_var]: + remaining |= set(infer_constraints(ut, c.target, SUPERTYPE_OF)) + remaining |= set(infer_constraints(c.target, ut, SUBTYPE_OF)) return lowers, uppers From e30c05787ab93534efec18ebb71f19c7ab7d06eb Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 20 Jul 2023 23:07:35 +0100 Subject: [PATCH 04/14] Correctly use full algorithm --- mypy/constraints.py | 1 - mypy/solve.py | 165 +++++++++++++++++++++----------------------- 2 files changed, 79 insertions(+), 87 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 4820547118d2..46aed550f7cb 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -912,7 +912,6 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # depends on this old behaviour. 
and not any(tv.id.raw_id == 0 for tv in cactual.variables) ): - # TODO: implement secondary constraints # TODO: fix polymorphic application to support new vars res.extend(infer_constraints(cactual, template, neg_op(self.direction))) extra_tvars = True diff --git a/mypy/solve.py b/mypy/solve.py index 8a7dff79e469..05a6462bec65 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -61,7 +61,7 @@ def solve_constraints( cmap[con.type_var].append(con) if allow_polymorphic: - solutions = solve_non_linear(vars + extra_vars, constraints, cmap) + solutions = solve_non_linear(vars + extra_vars, constraints) else: solutions = {} for tv, cs in cmap.items(): @@ -88,7 +88,7 @@ def solve_constraints( def solve_non_linear( - vars: list[TypeVarId], constraints: list[Constraint], cmap: dict[TypeVarId, list[Constraint]] + vars: list[TypeVarId], constraints: list[Constraint] ) -> dict[TypeVarId, Type | None]: """Solve set of constraints that may include non-linear ones, like T <: List[S]. @@ -99,21 +99,15 @@ def solve_non_linear( * Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC) * Solve constraints iteratively starting from leafs, updating targets after each step. """ - extra_constraints = [] - for tvar in vars: - extra_constraints.extend(propagate_constraints_for(tvar, SUBTYPE_OF, cmap)) - extra_constraints.extend(propagate_constraints_for(tvar, SUPERTYPE_OF, cmap)) - constraints += remove_dups(extra_constraints) - - # Recompute constraint map after propagating. - cmap = {tv: [] for tv in vars} - for con in constraints: - if con.type_var in vars: - cmap[con.type_var].append(con) + originals = {c.type_var: c.origin_type_var for c in constraints} + for c in constraints: + for extra_tv in c.extra_tvars: + originals[extra_tv.id] = extra_tv + graph, lowers, uppers = transitive_closure(vars, constraints) - dmap = compute_dependencies(cmap) + dmap = compute_dependencies(vars, graph, lowers, uppers) sccs = list(strongly_connected_components(set(vars), dmap)) - if all(check_linear(scc, cmap) for scc in sccs): + if all(check_linear(scc, lowers, uppers) for scc in sccs): raw_batches = list(topsort(prepare_sccs(sccs, dmap))) leafs = raw_batches[0] free_vars = [] @@ -122,11 +116,7 @@ def solve_non_linear( # same SCC then the only meaningful solution we can express, is that # each variable is equal to a new free variable. For example if we # have T <: S, S <: U, we deduce: T = S = U = . - if all( - isinstance(c.target, TypeVarType) and c.target.id in vars - for tv in scc - for c in cmap[tv] - ): + if all(not lowers[tv] and not uppers[tv] for tv in scc): # For convenience with current type application machinery, we randomly # choose one of the existing type variables in SCC and designate it as free # instead of defining a new type variable as a common solution. @@ -142,9 +132,21 @@ def solve_non_linear( next_bc.extend(list(scc)) batches.append(next_bc) + # Update lowers/uppers with free vars, so these can now be used + # as valid solutions. + for l, u in graph.copy(): + if l == u: + continue + if l in free_vars: + lowers[u].add(originals[l]) + graph.remove((l, u)) + if u in free_vars: + uppers[l].add(originals[u]) + graph.remove((l, u)) + solutions: dict[TypeVarId, Type | None] = {} for flat_batch in batches: - solutions.update(solve_iteratively(flat_batch, cmap, free_vars)) + solutions.update(solve_iteratively(flat_batch, graph, lowers, uppers, free_vars)) # We remove the solutions like T = T for free variables. This will indicate # to the apply function, that they should not be touched. 
# TODO: return list of free type variables explicitly, this logic is fragile @@ -157,7 +159,11 @@ def solve_non_linear( def solve_iteratively( - batch: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId] + batch: list[TypeVarId], + graph: set[tuple[TypeVarId, TypeVarId]], + lowers: dict[TypeVarId, set[Type]], + uppers: dict[TypeVarId, set[Type]], + free_vars: list[TypeVarId], ) -> dict[TypeVarId, Type | None]: """Solve constraints sequentially, updating constraint targets after each step. @@ -171,17 +177,11 @@ def solve_iteratively( designate S as free, and therefore T = List[S] is a valid solution for T. """ solutions = {} - relevant_constraints = [] - for tv in batch: - relevant_constraints.extend(cmap.get(tv, [])) - lowers, uppers = transitive_closure(batch, relevant_constraints) s_batch = set(batch) not_allowed_vars = [v for v in batch if v not in free_vars] while s_batch: - for tv in s_batch: - if any(not get_vars(l, not_allowed_vars) for l in lowers[tv]) or any( - not get_vars(u, not_allowed_vars) for u in uppers[tv] - ): + for tv in sorted(s_batch, key=lambda x: x.raw_id): + if lowers[tv] or uppers[tv]: solvable_tv = tv break else: @@ -194,15 +194,30 @@ def solve_iteratively( # TODO: support backtracking lower/upper bound choices and order within SCCs. # (will require switching this function from iterative to recursive). continue - # Update the (transitive) constraints if there is a solution. - # TODO: do update from graph even within batch - # TODO: no need to update lowers/uppers/cmap within batch - subs = {solvable_tv: result} - lowers = {tv: {expand_type(l, subs) for l in lowers[tv]} for tv in lowers} - uppers = {tv: {expand_type(u, subs) for u in uppers[tv]} for tv in uppers} - for v in cmap: - for c in cmap[v]: - c.target = expand_type(c.target, subs) + + # Update the (transitive) bounds from graph if there is a solution. + # This is needed to guarantee solutions will never contradict the initial + # constraints. For example, consider {T <: S, T <: A, S :> B} with A :> B. + # If we would not update the uppers/lowers from graph, we would infer T = A, S = B + # which is not correct. + for l, u in graph.copy(): + if l == u: + continue + if l == solvable_tv: + lowers[u].add(result) + graph.remove((l, u)) + if u == solvable_tv: + uppers[l].add(result) + graph.remove((l, u)) + + # We can update uppers/lowers only once after solving the whole SCC, + # since uppers/lowers can't depend on type variables in the SCC + # (and we would reject such SCC as non-linear and therefore not solvable). + subs = {tv: s for (tv, s) in solutions.items() if s is not None} + for tv in lowers: + lowers[tv] = {expand_type(lt, subs) for lt in lowers[tv]} + for tv in uppers: + uppers[tv] = {expand_type(ut, subs) for ut in uppers[tv]} return solutions @@ -278,44 +293,11 @@ def normalize_constraints( return [c for c in remove_dups(constraints) if c.type_var in vars] -def propagate_constraints_for( - var: TypeVarId, direction: int, cmap: dict[TypeVarId, list[Constraint]] -) -> list[Constraint]: - """Propagate via linear constraints to get additional constraints for `var`. 
- - For example if we have constraints: - [T <: int, S <: T, S :> str] - we can add two more - [S <: int, T :> str] - """ - extra_constraints = [] - seen = set() - front = [var] - if cmap[var]: - var_def = cmap[var][0].origin_type_var - else: - return [] - while front: - tv = front.pop(0) - for c in cmap[tv]: - if ( - isinstance(c.target, TypeVarType) - and c.target.id not in seen - and c.target.id in cmap - and c.op == direction - ): - front.append(c.target.id) - seen.add(c.target.id) - elif c.op == direction: - new_c = Constraint(var_def, direction, c.target) - if new_c not in cmap[var]: - extra_constraints.append(new_c) - return extra_constraints - - def transitive_closure( tvars: list[TypeVarId], constraints: list[Constraint] -) -> tuple[dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]]]: +) -> tuple[ + set[tuple[TypeVarId, TypeVarId]], dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]] +]: """Find transitive closure for given constraints on type variables. Transitive closure gives maximal set of lower/upper bounds for each type variable, @@ -372,11 +354,14 @@ def transitive_closure( for ut in uppers[c.type_var]: remaining |= set(infer_constraints(ut, c.target, SUPERTYPE_OF)) remaining |= set(infer_constraints(c.target, ut, SUBTYPE_OF)) - return lowers, uppers + return graph, lowers, uppers def compute_dependencies( - cmap: dict[TypeVarId, list[Constraint]] + tvars: list[TypeVarId], + graph: set[tuple[TypeVarId, TypeVarId]], + lowers: dict[TypeVarId, set[Type]], + uppers: dict[TypeVarId, set[Type]], ) -> dict[TypeVarId, list[TypeVarId]]: """Compute dependencies between type variables induced by constraints. @@ -384,25 +369,33 @@ def compute_dependencies( we will need to solve for S first before we can solve for T. """ res = {} - vars = list(cmap.keys()) - for tv in cmap: + for tv in tvars: deps = set() - for c in cmap[tv]: - deps |= get_vars(c.target, vars) + for lt in lowers[tv]: + deps |= get_vars(lt, tvars) + for ut in uppers[tv]: + deps |= get_vars(ut, tvars) + for other in tvars: + if other == tv: + # TODO: is there a value in either skipping or adding trivial deps? + continue + if (tv, other) in graph or (other, tv) in graph: + deps.add(other) res[tv] = list(deps) return res -def check_linear(scc: set[TypeVarId], cmap: dict[TypeVarId, list[Constraint]]) -> bool: +def check_linear( + scc: set[TypeVarId], lowers: dict[TypeVarId, set[Type]], uppers: dict[TypeVarId, set[Type]] +) -> bool: """Check there are only linear constraints between type variables in SCC. Linear are constraints like T <: S (while T <: F[S] are non-linear). 
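    For example, an SCC {T, S} where the only relations are T <: S and S <: T is linear,
    but if a bound of T mentions List[S], the SCC is rejected and no solution is inferred.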
""" for tv in scc: - if any( - get_vars(c.target, list(scc)) and not isinstance(c.target, TypeVarType) - for c in cmap[tv] - ): + if any(get_vars(lt, list(scc)) for lt in lowers[tv]): + return False + if any(get_vars(ut, list(scc)) for ut in uppers[tv]): return False return True From 71e46cc7a55072cee7a4e6569c4f356f517ca192 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 22 Jul 2023 18:25:53 +0100 Subject: [PATCH 05/14] Make some progress --- mypy/constraints.py | 6 +++++- mypy/options.py | 2 +- mypy/solve.py | 33 ++++++++++++++++++------------ test-data/unit/check-generics.test | 3 +-- 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 46aed550f7cb..7e0abdb1dd19 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -913,7 +913,11 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: and not any(tv.id.raw_id == 0 for tv in cactual.variables) ): # TODO: fix polymorphic application to support new vars - res.extend(infer_constraints(cactual, template, neg_op(self.direction))) + res.extend( + infer_constraints( + cactual, template, neg_op(self.direction), skip_neg_op=True + ) + ) extra_tvars = True # We can't infer constraints from arguments if the template is Callable[..., T] diff --git a/mypy/options.py b/mypy/options.py index daa666dc7638..f4abc992e670 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -350,7 +350,7 @@ def __init__(self) -> None: # -1 means unlimited. self.many_errors_threshold = defaults.MANY_ERRORS_THRESHOLD # Enable new experimental type inference algorithm. - self.new_type_inference = False + self.new_type_inference = True # Disable recursive type aliases (currently experimental) self.disable_recursive_aliases = False # Deprecated reverse version of the above, do not use. diff --git a/mypy/solve.py b/mypy/solve.py index 05a6462bec65..575c94952216 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections import defaultdict from typing import Iterable from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, neg_op, infer_constraints @@ -43,16 +44,15 @@ def solve_constraints( """ if not vars: return [] - if allow_polymorphic: - # Constraints like T :> S and S <: T are semantically the same, but they are - # represented differently. Normalize the constraint list w.r.t this equivalence. - constraints = normalize_constraints(constraints, vars) extra_vars: list[TypeVarId] = [] + # Get additional variables from generic actuals. + for c in constraints: + extra_vars.extend([v.id for v in c.extra_tvars if v.id not in vars + extra_vars]) if allow_polymorphic: - # Get additional variables from generic actuals. - for c in constraints: - extra_vars.extend([v.id for v in c.extra_tvars if v.id not in vars + extra_vars]) + # Constraints like T :> S and S <: T are semantically the same, but they are + # represented differently. Normalize the constraint list w.r.t this equivalence. + constraints = normalize_constraints(constraints, vars + extra_vars) # Collect a list of constraints for each type variable. 
cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars + extra_vars} @@ -61,7 +61,7 @@ def solve_constraints( cmap[con.type_var].append(con) if allow_polymorphic: - solutions = solve_non_linear(vars + extra_vars, constraints) + solutions = solve_non_linear(vars + extra_vars, constraints, vars) else: solutions = {} for tv, cs in cmap.items(): @@ -69,7 +69,14 @@ def solve_constraints( continue lowers = [c.target for c in cs if c.op == SUPERTYPE_OF] uppers = [c.target for c in cs if c.op == SUBTYPE_OF] - solutions[tv] = solve_one(lowers, uppers, []) + solution = solve_one(lowers, uppers, []) + + # Do not leak type variables in non-polymorphic solutions. + # XXX: somehow this breaks defaultdict partial types. + if solution is None or not get_vars( + solution, [tv for tv in extra_vars if tv not in vars] + ): + solutions[tv] = solution res: list[Type | None] = [] for v in vars: @@ -88,7 +95,7 @@ def solve_constraints( def solve_non_linear( - vars: list[TypeVarId], constraints: list[Constraint] + vars: list[TypeVarId], constraints: list[Constraint], original_vars: list[TypeVarId] ) -> dict[TypeVarId, Type | None]: """Solve set of constraints that may include non-linear ones, like T <: List[S]. @@ -121,7 +128,7 @@ def solve_non_linear( # choose one of the existing type variables in SCC and designate it as free # instead of defining a new type variable as a common solution. # TODO: be careful about upper bounds (or values) when introducing free vars. - free_vars.append(sorted(scc, key=lambda x: x.raw_id)[0]) + free_vars.append(sorted(scc, key=lambda x: (x not in original_vars, x.raw_id))[0]) # Flatten the SCCs that are independent, we can solve them together, # since we don't need to update any targets in between. @@ -309,8 +316,8 @@ def transitive_closure( * {T} <: S <: {U, int} * {T, S} <: U <: {int} """ - uppers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars} - lowers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars} + uppers: dict[TypeVarId, set[Type]] = defaultdict(set) + lowers: dict[TypeVarId, set[Type]] = defaultdict(set) graph: set[tuple[TypeVarId, TypeVarId]] = {(tv, tv) for tv in tvars} remaining = set(constraints) diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 14d810560c78..3e5a751746ef 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2898,8 +2898,7 @@ def test1(x: V) -> V: ... def test2(x: V, y: V) -> V: ... 
reveal_type(dec1(test1)) # N: Revealed type is "def () -> def [T] (T`1) -> T`1" -# TODO: support this situation -reveal_type(dec2(test2)) # N: Revealed type is "def (builtins.object) -> def (builtins.object) -> builtins.object" +reveal_type(dec2(test2)) # N: Revealed type is "def [T] (T`3) -> def (T`3) -> T`3" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableNewVariable] From 9b0aa1933bf4c4d1b8884a2bcf1c40393fc22474 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 22 Jul 2023 23:31:38 +0100 Subject: [PATCH 06/14] Complete feature work --- mypy/checkexpr.py | 32 ++++---- mypy/constraints.py | 1 - mypy/expandtype.py | 5 ++ mypy/infer.py | 10 +-- mypy/solve.py | 116 ++++++++++++++--------------- mypy/subtypes.py | 3 +- mypy/test/testsolve.py | 50 ++++++------- mypy/typeops.py | 6 +- test-data/unit/check-generics.test | 2 +- 9 files changed, 106 insertions(+), 119 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 986e58c21762..d68bbccd187f 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1852,7 +1852,7 @@ def infer_function_type_arguments_using_context( # expects_literal(identity(3)) # Should type-check if not is_generic_instance(ctx) and not is_literal_type_like(ctx): return callable.copy_modified() - args = infer_type_arguments(callable.type_var_ids(), ret_type, erased_ctx) + args = infer_type_arguments(callable.variables, ret_type, erased_ctx) # Only substitute non-Uninhabited and non-erased types. new_args: list[Type | None] = [] for arg in args: @@ -1901,7 +1901,7 @@ def infer_function_type_arguments( else: pass1_args.append(arg) - inferred_args = infer_function_type_arguments( + inferred_args, _ = infer_function_type_arguments( callee_type, pass1_args, arg_kinds, @@ -1943,7 +1943,7 @@ def infer_function_type_arguments( # variables while allowing for polymorphic solutions, i.e. for solutions # potentially involving free variables. # TODO: support the similar inference for return type context. - poly_inferred_args = infer_function_type_arguments( + poly_inferred_args, free_vars = infer_function_type_arguments( callee_type, arg_types, arg_kinds, @@ -1952,24 +1952,18 @@ def infer_function_type_arguments( strict=self.chk.in_checked_function(), allow_polymorphic=True, ) - for i, pa in enumerate(get_proper_types(poly_inferred_args)): - if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa): - # Indicate that free variables should not be applied in the call below. - poly_inferred_args[i] = None poly_callee_type = self.apply_generic_arguments( callee_type, poly_inferred_args, context ) - yes_vars = poly_callee_type.variables - no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables} - if not set(get_type_vars(poly_callee_type)) & no_vars: - # Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can - # be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed. - applied = apply_poly(poly_callee_type, yes_vars) - if applied is not None and poly_inferred_args != [UninhabitedType()] * len( - poly_inferred_args - ): - freeze_all_type_vars(applied) - return applied + # Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can + # be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed. 
+ applied = apply_poly(poly_callee_type, free_vars) + if applied is not None and all( + a is not None and not isinstance(get_proper_type(a), UninhabitedType) + for a in poly_inferred_args + ): + freeze_all_type_vars(applied) + return applied # If it didn't work, erase free variables as , to avoid confusing errors. inferred_args = [ expand_type(a, {v.id: UninhabitedType() for v in callee_type.variables}) @@ -2014,7 +2008,7 @@ def infer_function_type_arguments_pass2( arg_types = self.infer_arg_types_in_context(callee_type, args, arg_kinds, formal_to_actual) - inferred_args = infer_function_type_arguments( + inferred_args, _ = infer_function_type_arguments( callee_type, arg_types, arg_kinds, diff --git a/mypy/constraints.py b/mypy/constraints.py index 7e0abdb1dd19..9d4ccb5b5e3f 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -912,7 +912,6 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # depends on this old behaviour. and not any(tv.id.raw_id == 0 for tv in cactual.variables) ): - # TODO: fix polymorphic application to support new vars res.extend( infer_constraints( cactual, template, neg_op(self.direction), skip_neg_op=True diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 5bbdd5311da7..7e3c9c1032a4 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -271,6 +271,11 @@ def visit_param_spec(self, t: ParamSpecType) -> Type: return repl def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: + # Sometimes solver may need to expand a type variable with itself + # (usually for other TypeVars, and it is hard to filter out TypeVarTuples). + repl = self.variables[t.id] + if isinstance(repl, TypeVarTupleType): + return repl raise NotImplementedError def visit_unpack_type(self, t: UnpackType) -> Type: diff --git a/mypy/infer.py b/mypy/infer.py index 66ca4169e2ff..a347f192f6f9 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -12,7 +12,7 @@ ) from mypy.nodes import ArgKind from mypy.solve import solve_constraints -from mypy.types import CallableType, Instance, Type, TypeVarId +from mypy.types import CallableType, Instance, Type, TypeVarId, TypeVarType, TypeVarLikeType class ArgumentInferContext(NamedTuple): @@ -37,7 +37,7 @@ def infer_function_type_arguments( context: ArgumentInferContext, strict: bool = True, allow_polymorphic: bool = False, -) -> list[Type | None]: +) -> tuple[list[Type | None], list[TypeVarType]]: """Infer the type arguments of a generic function. Return an array of lower bound types for the type variables -1 (at @@ -57,14 +57,14 @@ def infer_function_type_arguments( ) # Solve constraints. - type_vars = callee_type.type_var_ids() + type_vars = callee_type.variables return solve_constraints(type_vars, constraints, strict, allow_polymorphic) def infer_type_arguments( - type_var_ids: list[TypeVarId], template: Type, actual: Type, is_supertype: bool = False + type_var_ids: Sequence[TypeVarLikeType], template: Type, actual: Type, is_supertype: bool = False ) -> list[Type | None]: # Like infer_function_type_arguments, but only match a single type # against a generic type. 
constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF) - return solve_constraints(type_var_ids, constraints) + return solve_constraints(type_var_ids, constraints)[0] diff --git a/mypy/solve.py b/mypy/solve.py index 575c94952216..4e27a2d163e0 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections import defaultdict -from typing import Iterable +from typing import Iterable, Sequence from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, neg_op, infer_constraints from mypy.expandtype import expand_type @@ -23,16 +23,17 @@ UnionType, get_proper_type, remove_dups, + TypeVarLikeType, ) from mypy.typestate import type_state def solve_constraints( - vars: list[TypeVarId], + original_vars: Sequence[TypeVarLikeType], constraints: list[Constraint], strict: bool = True, allow_polymorphic: bool = False, -) -> list[Type | None]: +) -> tuple[list[Type | None], list[TypeVarType]]: """Solve type constraints. Return the best type(s) for type variables; each type can be None if the value of the variable @@ -42,13 +43,16 @@ def solve_constraints( pick NoneType as the value of the type variable. If strict=False, pick AnyType. """ + vars = [tv.id for tv in original_vars] if not vars: - return [] + return [], [] + originals = {tv.id: tv for tv in original_vars} extra_vars: list[TypeVarId] = [] # Get additional variables from generic actuals. for c in constraints: extra_vars.extend([v.id for v in c.extra_tvars if v.id not in vars + extra_vars]) + originals.update({v.id: v for v in c.extra_tvars if v.id not in originals}) if allow_polymorphic: # Constraints like T :> S and S <: T are semantically the same, but they are # represented differently. Normalize the constraint list w.r.t this equivalence. @@ -61,9 +65,10 @@ def solve_constraints( cmap[con.type_var].append(con) if allow_polymorphic: - solutions = solve_non_linear(vars + extra_vars, constraints, vars) + solutions, free_vars = solve_non_linear(vars + extra_vars, constraints, vars, originals) else: solutions = {} + free_vars = [] for tv, cs in cmap.items(): if not cs: continue @@ -91,12 +96,15 @@ def solve_constraints( else: candidate = AnyType(TypeOfAny.special_form) res.append(candidate) - return res + return res, [originals[tv] for tv in free_vars] def solve_non_linear( - vars: list[TypeVarId], constraints: list[Constraint], original_vars: list[TypeVarId] -) -> dict[TypeVarId, Type | None]: + vars: list[TypeVarId], + constraints: list[Constraint], + original_vars: list[TypeVarId], + originals: dict[TypeVarId, TypeVarLikeType], +) -> tuple[dict[TypeVarId, Type | None], list[TypeVarId]]: """Solve set of constraints that may include non-linear ones, like T <: List[S]. The whole algorithm consists of five steps: @@ -106,63 +114,49 @@ def solve_non_linear( * Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC) * Solve constraints iteratively starting from leafs, updating targets after each step. 
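
    For example, given T <: List[S] and S <: int, the SCCs are {T} and {S}; we solve
    S = int first, then substitute it into List[S] and solve T = List[int].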
""" - originals = {c.type_var: c.origin_type_var for c in constraints} - for c in constraints: - for extra_tv in c.extra_tvars: - originals[extra_tv.id] = extra_tv graph, lowers, uppers = transitive_closure(vars, constraints) dmap = compute_dependencies(vars, graph, lowers, uppers) sccs = list(strongly_connected_components(set(vars), dmap)) - if all(check_linear(scc, lowers, uppers) for scc in sccs): - raw_batches = list(topsort(prepare_sccs(sccs, dmap))) - leafs = raw_batches[0] - free_vars = [] - for scc in leafs: - # If all constrain targets in this SCC are type variables within the - # same SCC then the only meaningful solution we can express, is that - # each variable is equal to a new free variable. For example if we - # have T <: S, S <: U, we deduce: T = S = U = . - if all(not lowers[tv] and not uppers[tv] for tv in scc): - # For convenience with current type application machinery, we randomly - # choose one of the existing type variables in SCC and designate it as free - # instead of defining a new type variable as a common solution. - # TODO: be careful about upper bounds (or values) when introducing free vars. - free_vars.append(sorted(scc, key=lambda x: (x not in original_vars, x.raw_id))[0]) - - # Flatten the SCCs that are independent, we can solve them together, - # since we don't need to update any targets in between. - batches = [] - for batch in raw_batches: - next_bc = [] - for scc in batch: - next_bc.extend(list(scc)) - batches.append(next_bc) - - # Update lowers/uppers with free vars, so these can now be used - # as valid solutions. - for l, u in graph.copy(): - if l == u: - continue - if l in free_vars: - lowers[u].add(originals[l]) - graph.remove((l, u)) - if u in free_vars: - uppers[l].add(originals[u]) - graph.remove((l, u)) - - solutions: dict[TypeVarId, Type | None] = {} - for flat_batch in batches: - solutions.update(solve_iteratively(flat_batch, graph, lowers, uppers, free_vars)) - # We remove the solutions like T = T for free variables. This will indicate - # to the apply function, that they should not be touched. - # TODO: return list of free type variables explicitly, this logic is fragile - # (but if we do, we need to be careful everything works in incremental modes). - for tv in free_vars: - if tv in solutions: - del solutions[tv] - return solutions - return {} + if not all(check_linear(scc, lowers, uppers) for scc in sccs): + return {}, [] + raw_batches = list(topsort(prepare_sccs(sccs, dmap))) + + free_vars = [] + for scc in raw_batches[0]: + # If all constrain targets in this SCC are type variables within the + # same SCC then the only meaningful solution we can express, is that + # each variable is equal to a new free variable. For example if we + # have T <: S, S <: U, we deduce: T = S = U = . + if all(not lowers[tv] and not uppers[tv] for tv in scc): + # For convenience with current type application machinery, we randomly + # choose one of the existing type variables in SCC and designate it as free + # instead of defining a new type variable as a common solution. + # TODO: be careful about upper bounds (or values) when introducing free vars. + free_vars.append(sorted(scc, key=lambda x: (x not in original_vars, x.raw_id))[0]) + + # Update lowers/uppers with free vars, so these can now be used + # as valid solutions. 
+ for l, u in graph.copy(): + if l in free_vars: + lowers[u].add(originals[l]) + if u in free_vars: + uppers[l].add(originals[u]) + + # Flatten the SCCs that are independent, we can solve them together, + # since we don't need to update any targets in between. + batches = [] + for batch in raw_batches: + next_bc = [] + for scc in batch: + next_bc.extend(list(scc)) + batches.append(next_bc) + + solutions: dict[TypeVarId, Type | None] = {} + for flat_batch in batches: + res = solve_iteratively(flat_batch, graph, lowers, uppers, free_vars) + solutions.update(res) + return solutions, free_vars def solve_iteratively( diff --git a/mypy/subtypes.py b/mypy/subtypes.py index c9de56edfa36..b9efa8697a4e 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1708,8 +1708,7 @@ def unify_generic_callable( type.ret_type, target.ret_type, return_constraint_direction ) constraints.extend(c) - type_var_ids = [tvar.id for tvar in type.variables] - inferred_vars = mypy.solve.solve_constraints(type_var_ids, constraints) + inferred_vars, _ = mypy.solve.solve_constraints(type.variables, constraints) if None in inferred_vars: return None non_none_inferred_vars = cast(List[Type], inferred_vars) diff --git a/mypy/test/testsolve.py b/mypy/test/testsolve.py index d6c585ef4aaa..f0798fd29017 100644 --- a/mypy/test/testsolve.py +++ b/mypy/test/testsolve.py @@ -6,7 +6,7 @@ from mypy.solve import solve_constraints from mypy.test.helpers import Suite, assert_equal from mypy.test.typefixture import TypeFixture -from mypy.types import Type, TypeVarId, TypeVarType +from mypy.types import Type, TypeVarId, TypeVarType, TypeVarLikeType class SolveSuite(Suite): @@ -17,26 +17,24 @@ def test_empty_input(self) -> None: self.assert_solve([], [], []) def test_simple_supertype_constraints(self) -> None: + self.assert_solve([self.fx.t], [self.supc(self.fx.t, self.fx.a)], [(self.fx.a, self.fx.o)]) self.assert_solve( - [self.fx.t.id], [self.supc(self.fx.t, self.fx.a)], [(self.fx.a, self.fx.o)] - ) - self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.a), self.supc(self.fx.t, self.fx.b)], [(self.fx.a, self.fx.o)], ) def test_simple_subtype_constraints(self) -> None: - self.assert_solve([self.fx.t.id], [self.subc(self.fx.t, self.fx.a)], [self.fx.a]) + self.assert_solve([self.fx.t], [self.subc(self.fx.t, self.fx.a)], [self.fx.a]) self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.subc(self.fx.t, self.fx.a), self.subc(self.fx.t, self.fx.b)], [self.fx.b], ) def test_both_kinds_of_constraints(self) -> None: self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.b), self.subc(self.fx.t, self.fx.a)], [(self.fx.b, self.fx.a)], ) @@ -44,21 +42,19 @@ def test_both_kinds_of_constraints(self) -> None: def test_unsatisfiable_constraints(self) -> None: # The constraints are impossible to satisfy. 
self.assert_solve( - [self.fx.t.id], - [self.supc(self.fx.t, self.fx.a), self.subc(self.fx.t, self.fx.b)], - [None], + [self.fx.t], [self.supc(self.fx.t, self.fx.a), self.subc(self.fx.t, self.fx.b)], [None] ) def test_exactly_specified_result(self) -> None: self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.b), self.subc(self.fx.t, self.fx.b)], [(self.fx.b, self.fx.b)], ) def test_multiple_variables(self) -> None: self.assert_solve( - [self.fx.t.id, self.fx.s.id], + [self.fx.t, self.fx.s], [ self.supc(self.fx.t, self.fx.b), self.supc(self.fx.s, self.fx.c), @@ -68,40 +64,38 @@ def test_multiple_variables(self) -> None: ) def test_no_constraints_for_var(self) -> None: - self.assert_solve([self.fx.t.id], [], [self.fx.uninhabited]) - self.assert_solve( - [self.fx.t.id, self.fx.s.id], [], [self.fx.uninhabited, self.fx.uninhabited] - ) + self.assert_solve([self.fx.t], [], [self.fx.uninhabited]) + self.assert_solve([self.fx.t, self.fx.s], [], [self.fx.uninhabited, self.fx.uninhabited]) self.assert_solve( - [self.fx.t.id, self.fx.s.id], + [self.fx.t, self.fx.s], [self.supc(self.fx.s, self.fx.a)], [self.fx.uninhabited, (self.fx.a, self.fx.o)], ) def test_simple_constraints_with_dynamic_type(self) -> None: self.assert_solve( - [self.fx.t.id], [self.supc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)] + [self.fx.t], [self.supc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)] ) self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.anyt), self.supc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)], ) self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.anyt), self.supc(self.fx.t, self.fx.a)], [(self.fx.anyt, self.fx.anyt)], ) self.assert_solve( - [self.fx.t.id], [self.subc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)] + [self.fx.t], [self.subc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)] ) self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.subc(self.fx.t, self.fx.anyt), self.subc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)], ) - # self.assert_solve([self.fx.t.id], + # self.assert_solve([self.fx.t], # [self.subc(self.fx.t, self.fx.anyt), # self.subc(self.fx.t, self.fx.a)], # [(self.fx.anyt, self.fx.anyt)]) @@ -111,20 +105,20 @@ def test_both_normal_and_any_types_in_results(self) -> None: # If one of the bounds is any, we promote the other bound to # any as well, since otherwise the type range does not make sense. 
self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.a), self.subc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)], ) self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.anyt), self.subc(self.fx.t, self.fx.a)], [(self.fx.anyt, self.fx.anyt)], ) def assert_solve( self, - vars: list[TypeVarId], + vars: list[TypeVarLikeType], constraints: list[Constraint], results: list[None | Type | tuple[Type, Type]], ) -> None: @@ -134,7 +128,7 @@ def assert_solve( res.append(r[0]) else: res.append(r) - actual = solve_constraints(vars, constraints) + actual, _ = solve_constraints(vars, constraints) assert_equal(str(actual), str(res)) def supc(self, type_var: TypeVarType, bound: Type) -> Constraint: diff --git a/mypy/typeops.py b/mypy/typeops.py index 58a641a54ab7..e84e0d054afc 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -313,7 +313,9 @@ class B(A): pass original_type = get_proper_type(original_type) all_ids = func.type_var_ids() - typeargs = infer_type_arguments(all_ids, self_param_type, original_type, is_supertype=True) + typeargs = infer_type_arguments( + func.variables, self_param_type, original_type, is_supertype=True + ) if ( is_classmethod # TODO: why do we need the extra guards here? @@ -322,7 +324,7 @@ class B(A): pass ): # In case we call a classmethod through an instance x, fallback to type(x) typeargs = infer_type_arguments( - all_ids, self_param_type, TypeType(original_type), is_supertype=True + func.variables, self_param_type, TypeType(original_type), is_supertype=True ) ids = [tid for tid in all_ids if any(tid == t.id for t in get_type_vars(self_param_type))] diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 3e5a751746ef..95f85280e514 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2913,7 +2913,7 @@ def dec(f: Callable[[S], T]) -> Callable[[S], T]: ... def test(x: List[U]) -> List[U]: ... -reveal_type(dec(test)) # N: Revealed type is "def [U] (builtins.list[U`1]) -> builtins.list[U`1]" +reveal_type(dec(test)) # N: Revealed type is "def [U] (builtins.list[U`-1]) -> builtins.list[U`-1]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableGenericAlias] From b43f286d1524f1a665b7ea8be8cdfe9bc5a44a90 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 23 Jul 2023 11:54:45 +0100 Subject: [PATCH 07/14] Fix partial types (and TypeVarTuple) --- mypy/checker.py | 11 ++++++----- mypy/checkexpr.py | 8 ++++++-- mypy/constraints.py | 11 +++++++---- mypy/expandtype.py | 2 +- mypy/infer.py | 9 ++++++--- mypy/solve.py | 15 ++++++++++----- mypy/test/testsolve.py | 2 +- 7 files changed, 37 insertions(+), 21 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index cdce42ddaaa1..e85450e380d6 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3878,11 +3878,12 @@ def is_valid_defaultdict_partial_value_type(self, t: ProperType) -> bool: return True if len(t.args) == 1: arg = get_proper_type(t.args[0]) - # TODO: This is too permissive -- we only allow TypeVarType since - # they leak in cases like defaultdict(list) due to a bug. - # This can result in incorrect types being inferred, but only - # in rare cases. - if isinstance(arg, (TypeVarType, UninhabitedType, NoneType)): + if self.options.new_type_inference: + allowed = isinstance(arg, (UninhabitedType, NoneType)) + else: + # Allow leaked TypeVars for legacy inference logic. 
+ allowed = isinstance(arg, (UninhabitedType, NoneType, TypeVarType)) + if allowed: return True return False diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d68bbccd187f..7a2774df6f0a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1965,11 +1965,15 @@ def infer_function_type_arguments( freeze_all_type_vars(applied) return applied # If it didn't work, erase free variables as , to avoid confusing errors. + unknown = UninhabitedType() + unknown.ambiguous = True inferred_args = [ - expand_type(a, {v.id: UninhabitedType() for v in callee_type.variables}) + expand_type( + a, {v.id: unknown for v in list(callee_type.variables) + free_vars} + ) if a is not None else None - for a in inferred_args + for a in poly_inferred_args ] else: # In dynamically typed functions use implicit 'Any' types for diff --git a/mypy/constraints.py b/mypy/constraints.py index 9d4ccb5b5e3f..48833ec083e6 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -661,13 +661,13 @@ def visit_instance(self, template: Instance) -> list[Constraint]: assert mapped.type.type_var_tuple_prefix is not None assert mapped.type.type_var_tuple_suffix is not None - unpack_constraints, mapped_args, instance_args = build_constraints_for_unpack( - mapped.args, - mapped.type.type_var_tuple_prefix, - mapped.type.type_var_tuple_suffix, + unpack_constraints, instance_args, mapped_args = build_constraints_for_unpack( instance.args, instance.type.type_var_tuple_prefix, instance.type.type_var_tuple_suffix, + mapped.args, + mapped.type.type_var_tuple_prefix, + mapped.type.type_var_tuple_suffix, self.direction, ) res.extend(unpack_constraints) @@ -1217,6 +1217,9 @@ def find_and_build_constraints_for_unpack( def build_constraints_for_unpack( + # TODO: this naming is misleading, these should be "actual", not "mapped" + # both template and actual can be mapped before, depending on direction. + # Also the convention is to put template related args first. mapped: tuple[Type, ...], mapped_prefix_len: int | None, mapped_suffix_len: int | None, diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 7e3c9c1032a4..48a9dd162843 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -271,7 +271,7 @@ def visit_param_spec(self, t: ParamSpecType) -> Type: return repl def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: - # Sometimes solver may need to expand a type variable with itself + # Sometimes solver may need to expand a type variable with (a copy of) itself # (usually for other TypeVars, and it is hard to filter out TypeVarTuples). repl = self.variables[t.id] if isinstance(repl, TypeVarTupleType): diff --git a/mypy/infer.py b/mypy/infer.py index a347f192f6f9..e9d2cbf6868b 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -12,7 +12,7 @@ ) from mypy.nodes import ArgKind from mypy.solve import solve_constraints -from mypy.types import CallableType, Instance, Type, TypeVarId, TypeVarType, TypeVarLikeType +from mypy.types import CallableType, Instance, Type, TypeVarLikeType class ArgumentInferContext(NamedTuple): @@ -37,7 +37,7 @@ def infer_function_type_arguments( context: ArgumentInferContext, strict: bool = True, allow_polymorphic: bool = False, -) -> tuple[list[Type | None], list[TypeVarType]]: +) -> tuple[list[Type | None], list[TypeVarLikeType]]: """Infer the type arguments of a generic function. 
Return an array of lower bound types for the type variables -1 (at @@ -62,7 +62,10 @@ def infer_function_type_arguments( def infer_type_arguments( - type_var_ids: Sequence[TypeVarLikeType], template: Type, actual: Type, is_supertype: bool = False + type_var_ids: Sequence[TypeVarLikeType], + template: Type, + actual: Type, + is_supertype: bool = False, ) -> list[Type | None]: # Like infer_function_type_arguments, but only match a single type # against a generic type. diff --git a/mypy/solve.py b/mypy/solve.py index 4e27a2d163e0..9cb222e4d6e0 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -5,7 +5,7 @@ from collections import defaultdict from typing import Iterable, Sequence -from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, neg_op, infer_constraints +from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints, neg_op from mypy.expandtype import expand_type from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort from mypy.join import join_types @@ -18,12 +18,12 @@ Type, TypeOfAny, TypeVarId, + TypeVarLikeType, TypeVarType, UninhabitedType, UnionType, get_proper_type, remove_dups, - TypeVarLikeType, ) from mypy.typestate import type_state @@ -33,7 +33,7 @@ def solve_constraints( constraints: list[Constraint], strict: bool = True, allow_polymorphic: bool = False, -) -> tuple[list[Type | None], list[TypeVarType]]: +) -> tuple[list[Type | None], list[TypeVarLikeType]]: """Solve type constraints. Return the best type(s) for type variables; each type can be None if the value of the variable @@ -65,7 +65,13 @@ def solve_constraints( cmap[con.type_var].append(con) if allow_polymorphic: - solutions, free_vars = solve_non_linear(vars + extra_vars, constraints, vars, originals) + if constraints: + solutions, free_vars = solve_non_linear( + vars + extra_vars, constraints, vars, originals + ) + else: + solutions = {} + free_vars = [] else: solutions = {} free_vars = [] @@ -77,7 +83,6 @@ def solve_constraints( solution = solve_one(lowers, uppers, []) # Do not leak type variables in non-polymorphic solutions. - # XXX: somehow this breaks defaultdict partial types. if solution is None or not get_vars( solution, [tv for tv in extra_vars if tv not in vars] ): diff --git a/mypy/test/testsolve.py b/mypy/test/testsolve.py index f0798fd29017..5d67203dbbf5 100644 --- a/mypy/test/testsolve.py +++ b/mypy/test/testsolve.py @@ -6,7 +6,7 @@ from mypy.solve import solve_constraints from mypy.test.helpers import Suite, assert_equal from mypy.test.typefixture import TypeFixture -from mypy.types import Type, TypeVarId, TypeVarType, TypeVarLikeType +from mypy.types import Type, TypeVarLikeType, TypeVarType class SolveSuite(Suite): From b0ddd6d2288a6dc107d7c488c80a33f26438f3a3 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 23 Jul 2023 18:13:23 +0100 Subject: [PATCH 08/14] Some simplifications; better docs/comments --- mypy/constraints.py | 4 +- mypy/expandtype.py | 2 +- mypy/infer.py | 7 +-- mypy/solve.py | 129 +++++++++++++++++++++----------------------- 4 files changed, 67 insertions(+), 75 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 48833ec083e6..a504bd1e2d03 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -194,7 +194,9 @@ def infer_constraints( ((T, S), (X, Y)) --> T :> X and S :> Y (X[T], Any) --> T <: Any and T :> Any - The constraints are represented as Constraint objects. + The constraints are represented as Constraint objects. 
If skip_neg_op == True, + then skip adding reverse (polymorphic) constraints (since this is already a call + to infer such constraints). """ if any( get_proper_type(template) == get_proper_type(t) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 48a9dd162843..85648f88d0b5 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -272,7 +272,7 @@ def visit_param_spec(self, t: ParamSpecType) -> Type: def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: # Sometimes solver may need to expand a type variable with (a copy of) itself - # (usually for other TypeVars, and it is hard to filter out TypeVarTuples). + # (usually together with other TypeVars, but it is hard to filter out TypeVarTuples). repl = self.variables[t.id] if isinstance(repl, TypeVarTupleType): return repl diff --git a/mypy/infer.py b/mypy/infer.py index e9d2cbf6868b..f34087910e4b 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -62,12 +62,9 @@ def infer_function_type_arguments( def infer_type_arguments( - type_var_ids: Sequence[TypeVarLikeType], - template: Type, - actual: Type, - is_supertype: bool = False, + type_vars: Sequence[TypeVarLikeType], template: Type, actual: Type, is_supertype: bool = False ) -> list[Type | None]: # Like infer_function_type_arguments, but only match a single type # against a generic type. constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF) - return solve_constraints(type_var_ids, constraints)[0] + return solve_constraints(type_vars, constraints)[0] diff --git a/mypy/solve.py b/mypy/solve.py index 9cb222e4d6e0..53d42379ceeb 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -4,6 +4,7 @@ from collections import defaultdict from typing import Iterable, Sequence +from typing_extensions import TypeAlias as _TypeAlias from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints, neg_op from mypy.expandtype import expand_type @@ -27,6 +28,10 @@ ) from mypy.typestate import type_state +Bounds: _TypeAlias = "dict[TypeVarId, set[Type]]" +Graph: _TypeAlias = "set[tuple[TypeVarId, TypeVarId]]" +Solutions: _TypeAlias = "dict[TypeVarId, Type | None]" + def solve_constraints( original_vars: Sequence[TypeVarLikeType], @@ -36,12 +41,14 @@ def solve_constraints( ) -> tuple[list[Type | None], list[TypeVarLikeType]]: """Solve type constraints. - Return the best type(s) for type variables; each type can be None if the value of the variable - could not be solved. + Return the best type(s) for type variables; each type can be None if the value of + the variable could not be solved. If a variable has no constraints, if strict=True then arbitrarily - pick NoneType as the value of the type variable. If strict=False, - pick AnyType. + pick UninhabitedType as the value of the type variable. If strict=False, pick AnyType. + If allow_polymorphic=True, then use the full algorithm that can potentially return + free type variables in solutions (these require special care when applying). Otherwise, + use a simplified algorithm that just solves each type variable individually if possible. """ vars = [tv.id for tv in original_vars] if not vars: @@ -49,7 +56,7 @@ def solve_constraints( originals = {tv.id: tv for tv in original_vars} extra_vars: list[TypeVarId] = [] - # Get additional variables from generic actuals. + # Get additional type variables from generic actuals. 
for c in constraints: extra_vars.extend([v.id for v in c.extra_tvars if v.id not in vars + extra_vars]) originals.update({v.id: v for v in c.extra_tvars if v.id not in originals}) @@ -66,7 +73,7 @@ def solve_constraints( if allow_polymorphic: if constraints: - solutions, free_vars = solve_non_linear( + solutions, free_vars = solve_with_dependent( vars + extra_vars, constraints, vars, originals ) else: @@ -80,7 +87,7 @@ def solve_constraints( continue lowers = [c.target for c in cs if c.op == SUPERTYPE_OF] uppers = [c.target for c in cs if c.op == SUBTYPE_OF] - solution = solve_one(lowers, uppers, []) + solution = solve_one(lowers, uppers) # Do not leak type variables in non-polymorphic solutions. if solution is None or not get_vars( @@ -104,20 +111,20 @@ def solve_constraints( return res, [originals[tv] for tv in free_vars] -def solve_non_linear( +def solve_with_dependent( vars: list[TypeVarId], constraints: list[Constraint], original_vars: list[TypeVarId], originals: dict[TypeVarId, TypeVarLikeType], -) -> tuple[dict[TypeVarId, Type | None], list[TypeVarId]]: - """Solve set of constraints that may include non-linear ones, like T <: List[S]. +) -> tuple[Solutions, list[TypeVarId]]: + """Solve set of constraints that may depend on each other, like T <: List[S]. The whole algorithm consists of five steps: - * Propagate via linear constraints to get all possible constraints for each variable + * Propagate via linear constraints and use secondary constraints to get transitive closure * Find dependencies between type variables, group them in SCCs, and sort topologically - * Check all SCC are intrinsically linear, we can't solve (express) T <: List[T] + * Check that all SCC are intrinsically linear, we can't solve (express) T <: List[T] * Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC) - * Solve constraints iteratively starting from leafs, updating targets after each step. + * Solve constraints iteratively starting from leafs, updating bounds after each step. """ graph, lowers, uppers = transitive_closure(vars, constraints) @@ -129,14 +136,12 @@ def solve_non_linear( free_vars = [] for scc in raw_batches[0]: - # If all constrain targets in this SCC are type variables within the - # same SCC then the only meaningful solution we can express, is that - # each variable is equal to a new free variable. For example if we - # have T <: S, S <: U, we deduce: T = S = U = . + # If there are no bounds on this SCC, then the only meaningful solution we can + # express, is that each variable is equal to a new free variable. For example, + # if we have T <: S, S <: U, we deduce: T = S = U = . if all(not lowers[tv] and not uppers[tv] for tv in scc): - # For convenience with current type application machinery, we randomly - # choose one of the existing type variables in SCC and designate it as free - # instead of defining a new type variable as a common solution. + # For convenience with current type application machinery, we use a stable + # choice that prefers the original type variables (not polymorphic ones) in SCC. # TODO: be careful about upper bounds (or values) when introducing free vars. 
free_vars.append(sorted(scc, key=lambda x: (x not in original_vars, x.raw_id))[0]) @@ -159,32 +164,29 @@ def solve_non_linear( solutions: dict[TypeVarId, Type | None] = {} for flat_batch in batches: - res = solve_iteratively(flat_batch, graph, lowers, uppers, free_vars) + res = solve_iteratively(flat_batch, graph, lowers, uppers) solutions.update(res) return solutions, free_vars def solve_iteratively( - batch: list[TypeVarId], - graph: set[tuple[TypeVarId, TypeVarId]], - lowers: dict[TypeVarId, set[Type]], - uppers: dict[TypeVarId, set[Type]], - free_vars: list[TypeVarId], -) -> dict[TypeVarId, Type | None]: - """Solve constraints sequentially, updating constraint targets after each step. - - We solve for type variables that appear in `batch`. If a constraint target is not constant - (i.e. constraint looks like T :> F[S, ...]), we substitute solutions found so far in - the target F[S, ...]. This way we can gradually solve for all variables in the batch taking - one solvable variable at a time (i.e. such a variable that has at least one constant bound). - - Importantly, variables in free_vars are considered constants, so for example if we have just - one initial constraint T <: List[S], we will have two SCCs {T} and {S}, then we first - designate S as free, and therefore T = List[S] is a valid solution for T. + batch: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds +) -> Solutions: + """Solve transitive closure sequentially, updating upper/lower bounds after each step. + + Transitive closure is represented as a linear graph plus lower/upper bounds for each + type variable, see transitive_closure() docstring for details. + + We solve for type variables that appear in `batch`. If a bound is not constant (i.e. it + looks like T :> F[S, ...]), we substitute solutions found so far in the target F[S, ...] + after solving the batch. + + Importantly, after solving each variable in a batch, we move it from linear graph to + upper/lower bounds, this way we can guarantee consistency of solutions (see comment below + for an example when this is important). """ solutions = {} s_batch = set(batch) - not_allowed_vars = [v for v in batch if v not in free_vars] while s_batch: for tv in sorted(s_batch, key=lambda x: x.raw_id): if lowers[tv] or uppers[tv]: @@ -194,7 +196,7 @@ def solve_iteratively( break # Solve each solvable type variable separately. s_batch.remove(solvable_tv) - result = solve_one(lowers[solvable_tv], uppers[solvable_tv], not_allowed_vars) + result = solve_one(lowers[solvable_tv], uppers[solvable_tv]) solutions[solvable_tv] = result if result is None: # TODO: support backtracking lower/upper bound choices and order within SCCs. @@ -227,9 +229,7 @@ def solve_iteratively( return solutions -def solve_one( - lowers: Iterable[Type], uppers: Iterable[Type], not_allowed_vars: list[TypeVarId] -) -> Type | None: +def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None: """Solve constraints by finding by using meets of upper bounds, and joins of lower bounds.""" bottom: Type | None = None top: Type | None = None @@ -239,10 +239,6 @@ def solve_one( # bounds based on constraints. Note that we assume that the constraint # targets do not have constraint references. for target in lowers: - # There may be multiple steps needed to solve all vars within a - # (linear) SCC. We ignore targets pointing to not yet solved vars. 
- if get_vars(target, not_allowed_vars): - continue if bottom is None: bottom = target else: @@ -254,9 +250,6 @@ def solve_one( bottom = join_types(bottom, target) for target in uppers: - # Same as above. - if get_vars(target, not_allowed_vars): - continue if top is None: top = target else: @@ -291,6 +284,7 @@ def normalize_constraints( This includes two things currently: * Complement T :> S by S <: T * Remove strict duplicates + * Remove constrains for unrelated variables """ res = constraints.copy() for c in constraints: @@ -301,23 +295,29 @@ def normalize_constraints( def transitive_closure( tvars: list[TypeVarId], constraints: list[Constraint] -) -> tuple[ - set[tuple[TypeVarId, TypeVarId]], dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]] -]: +) -> tuple[Graph, Bounds, Bounds]: """Find transitive closure for given constraints on type variables. Transitive closure gives maximal set of lower/upper bounds for each type variable, such that we cannot deduce any further bounds by chaining other existing bounds. + The transitive closure is represented by: + * A set of lower and upper bounds for each type variable, where only constant and + non-linear terms are included in the bounds. + * A graph of linear constraints between type variables (represented as a set of pairs) + Such separation simplifies reasoning, and allows an efficient and simple incremental + transitive closure algorithm that we use here. + For example if we have initial constraints [T <: S, S <: U, U <: int], the transitive closure is given by: - * {} <: T <: {S, U, int} - * {T} <: S <: {U, int} - * {T, S} <: U <: {int} + * {} <: T <: {int} + * {} <: S <: {int} + * {} <: U <: {int} + * {T <: S, S <: U, T <: U} """ - uppers: dict[TypeVarId, set[Type]] = defaultdict(set) - lowers: dict[TypeVarId, set[Type]] = defaultdict(set) - graph: set[tuple[TypeVarId, TypeVarId]] = {(tv, tv) for tv in tvars} + uppers: Bounds = defaultdict(set) + lowers: Bounds = defaultdict(set) + graph: Graph = {(tv, tv) for tv in tvars} remaining = set(constraints) while remaining: @@ -329,10 +329,9 @@ def transitive_closure( lower, upper = c.target.id, c.type_var if (lower, upper) in graph: continue - extras = { + graph |= { (l, u) for l in tvars for u in tvars if (l, lower) in graph and (upper, u) in graph } - graph |= extras for u in tvars: if (upper, u) in graph: lowers[u] |= lowers[lower] @@ -364,10 +363,7 @@ def transitive_closure( def compute_dependencies( - tvars: list[TypeVarId], - graph: set[tuple[TypeVarId, TypeVarId]], - lowers: dict[TypeVarId, set[Type]], - uppers: dict[TypeVarId, set[Type]], + tvars: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds ) -> dict[TypeVarId, list[TypeVarId]]: """Compute dependencies between type variables induced by constraints. @@ -383,7 +379,6 @@ def compute_dependencies( deps |= get_vars(ut, tvars) for other in tvars: if other == tv: - # TODO: is there a value in either skipping or adding trivial deps? continue if (tv, other) in graph or (other, tv) in graph: deps.add(other) @@ -391,9 +386,7 @@ def compute_dependencies( return res -def check_linear( - scc: set[TypeVarId], lowers: dict[TypeVarId, set[Type]], uppers: dict[TypeVarId, set[Type]] -) -> bool: +def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool: """Check there are only linear constraints between type variables in SCC. Linear are constraints like T <: S (while T <: F[S] are non-linear). 
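For illustration, here is a minimal, self-contained sketch of the transitive-closure bookkeeping described in the docstrings above. It is not mypy code: plain strings stand in for TypeVarId and Type, the function name close() is invented for the sketch, and the secondary constraint inference that the real transitive_closure() runs between lower and upper bounds is omitted.

from collections import defaultdict

# Simplified model of the closure kept by mypy/solve.py: a set of linear
# edges between type variables plus constant lower/upper bounds per variable.

def close(tvars, constraints):
    # constraints: (left, right) pairs meaning left <: right, where either
    # side is a type variable from tvars or a constant such as "int".
    graph = {(tv, tv) for tv in tvars}   # reflexive linear edges
    lowers = defaultdict(set)            # constant lower bounds per variable
    uppers = defaultdict(set)            # constant upper bounds per variable
    for left, right in constraints:
        if left in tvars and right in tvars:
            # Chain the new linear edge with everything already reachable;
            # the reflexive edges make (left, right) itself appear as well.
            graph |= {
                (l, u)
                for (l, x) in graph
                for (y, u) in graph
                if x == left and y == right
            }
            # Propagate already-known constant bounds along the closed graph.
            for l, u in graph:
                lowers[u] |= lowers[l]
                uppers[l] |= uppers[u]
        elif left in tvars:
            # left <: constant: an upper bound for left and everything below it.
            for l, u in graph:
                if u == left:
                    uppers[l].add(right)
        else:
            # constant <: right: a lower bound for right and everything above it.
            for l, u in graph:
                if l == right:
                    lowers[u].add(left)
    return graph, lowers, uppers

graph, lowers, uppers = close(
    ["T", "S", "U"], [("T", "S"), ("S", "U"), ("U", "int")]
)
print(("T", "U") in graph)                                   # True: chained from T <: S <: U
print(all(uppers[tv] == {"int"} for tv in ["T", "S", "U"]))  # True: int reached every variable

Keeping the linear edges separate from the constant bounds is what lets solve_with_dependent() designate one free variable per bound-less SCC and then solve the remaining variables iteratively.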
From 49d5415eb71d50f9d09785e99aeaf745648a5ad7 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 23 Jul 2023 18:57:08 +0100 Subject: [PATCH 09/14] Add another comment --- mypy/constraints.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mypy/constraints.py b/mypy/constraints.py index 016ab7576127..299c6292a259 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -913,6 +913,9 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # depends on this old behaviour. and not any(tv.id.raw_id == 0 for tv in cactual.variables) ): + # If the actual callable is generic, infer constraints in the opposite + # direction, and indicate to the solver there are extra type variables + # to solve for (see more details in mypy/solve.py). res.extend( infer_constraints( cactual, template, neg_op(self.direction), skip_neg_op=True From 99699c3549c3f85c8e7143345d9b0c3bf43377b6 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 23 Jul 2023 20:01:58 +0100 Subject: [PATCH 10/14] Fix xfail --- mypy/test/testconstraints.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mypy/test/testconstraints.py b/mypy/test/testconstraints.py index b46f31327150..338f764dc8e8 100644 --- a/mypy/test/testconstraints.py +++ b/mypy/test/testconstraints.py @@ -22,7 +22,6 @@ def test_basic_type_variable(self) -> None: Constraint(type_var=fx.t, op=direction, target=fx.a) ] - @pytest.mark.xfail def test_basic_type_var_tuple_subtype(self) -> None: fx = self.fx assert infer_constraints( From bfc7ffc20d5eea9967ae2d758ea44b2aa383f369 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 23 Jul 2023 19:03:52 +0000 Subject: [PATCH 11/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/test/testconstraints.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mypy/test/testconstraints.py b/mypy/test/testconstraints.py index 338f764dc8e8..f40996145cba 100644 --- a/mypy/test/testconstraints.py +++ b/mypy/test/testconstraints.py @@ -1,7 +1,5 @@ from __future__ import annotations -import pytest - from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints from mypy.test.helpers import Suite from mypy.test.typefixture import TypeFixture From 1018146b839b57bff6c9a0584626a38d760410fa Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 23 Jul 2023 21:28:24 +0100 Subject: [PATCH 12/14] Fix infinite loop --- mypy/solve.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mypy/solve.py b/mypy/solve.py index 53d42379ceeb..02df90aff1e1 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -345,6 +345,8 @@ def transitive_closure( remaining |= set(infer_constraints(lt, ut, SUBTYPE_OF)) remaining |= set(infer_constraints(ut, lt, SUPERTYPE_OF)) elif c.op == SUBTYPE_OF: + if c.target in uppers[c.type_var]: + continue for l in tvars: if (l, c.type_var) in graph: uppers[l].add(c.target) @@ -353,6 +355,8 @@ def transitive_closure( remaining |= set(infer_constraints(c.target, lt, SUPERTYPE_OF)) else: assert c.op == SUPERTYPE_OF + if c.target in lowers[c.type_var]: + continue for u in tvars: if (c.type_var, u) in graph: lowers[u].add(c.target) From db9df7192e9929ffc0e6b3d715eced356d6e6591 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 26 Jul 2023 23:10:43 +0100 Subject: [PATCH 13/14] Fix a bug in detach_callable --- mypy/checker.py | 61 +++++++++------------------ test-data/unit/check-overloading.test | 15 +++++++ 2 files changed, 35 insertions(+), 41 
deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 0a053b28ef99..6d0fcfcf699a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -733,8 +733,10 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: # def foo(x: str) -> str: ... # # See Python 2's map function for a concrete example of this kind of overload. + current_class = self.scope.active_class() + type_vars = current_class.defn.type_vars if current_class else [] with state.strict_optional_set(True): - if is_unsafe_overlapping_overload_signatures(sig1, sig2): + if is_unsafe_overlapping_overload_signatures(sig1, sig2, type_vars): self.msg.overloaded_signatures_overlap(i + 1, i + j + 2, item.func) if impl_type is not None: @@ -1697,7 +1699,9 @@ def is_unsafe_overlapping_op( first = forward_tweaked second = reverse_tweaked - return is_unsafe_overlapping_overload_signatures(first, second) + current_class = self.scope.active_class() + type_vars = current_class.defn.type_vars if current_class else [] + return is_unsafe_overlapping_overload_signatures(first, second, type_vars) def check_inplace_operator_method(self, defn: FuncBase) -> None: """Check an inplace operator method such as __iadd__. @@ -7154,7 +7158,7 @@ def are_argument_counts_overlapping(t: CallableType, s: CallableType) -> bool: def is_unsafe_overlapping_overload_signatures( - signature: CallableType, other: CallableType + signature: CallableType, other: CallableType, class_type_vars: list[TypeVarLikeType] ) -> bool: """Check if two overloaded signatures are unsafely overlapping or partially overlapping. @@ -7173,8 +7177,8 @@ def is_unsafe_overlapping_overload_signatures( # This lets us identify cases where the two signatures use completely # incompatible types -- e.g. see the testOverloadingInferUnionReturnWithMixedTypevars # test case. - signature = detach_callable(signature) - other = detach_callable(other) + signature = detach_callable(signature, class_type_vars) + other = detach_callable(other, class_type_vars) # Note: We repeat this check twice in both directions due to a slight # asymmetry in 'is_callable_compatible'. When checking for partial overlaps, @@ -7205,7 +7209,7 @@ def is_unsafe_overlapping_overload_signatures( ) -def detach_callable(typ: CallableType) -> CallableType: +def detach_callable(typ: CallableType, class_type_vars: list[TypeVarLikeType]) -> CallableType: """Ensures that the callable's type variables are 'detached' and independent of the context. A callable normally keeps track of the type variables it uses within its 'variables' field. @@ -7215,42 +7219,17 @@ def detach_callable(typ: CallableType) -> CallableType: This function will traverse the callable and find all used type vars and add them to the variables field if it isn't already present. 
- The caller can then unify on all type variables whether or not the callable is originally - from a class or not.""" - type_list = typ.arg_types + [typ.ret_type] - - appear_map: dict[str, list[int]] = {} - for i, inner_type in enumerate(type_list): - typevars_available = get_type_vars(inner_type) - for var in typevars_available: - if var.fullname not in appear_map: - appear_map[var.fullname] = [] - appear_map[var.fullname].append(i) - - used_type_var_names = set() - for var_name, appearances in appear_map.items(): - used_type_var_names.add(var_name) - - all_type_vars = get_type_vars(typ) - new_variables = [] - for var in set(all_type_vars): - if var.fullname not in used_type_var_names: - continue - new_variables.append( - TypeVarType( - name=var.name, - fullname=var.fullname, - id=var.id, - values=var.values, - upper_bound=var.upper_bound, - default=var.default, - variance=var.variance, - ) - ) - out = typ.copy_modified( - variables=new_variables, arg_types=type_list[:-1], ret_type=type_list[-1] + The caller can then unify on all type variables whether the callable is originally from + the class or not.""" + if not class_type_vars: + # Fast path, nothing to update. + return typ + seen_type_vars = set() + for t in typ.arg_types + [typ.ret_type]: + seen_type_vars |= set(get_type_vars(t)) + return typ.copy_modified( + variables=list(typ.variables) + [tv for tv in class_type_vars if tv in seen_type_vars] ) - return out def overload_can_never_match(signature: CallableType, other: CallableType) -> bool: diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index f49a15ada85c..8a27c2bc2266 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6554,3 +6554,18 @@ class Snafu(object): reveal_type(Snafu().snafu('123')) # N: Revealed type is "builtins.str" reveal_type(Snafu.snafu('123')) # N: Revealed type is "builtins.str" [builtins fixtures/staticmethod.pyi] + +[case testOverloadedWithInternalTypeVars] +# flags: --new-type-inference +import m + +[file m.pyi] +from typing import Callable, TypeVar, overload + +T = TypeVar("T") +S = TypeVar("S", bound=str) + +@overload +def foo(x: int = ...) -> Callable[[T], T]: ... +@overload +def foo(x: S = ...) -> Callable[[T], T]: ... From ab8a6c2b659c77398c9e2bd90e12f05e7f8c223d Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 27 Jul 2023 00:12:49 +0100 Subject: [PATCH 14/14] Flip the default back --- mypy/options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/options.py b/mypy/options.py index e7074f07f7f6..75343acd38bb 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -349,7 +349,7 @@ def __init__(self) -> None: # -1 means unlimited. self.many_errors_threshold = defaults.MANY_ERRORS_THRESHOLD # Enable new experimental type inference algorithm. - self.new_type_inference = True + self.new_type_inference = False # Disable recursive type aliases (currently experimental) self.disable_recursive_aliases = False # Deprecated reverse version of the above, do not use.
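To connect the last two patches: the overlap check now supplies detach_callable() with the type variables of the enclosing class, and detach_callable() appends exactly those class variables that actually occur in the signature instead of rebuilding the whole variable list. A hypothetical illustration of the kind of code this affects follows (it is not taken from the test suite, and no particular mypy diagnostic is claimed for it); note also that the final patch makes the new solver opt-in again, so the behaviour exercised by the --new-type-inference test cases above is off by default.

from typing import Generic, TypeVar, overload

T = TypeVar("T")  # class-scoped type variable
S = TypeVar("S")  # method-scoped type variable

class Pair(Generic[T]):
    # The first overload uses the class-level T, the second the method-level S.
    # Comparing them for unsafe overlaps requires both signatures to carry the
    # class type vars they mention in their 'variables' list, which is what the
    # reworked detach_callable() arranges.
    @overload
    def set(self, key: int, value: T) -> T: ...
    @overload
    def set(self, key: str, value: S) -> S: ...
    def set(self, key, value):  # shared implementation for both overloads
        return value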