Skip to content

Commit 59617e8

Browse files
authored
Implementing background infrastructure for recursive types: Part 3 (#7885)
This is the last part of plumbing for recursive types (previous #7366 and #7330). Here I implement visitors and related functions. I convinced myself that we need to only be more careful when a recursive type is checked against another recursive one, so I only special-case these. Logic is similar to how protocols behave, because very roughly type alias can be imagined as a protocol with single property: ```python A = Union[T, Tuple[A[T], ...]] class A(Protocol[T]): @Property def __base__(self) -> Union[T, Tuple[A[T], ...]]: ... ``` but where `TypeAliasType` plays role of `Instance` and `TypeAlias` plays role of `TypeInfo`. Next two pull requests will contain some non-trivial implementation logic.
1 parent 834efe4 commit 59617e8

30 files changed

+636
-221
lines changed

mypy/binder.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -296,16 +296,17 @@ def assign_type(self, expr: Expression,
296296
# (See discussion in #3526)
297297
elif (isinstance(type, AnyType)
298298
and isinstance(declared_type, UnionType)
299-
and any(isinstance(item, NoneType) for item in declared_type.items)
299+
and any(isinstance(get_proper_type(item), NoneType) for item in declared_type.items)
300300
and isinstance(get_proper_type(self.most_recent_enclosing_type(expr, NoneType())),
301301
NoneType)):
302302
# Replace any Nones in the union type with Any
303-
new_items = [type if isinstance(item, NoneType) else item
303+
new_items = [type if isinstance(get_proper_type(item), NoneType) else item
304304
for item in declared_type.items]
305305
self.put(expr, UnionType(new_items))
306306
elif (isinstance(type, AnyType)
307307
and not (isinstance(declared_type, UnionType)
308-
and any(isinstance(item, AnyType) for item in declared_type.items))):
308+
and any(isinstance(get_proper_type(item), AnyType)
309+
for item in declared_type.items))):
309310
# Assigning an Any value doesn't affect the type to avoid false negatives, unless
310311
# there is an Any item in a declared union type.
311312
self.put(expr, declared_type)

mypy/checker.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@
3636
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef,
3737
is_named_instance, union_items, TypeQuery, LiteralType,
3838
is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType,
39-
get_proper_types, is_literal_type
40-
)
39+
get_proper_types, is_literal_type, TypeAliasType)
4140
from mypy.sametypes import is_same_type
4241
from mypy.messages import (
4342
MessageBuilder, make_inferred_type_note, append_invariance_notes,
@@ -2480,7 +2479,7 @@ def check_multi_assignment(self, lvalues: List[Lvalue],
24802479
# If this is an Optional type in non-strict Optional code, unwrap it.
24812480
relevant_items = rvalue_type.relevant_items()
24822481
if len(relevant_items) == 1:
2483-
rvalue_type = relevant_items[0]
2482+
rvalue_type = get_proper_type(relevant_items[0])
24842483

24852484
if isinstance(rvalue_type, AnyType):
24862485
for lv in lvalues:
@@ -2587,7 +2586,7 @@ def check_multi_assignment_from_tuple(self, lvalues: List[Lvalue], rvalue: Expre
25872586
# If this is an Optional type in non-strict Optional code, unwrap it.
25882587
relevant_items = reinferred_rvalue_type.relevant_items()
25892588
if len(relevant_items) == 1:
2590-
reinferred_rvalue_type = relevant_items[0]
2589+
reinferred_rvalue_type = get_proper_type(relevant_items[0])
25912590
if isinstance(reinferred_rvalue_type, UnionType):
25922591
self.check_multi_assignment_from_union(lvalues, rvalue,
25932592
reinferred_rvalue_type, context,
@@ -3732,7 +3731,7 @@ def find_isinstance_check(self, node: Expression
37323731
type = get_isinstance_type(node.args[1], type_map)
37333732
if isinstance(vartype, UnionType):
37343733
union_list = []
3735-
for t in vartype.items:
3734+
for t in get_proper_types(vartype.items):
37363735
if isinstance(t, TypeType):
37373736
union_list.append(t.item)
37383737
else:
@@ -4558,6 +4557,7 @@ def overload_can_never_match(signature: CallableType, other: CallableType) -> bo
45584557
# TODO: find a cleaner solution instead of this ad-hoc erasure.
45594558
exp_signature = expand_type(signature, {tvar.id: erase_def_to_union_or_bound(tvar)
45604559
for tvar in signature.variables})
4560+
assert isinstance(exp_signature, ProperType)
45614561
assert isinstance(exp_signature, CallableType)
45624562
return is_callable_compatible(exp_signature, other,
45634563
is_compat=is_more_precise,
@@ -4641,6 +4641,11 @@ def visit_uninhabited_type(self, t: UninhabitedType) -> Type:
46414641
return AnyType(TypeOfAny.from_error)
46424642
return t
46434643

4644+
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
4645+
# Target of the alias cannot by an ambigous <nothing>, so we just
4646+
# replace the arguments.
4647+
return t.copy_modified(args=[a.accept(self) for a in t.args])
4648+
46444649

46454650
def is_node_static(node: Optional[Node]) -> Optional[bool]:
46464651
"""Find out if a node describes a static function method."""

mypy/checkexpr.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -900,7 +900,7 @@ def check_callable_call(self,
900900
callee = callee.copy_modified(ret_type=new_ret_type)
901901
return callee.ret_type, callee
902902

903-
def analyze_type_type_callee(self, item: ProperType, context: Context) -> ProperType:
903+
def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type:
904904
"""Analyze the callee X in X(...) where X is Type[item].
905905
906906
Return a Y that we can pass to check_call(Y, ...).
@@ -913,14 +913,15 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Proper
913913
res = res.copy_modified(from_type_type=True)
914914
return expand_type_by_instance(res, item)
915915
if isinstance(item, UnionType):
916-
return UnionType([self.analyze_type_type_callee(tp, context)
916+
return UnionType([self.analyze_type_type_callee(get_proper_type(tp), context)
917917
for tp in item.relevant_items()], item.line)
918918
if isinstance(item, TypeVarType):
919919
# Pretend we're calling the typevar's upper bound,
920920
# i.e. its constructor (a poor approximation for reality,
921921
# but better than AnyType...), but replace the return type
922922
# with typevar.
923923
callee = self.analyze_type_type_callee(get_proper_type(item.upper_bound), context)
924+
callee = get_proper_type(callee)
924925
if isinstance(callee, CallableType):
925926
callee = callee.copy_modified(ret_type=item)
926927
elif isinstance(callee, Overloaded):
@@ -2144,8 +2145,7 @@ def dangerous_comparison(self, left: Type, right: Type,
21442145
if not self.chk.options.strict_equality:
21452146
return False
21462147

2147-
left = get_proper_type(left)
2148-
right = get_proper_type(right)
2148+
left, right = get_proper_types((left, right))
21492149

21502150
if self.chk.binder.is_unreachable_warning_suppressed():
21512151
# We are inside a function that contains type variables with value restrictions in
@@ -2165,6 +2165,7 @@ def dangerous_comparison(self, left: Type, right: Type,
21652165
if isinstance(left, UnionType) and isinstance(right, UnionType):
21662166
left = remove_optional(left)
21672167
right = remove_optional(right)
2168+
left, right = get_proper_types((left, right))
21682169
py2 = self.chk.options.python_version < (3, 0)
21692170
if (original_container and has_bytes_component(original_container, py2) and
21702171
has_bytes_component(left, py2)):
@@ -2794,7 +2795,7 @@ def try_getting_int_literals(self, index: Expression) -> Optional[List[int]]:
27942795
return [typ.value]
27952796
if isinstance(typ, UnionType):
27962797
out = []
2797-
for item in typ.items:
2798+
for item in get_proper_types(typ.items):
27982799
if isinstance(item, LiteralType) and isinstance(item.value, int):
27992800
out.append(item.value)
28002801
else:
@@ -2969,7 +2970,7 @@ class LongName(Generic[T]): ...
29692970
# For example:
29702971
# A = List[Tuple[T, T]]
29712972
# x = A() <- same as List[Tuple[Any, Any]], see PEP 484.
2972-
item = set_any_tvars(target, alias_tvars, ctx.line, ctx.column)
2973+
item = get_proper_type(set_any_tvars(target, alias_tvars, ctx.line, ctx.column))
29732974
if isinstance(item, Instance):
29742975
# Normally we get a callable type (or overloaded) with .is_type_obj() true
29752976
# representing the class's constructor
@@ -3052,7 +3053,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type:
30523053
type_context = get_proper_type(self.type_context[-1])
30533054
type_context_items = None
30543055
if isinstance(type_context, UnionType):
3055-
tuples_in_context = [t for t in type_context.items
3056+
tuples_in_context = [t for t in get_proper_types(type_context.items)
30563057
if (isinstance(t, TupleType) and len(t.items) == len(e.items)) or
30573058
is_named_instance(t, 'builtins.tuple')]
30583059
if len(tuples_in_context) == 1:
@@ -3240,7 +3241,8 @@ def infer_lambda_type_using_context(self, e: LambdaExpr) -> Tuple[Optional[Calla
32403241
ctx = get_proper_type(self.type_context[-1])
32413242

32423243
if isinstance(ctx, UnionType):
3243-
callables = [t for t in ctx.relevant_items() if isinstance(t, CallableType)]
3244+
callables = [t for t in get_proper_types(ctx.relevant_items())
3245+
if isinstance(t, CallableType)]
32443246
if len(callables) == 1:
32453247
ctx = callables[0]
32463248

mypy/checkmember.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mypy.types import (
77
Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarDef,
88
Overloaded, TypeVarType, UnionType, PartialType, TypeOfAny, LiteralType,
9-
DeletedType, NoneType, TypeType, get_type_vars, get_proper_type, ProperType
9+
DeletedType, NoneType, TypeType, has_type_vars, get_proper_type, ProperType
1010
)
1111
from mypy.nodes import (
1212
TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context, MypyFile, TypeVarExpr,
@@ -377,7 +377,7 @@ def analyze_member_var_access(name: str,
377377
function = function_type(method, mx.builtin_type('builtins.function'))
378378
bound_method = bind_self(function, mx.self_type)
379379
typ = map_instance_to_supertype(itype, method.info)
380-
getattr_type = expand_type_by_instance(bound_method, typ)
380+
getattr_type = get_proper_type(expand_type_by_instance(bound_method, typ))
381381
if isinstance(getattr_type, CallableType):
382382
result = getattr_type.ret_type
383383

@@ -394,7 +394,7 @@ def analyze_member_var_access(name: str,
394394
setattr_func = function_type(setattr_meth, mx.builtin_type('builtins.function'))
395395
bound_type = bind_self(setattr_func, mx.self_type)
396396
typ = map_instance_to_supertype(itype, setattr_meth.info)
397-
setattr_type = expand_type_by_instance(bound_type, typ)
397+
setattr_type = get_proper_type(expand_type_by_instance(bound_type, typ))
398398
if isinstance(setattr_type, CallableType) and len(setattr_type.arg_types) > 0:
399399
return setattr_type.arg_types[-1]
400400

@@ -497,10 +497,11 @@ def instance_alias_type(alias: TypeAlias,
497497
498498
As usual, we first erase any unbound type variables to Any.
499499
"""
500-
target = get_proper_type(alias.target)
501-
assert isinstance(target, Instance), "Must be called only with aliases to classes"
500+
target = get_proper_type(alias.target) # type: Type
501+
assert isinstance(get_proper_type(target),
502+
Instance), "Must be called only with aliases to classes"
502503
target = set_any_tvars(target, alias.alias_tvars, alias.line, alias.column)
503-
assert isinstance(target, Instance)
504+
assert isinstance(target, Instance) # type: ignore[misc]
504505
tp = type_object_type(target.type, builtin_type)
505506
return expand_type_by_instance(tp, target)
506507

@@ -525,7 +526,7 @@ def analyze_var(name: str,
525526
if typ:
526527
if isinstance(typ, PartialType):
527528
return mx.chk.handle_partial_var_type(typ, mx.is_lvalue, var, mx.context)
528-
t = expand_type_by_instance(typ, itype)
529+
t = get_proper_type(expand_type_by_instance(typ, itype))
529530
if mx.is_lvalue and var.is_property and not var.is_settable_property:
530531
# TODO allow setting attributes in subclass (although it is probably an error)
531532
mx.msg.read_only_property(name, itype.type, mx.context)
@@ -577,7 +578,9 @@ def analyze_var(name: str,
577578
return result
578579

579580

580-
def freeze_type_vars(member_type: ProperType) -> None:
581+
def freeze_type_vars(member_type: Type) -> None:
582+
if not isinstance(member_type, ProperType):
583+
return
581584
if isinstance(member_type, CallableType):
582585
for v in member_type.variables:
583586
v.id.meta_level = 0
@@ -713,7 +716,7 @@ def analyze_class_attribute_access(itype: Instance,
713716
# x: T
714717
# C.x # Error, ambiguous access
715718
# C[int].x # Also an error, since C[int] is same as C at runtime
716-
if isinstance(t, TypeVarType) or get_type_vars(t):
719+
if isinstance(t, TypeVarType) or has_type_vars(t):
717720
# Exception: access on Type[...], including first argument of class methods is OK.
718721
if not isinstance(get_proper_type(mx.original_type), TypeType):
719722
if node.node.is_classvar:
@@ -799,7 +802,7 @@ class B(A[str]): pass
799802
info = itype.type # type: TypeInfo
800803
if is_classmethod:
801804
assert isuper is not None
802-
t = expand_type_by_instance(t, isuper)
805+
t = get_proper_type(expand_type_by_instance(t, isuper))
803806
# We add class type variables if the class method is accessed on class object
804807
# without applied type arguments, this matches the behavior of __init__().
805808
# For example (continuing the example in docstring):

mypy/checkstrformat.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from mypy.types import (
2121
Type, AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type, TypeVarType,
22-
CallableType, LiteralType
22+
CallableType, LiteralType, get_proper_types
2323
)
2424
from mypy.nodes import (
2525
StrExpr, BytesExpr, UnicodeExpr, TupleExpr, DictExpr, Context, Expression, StarExpr, CallExpr,
@@ -359,7 +359,8 @@ def check_specs_in_format_call(self, call: CallExpr,
359359
continue
360360

361361
a_type = get_proper_type(actual_type)
362-
actual_items = a_type.items if isinstance(a_type, UnionType) else [a_type]
362+
actual_items = (get_proper_types(a_type.items) if isinstance(a_type, UnionType)
363+
else [a_type])
363364
for a_type in actual_items:
364365
if custom_special_method(a_type, '__format__'):
365366
continue

mypy/constraints.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance,
88
TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
99
UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType,
10-
ProperType, get_proper_type
10+
ProperType, get_proper_type, TypeAliasType
1111
)
1212
from mypy.maptype import map_instance_to_supertype
1313
import mypy.subtypes
@@ -16,6 +16,7 @@
1616
from mypy.erasetype import erase_typevars
1717
from mypy.nodes import COVARIANT, CONTRAVARIANT
1818
from mypy.argmap import ArgTypeExpander
19+
from mypy.typestate import TypeState
1920

2021
SUBTYPE_OF = 0 # type: Final[int]
2122
SUPERTYPE_OF = 1 # type: Final[int]
@@ -89,6 +90,21 @@ def infer_constraints(template: Type, actual: Type,
8990
9091
The constraints are represented as Constraint objects.
9192
"""
93+
if any(get_proper_type(template) == get_proper_type(t) for t in TypeState._inferring):
94+
return []
95+
if (isinstance(template, TypeAliasType) and isinstance(actual, TypeAliasType) and
96+
template.is_recursive and actual.is_recursive):
97+
# This case requires special care because it may cause infinite recursion.
98+
TypeState._inferring.append(template)
99+
res = _infer_constraints(template, actual, direction)
100+
TypeState._inferring.pop()
101+
return res
102+
return _infer_constraints(template, actual, direction)
103+
104+
105+
def _infer_constraints(template: Type, actual: Type,
106+
direction: int) -> List[Constraint]:
107+
92108
template = get_proper_type(template)
93109
actual = get_proper_type(actual)
94110

@@ -487,6 +503,9 @@ def visit_union_type(self, template: UnionType) -> List[Constraint]:
487503
assert False, ("Unexpected UnionType in ConstraintBuilderVisitor"
488504
" (should have been handled in infer_constraints)")
489505

506+
def visit_type_alias_type(self, template: TypeAliasType) -> List[Constraint]:
507+
assert False, "This should be never called, got {}".format(template)
508+
490509
def infer_against_any(self, types: Iterable[Type], any_type: AnyType) -> List[Constraint]:
491510
res = [] # type: List[Constraint]
492511
for t in types:

mypy/erasetype.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarId, Instance, TypeVarType,
55
CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType,
66
DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType,
7-
get_proper_type
7+
get_proper_type, TypeAliasType
88
)
99
from mypy.nodes import ARG_STAR, ARG_STAR2
1010

@@ -93,6 +93,9 @@ def visit_union_type(self, t: UnionType) -> ProperType:
9393
def visit_type_type(self, t: TypeType) -> ProperType:
9494
return TypeType.make_normalized(t.item.accept(self), line=t.line)
9595

96+
def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
97+
raise RuntimeError("Type aliases should be expanded before accepting this visitor")
98+
9699

97100
def erase_typevars(t: Type, ids_to_erase: Optional[Container[TypeVarId]] = None) -> Type:
98101
"""Replace all type variables in a type with any,
@@ -122,6 +125,11 @@ def visit_type_var(self, t: TypeVarType) -> Type:
122125
return self.replacement
123126
return t
124127

128+
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
129+
# Type alias target can't contain bound type variables, so
130+
# it is safe to just erase the arguments.
131+
return t.copy_modified(args=[a.accept(self) for a in t.args])
132+
125133

126134
def remove_instance_last_known_values(t: Type) -> Type:
127135
return t.accept(LastKnownValueEraser())
@@ -135,3 +143,8 @@ def visit_instance(self, t: Instance) -> Type:
135143
if t.last_known_value:
136144
return t.copy_modified(last_known_value=None)
137145
return t
146+
147+
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
148+
# Type aliases can't contain literal values, because they are
149+
# always constructed as explicit types.
150+
return t

0 commit comments

Comments
 (0)