diff --git a/mypy/checker.py b/mypy/checker.py index 8973ade98228..6f14ee5fc2f3 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5017,6 +5017,45 @@ def conditional_callable_type_map( return None, {} + def conditional_types_for_iterable( + self, item_type: Type, iterable_type: Type + ) -> tuple[Type | None, Type | None]: + """ + Narrows the type of `iterable_type` based on the type of `item_type`. + For now, we only support narrowing unions of TypedDicts based on left operand being literal string(s). + """ + if_types: list[Type] = [] + else_types: list[Type] = [] + + iterable_type = get_proper_type(iterable_type) + if isinstance(iterable_type, UnionType): + possible_iterable_types = get_proper_types(iterable_type.relevant_items()) + else: + possible_iterable_types = [iterable_type] + + item_str_literals = try_getting_str_literals_from_type(item_type) + + for possible_iterable_type in possible_iterable_types: + if item_str_literals and isinstance(possible_iterable_type, TypedDictType): + for key in item_str_literals: + if key in possible_iterable_type.required_keys: + if_types.append(possible_iterable_type) + elif ( + key in possible_iterable_type.items or not possible_iterable_type.is_final + ): + if_types.append(possible_iterable_type) + else_types.append(possible_iterable_type) + else: + else_types.append(possible_iterable_type) + else: + if_types.append(possible_iterable_type) + else_types.append(possible_iterable_type) + + return ( + UnionType.make_union(if_types) if if_types else None, + UnionType.make_union(else_types) if else_types else None, + ) + def _is_truthy_type(self, t: ProperType) -> bool: return ( ( @@ -5324,28 +5363,42 @@ def has_no_custom_eq_checks(t: Type) -> bool: elif operator in {"in", "not in"}: assert len(expr_indices) == 2 left_index, right_index = expr_indices - if left_index not in narrowable_operand_index_to_hash: - continue - item_type = operand_types[left_index] - collection_type = operand_types[right_index] + iterable_type = operand_types[right_index] - # We only try and narrow away 'None' for now - if not is_optional(item_type): - continue + if_map, else_map = {}, {} + + if left_index in narrowable_operand_index_to_hash: + # We only try and narrow away 'None' for now + if is_optional(item_type): + collection_item_type = get_proper_type( + builtin_item_type(iterable_type) + ) + if ( + collection_item_type is not None + and not is_optional(collection_item_type) + and not ( + isinstance(collection_item_type, Instance) + and collection_item_type.type.fullname == "builtins.object" + ) + and is_overlapping_erased_types(item_type, collection_item_type) + ): + if_map[operands[left_index]] = remove_optional(item_type) + + if right_index in narrowable_operand_index_to_hash: + if_type, else_type = self.conditional_types_for_iterable( + item_type, iterable_type + ) + expr = operands[right_index] + if if_type is None: + if_map = None + else: + if_map[expr] = if_type + if else_type is None: + else_map = None + else: + else_map[expr] = else_type - collection_item_type = get_proper_type(builtin_item_type(collection_type)) - if collection_item_type is None or is_optional(collection_item_type): - continue - if ( - isinstance(collection_item_type, Instance) - and collection_item_type.type.fullname == "builtins.object" - ): - continue - if is_overlapping_erased_types(item_type, collection_item_type): - if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {} - else: - continue else: if_map = {} else_map = {} diff --git a/mypy/types.py b/mypy/types.py index e322cf02505f..12e15409f45d 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2290,6 +2290,10 @@ def deserialize(cls, data: JsonDict) -> TypedDictType: Instance.deserialize(data["fallback"]), ) + @property + def is_final(self) -> bool: + return self.fallback.type.is_final + def is_anonymous(self) -> bool: return self.fallback.type.fullname in TPDICT_FB_NAMES diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 796f2f547528..1a0b478db0f2 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2025,6 +2025,191 @@ v = {bad2: 2} # E: Extra key "bad" for TypedDict "Value" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] +[case testOperatorContainsNarrowsTypedDicts_unionWithList] +from __future__ import annotations +from typing import assert_type, TypedDict, Union +from typing_extensions import final + +@final +class D(TypedDict): + foo: int + + +d_or_list: D | list[str] + +if 'foo' in d_or_list: + assert_type(d_or_list, Union[D, list[str]]) +elif 'bar' in d_or_list: + assert_type(d_or_list, list[str]) +else: + assert_type(d_or_list, list[str]) + +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testOperatorContainsNarrowsTypedDicts_total] +from __future__ import annotations +from typing import assert_type, Literal, TypedDict, TypeVar, Union +from typing_extensions import final + +@final +class D1(TypedDict): + foo: int + + +@final +class D2(TypedDict): + bar: int + + +d: D1 | D2 + +if 'foo' in d: + assert_type(d, D1) +else: + assert_type(d, D2) + +foo_or_bar: Literal['foo', 'bar'] +if foo_or_bar in d: + assert_type(d, Union[D1, D2]) +else: + assert_type(d, Union[D1, D2]) + +foo_or_invalid: Literal['foo', 'invalid'] +if foo_or_invalid in d: + assert_type(d, D1) + # won't narrow 'foo_or_invalid' + assert_type(foo_or_invalid, Literal['foo', 'invalid']) +else: + assert_type(d, Union[D1, D2]) + # won't narrow 'foo_or_invalid' + assert_type(foo_or_invalid, Literal['foo', 'invalid']) + +TD = TypeVar('TD', D1, D2) + +def f(arg: TD) -> None: + value: int + if 'foo' in arg: + assert_type(arg['foo'], int) + else: + assert_type(arg['bar'], int) + + +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testOperatorContainsNarrowsTypedDicts_final] +# flags: --warn-unreachable +from __future__ import annotations +from typing import assert_type, TypedDict, Union +from typing_extensions import final + +@final +class DFinal(TypedDict): + foo: int + + +class DNotFinal(TypedDict): + bar: int + + +d_not_final: DNotFinal + +if 'bar' in d_not_final: + assert_type(d_not_final, DNotFinal) +else: + spam = 'ham' # E: Statement is unreachable + +if 'spam' in d_not_final: + assert_type(d_not_final, DNotFinal) +else: + assert_type(d_not_final, DNotFinal) + +d_final: DFinal + +if 'spam' in d_final: + spam = 'ham' # E: Statement is unreachable +else: + assert_type(d_final, DFinal) + +d_union: DFinal | DNotFinal + +if 'foo' in d_union: + assert_type(d_union, Union[DFinal, DNotFinal]) +else: + assert_type(d_union, DNotFinal) + +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testOperatorContainsNarrowsTypedDicts_partialThroughTotalFalse] +from __future__ import annotations +from typing import assert_type, Literal, TypedDict, Union +from typing_extensions import final + +@final +class DTotal(TypedDict): + required_key: int + + +@final +class DNotTotal(TypedDict, total=False): + optional_key: int + + +d: DTotal | DNotTotal + +if 'required_key' in d: + assert_type(d, DTotal) +else: + assert_type(d, DNotTotal) + +if 'optional_key' in d: + assert_type(d, DNotTotal) +else: + assert_type(d, Union[DTotal, DNotTotal]) + +key: Literal['optional_key', 'required_key'] +if key in d: + assert_type(d, Union[DTotal, DNotTotal]) +else: + assert_type(d, Union[DTotal, DNotTotal]) + +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testOperatorContainsNarrowsTypedDicts_partialThroughNotRequired] +from __future__ import annotations +from typing import assert_type, Required, NotRequired, TypedDict, Union +from typing_extensions import final + +@final +class D1(TypedDict): + required_key: Required[int] + optional_key: NotRequired[int] + + +@final +class D2(TypedDict): + abc: int + xyz: int + + +d: D1 | D2 + +if 'required_key' in d: + assert_type(d, D1) +else: + assert_type(d, D2) + +if 'optional_key' in d: + assert_type(d, D1) +else: + assert_type(d, Union[D1, D2]) + +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + [case testCannotSubclassFinalTypedDict] from typing import TypedDict from typing_extensions import final diff --git a/test-data/unit/fixtures/typing-typeddict.pyi b/test-data/unit/fixtures/typing-typeddict.pyi index 378570b4c19c..b45c31969a00 100644 --- a/test-data/unit/fixtures/typing-typeddict.pyi +++ b/test-data/unit/fixtures/typing-typeddict.pyi @@ -9,6 +9,7 @@ from abc import ABCMeta cast = 0 +assert_type = 0 overload = 0 Any = 0 Union = 0