Skip to content

Commit 08467a1

Browse files
ethan-lebaEthan Leba
authored and
Ethan Leba
committed
Expand booleans during truthy checks and bool ops
1 parent 9e7a695 commit 08467a1

8 files changed

+74
-18
lines changed

mypy/checker.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4318,7 +4318,7 @@ def has_no_custom_eq_checks(t: Type) -> bool:
43184318
elif isinstance(node, RefExpr):
43194319
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
43204320
# respectively
4321-
vartype = type_map[node]
4321+
vartype = try_expanding_sum_type_to_union(type_map[node], "builtins.bool")
43224322
if_type: Type = true_only(vartype)
43234323
else_type: Type = false_only(vartype)
43244324
ref: Expression = node

mypy/checkexpr.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@
6464
FunctionContext, FunctionSigContext,
6565
)
6666
from mypy.typeops import (
67-
tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
68-
function_type, callable_type, try_getting_str_literals, custom_special_method,
67+
try_expanding_sum_type_to_union, tuple_fallback, make_simplified_union,
68+
true_only, false_only, erase_to_union_or_bound, function_type,
69+
callable_type, try_getting_str_literals, custom_special_method,
6970
is_literal_type_like,
7071
)
7172
import mypy.errorcodes as codes
@@ -2783,6 +2784,9 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
27832784
# '[1] or []' are inferred correctly.
27842785
ctx = self.type_context[-1]
27852786
left_type = self.accept(e.left, ctx)
2787+
expanded_left_type = try_expanding_sum_type_to_union(
2788+
self.accept(e.left, ctx), "builtins.bool"
2789+
)
27862790

27872791
assert e.op in ('and', 'or') # Checked by visit_op_expr
27882792

@@ -2817,7 +2821,7 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
28172821
# to be unreachable and therefore any errors found in the right branch
28182822
# should be suppressed.
28192823
with (self.msg.disable_errors() if right_map is None else nullcontext()):
2820-
right_type = self.analyze_cond_branch(right_map, e.right, left_type)
2824+
right_type = self.analyze_cond_branch(right_map, e.right, expanded_left_type)
28212825

28222826
if right_map is None:
28232827
# The boolean expression is statically known to be the left value
@@ -2829,11 +2833,11 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
28292833
return right_type
28302834

28312835
if e.op == 'and':
2832-
restricted_left_type = false_only(left_type)
2833-
result_is_left = not left_type.can_be_true
2836+
restricted_left_type = false_only(expanded_left_type)
2837+
result_is_left = not expanded_left_type.can_be_true
28342838
elif e.op == 'or':
2835-
restricted_left_type = true_only(left_type)
2836-
result_is_left = not left_type.can_be_false
2839+
restricted_left_type = true_only(expanded_left_type)
2840+
result_is_left = not expanded_left_type.can_be_false
28372841

28382842
if isinstance(restricted_left_type, UninhabitedType):
28392843
# The left operand can never be the result

test-data/unit/check-dynamic-typing.test

+2-2
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,9 @@ a or d
159159
if int():
160160
c = a in d # E: Incompatible types in assignment (expression has type "bool", variable has type "C")
161161
if int():
162-
c = b and d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C")
162+
c = b and d # E: Incompatible types in assignment (expression has type "Union[Literal[False], Any]", variable has type "C")
163163
if int():
164-
c = b or d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C")
164+
c = b or d # E: Incompatible types in assignment (expression has type "Union[Literal[True], Any]", variable has type "C")
165165
if int():
166166
b = a + d
167167
if int():

test-data/unit/check-expressions.test

+2-2
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,11 @@ if int():
316316
if int():
317317
b = b or b
318318
if int():
319-
b = b and a # E: Incompatible types in assignment (expression has type "Union[bool, A]", variable has type "bool")
319+
b = b and a # E: Incompatible types in assignment (expression has type "Union[Literal[False], A]", variable has type "bool")
320320
if int():
321321
b = a and b # E: Incompatible types in assignment (expression has type "Union[A, bool]", variable has type "bool")
322322
if int():
323-
b = b or a # E: Incompatible types in assignment (expression has type "Union[bool, A]", variable has type "bool")
323+
b = b or a # E: Incompatible types in assignment (expression has type "Union[Literal[True], A]", variable has type "bool")
324324
if int():
325325
b = a or b # E: Incompatible types in assignment (expression has type "Union[A, bool]", variable has type "bool")
326326
class A: pass

test-data/unit/check-narrowing.test

+51
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,57 @@ else:
10521052
reveal_type(opt_bool_val) # N: Revealed type is "Literal[False]"
10531053
[builtins fixtures/primitives.pyi]
10541054

