Skip to content

Commit 7335991

Browse files
committed
Switch to making final variables context-sensitive
This commit modifies this PR to make selecting the type of final variables context-sensitive. Now, when we do: x: Final = 1 ...the variable `x` is normally inferred to be of type `int`. However, if that variable is used in a context which expects `Literal`, we infer the literal type. This commit also removes some of the hacks to mypy and the tests that the first iteration added.
1 parent e5a7495 commit 7335991

21 files changed

+378
-142
lines changed

mypy/checker.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
174174
# Type checking pass number (0 = first pass)
175175
pass_num = 0
176176
# Last pass number to take
177-
last_pass = DEFAULT_LAST_PASS # type: int
177+
last_pass = DEFAULT_LAST_PASS
178178
# Have we deferred the current function? If yes, don't infer additional
179179
# types during this pass within the function.
180180
current_node_deferred = False
@@ -1809,7 +1809,10 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
18091809
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)
18101810

18111811
if inferred:
1812-
rvalue_type = self.expr_checker.accept(rvalue, infer_literal=inferred.is_final)
1812+
rvalue_type = self.expr_checker.accept(
1813+
rvalue,
1814+
in_final_declaration=inferred.is_final,
1815+
)
18131816
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)
18141817

18151818
def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type],

mypy/checkexpr.py

