18
18
from mypy .types import (
19
19
Type , AnyType , CallableType , Overloaded , NoneTyp , TypeVarDef ,
20
20
TupleType , TypedDictType , Instance , TypeVarType , ErasedType , UnionType ,
21
- PartialType , DeletedType , UninhabitedType , TypeType , TypeOfAny , LiteralType ,
21
+ PartialType , DeletedType , UninhabitedType , TypeType , TypeOfAny , LiteralType , LiteralValue ,
22
22
true_only , false_only , is_named_instance , function_type , callable_type , FunctionLike ,
23
23
StarType , is_optional , remove_optional , is_generic_instance
24
24
)
@@ -139,7 +139,7 @@ def __init__(self,
139
139
self .msg = msg
140
140
self .plugin = plugin
141
141
self .type_context = [None ]
142
- self .infer_literal = False
142
+ self .in_final_declaration = False
143
143
# Temporary overrides for expression types. This is currently
144
144
# used by the union math in overloads.
145
145
# TODO: refactor this to use a pattern similar to one in
@@ -211,10 +211,12 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
211
211
212
212
def analyze_var_ref (self , var : Var , context : Context ) -> Type :
213
213
if var .type :
214
- if self .is_literal_context () and var .name () in {'True' , 'False' }:
215
- return LiteralType (var .name () == 'True' , self .named_type ('builtins.bool' ))
216
- else :
217
- return var .type
214
+ if isinstance (var .type , Instance ):
215
+ if self ._is_literal_context () and var .type .final_value is not None :
216
+ return var .type .final_value
217
+ if var .name () in {'True' , 'False' }:
218
+ return self ._handle_literal_expr (var .name () == 'True' , 'builtins.bool' )
219
+ return var .type
218
220
else :
219
221
if not var .is_ready and self .chk .in_checked_function ():
220
222
self .chk .handle_cannot_determine_type (var .name (), context )
@@ -693,7 +695,8 @@ def check_call(self,
693
695
elif isinstance (callee , Instance ):
694
696
call_function = analyze_member_access ('__call__' , callee , context ,
695
697
False , False , False , self .msg ,
696
- original_type = callee , chk = self .chk )
698
+ original_type = callee , chk = self .chk ,
699
+ in_literal_context = self ._is_literal_context ())
697
700
return self .check_call (call_function , args , arg_kinds , context , arg_names ,
698
701
callable_node , arg_messages )
699
702
elif isinstance (callee , TypeVarType ):
@@ -1757,7 +1760,8 @@ def analyze_ordinary_member_access(self, e: MemberExpr,
1757
1760
original_type = self .accept (e .expr )
1758
1761
member_type = analyze_member_access (
1759
1762
e .name , original_type , e , is_lvalue , False , False ,
1760
- self .msg , original_type = original_type , chk = self .chk )
1763
+ self .msg , original_type = original_type , chk = self .chk ,
1764
+ in_literal_context = self ._is_literal_context ())
1761
1765
return member_type
1762
1766
1763
1767
def analyze_external_member_access (self , member : str , base_type : Type ,
@@ -1767,35 +1771,36 @@ def analyze_external_member_access(self, member: str, base_type: Type,
1767
1771
"""
1768
1772
# TODO remove; no private definitions in mypy
1769
1773
return analyze_member_access (member , base_type , context , False , False , False ,
1770
- self .msg , original_type = base_type , chk = self .chk )
1774
+ self .msg , original_type = base_type , chk = self .chk ,
1775
+ in_literal_context = self ._is_literal_context ())
1776
+
1777
+ def _is_literal_context (self ) -> bool :
1778
+ return is_literal_type_like (self .type_context [- 1 ])
1779
+
1780
+ def _handle_literal_expr (self , value : LiteralValue , fallback_name : str ) -> Type :
1781
+ typ = self .named_type (fallback_name )
1782
+ if self ._is_literal_context ():
1783
+ return LiteralType (value = value , fallback = typ )
1784
+ elif self .in_final_declaration :
1785
+ return typ .copy_with_final_value (value )
1786
+ else :
1787
+ return typ
1771
1788
1772
1789
def visit_int_expr (self , e : IntExpr ) -> Type :
1773
1790
"""Type check an integer literal (trivial)."""
1774
- typ = self .named_type ('builtins.int' )
1775
- if self .is_literal_context ():
1776
- return LiteralType (value = e .value , fallback = typ )
1777
- return typ
1791
+ return self ._handle_literal_expr (e .value , 'builtins.int' )
1778
1792
1779
1793
def visit_str_expr (self , e : StrExpr ) -> Type :
1780
1794
"""Type check a string literal (trivial)."""
1781
- typ = self .named_type ('builtins.str' )
1782
- if self .is_literal_context ():
1783
- return LiteralType (value = e .value , fallback = typ )
1784
- return typ
1795
+ return self ._handle_literal_expr (e .value , 'builtins.str' )
1785
1796
1786
1797
def visit_bytes_expr (self , e : BytesExpr ) -> Type :
1787
1798
"""Type check a bytes literal (trivial)."""
1788
- typ = self .named_type ('builtins.bytes' )
1789
- if is_literal_type_like (self .type_context [- 1 ]):
1790
- return LiteralType (value = e .value , fallback = typ )
1791
- return typ
1799
+ return self ._handle_literal_expr (e .value , 'builtins.bytes' )
1792
1800
1793
1801
def visit_unicode_expr (self , e : UnicodeExpr ) -> Type :
1794
1802
"""Type check a unicode literal (trivial)."""
1795
- typ = self .named_type ('builtins.unicode' )
1796
- if is_literal_type_like (self .type_context [- 1 ]):
1797
- return LiteralType (value = e .value , fallback = typ )
1798
- return typ
1803
+ return self ._handle_literal_expr (e .value , 'builtins.unicode' )
1799
1804
1800
1805
def visit_float_expr (self , e : FloatExpr ) -> Type :
1801
1806
"""Type check a float literal (trivial)."""
@@ -1932,7 +1937,8 @@ def check_method_call_by_name(self,
1932
1937
"""
1933
1938
local_errors = local_errors or self .msg
1934
1939
method_type = analyze_member_access (method , base_type , context , False , False , True ,
1935
- local_errors , original_type = base_type , chk = self .chk )
1940
+ local_errors , original_type = base_type , chk = self .chk ,
1941
+ in_literal_context = self ._is_literal_context ())
1936
1942
return self .check_method_call (
1937
1943
method , base_type , method_type , args , arg_kinds , context , local_errors )
1938
1944
@@ -1996,6 +2002,7 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]:
1996
2002
context = context ,
1997
2003
msg = local_errors ,
1998
2004
chk = self .chk ,
2005
+ in_literal_context = self ._is_literal_context ()
1999
2006
)
2000
2007
if local_errors .is_errors ():
2001
2008
return None
@@ -2946,7 +2953,8 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type:
2946
2953
override_info = base ,
2947
2954
context = e ,
2948
2955
msg = self .msg ,
2949
- chk = self .chk )
2956
+ chk = self .chk ,
2957
+ in_literal_context = self ._is_literal_context ())
2950
2958
assert False , 'unreachable'
2951
2959
else :
2952
2960
# Invalid super. This has been reported by the semantic analyzer.
@@ -3113,16 +3121,16 @@ def accept(self,
3113
3121
type_context : Optional [Type ] = None ,
3114
3122
allow_none_return : bool = False ,
3115
3123
always_allow_any : bool = False ,
3116
- infer_literal : bool = False ,
3124
+ in_final_declaration : bool = False ,
3117
3125
) -> Type :
3118
3126
"""Type check a node in the given type context. If allow_none_return
3119
3127
is True and this expression is a call, allow it to return None. This
3120
3128
applies only to this expression and not any subexpressions.
3121
3129
"""
3122
3130
if node in self .type_overrides :
3123
3131
return self .type_overrides [node ]
3124
- old_infer_literal = self .infer_literal
3125
- self .infer_literal = infer_literal
3132
+ old_in_final_declaration = self .in_final_declaration
3133
+ self .in_final_declaration = in_final_declaration
3126
3134
self .type_context .append (type_context )
3127
3135
try :
3128
3136
if allow_none_return and isinstance (node , CallExpr ):
@@ -3135,7 +3143,7 @@ def accept(self,
3135
3143
report_internal_error (err , self .chk .errors .file ,
3136
3144
node .line , self .chk .errors , self .chk .options )
3137
3145
self .type_context .pop ()
3138
- self .infer_literal = old_infer_literal
3146
+ self .in_final_declaration = old_in_final_declaration
3139
3147
assert typ is not None
3140
3148
self .chk .store_type (node , typ )
3141
3149
@@ -3381,9 +3389,6 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type:
3381
3389
return ans
3382
3390
return known_type
3383
3391
3384
- def is_literal_context (self ) -> bool :
3385
- return self .infer_literal or is_literal_type_like (self .type_context [- 1 ])
3386
-
3387
3392
3388
3393
def has_any_type (t : Type ) -> bool :
3389
3394
"""Whether t contains an Any type"""
0 commit comments