From ac3f31357f9fa316856f9a4d175d19763ff2ae58 Mon Sep 17 00:00:00 2001 From: ethan-leba Date: Fri, 30 Apr 2021 08:19:13 -0400 Subject: [PATCH 1/6] Narrow booleans to literals with identity check --- mypy/checker.py | 3 +- mypy/typeops.py | 53 +++++++++++++++++------------ test-data/unit/check-narrowing.test | 24 ++++++++++++- 3 files changed, 56 insertions(+), 24 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index ba020f5d97d5..d21af10ab16a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4614,7 +4614,8 @@ def refine_identity_comparison_expression(self, enum_name = None target = get_proper_type(target) - if isinstance(target, LiteralType) and target.is_enum_literal(): + if (isinstance(target, LiteralType) and + (target.is_enum_literal() or isinstance(target.value, bool))): enum_name = target.fallback.type.fullname target_type = [TypeRange(target, is_upper_bound=False)] diff --git a/mypy/typeops.py b/mypy/typeops.py index 20772c82c765..08c7c15d5cce 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -723,26 +723,32 @@ class Status(Enum): if isinstance(typ, UnionType): items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items] return make_simplified_union(items, contract_literals=False) - elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname == target_fullname: - new_items = [] - for name, symbol in typ.type.names.items(): - if not isinstance(symbol.node, Var): - continue - # Skip "_order_" and "__order__", since Enum will remove it - if name in ("_order_", "__order__"): - continue - new_items.append(LiteralType(name, typ)) - # SymbolTables are really just dicts, and dicts are guaranteed to preserve - # insertion order only starting with Python 3.7. So, we sort these for older - # versions of Python to help make tests deterministic. - # - # We could probably skip the sort for Python 3.6 since people probably run mypy - # only using CPython, but we might as well for the sake of full correctness. - if sys.version_info < (3, 7): - new_items.sort(key=lambda lit: lit.value) - return make_simplified_union(new_items, contract_literals=False) - else: - return typ + elif isinstance(typ, Instance) and typ.type.fullname == target_fullname: + if typ.type.is_enum: + new_items = [] + for name, symbol in typ.type.names.items(): + if not isinstance(symbol.node, Var): + continue + # Skip "_order_" and "__order__", since Enum will remove it + if name in ("_order_", "__order__"): + continue + new_items.append(LiteralType(name, typ)) + # SymbolTables are really just dicts, and dicts are guaranteed to preserve + # insertion order only starting with Python 3.7. So, we sort these for older + # versions of Python to help make tests deterministic. + # + # We could probably skip the sort for Python 3.6 since people probably run mypy + # only using CPython, but we might as well for the sake of full correctness. + if sys.version_info < (3, 7): + new_items.sort(key=lambda lit: lit.value) + return make_simplified_union(new_items, contract_literals=False) + elif typ.type.fullname == "builtins.bool": + return make_simplified_union( + [LiteralType(True, typ), LiteralType(False, typ)], + contract_literals=False + ) + + return typ def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]: @@ -760,9 +766,12 @@ def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType] for idx, typ in enumerate(proper_types): if isinstance(typ, LiteralType): fullname = typ.fallback.type.fullname - if typ.fallback.type.is_enum: + if typ.fallback.type.is_enum or isinstance(typ.value, bool): if fullname not in sum_types: - sum_types[fullname] = (set(get_enum_values(typ.fallback)), []) + sum_types[fullname] = (set(get_enum_values(typ.fallback)) + if typ.fallback.type.is_enum + else set((True, False)), + []) literals, indexes = sum_types[fullname] literals.discard(typ.value) indexes.append(idx) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 952192995642..78c49abee05f 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1026,8 +1026,30 @@ else: if str_or_bool_literal is not True and str_or_bool_literal is not False: reveal_type(str_or_bool_literal) # N: Revealed type is "builtins.str" else: - reveal_type(str_or_bool_literal) # N: Revealed type is "Union[Literal[False], Literal[True]]" + reveal_type(str_or_bool_literal) # N: Revealed type is "builtins.bool" +[builtins fixtures/primitives.pyi] + +[case testNarrowingBooleanIdentityCheck] +# flags: --strict-optional +from typing import Optional +from typing_extensions import Literal + +bool_val: bool +if bool_val is not False: + reveal_type(bool_val) # N: Revealed type is "Literal[True]" +else: + reveal_type(bool_val) # N: Revealed type is "Literal[False]" + +opt_bool_val: Optional[bool] + +if opt_bool_val is not None: + reveal_type(opt_bool_val) # N: Revealed type is "builtins.bool" + +if opt_bool_val is not False: + reveal_type(opt_bool_val) # N: Revealed type is "Union[Literal[True], None]" +else: + reveal_type(opt_bool_val) # N: Revealed type is "Literal[False]" [builtins fixtures/primitives.pyi] [case testNarrowingTypedDictUsingEnumLiteral] From 9e7a695b308cb658d0c0cc0f09ab1f9f2bc62a43 Mon Sep 17 00:00:00 2001 From: ethan-leba Date: Fri, 30 Apr 2021 08:43:17 -0400 Subject: [PATCH 2/6] Rename expansion-related locals/functions No longer only works on enums --- mypy/checker.py | 10 +++++----- mypy/typeops.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index d21af10ab16a..7c3087d0fc92 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -50,7 +50,7 @@ map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union, erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal, try_getting_str_literals_from_type, try_getting_int_literals_from_type, - tuple_fallback, is_singleton_type, try_expanding_enum_to_union, + tuple_fallback, is_singleton_type, try_expanding_sum_type_to_union, true_only, false_only, function_type, get_type_vars, custom_special_method, is_literal_type_like, ) @@ -4612,11 +4612,11 @@ def refine_identity_comparison_expression(self, if singleton_index == -1: singleton_index = possible_target_indices[-1] - enum_name = None + sum_type_name = None target = get_proper_type(target) if (isinstance(target, LiteralType) and (target.is_enum_literal() or isinstance(target.value, bool))): - enum_name = target.fallback.type.fullname + sum_type_name = target.fallback.type.fullname target_type = [TypeRange(target, is_upper_bound=False)] @@ -4637,8 +4637,8 @@ def refine_identity_comparison_expression(self, expr = operands[i] expr_type = coerce_to_literal(operand_types[i]) - if enum_name is not None: - expr_type = try_expanding_enum_to_union(expr_type, enum_name) + if sum_type_name is not None: + expr_type = try_expanding_sum_type_to_union(expr_type, sum_type_name) # We intentionally use 'conditional_type_map' directly here instead of # 'self.conditional_type_map_with_intersection': we only compute ad-hoc diff --git a/mypy/typeops.py b/mypy/typeops.py index 08c7c15d5cce..507e7cfc0489 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -699,7 +699,7 @@ def is_singleton_type(typ: Type) -> bool: ) -def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> ProperType: +def try_expanding_sum_type_to_union(typ: Type, target_fullname: str) -> ProperType: """Attempts to recursively expand any enum Instances with the given target_fullname into a Union of all of its component LiteralTypes. @@ -721,7 +721,7 @@ class Status(Enum): typ = get_proper_type(typ) if isinstance(typ, UnionType): - items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items] + items = [try_expanding_sum_type_to_union(item, target_fullname) for item in typ.items] return make_simplified_union(items, contract_literals=False) elif isinstance(typ, Instance) and typ.type.fullname == target_fullname: if typ.type.is_enum: From 08467a166a426d6f6ba279802323f75cab46fc44 Mon Sep 17 00:00:00 2001 From: ethan-leba Date: Thu, 6 May 2021 10:00:33 -0400 Subject: [PATCH 3/6] Expand booleans during truthy checks and bool ops --- mypy/checker.py | 2 +- mypy/checkexpr.py | 18 +++++--- test-data/unit/check-dynamic-typing.test | 4 +- test-data/unit/check-expressions.test | 4 +- test-data/unit/check-narrowing.test | 51 ++++++++++++++++++++++ test-data/unit/check-newsemanal.test | 1 + test-data/unit/check-unreachable-code.test | 8 ++-- test-data/unit/typexport-basic.test | 4 +- 8 files changed, 74 insertions(+), 18 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 7c3087d0fc92..e6928833bfe3 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4318,7 +4318,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: elif isinstance(node, RefExpr): # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively - vartype = type_map[node] + vartype = try_expanding_sum_type_to_union(type_map[node], "builtins.bool") if_type: Type = true_only(vartype) else_type: Type = false_only(vartype) ref: Expression = node diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 5aef881aa2ff..02ede5c7b11f 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -64,8 +64,9 @@ FunctionContext, FunctionSigContext, ) from mypy.typeops import ( - tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound, - function_type, callable_type, try_getting_str_literals, custom_special_method, + try_expanding_sum_type_to_union, tuple_fallback, make_simplified_union, + true_only, false_only, erase_to_union_or_bound, function_type, + callable_type, try_getting_str_literals, custom_special_method, is_literal_type_like, ) import mypy.errorcodes as codes @@ -2783,6 +2784,9 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: # '[1] or []' are inferred correctly. ctx = self.type_context[-1] left_type = self.accept(e.left, ctx) + expanded_left_type = try_expanding_sum_type_to_union( + self.accept(e.left, ctx), "builtins.bool" + ) assert e.op in ('and', 'or') # Checked by visit_op_expr @@ -2817,7 +2821,7 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: # to be unreachable and therefore any errors found in the right branch # should be suppressed. with (self.msg.disable_errors() if right_map is None else nullcontext()): - right_type = self.analyze_cond_branch(right_map, e.right, left_type) + right_type = self.analyze_cond_branch(right_map, e.right, expanded_left_type) if right_map is None: # The boolean expression is statically known to be the left value @@ -2829,11 +2833,11 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: return right_type if e.op == 'and': - restricted_left_type = false_only(left_type) - result_is_left = not left_type.can_be_true + restricted_left_type = false_only(expanded_left_type) + result_is_left = not expanded_left_type.can_be_true elif e.op == 'or': - restricted_left_type = true_only(left_type) - result_is_left = not left_type.can_be_false + restricted_left_type = true_only(expanded_left_type) + result_is_left = not expanded_left_type.can_be_false if isinstance(restricted_left_type, UninhabitedType): # The left operand can never be the result diff --git a/test-data/unit/check-dynamic-typing.test b/test-data/unit/check-dynamic-typing.test index 137376535b4e..69eba1f894e3 100644 --- a/test-data/unit/check-dynamic-typing.test +++ b/test-data/unit/check-dynamic-typing.test @@ -159,9 +159,9 @@ a or d if int(): c = a in d # E: Incompatible types in assignment (expression has type "bool", variable has type "C") if int(): - c = b and d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C") + c = b and d # E: Incompatible types in assignment (expression has type "Union[Literal[False], Any]", variable has type "C") if int(): - c = b or d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C") + c = b or d # E: Incompatible types in assignment (expression has type "Union[Literal[True], Any]", variable has type "C") if int(): b = a + d if int(): diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index ff3a5efde6ad..7b02c43476db 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -316,11 +316,11 @@ if int(): if int(): b = b or b if int(): - b = b and a # E: Incompatible types in assignment (expression has type "Union[bool, A]", variable has type "bool") + b = b and a # E: Incompatible types in assignment (expression has type "Union[Literal[False], A]", variable has type "bool") if int(): b = a and b # E: Incompatible types in assignment (expression has type "Union[A, bool]", variable has type "bool") if int(): - b = b or a # E: Incompatible types in assignment (expression has type "Union[bool, A]", variable has type "bool") + b = b or a # E: Incompatible types in assignment (expression has type "Union[Literal[True], A]", variable has type "bool") if int(): b = a or b # E: Incompatible types in assignment (expression has type "Union[A, bool]", variable has type "bool") class A: pass diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 78c49abee05f..53bc886e19ee 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1052,6 +1052,57 @@ else: reveal_type(opt_bool_val) # N: Revealed type is "Literal[False]" [builtins fixtures/primitives.pyi] +[case testNarrowingBooleanTruthiness] +# flags: --strict-optional +from typing import Optional +from typing_extensions import Literal + +bool_val: bool + +if bool_val: + reveal_type(bool_val) # N: Revealed type is "Literal[True]" +else: + reveal_type(bool_val) # N: Revealed type is "Literal[False]" +reveal_type(bool_val) # N: Revealed type is "builtins.bool" + +opt_bool_val: Optional[bool] + +if opt_bool_val: + reveal_type(opt_bool_val) # N: Revealed type is "Literal[True]" +else: + reveal_type(opt_bool_val) # N: Revealed type is "Union[Literal[False], None]" +reveal_type(opt_bool_val) # N: Revealed type is "Union[builtins.bool, None]" +[builtins fixtures/primitives.pyi] + +[case testNarrowingBooleanBoolOp] +# flags: --strict-optional +from typing import Optional +from typing_extensions import Literal + +bool_a: bool +bool_b: bool + +if bool_a and bool_b: + reveal_type(bool_a) # N: Revealed type is "Literal[True]" + reveal_type(bool_b) # N: Revealed type is "Literal[True]" +else: + reveal_type(bool_a) # N: Revealed type is "builtins.bool" + reveal_type(bool_b) # N: Revealed type is "builtins.bool" + +if not bool_a or bool_b: + reveal_type(bool_a) # N: Revealed type is "builtins.bool" + reveal_type(bool_b) # N: Revealed type is "builtins.bool" +else: + reveal_type(bool_a) # N: Revealed type is "Literal[True]" + reveal_type(bool_b) # N: Revealed type is "Literal[False]" + +if True and bool_b: + reveal_type(bool_b) # N: Revealed type is "Literal[True]" + +x = True and bool_b +reveal_type(x) # N: Revealed type is "builtins.bool" +[builtins fixtures/primitives.pyi] + [case testNarrowingTypedDictUsingEnumLiteral] # flags: --python-version 3.6 from typing import Union diff --git a/test-data/unit/check-newsemanal.test b/test-data/unit/check-newsemanal.test index ed999d1f46b6..d44d8ad0348e 100644 --- a/test-data/unit/check-newsemanal.test +++ b/test-data/unit/check-newsemanal.test @@ -304,6 +304,7 @@ from a import x def f(): pass [targets a, b, a, a.y, b.f, __main__] +[builtins fixtures/tuple.pyi] [case testNewAnalyzerRedefinitionAndDeferral1b] import a diff --git a/test-data/unit/check-unreachable-code.test b/test-data/unit/check-unreachable-code.test index 010c944e3bfc..0e11ddef0998 100644 --- a/test-data/unit/check-unreachable-code.test +++ b/test-data/unit/check-unreachable-code.test @@ -533,11 +533,11 @@ f = (PY3 or PY2) and 's' g = (PY2 or PY3) or 's' h = (PY3 or PY2) or 's' reveal_type(a) # N: Revealed type is "builtins.bool" -reveal_type(b) # N: Revealed type is "builtins.str" -reveal_type(c) # N: Revealed type is "builtins.str" +reveal_type(b) # N: Revealed type is "Literal['s']" +reveal_type(c) # N: Revealed type is "Literal['s']" reveal_type(d) # N: Revealed type is "builtins.bool" -reveal_type(e) # N: Revealed type is "builtins.str" -reveal_type(f) # N: Revealed type is "builtins.str" +reveal_type(e) # N: Revealed type is "Literal['s']" +reveal_type(f) # N: Revealed type is "Literal['s']" reveal_type(g) # N: Revealed type is "builtins.bool" reveal_type(h) # N: Revealed type is "builtins.bool" [builtins fixtures/ops.pyi] diff --git a/test-data/unit/typexport-basic.test b/test-data/unit/typexport-basic.test index deb43f6d316f..4f40117d18d2 100644 --- a/test-data/unit/typexport-basic.test +++ b/test-data/unit/typexport-basic.test @@ -247,7 +247,7 @@ elif not a: [out] NameExpr(3) : builtins.bool IntExpr(4) : Literal[1]? -NameExpr(5) : builtins.bool +NameExpr(5) : Literal[False] UnaryExpr(5) : builtins.bool IntExpr(6) : Literal[1]? @@ -259,7 +259,7 @@ while a: [builtins fixtures/bool.pyi] [out] NameExpr(3) : builtins.bool -NameExpr(4) : builtins.bool +NameExpr(4) : Literal[True] -- Simple type inference From 2fe0e4c1d7c4facb269a9f600e3d0d0eeba3e876 Mon Sep 17 00:00:00 2001 From: hauntsaninja <> Date: Tue, 28 Sep 2021 22:21:17 -0700 Subject: [PATCH 4/6] fix tests --- mypy/checker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index e9b04f344a9a..43df720d7694 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4477,8 +4477,9 @@ def has_no_custom_eq_checks(t: Type) -> bool: # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively - vartype = try_expanding_sum_type_to_union(type_map[node], "builtins.bool") - self._check_for_truthy_type(vartype, node) + original_vartype = type_map[node] + self._check_for_truthy_type(original_vartype, node) + vartype = try_expanding_sum_type_to_union(original_vartype, "builtins.bool") if_type: Type = true_only(vartype) else_type: Type = false_only(vartype) From 0143fe70dc99193a86a6a2307b90978892269687 Mon Sep 17 00:00:00 2001 From: hauntsaninja <> Date: Tue, 28 Sep 2021 22:46:54 -0700 Subject: [PATCH 5/6] fix my bad merge --- mypy/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 134d4236e831..363405f7fa63 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4516,7 +4516,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: original_vartype = type_map[node] self._check_for_truthy_type(original_vartype, node) vartype = try_expanding_sum_type_to_union(original_vartype, "builtins.bool") - self._check_for_truthy_type(vartype, node) + if_type = true_only(vartype) # type: Type else_type = false_only(vartype) # type: Type ref = node # type: Expression From 55e96b4dcd7c5ef1317f77ccb5fbef662afe0bb3 Mon Sep 17 00:00:00 2001 From: hauntsaninja <> Date: Tue, 28 Sep 2021 23:10:10 -0700 Subject: [PATCH 6/6] one more test --- test-data/unit/check-python38.test | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test-data/unit/check-python38.test b/test-data/unit/check-python38.test index f033f7e65e01..b5471f02c408 100644 --- a/test-data/unit/check-python38.test +++ b/test-data/unit/check-python38.test @@ -411,10 +411,10 @@ from typing import Optional maybe_str: Optional[str] if (is_str := maybe_str is not None): - reveal_type(is_str) # N: Revealed type is "builtins.bool" + reveal_type(is_str) # N: Revealed type is "Literal[True]" reveal_type(maybe_str) # N: Revealed type is "builtins.str" else: - reveal_type(is_str) # N: Revealed type is "builtins.bool" + reveal_type(is_str) # N: Revealed type is "Literal[False]" reveal_type(maybe_str) # N: Revealed type is "None" reveal_type(maybe_str) # N: Revealed type is "Union[builtins.str, None]"