Skip to content

Commit bf4530e

Browse files
ethan-lebatushar-deepsource
authored andcommitted
Allow booleans to be narrowed to literal types (python#10389)
1 parent 2db8ff8 commit bf4530e

10 files changed

+142
-52
lines changed

mypy/checker.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union,
5252
erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal,
5353
try_getting_str_literals_from_type, try_getting_int_literals_from_type,
54-
tuple_fallback, is_singleton_type, try_expanding_enum_to_union,
54+
tuple_fallback, is_singleton_type, try_expanding_sum_type_to_union,
5555
true_only, false_only, function_type, get_type_vars, custom_special_method,
5656
is_literal_type_like,
5757
)
@@ -4583,8 +4583,10 @@ def has_no_custom_eq_checks(t: Type) -> bool:
45834583

45844584
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
45854585
# respectively
4586-
vartype = type_map[node]
4587-
self._check_for_truthy_type(vartype, node)
4586+
original_vartype = type_map[node]
4587+
self._check_for_truthy_type(original_vartype, node)
4588+
vartype = try_expanding_sum_type_to_union(original_vartype, "builtins.bool")
4589+
45884590
if_type = true_only(vartype) # type: Type
45894591
else_type = false_only(vartype) # type: Type
45904592
ref = node # type: Expression
@@ -4857,10 +4859,11 @@ def refine_identity_comparison_expression(self,
48574859
if singleton_index == -1:
48584860
singleton_index = possible_target_indices[-1]
48594861

4860-
enum_name = None
4862+
sum_type_name = None
48614863
target = get_proper_type(target)
4862-
if isinstance(target, LiteralType) and target.is_enum_literal():
4863-
enum_name = target.fallback.type.fullname
4864+
if (isinstance(target, LiteralType) and
4865+
(target.is_enum_literal() or isinstance(target.value, bool))):
4866+
sum_type_name = target.fallback.type.fullname
48644867

48654868
target_type = [TypeRange(target, is_upper_bound=False)]
48664869

@@ -4881,8 +4884,8 @@ def refine_identity_comparison_expression(self,
48814884
expr = operands[i]
48824885
expr_type = coerce_to_literal(operand_types[i])
48834886

4884-
if enum_name is not None:
4885-
expr_type = try_expanding_enum_to_union(expr_type, enum_name)
4887+
if sum_type_name is not None:
4888+
expr_type = try_expanding_sum_type_to_union(expr_type, sum_type_name)
48864889

48874890
# We intentionally use 'conditional_type_map' directly here instead of
48884891
# 'self.conditional_type_map_with_intersection': we only compute ad-hoc

mypy/checkexpr.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@
6666
FunctionContext, FunctionSigContext,
6767
)
6868
from mypy.typeops import (
69-
tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
70-
function_type, callable_type, try_getting_str_literals, custom_special_method,
69+
try_expanding_sum_type_to_union, tuple_fallback, make_simplified_union,
70+
true_only, false_only, erase_to_union_or_bound, function_type,
71+
callable_type, try_getting_str_literals, custom_special_method,
7172
is_literal_type_like,
7273
)
7374
import mypy.errorcodes as codes
@@ -2800,6 +2801,9 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
28002801
# '[1] or []' are inferred correctly.
28012802
ctx = self.type_context[-1]
28022803
left_type = self.accept(e.left, ctx)
2804+
expanded_left_type = try_expanding_sum_type_to_union(
2805+
self.accept(e.left, ctx), "builtins.bool"
2806+
)
28032807

28042808
assert e.op in ('and', 'or') # Checked by visit_op_expr
28052809

@@ -2834,7 +2838,7 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
28342838
# to be unreachable and therefore any errors found in the right branch
28352839
# should be suppressed.
28362840
with (self.msg.disable_errors() if right_map is None else nullcontext()):
2837-
right_type = self.analyze_cond_branch(right_map, e.right, left_type)
2841+
right_type = self.analyze_cond_branch(right_map, e.right, expanded_left_type)
28382842

28392843
if right_map is None:
28402844
# The boolean expression is statically known to be the left value
@@ -2846,11 +2850,11 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
28462850
return right_type
28472851

28482852
if e.op == 'and':
2849-
restricted_left_type = false_only(left_type)
2850-
result_is_left = not left_type.can_be_true
2853+
restricted_left_type = false_only(expanded_left_type)
2854+
result_is_left = not expanded_left_type.can_be_true
28512855
elif e.op == 'or':
2852-
restricted_left_type = true_only(left_type)
2853-
result_is_left = not left_type.can_be_false
2856+
restricted_left_type = true_only(expanded_left_type)
2857+
result_is_left = not expanded_left_type.can_be_false
28542858

28552859
if isinstance(restricted_left_type, UninhabitedType):
28562860
# The left operand can never be the result

mypy/typeops.py

+33-24
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,7 @@ def is_singleton_type(typ: Type) -> bool:
701701
)
702702