1055+
[case testNarrowingBooleanTruthiness]
1056+
# flags: --strict-optional
1057+
from typing import Optional
1058+
from typing_extensions import Literal
1059+
1060+
bool_val: bool
1061+
1062+
if bool_val:
1063+
reveal_type(bool_val) # N: Revealed type is "Literal[True]"
1064+
else:
1065+
reveal_type(bool_val) # N: Revealed type is "Literal[False]"
1066+
reveal_type(bool_val) # N: Revealed type is "builtins.bool"
1067+
1068+
opt_bool_val: Optional[bool]
1069+
1070+
if opt_bool_val:
1071+
reveal_type(opt_bool_val) # N: Revealed type is "Literal[True]"
1072+
else:
1073+
reveal_type(opt_bool_val) # N: Revealed type is "Union[Literal[False], None]"
1074+
reveal_type(opt_bool_val) # N: Revealed type is "Union[builtins.bool, None]"
1075+
[builtins fixtures/primitives.pyi]
1076+
1077+
[case testNarrowingBooleanBoolOp]
1078+
# flags: --strict-optional
1079+
from typing import Optional
1080+
from typing_extensions import Literal
1081+
1082+
bool_a: bool
1083+
bool_b: bool
1084+
1085+
if bool_a and bool_b:
1086+
reveal_type(bool_a) # N: Revealed type is "Literal[True]"
1087+
reveal_type(bool_b) # N: Revealed type is "Literal[True]"
1088+
else:
1089+
reveal_type(bool_a) # N: Revealed type is "builtins.bool"
1090+
reveal_type(bool_b) # N: Revealed type is "builtins.bool"
1091+
1092+
if not bool_a or bool_b:
1093+
reveal_type(bool_a) # N: Revealed type is "builtins.bool"
1094+
reveal_type(bool_b) # N: Revealed type is "builtins.bool"
1095+
else:
1096+
reveal_type(bool_a) # N: Revealed type is "Literal[True]"
1097+
reveal_type(bool_b) # N: Revealed type is "Literal[False]"
1098+
1099+
if True and bool_b:
1100+
reveal_type(bool_b) # N: Revealed type is "Literal[True]"
1101+
1102+
x = True and bool_b
1103+
reveal_type(x) # N: Revealed type is "builtins.bool"
1104+
[builtins fixtures/primitives.pyi]
1105+
10551106
[case testNarrowingTypedDictUsingEnumLiteral]
10561107
# flags: --python-version 3.6
10571108
from typing import Union

test-data/unit/check-newsemanal.test

+1
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ from a import x
304304
def f(): pass
305305

306306
[targets a, b, a, a.y, b.f, __main__]
307+
[builtins fixtures/tuple.pyi]
307308

308309
[case testNewAnalyzerRedefinitionAndDeferral1b]
309310
import a

test-data/unit/check-unreachable-code.test

+4-4
Original file line numberDiff line numberDiff line change
@@ -533,11 +533,11 @@ f = (PY3 or PY2) and 's'
533533
g = (PY2 or PY3) or 's'
534534
h = (PY3 or PY2) or 's'
535535
reveal_type(a) # N: Revealed type is "builtins.bool"
536-
reveal_type(b) # N: Revealed type is "builtins.str"
537-
reveal_type(c) # N: Revealed type is "builtins.str"
536+
reveal_type(b) # N: Revealed type is "Literal['s']"
537+
reveal_type(c) # N: Revealed type is "Literal['s']"
538538
reveal_type(d) # N: Revealed type is "builtins.bool"
539-
reveal_type(e) # N: Revealed type is "builtins.str"
540-
reveal_type(f) # N: Revealed type is "builtins.str"
539+
reveal_type(e) # N: Revealed type is "Literal['s']"
540+
reveal_type(f) # N: Revealed type is "Literal['s']"
541541
reveal_type(g) # N: Revealed type is "builtins.bool"
542542
reveal_type(h) # N: Revealed type is "builtins.bool"
543543
[builtins fixtures/ops.pyi]

test-data/unit/typexport-basic.test

+2-2
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ elif not a:
247247
[out]
248248
NameExpr(3) : builtins.bool
249249
IntExpr(4) : Literal[1]?
250-
NameExpr(5) : builtins.bool
250+
NameExpr(5) : Literal[False]
251251
UnaryExpr(5) : builtins.bool
252252
IntExpr(6) : Literal[1]?
253253

@@ -259,7 +259,7 @@ while a:
259259
[builtins fixtures/bool.pyi]
260260
[out]
261261
NameExpr(3) : builtins.bool
262-
NameExpr(4) : builtins.bool
262+
NameExpr(4) : Literal[True]
263263

264264

265265
-- Simple type inference

0 commit comments

Comments
 (0)