+39-34
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from mypy.types import (
1919
Type, AnyType, CallableType, Overloaded, NoneTyp, TypeVarDef,
2020
TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType,
21-
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType,
21+
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
2222
true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike,
2323
StarType, is_optional, remove_optional, is_generic_instance
2424
)
@@ -139,7 +139,7 @@ def __init__(self,
139139
self.msg = msg
140140
self.plugin = plugin
141141
self.type_context = [None]
142-
self.infer_literal = False
142+
self.in_final_declaration = False
143143
# Temporary overrides for expression types. This is currently
144144
# used by the union math in overloads.
145145
# TODO: refactor this to use a pattern similar to one in
@@ -211,10 +211,12 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
211211

212212
def analyze_var_ref(self, var: Var, context: Context) -> Type:
213213
if var.type:
214-
if self.is_literal_context() and var.name() in {'True', 'False'}:
215-
return LiteralType(var.name() == 'True', self.named_type('builtins.bool'))
216-
else:
217-
return var.type
214+
if isinstance(var.type, Instance):
215+
if self._is_literal_context() and var.type.final_value is not None:
216+
return var.type.final_value
217+
if var.name() in {'True', 'False'}:
218+
return self._handle_literal_expr(var.name() == 'True', 'builtins.bool')
219+
return var.type
218220
else:
219221
if not var.is_ready and self.chk.in_checked_function():
220222
self.chk.handle_cannot_determine_type(var.name(), context)
@@ -693,7 +695,8 @@ def check_call(self,
693695
elif isinstance(callee, Instance):
694696
call_function = analyze_member_access('__call__', callee, context,
695697
False, False, False, self.msg,
696-
original_type=callee, chk=self.chk)
698+
original_type=callee, chk=self.chk,
699+
in_literal_context=self._is_literal_context())
697700
return self.check_call(call_function, args, arg_kinds, context, arg_names,
698701
callable_node, arg_messages)
699702
elif isinstance(callee, TypeVarType):
@@ -1757,7 +1760,8 @@ def analyze_ordinary_member_access(self, e: MemberExpr,
17571760
original_type = self.accept(e.expr)
17581761
member_type = analyze_member_access(
17591762
e.name, original_type, e, is_lvalue, False, False,
1760-
self.msg, original_type=original_type, chk=self.chk)
1763+
self.msg, original_type=original_type, chk=self.chk,
1764+
in_literal_context=self._is_literal_context())
17611765
return member_type
17621766

17631767
def analyze_external_member_access(self, member: str, base_type: Type,
@@ -1767,35 +1771,36 @@ def analyze_external_member_access(self, member: str, base_type: Type,
17671771
"""
17681772
# TODO remove; no private definitions in mypy
17691773
return analyze_member_access(member, base_type, context, False, False, False,
1770-
self.msg, original_type=base_type, chk=self.chk)
1774+
self.msg, original_type=base_type, chk=self.chk,
1775+
in_literal_context=self._is_literal_context())
1776+
1777+
def _is_literal_context(self) -> bool:
1778+
return is_literal_type_like(self.type_context[-1])
1779+
1780+
def _handle_literal_expr(self, value: LiteralValue, fallback_name: str) -> Type:
1781+
typ = self.named_type(fallback_name)
1782+
if self._is_literal_context():
1783+
return LiteralType(value=value, fallback=typ)
1784+
elif self.in_final_declaration:
1785+
return typ.copy_with_final_value(value)
1786+
else:
1787+
return typ
17711788

17721789
def visit_int_expr(self, e: IntExpr) -> Type:
17731790
"""Type check an integer literal (trivial)."""
1774-
typ = self.named_type('builtins.int')
1775-
if self.is_literal_context():
1776-
return LiteralType(value=e.value, fallback=typ)
1777-
return typ
1791+
return self._handle_literal_expr(e.value, 'builtins.int')
17781792

17791793
def visit_str_expr(self, e: StrExpr) -> Type:
17801794
"""Type check a string literal (trivial)."""
1781-
typ = self.named_type('builtins.str')
1782-
if self.is_literal_context():
1783-
return LiteralType(value=e.value, fallback=typ)
1784-
return typ
1795+
return self._handle_literal_expr(e.value, 'builtins.str')
17851796

17861797
def visit_bytes_expr(self, e: BytesExpr) -> Type:
17871798
"""Type check a bytes literal (trivial)."""
1788-
typ = self.named_type('builtins.bytes')
1789-
if is_literal_type_like(self.type_context[-1]):
1790-
return LiteralType(value=e.value, fallback=typ)
1791-
return typ
1799+
return self._handle_literal_expr(e.value, 'builtins.bytes')
17921800

17931801
def visit_unicode_expr(self, e: UnicodeExpr) -> Type:
17941802
"""Type check a unicode literal (trivial)."""
1795-
typ = self.named_type('builtins.unicode')
1796-
if is_literal_type_like(self.type_context[-1]):
1797-
return LiteralType(value=e.value, fallback=typ)
1798-
return typ
1803+
return self._handle_literal_expr(e.value, 'builtins.unicode')
17991804

18001805
def visit_float_expr(self, e: FloatExpr) -> Type:
18011806
"""Type check a float literal (trivial)."""
@@ -1932,7 +1937,8 @@ def check_method_call_by_name(self,
19321937
"""
19331938
local_errors = local_errors or self.msg
19341939
method_type = analyze_member_access(method, base_type, context, False, False, True,
1935-
local_errors, original_type=base_type, chk=self.chk)
1940+
local_errors, original_type=base_type, chk=self.chk,
1941+
in_literal_context=self._is_literal_context())
19361942
return self.check_method_call(
19371943
method, base_type, method_type, args, arg_kinds, context, local_errors)
19381944

@@ -1996,6 +2002,7 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]:
19962002
context=context,
19972003
msg=local_errors,
19982004
chk=self.chk,
2005+
in_literal_context=self._is_literal_context()
19992006
)
20002007
if local_errors.is_errors():
20012008
return None
@@ -2946,7 +2953,8 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type:
29462953
override_info=base,
29472954
context=e,
29482955
msg=self.msg,
2949-
chk=self.chk)
2956+
chk=self.chk,
2957+
in_literal_context=self._is_literal_context())
29502958
assert False, 'unreachable'
29512959
else:
29522960
# Invalid super. This has been reported by the semantic analyzer.
@@ -3113,16 +3121,16 @@ def accept(self,
31133121
type_context: Optional[Type] = None,
31143122
allow_none_return: bool = False,
31153123
always_allow_any: bool = False,
3116-
infer_literal: bool = False,
3124+
in_final_declaration: bool = False,
31173125
) -> Type:
31183126
"""Type check a node in the given type context. If allow_none_return
31193127
is True and this expression is a call, allow it to return None. This
31203128
applies only to this expression and not any subexpressions.
31213129
"""
31223130
if node in self.type_overrides:
31233131
return self.type_overrides[node]
3124-
old_infer_literal = self.infer_literal
3125-
self.infer_literal = infer_literal
3132+
old_in_final_declaration = self.in_final_declaration
3133+
self.in_final_declaration = in_final_declaration
31263134
self.type_context.append(type_context)
31273135
try:
31283136
if allow_none_return and isinstance(node, CallExpr):
@@ -3135,7 +3143,7 @@ def accept(self,
31353143
report_internal_error(err, self.chk.errors.file,
31363144
node.line, self.chk.errors, self.chk.options)
31373145
self.type_context.pop()
3138-
self.infer_literal = old_infer_literal
3146+
self.in_final_declaration = old_in_final_declaration
31393147
assert typ is not None
31403148
self.chk.store_type(node, typ)
31413149

@@ -3381,9 +3389,6 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type:
33813389
return ans
33823390
return known_type
33833391

3384-
def is_literal_context(self) -> bool:
3385-
return self.infer_literal or is_literal_type_like(self.type_context[-1])
3386-
33873392

33883393
def has_any_type(t: Type) -> bool:
33893394
"""Whether t contains an Any type"""

mypy/checkmember.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ def analyze_member_access(name: str,
7171
msg: MessageBuilder, *,
7272
original_type: Type,
7373
chk: 'mypy.checker.TypeChecker',
74-
override_info: Optional[TypeInfo] = None) -> Type:
74+
override_info: Optional[TypeInfo] = None,
75+
in_literal_context: bool = False) -> Type:
7576
"""Return the type of attribute 'name' of 'typ'.
7677
7778
The actual implementation is in '_analyze_member_access' and this docstring
@@ -96,7 +97,11 @@ def analyze_member_access(name: str,
9697
context,
9798
msg,
9899
chk=chk)
99-
return _analyze_member_access(name, typ, mx, override_info)
100+
result = _analyze_member_access(name, typ, mx, override_info)
101+
if in_literal_context and isinstance(result, Instance) and result.final_value is not None:
102+
return result.final_value
103+
else:
104+
return result
100105

101106

102107
def _analyze_member_access(name: str,

mypy/defaults.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
PYTHON2_VERSION = (2, 7) # type: Final
66
PYTHON3_VERSION = (3, 6) # type: Final
77
PYTHON3_VERSION_MIN = (3, 4) # type: Final
8-
CACHE_DIR = '.mypy_cache' # type: Final[str]
9-
CONFIG_FILE = 'mypy.ini' # type: Final[str]
8+
CACHE_DIR = '.mypy_cache' # type: Final
9+
CONFIG_FILE = 'mypy.ini' # type: Final
1010
SHARED_CONFIG_FILES = ('setup.cfg',) # type: Final
1111
USER_CONFIG_FILES = ('~/.mypy.ini',) # type: Final
1212
CONFIG_FILES = (CONFIG_FILE,) + SHARED_CONFIG_FILES + USER_CONFIG_FILES # type: Final

mypy/erasetype.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def visit_deleted_type(self, t: DeletedType) -> Type:
5151
return t
5252

5353
def visit_instance(self, t: Instance) -> Type:
54-
return Instance(t.type, [AnyType(TypeOfAny.special_form)] * len(t.args), t.line)
54+
return t.copy_modified(args=[AnyType(TypeOfAny.special_form)] * len(t.args))
5555

5656
def visit_type_var(self, t: TypeVarType) -> Type:
5757
return AnyType(TypeOfAny.special_form)

mypy/expandtype.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,14 @@ def visit_erased_type(self, t: ErasedType) -> Type:
8080
raise RuntimeError()
8181

8282
def visit_instance(self, t: Instance) -> Type:
83-
args = self.expand_types(t.args)
84-
return Instance(t.type, args, t.line, t.column)
83+
return t.copy_modified(args=self.expand_types(t.args))
8584

8685
def visit_type_var(self, t: TypeVarType) -> Type:
8786
repl = self.variables.get(t.id, t)
8887
if isinstance(repl, Instance):
8988
inst = repl
9089
# Return copy of instance with type erasure flag on.
91-
return Instance(inst.type, inst.args, line=inst.line,
92-
column=inst.column, erased=True)
90+
return inst.copy_modified(erased=True)
9391
else:
9492
return repl
9593

mypy/reachability.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def infer_condition_value(expr: Expression, options: Options) -> int:
7777
if alias.op == 'not':
7878
expr = alias.expr
7979
negated = True
80-
result = TRUTH_VALUE_UNKNOWN # type: int
80+
result = TRUTH_VALUE_UNKNOWN
8181
if isinstance(expr, NameExpr):
8282
name = expr.name
8383
elif isinstance(expr, MemberExpr):

mypy/sametypes.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def visit_deleted_type(self, left: DeletedType) -> bool:
7777
def visit_instance(self, left: Instance) -> bool:
7878
return (isinstance(self.right, Instance) and
7979
left.type == self.right.type and
80-
is_same_types(left.args, self.right.args))
80+
is_same_types(left.args, self.right.args) and
81+
left.final_value == self.right.final_value)
8182

8283
def visit_type_var(self, left: TypeVarType) -> bool:
8384
return (isinstance(self.right, TypeVarType) and

mypy/semanal.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
from mypy.messages import CANNOT_ASSIGN_TO_TYPE, MessageBuilder
6666
from mypy.types import (
6767
FunctionLike, UnboundType, TypeVarDef, TupleType, UnionType, StarType, function_type,
68-
CallableType, Overloaded, Instance, Type, AnyType, LiteralType,
68+
CallableType, Overloaded, Instance, Type, AnyType,
6969
TypeTranslator, TypeOfAny, TypeType, NoneTyp,
7070
)
7171
from mypy.nodes import implicit_module_attrs
@@ -1908,22 +1908,29 @@ def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Opt
19081908
# inside type variables with value restrictions (like
19091909
# AnyStr).
19101910
return None
1911+
if isinstance(rvalue, FloatExpr):
1912+
return self.named_type_or_none('builtins.float')
1913+
19111914
if isinstance(rvalue, IntExpr):
19121915
typ = self.named_type_or_none('builtins.int')
19131916
if typ and is_final:
1914-
return LiteralType(rvalue.value, typ, rvalue.line, rvalue.column)
1917+
return typ.copy_with_final_value(rvalue.value)
19151918
return typ
1916-
if isinstance(rvalue, FloatExpr):
1917-
return self.named_type_or_none('builtins.float')
19181919
if isinstance(rvalue, StrExpr):
19191920
typ = self.named_type_or_none('builtins.str')
19201921
if typ and is_final:
1921-
return LiteralType(rvalue.value, typ, rvalue.line, rvalue.column)
1922+
return typ.copy_with_final_value(rvalue.value)
19221923
return typ
19231924
if isinstance(rvalue, BytesExpr):
1924-
return self.named_type_or_none('builtins.bytes')
1925+
typ = self.named_type_or_none('builtins.bytes')
1926+
if typ and is_final:
1927+
return typ.copy_with_final_value(rvalue.value)
1928+
return typ
19251929
if isinstance(rvalue, UnicodeExpr):
1926-
return self.named_type_or_none('builtins.unicode')
1930+
typ = self.named_type_or_none('builtins.unicode')
1931+
if typ and is_final:
1932+
return typ.copy_with_final_value(rvalue.value)
1933+
return typ
19271934

19281935
return None
19291936

mypy/server/astdiff.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,8 @@ def visit_deleted_type(self, typ: DeletedType) -> SnapshotItem:
284284
def visit_instance(self, typ: Instance) -> SnapshotItem:
285285
return ('Instance',
286286
typ.type.fullname(),
287-
snapshot_types(typ.args))
287+
snapshot_types(typ.args),
288+
None if typ.final_value is None else snapshot_type(typ.final_value))
288289

289290
def visit_type_var(self, typ: TypeVarType) -> SnapshotItem:
290291
return ('TypeVar',

mypy/server/astmerge.py

+2
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,8 @@ def visit_instance(self, typ: Instance) -> None:
337337
typ.type = self.fixup(typ.type)
338338
for arg in typ.args:
339339
arg.accept(self)
340+
if typ.final_value:
341+
typ.final_value.accept(self)
340342

341343
def visit_any(self, typ: AnyType) -> None:
342344
pass

mypy/server/deps.py

+2
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,8 @@ def visit_instance(self, typ: Instance) -> List[str]:
882882
triggers = [trigger]
883883
for arg in typ.args:
884884
triggers.extend(self.get_type_triggers(arg))
885+
if typ.final_value:
886+
triggers.extend(self.get_type_triggers(typ.final_value))
885887
return triggers
886888

887889
def visit_any(self, typ: AnyType) -> List[str]:

mypy/stubgen.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ def __init__(self, _all_: Optional[List[str]], pyversion: Tuple[int, int],
420420
self._import_lines = [] # type: List[str]
421421
self._indent = ''
422422
self._vars = [[]] # type: List[List[str]]
423-
self._state = EMPTY # type: str
423+
self._state = EMPTY
424424
self._toplevel_names = [] # type: List[str]
425425
self._pyversion = pyversion
426426
self._include_private = include_private

mypy/subtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def get_member_flags(name: str, info: TypeInfo) -> Set[int]:
589589
return {IS_CLASS_OR_STATIC}
590590
# just a variable
591591
if isinstance(v, Var) and not v.is_property:
592-
flags = {IS_SETTABLE} # type: Set[int]
592+
flags = {IS_SETTABLE}
593593
if v.is_classvar:
594594
flags.add(IS_CLASSVAR)
595595
return flags

mypy/type_visitor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def visit_deleted_type(self, t: DeletedType) -> Type:
159159
return t
160160

161161
def visit_instance(self, t: Instance) -> Type:
162-
return Instance(t.type, self.translate_types(t.args), t.line, t.column)
162+
return t.copy_modified(args=self.translate_types(t.args))
163163

164164
def visit_type_var(self, t: TypeVarType) -> Type:
165165
return t

mypy/typeanal.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,9 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
678678
elif isinstance(arg, (NoneTyp, LiteralType)):
679679
# Types that we can just add directly to the literal/potential union of literals.
680680
return [arg]
681+
elif isinstance(arg, Instance) and arg.final_value is not None:
682+
# Types generated from declarations like "var: Final = 4".
683+
return [arg.final_value]
681684
elif isinstance(arg, UnionType):
682685
out = []
683686
for union_arg in arg.items:
@@ -1073,7 +1076,7 @@ def replace_alias_tvars(tp: Type, vars: List[str], subs: List[Type],
10731076
def set_any_tvars(tp: Type, vars: List[str],
10741077
newline: int, newcolumn: int, implicit: bool = True) -> Type:
10751078
if implicit:
1076-
type_of_any = TypeOfAny.from_omitted_generics # type: int
1079+
type_of_any = TypeOfAny.from_omitted_generics
10771080
else:
10781081
type_of_any = TypeOfAny.special_form
10791082
any_type = AnyType(type_of_any, line=newline, column=newcolumn)

0 commit comments

Comments
 (0)