703703

704-
def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> ProperType:
704+
def try_expanding_sum_type_to_union(typ: Type, target_fullname: str) -> ProperType:
705705
"""Attempts to recursively expand any enum Instances with the given target_fullname
706706
into a Union of all of its component LiteralTypes.
707707
@@ -723,28 +723,34 @@ class Status(Enum):
723723
typ = get_proper_type(typ)
724724

725725
if isinstance(typ, UnionType):
726-
items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items]
726+
items = [try_expanding_sum_type_to_union(item, target_fullname) for item in typ.items]
727727
return make_simplified_union(items, contract_literals=False)
728-
elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname == target_fullname:
729-
new_items = []
730-
for name, symbol in typ.type.names.items():
731-
if not isinstance(symbol.node, Var):
732-
continue
733-
# Skip "_order_" and "__order__", since Enum will remove it
734-
if name in ("_order_", "__order__"):
735-
continue
736-
new_items.append(LiteralType(name, typ))
737-
# SymbolTables are really just dicts, and dicts are guaranteed to preserve
738-
# insertion order only starting with Python 3.7. So, we sort these for older
739-
# versions of Python to help make tests deterministic.
740-
#
741-
# We could probably skip the sort for Python 3.6 since people probably run mypy
742-
# only using CPython, but we might as well for the sake of full correctness.
743-
if sys.version_info < (3, 7):
744-
new_items.sort(key=lambda lit: lit.value)
745-
return make_simplified_union(new_items, contract_literals=False)
746-
else:
747-
return typ
728+
elif isinstance(typ, Instance) and typ.type.fullname == target_fullname:
729+
if typ.type.is_enum:
730+
new_items = []
731+
for name, symbol in typ.type.names.items():
732+
if not isinstance(symbol.node, Var):
733+
continue
734+
# Skip "_order_" and "__order__", since Enum will remove it
735+
if name in ("_order_", "__order__"):
736+
continue
737+
new_items.append(LiteralType(name, typ))
738+
# SymbolTables are really just dicts, and dicts are guaranteed to preserve
739+
# insertion order only starting with Python 3.7. So, we sort these for older
740+
# versions of Python to help make tests deterministic.
741+
#
742+
# We could probably skip the sort for Python 3.6 since people probably run mypy
743+
# only using CPython, but we might as well for the sake of full correctness.
744+
if sys.version_info < (3, 7):
745+
new_items.sort(key=lambda lit: lit.value)
746+
return make_simplified_union(new_items, contract_literals=False)
747+
elif typ.type.fullname == "builtins.bool":
748+
return make_simplified_union(
749+
[LiteralType(True, typ), LiteralType(False, typ)],
750+
contract_literals=False
751+
)
752+
753+
return typ
748754

749755

750756
def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]:
@@ -762,9 +768,12 @@ def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]
762768
for idx, typ in enumerate(proper_types):
763769
if isinstance(typ, LiteralType):
764770
fullname = typ.fallback.type.fullname
765-
if typ.fallback.type.is_enum:
771+
if typ.fallback.type.is_enum or isinstance(typ.value, bool):
766772
if fullname not in sum_types:
767-
sum_types[fullname] = (set(get_enum_values(typ.fallback)), [])
773+
sum_types[fullname] = (set(get_enum_values(typ.fallback))
774+
if typ.fallback.type.is_enum
775+
else set((True, False)),
776+
[])
768777
literals, indexes = sum_types[fullname]
769778
literals.discard(typ.value)
770779
indexes.append(idx)

test-data/unit/check-dynamic-typing.test

+2-2
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,9 @@ a or d
159159
if int():
160160
c = a in d # E: Incompatible types in assignment (expression has type "bool", variable has type "C")
161161
if int():
162-
c = b and d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C")
162+
c = b and d # E: Incompatible types in assignment (expression has type "Union[Literal[False], Any]", variable has type "C")
163163
if int():
164-
c = b or d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C")
164+
c = b or d # E: Incompatible types in assignment (expression has type "Union[Literal[True], Any]", variable has type "C")
165165
if int():
166166
b = a + d
167167
if int():

test-data/unit/check-expressions.test

+2-2
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,11 @@ if int():
323323
if int():
324324
b = b or b
325325
if int():
326-
b = b and a # E: Incompatible types in assignment (expression has type "Union[bool, A]", variable has type "bool")
326+
b = b and a # E: Incompatible types in assignment (expression has type "Union[Literal[False], A]", variable has type "bool")
327327
if int():
328328
b = a and b # E: Incompatible types in assignment (expression has type "Union[A, bool]", variable has type "bool")
329329
if int():
330-
b = b or a # E: Incompatible types in assignment (expression has type "Union[bool, A]", variable has type "bool")
330+
b = b or a # E: Incompatible types in assignment (expression has type "Union[Literal[True], A]", variable has type "bool")
331331
if int():
332332
b = a or b # E: Incompatible types in assignment (expression has type "Union[A, bool]", variable has type "bool")
333333
class A: pass

test-data/unit/check-narrowing.test

+74-1
Original file line numberDiff line numberDiff line change
@@ -1047,8 +1047,81 @@ else:
10471047
if str_or_bool_literal is not True and str_or_bool_literal is not False:
10481048
reveal_type(str_or_bool_literal) # N: Revealed type is "builtins.str"
10491049
else:
1050-
reveal_type(str_or_bool_literal) # N: Revealed type is "Union[Literal[False], Literal[True]]"
1050+
reveal_type(str_or_bool_literal) # N: Revealed type is "builtins.bool"
1051+
[builtins fixtures/primitives.pyi]
1052+
1053+
[case testNarrowingBooleanIdentityCheck]
1054+
# flags: --strict-optional
1055+
from typing import Optional
1056+
from typing_extensions import Literal
1057+
1058+
bool_val: bool
1059+
1060+
if bool_val is not False:
1061+
reveal_type(bool_val) # N: Revealed type is "Literal[True]"
1062+
else:
1063+
reveal_type(bool_val) # N: Revealed type is "Literal[False]"
1064+
1065+
opt_bool_val: Optional[bool]
1066+
1067+
if opt_bool_val is not None:
1068+
reveal_type(opt_bool_val) # N: Revealed type is "builtins.bool"
1069+
1070+
if opt_bool_val is not False:
1071+
reveal_type(opt_bool_val) # N: Revealed type is "Union[Literal[True], None]"
1072+
else:
1073+
reveal_type(opt_bool_val) # N: Revealed type is "Literal[False]"
1074+
[builtins fixtures/primitives.pyi]
1075+
1076+
[case testNarrowingBooleanTruthiness]
1077+
# flags: --strict-optional
1078+
from typing import Optional
1079+
from typing_extensions import Literal
1080+
1081+
bool_val: bool
1082+
1083+
if bool_val:
1084+
reveal_type(bool_val) # N: Revealed type is "Literal[True]"
1085+
else:
1086+
reveal_type(bool_val) # N: Revealed type is "Literal[False]"
1087+
reveal_type(bool_val) # N: Revealed type is "builtins.bool"
1088+
1089+
opt_bool_val: Optional[bool]
1090+
1091+
if opt_bool_val:
1092+
reveal_type(opt_bool_val) # N: Revealed type is "Literal[True]"
1093+
else:
1094+
reveal_type(opt_bool_val) # N: Revealed type is "Union[Literal[False], None]"
1095+
reveal_type(opt_bool_val) # N: Revealed type is "Union[builtins.bool, None]"
1096+
[builtins fixtures/primitives.pyi]
1097+
1098+
[case testNarrowingBooleanBoolOp]
1099+
# flags: --strict-optional
1100+
from typing import Optional
1101+
from typing_extensions import Literal
1102+
1103+
bool_a: bool
1104+
bool_b: bool
1105+
1106+
if bool_a and bool_b:
1107+
reveal_type(bool_a) # N: Revealed type is "Literal[True]"
1108+
reveal_type(bool_b) # N: Revealed type is "Literal[True]"
1109+
else:
1110+
reveal_type(bool_a) # N: Revealed type is "builtins.bool"
1111+
reveal_type(bool_b) # N: Revealed type is "builtins.bool"
1112+
1113+
if not bool_a or bool_b:
1114+
reveal_type(bool_a) # N: Revealed type is "builtins.bool"
1115+
reveal_type(bool_b) # N: Revealed type is "builtins.bool"
1116+
else:
1117+
reveal_type(bool_a) # N: Revealed type is "Literal[True]"
1118+
reveal_type(bool_b) # N: Revealed type is "Literal[False]"
1119+
1120+
if True and bool_b:
1121+
reveal_type(bool_b) # N: Revealed type is "Literal[True]"
10511122

1123+
x = True and bool_b
1124+
reveal_type(x) # N: Revealed type is "builtins.bool"
10521125
[builtins fixtures/primitives.pyi]
10531126

10541127
[case testNarrowingTypedDictUsingEnumLiteral]

test-data/unit/check-newsemanal.test

+1
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ from a import x
304304
def f(): pass
305305

306306
[targets a, b, a, a.y, b.f, __main__]
307+
[builtins fixtures/tuple.pyi]
307308

308309
[case testNewAnalyzerRedefinitionAndDeferral1b]
309310
import a

test-data/unit/check-python38.test

+2-2
Original file line numberDiff line numberDiff line change
@@ -411,10 +411,10 @@ from typing import Optional
411411
maybe_str: Optional[str]
412412

413413
if (is_str := maybe_str is not None):
414-
reveal_type(is_str) # N: Revealed type is "builtins.bool"
414+
reveal_type(is_str) # N: Revealed type is "Literal[True]"
415415
reveal_type(maybe_str) # N: Revealed type is "builtins.str"
416416
else:
417-
reveal_type(is_str) # N: Revealed type is "builtins.bool"
417+
reveal_type(is_str) # N: Revealed type is "Literal[False]"
418418
reveal_type(maybe_str) # N: Revealed type is "None"
419419

420420
reveal_type(maybe_str) # N: Revealed type is "Union[builtins.str, None]"

test-data/unit/check-unreachable-code.test

+4-4
Original file line numberDiff line numberDiff line change
@@ -533,11 +533,11 @@ f = (PY3 or PY2) and 's'
533533
g = (PY2 or PY3) or 's'
534534
h = (PY3 or PY2) or 's'
535535
reveal_type(a) # N: Revealed type is "builtins.bool"
536-
reveal_type(b) # N: Revealed type is "builtins.str"
537-
reveal_type(c) # N: Revealed type is "builtins.str"
536+
reveal_type(b) # N: Revealed type is "Literal['s']"
537+
reveal_type(c) # N: Revealed type is "Literal['s']"
538538
reveal_type(d) # N: Revealed type is "builtins.bool"
539-
reveal_type(e) # N: Revealed type is "builtins.str"
540-
reveal_type(f) # N: Revealed type is "builtins.str"
539+
reveal_type(e) # N: Revealed type is "Literal['s']"
540+
reveal_type(f) # N: Revealed type is "Literal['s']"
541541
reveal_type(g) # N: Revealed type is "builtins.bool"
542542
reveal_type(h) # N: Revealed type is "builtins.bool"
543543
[builtins fixtures/ops.pyi]

test-data/unit/typexport-basic.test

+2-2
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ elif not a:
247247
[out]
248248
NameExpr(3) : builtins.bool
249249
IntExpr(4) : Literal[1]?
250-
NameExpr(5) : builtins.bool
250+
NameExpr(5) : Literal[False]
251251
UnaryExpr(5) : builtins.bool
252252
IntExpr(6) : Literal[1]?
253253

@@ -259,7 +259,7 @@ while a:
259259
[builtins fixtures/bool.pyi]
260260
[out]
261261
NameExpr(3) : builtins.bool
262-
NameExpr(4) : builtins.bool
262+
NameExpr(4) : Literal[True]
263263

264264

265265
-- Simple type inference

0 commit comments

Comments
 (0)