Skip to content

Commit d2063d2

Browse files
authored
Enable recursive type aliases behind a flag (#13297)
This PR exposes recursive type aliases that were secretly there for the last ~3 years. For now they will still be behind an opt-in flag, because they are not production ready. As we discussed with Jukka during PyCon, I use a couple of hacks to make them minimally usable, as proper solutions will take time. I may clean up some of them in the near future (or may not). You can see a few added test cases to get an idea of what is supported, for example:

```python
Nested = Sequence[Union[T, Nested[T]]]

def flatten(seq: Nested[T]) -> List[T]:
    flat: List[T] = []
    for item in seq:
        if isinstance(item, Sequence):
            flat.extend(flatten(item))
        else:
            flat.append(item)
    return flat

reveal_type(flatten([1, [2, [3]]]))  # N: Revealed type is "builtins.list[builtins.int]"
```
1 parent 1bb970a commit d2063d2

16 files changed

+398
-43
lines changed

mypy/checker.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -3620,20 +3620,23 @@ def check_simple_assignment(
36203620
# '...' is always a valid initializer in a stub.
36213621
return AnyType(TypeOfAny.special_form)
36223622
else:
3623+
orig_lvalue = lvalue_type
36233624
lvalue_type = get_proper_type(lvalue_type)
36243625
always_allow_any = lvalue_type is not None and not isinstance(lvalue_type, AnyType)
36253626
rvalue_type = self.expr_checker.accept(
36263627
rvalue, lvalue_type, always_allow_any=always_allow_any
36273628
)
3629+
orig_rvalue = rvalue_type
36283630
rvalue_type = get_proper_type(rvalue_type)
36293631
if isinstance(rvalue_type, DeletedType):
36303632
self.msg.deleted_as_rvalue(rvalue_type, context)
36313633
if isinstance(lvalue_type, DeletedType):
36323634
self.msg.deleted_as_lvalue(lvalue_type, context)
36333635
elif lvalue_type:
36343636
self.check_subtype(
3635-
rvalue_type,
3636-
lvalue_type,
3637+
# Preserve original aliases for error messages when possible.
3638+
orig_rvalue,
3639+
orig_lvalue or lvalue_type,
36373640
context,
36383641
msg,
36393642
f"{rvalue_name} has type",
@@ -5526,7 +5529,9 @@ def check_subtype(
55265529
code = msg.code
55275530
else:
55285531
msg_text = msg
5532+
orig_subtype = subtype
55295533
subtype = get_proper_type(subtype)
5534+
orig_supertype = supertype
55305535
supertype = get_proper_type(supertype)
55315536
if self.msg.try_report_long_tuple_assignment_error(
55325537
subtype, supertype, context, msg_text, subtype_label, supertype_label, code=code
@@ -5538,7 +5543,7 @@ def check_subtype(
55385543
note_msg = ""
55395544
notes: List[str] = []
55405545
if subtype_label is not None or supertype_label is not None:
5541-
subtype_str, supertype_str = format_type_distinctly(subtype, supertype)
5546+
subtype_str, supertype_str = format_type_distinctly(orig_subtype, orig_supertype)
55425547
if subtype_label is not None:
55435548
extra_info.append(subtype_label + " " + subtype_str)
55445549
if supertype_label is not None:

mypy/checkexpr.py

+13
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@
149149
flatten_nested_unions,
150150
get_proper_type,
151151
get_proper_types,
152+
has_recursive_types,
152153
is_generic_instance,
153154
is_named_instance,
154155
is_optional,
@@ -1568,13 +1569,25 @@ def infer_function_type_arguments(
15681569
else:
15691570
pass1_args.append(arg)
15701571

1572+
# This is a hack to better support inference for recursive types.
1573+
# When the outer context for a function call is known to be recursive,
1574+
# we solve type constraints inferred from arguments using unions instead
1575+
# of joins. This is a bit arbitrary, but in practice it works for most
1576+
# cases. A cleaner alternative would be to switch to single bin type
1577+
# inference, but this is a lot of work.
1578+
ctx = self.type_context[-1]
1579+
if ctx and has_recursive_types(ctx):
1580+
infer_unions = True
1581+
else:
1582+
infer_unions = False
15711583
inferred_args = infer_function_type_arguments(
15721584
callee_type,
15731585
pass1_args,
15741586
arg_kinds,
15751587
formal_to_actual,
15761588
context=self.argument_infer_context(),
15771589
strict=self.chk.in_checked_function(),
1590+
infer_unions=infer_unions,
15781591
)
15791592

15801593
if 2 in arg_pass_nums:

mypy/expandtype.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@ def expand_types_with_unpack(
293293
else:
294294
items.extend(unpacked_items)
295295
else:
296-
items.append(proper_item.accept(self))
296+
# Must preserve original aliases when possible.
297+
items.append(item.accept(self))
297298
return items
298299

299300
def visit_tuple_type(self, t: TupleType) -> Type:

mypy/infer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def infer_function_type_arguments(
3434
formal_to_actual: List[List[int]],
3535
context: ArgumentInferContext,
3636
strict: bool = True,
37+
infer_unions: bool = False,
3738
) -> List[Optional[Type]]:
3839
"""Infer the type arguments of a generic function.
3940
@@ -55,7 +56,7 @@ def infer_function_type_arguments(
5556

5657
# Solve constraints.
5758
type_vars = callee_type.type_var_ids()
58-
return solve_constraints(type_vars, constraints, strict)
59+
return solve_constraints(type_vars, constraints, strict, infer_unions=infer_unions)
5960

6061

6162
def infer_type_arguments(

mypy/main.py

+5
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,11 @@ def add_invertible_flag(
977977
dest="custom_typing_module",
978978
help="Use a custom typing module",
979979
)
980+
internals_group.add_argument(
981+
"--enable-recursive-aliases",
982+
action="store_true",
983+
help="Experimental support for recursive type aliases",
984+
)
980985
internals_group.add_argument(
981986
"--custom-typeshed-dir", metavar="DIR", help="Use the custom typeshed in DIR"
982987
)

mypy/messages.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
ProperType,
8888
TupleType,
8989
Type,
90+
TypeAliasType,
9091
TypedDictType,
9192
TypeOfAny,
9293
TypeType,
@@ -2146,7 +2147,17 @@ def format_literal_value(typ: LiteralType) -> str:
21462147
else:
21472148
return typ.value_repr()
21482149

2149-
# TODO: show type alias names in errors.
2150+
if isinstance(typ, TypeAliasType) and typ.is_recursive:
2151+
# TODO: find balance here, str(typ) doesn't support custom verbosity, and may be
2152+
# too verbose for user messages, OTOH it nicely shows structure of recursive types.
2153+
if verbosity < 2:
2154+
type_str = typ.alias.name if typ.alias else "<alias (unfixed)>"
2155+
if typ.args:
2156+
type_str += f"[{format_list(typ.args)}]"
2157+
return type_str
2158+
return str(typ)
2159+
2160+
# TODO: always mention type alias names in errors.
21502161
typ = get_proper_type(typ)
21512162

21522163
if isinstance(typ, Instance):

mypy/options.py

+2
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,8 @@ def __init__(self) -> None:
315315
# skip most errors after this many messages have been reported.
316316
# -1 means unlimited.
317317
self.many_errors_threshold = defaults.MANY_ERRORS_THRESHOLD
318+
# Enable recursive type aliases (currently experimental)
319+
self.enable_recursive_aliases = False
318320

319321
# To avoid breaking plugin compatibility, keep providing new_semantic_analyzer
320322
@property

mypy/sametypes.py

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333
def is_same_type(left: Type, right: Type) -> bool:
3434
"""Is 'left' the same type as 'right'?"""
3535

36+
if isinstance(left, TypeAliasType) and isinstance(right, TypeAliasType):
37+
if left.is_recursive and right.is_recursive:
38+
return left.alias == right.alias and left.args == right.args
39+
3640
left = get_proper_type(left)
3741
right = get_proper_type(right)
3842

mypy/semanal.py

+71-9
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,14 @@ def __init__(
453453
# current SCC or top-level function.
454454
self.deferral_debug_context: List[Tuple[str, int]] = []
455455

456+
# This is needed to properly support recursive type aliases. The problem is that
457+
# Foo[Bar] could mean three things depending on context: a target for type alias,
458+
# a normal index expression (including enum index), or a type application.
459+
# The latter is particularly problematic as it can falsely create incomplete
460+
# refs while analysing rvalues of type aliases. To avoid this we first analyse
461+
# rvalues while temporarily setting this to True.
462+
self.basic_type_applications = False
463+
456464
# mypyc doesn't properly handle implementing an abstractproperty
457465
# with a regular attribute so we make them properties
458466
@property
@@ -2319,14 +2327,25 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
23192327
return
23202328

23212329
tag = self.track_incomplete_refs()
2322-
s.rvalue.accept(self)
2330+
2331+
# Here we have a chicken and egg problem: at this stage we can't call
2332+
# can_be_type_alias(), because we have not enough information about rvalue.
2333+
# But we can't use a full visit because it may emit extra incomplete refs (namely
2334+
# when analysing any type applications there) thus preventing the further analysis.
2335+
# To break the tie, we first analyse rvalue partially, if it can be a type alias.
2336+
with self.basic_type_applications_set(s):
2337+
s.rvalue.accept(self)
23232338
if self.found_incomplete_ref(tag) or self.should_wait_rhs(s.rvalue):
23242339
# Initializer couldn't be fully analyzed. Defer the current node and give up.
23252340
# Make sure that if we skip the definition of some local names, they can't be
23262341
# added later in this scope, since an earlier definition should take precedence.
23272342
for expr in names_modified_by_assignment(s):
23282343
self.mark_incomplete(expr.name, expr)
23292344
return
2345+
if self.can_possibly_be_index_alias(s):
2346+
# Now re-visit those rvalues for which we skipped type applications above.
2347+
# This should be safe as generally semantic analyzer is idempotent.
2348+
s.rvalue.accept(self)
23302349

23312350
# The r.h.s. is now ready to be classified, first check if it is a special form:
23322351
special_form = False
@@ -2465,6 +2484,36 @@ def can_be_type_alias(self, rv: Expression, allow_none: bool = False) -> bool:
24652484
return True
24662485
return False
24672486

2487+
def can_possibly_be_index_alias(self, s: AssignmentStmt) -> bool:
2488+
"""Like can_be_type_alias(), but simpler and doesn't require analyzed rvalue.
2489+
2490+
Instead, use lvalues/annotations structure to figure out whether this can
2491+
potentially be a type alias definition. Another difference from above function
2492+
is that we are only interested in IndexExpr and OpExpr rvalues, since only those
2493+
can be potentially recursive (things like `A = A` are never valid).
2494+
"""
2495+
if len(s.lvalues) > 1:
2496+
return False
2497+
if not isinstance(s.lvalues[0], NameExpr):
2498+
return False
2499+
if s.unanalyzed_type is not None and not self.is_pep_613(s):
2500+
return False
2501+
if not isinstance(s.rvalue, (IndexExpr, OpExpr)):
2502+
return False
2503+
# Something that looks like Foo = Bar[Baz, ...]
2504+
return True
2505+
2506+
@contextmanager
2507+
def basic_type_applications_set(self, s: AssignmentStmt) -> Iterator[None]:
2508+
old = self.basic_type_applications
2509+
# As an optimization, only use the double visit logic if this
2510+
# can possibly be a recursive type alias.
2511+
self.basic_type_applications = self.can_possibly_be_index_alias(s)
2512+
try:
2513+
yield
2514+
finally:
2515+
self.basic_type_applications = old
2516+
24682517
def is_type_ref(self, rv: Expression, bare: bool = False) -> bool:
24692518
"""Does this expression refer to a type?
24702519
@@ -2941,6 +2990,13 @@ def analyze_alias(
29412990
qualified_tvars = []
29422991
return typ, alias_tvars, depends_on, qualified_tvars
29432992

2993+
def is_pep_613(self, s: AssignmentStmt) -> bool:
2994+
if s.unanalyzed_type is not None and isinstance(s.unanalyzed_type, UnboundType):
2995+
lookup = self.lookup_qualified(s.unanalyzed_type.name, s, suppress_errors=True)
2996+
if lookup and lookup.fullname in TYPE_ALIAS_NAMES:
2997+
return True
2998+
return False
2999+
29443000
def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
29453001
"""Check if assignment creates a type alias and set it up as needed.
29463002
@@ -2955,11 +3011,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
29553011
# First rule: Only simple assignments like Alias = ... create aliases.
29563012
return False
29573013

2958-
pep_613 = False
2959-
if s.unanalyzed_type is not None and isinstance(s.unanalyzed_type, UnboundType):
2960-
lookup = self.lookup_qualified(s.unanalyzed_type.name, s, suppress_errors=True)
2961-
if lookup and lookup.fullname in TYPE_ALIAS_NAMES:
2962-
pep_613 = True
3014+
pep_613 = self.is_pep_613(s)
29633015
if not pep_613 and s.unanalyzed_type is not None:
29643016
# Second rule: Explicit type (cls: Type[A] = A) always creates variable, not alias.
29653017
# unless using PEP 613 `cls: TypeAlias = A`
@@ -3023,9 +3075,16 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
30233075
)
30243076
if not res:
30253077
return False
3026-
# TODO: Maybe we only need to reject top-level placeholders, similar
3027-
# to base classes.
3028-
if self.found_incomplete_ref(tag) or has_placeholder(res):
3078+
if self.options.enable_recursive_aliases:
3079+
# Only marking incomplete for top-level placeholders makes recursive aliases like
3080+
# `A = Sequence[str | A]` valid here, similar to how we treat base classes in class
3081+
# definitions, allowing `class str(Sequence[str]): ...`
3082+
incomplete_target = isinstance(res, ProperType) and isinstance(
3083+
res, PlaceholderType
3084+
)
3085+
else:
3086+
incomplete_target = has_placeholder(res)
3087+
if self.found_incomplete_ref(tag) or incomplete_target:
30293088
# Since we have got here, we know this must be a type alias (incomplete refs
30303089
# may appear in nested positions), therefore use becomes_typeinfo=True.
30313090
self.mark_incomplete(lvalue.name, rvalue, becomes_typeinfo=True)
@@ -4532,6 +4591,9 @@ def analyze_type_application_args(self, expr: IndexExpr) -> Optional[List[Type]]
45324591
self.analyze_type_expr(index)
45334592
if self.found_incomplete_ref(tag):
45344593
return None
4594+
if self.basic_type_applications:
4595+
# Postpone the rest until we have more information (for r.h.s. of an assignment)
4596+
return None
45354597
types: List[Type] = []
45364598
if isinstance(index, TupleExpr):
45374599
items = index.items

mypy/solve.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,22 @@
77
from mypy.join import join_types
88
from mypy.meet import meet_types
99
from mypy.subtypes import is_subtype
10-
from mypy.types import AnyType, Type, TypeOfAny, TypeVarId, UninhabitedType, get_proper_type
10+
from mypy.types import (
11+
AnyType,
12+
Type,
13+
TypeOfAny,
14+
TypeVarId,
15+
UninhabitedType,
16+
UnionType,
17+
get_proper_type,
18+
)
1119

1220

1321
def solve_constraints(
14-
vars: List[TypeVarId], constraints: List[Constraint], strict: bool = True
22+
vars: List[TypeVarId],
23+
constraints: List[Constraint],
24+
strict: bool = True,
25+
infer_unions: bool = False,
1526
) -> List[Optional[Type]]:
1627
"""Solve type constraints.
1728
@@ -43,7 +54,12 @@ def solve_constraints(
4354
if bottom is None:
4455
bottom = c.target
4556
else:
46-
bottom = join_types(bottom, c.target)
57+
if infer_unions:
58+
# This deviates from the general mypy semantics because
59+
# recursive types are union-heavy in 95% of cases.
60+
bottom = UnionType.make_union([bottom, c.target])
61+
else:
62+
bottom = join_types(bottom, c.target)
4763
else:
4864
if top is None:
4965
top = c.target

mypy/subtypes.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,15 @@ def is_subtype(
105105
if TypeState.is_assumed_subtype(left, right):
106106
return True
107107
if (
108+
# TODO: recursive instances like `class str(Sequence[str])` can also cause
109+
# issues, so we also need to include them in the assumptions stack
108110
isinstance(left, TypeAliasType)
109111
and isinstance(right, TypeAliasType)
110112
and left.is_recursive
111113
and right.is_recursive
112114
):
113115
# This case requires special care because it may cause infinite recursion.
114-
# Our view on recursive types is known under a fancy name of equirecursive mu-types.
116+
# Our view on recursive types is known under a fancy name of iso-recursive mu-types.
115117
# Roughly this means that a recursive type is defined as an alias where right hand side
116118
# can refer to the type as a whole, for example:
117119
# A = Union[int, Tuple[A, ...]]

0 commit comments

Comments
 (0)