From ca873608e2f71a0c25349ef1c1faf6f5a6d40a16 Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Sat, 8 Oct 2022 00:01:29 -0400 Subject: [PATCH 01/17] 'in' can narrow TypedDict unions --- mypy/checker.py | 89 +++++++++++++++++++++++------ test-data/unit/check-typeddict.test | 41 +++++++++++++ 2 files changed, 112 insertions(+), 18 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 16bbc1c982a6..b9ea110d99b4 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5009,6 +5009,44 @@ def conditional_callable_type_map( return None, {} + def contains_operator_right_operand_type_map( + self, item_type: Type, collection_type: Type + ) -> tuple[Type, Type]: + """ + Deduces the type of the right operand of the `in` operator. + For now, we only support narrowing unions of TypedDicts based on left operand being literal string(s). + """ + if_types, else_types = [collection_type], [collection_type] + item_strs = try_getting_str_literals_from_type(item_type) + if item_strs: + if_types, else_types = self._contains_string_right_operand_type_map( + set(item_strs), collection_type + ) + return UnionType.make_union(if_types), UnionType.make_union(else_types) + + def _contains_string_right_operand_type_map( + self, item_strs: set[str], t: Type + ) -> tuple[list[Type], list[Type]]: + t = get_proper_type(t) + if_types: list[Type] = [] + else_types: list[Type] = [] + if isinstance(t, TypedDictType): + if item_strs <= t.items.keys(): + if_types.append(t) + elif item_strs.isdisjoint(t.items.keys()): + else_types.append(t) + else: + if_types.append(t) + else_types.append(t) + elif isinstance(t, UnionType): + for union_item in t.items: + a, b = self._contains_string_right_operand_type_map(item_strs, union_item) + if_types.extend(a) + else_types.extend(b) + else: + if_types = else_types = [t] + return if_types, else_types + def _is_truthy_type(self, t: ProperType) -> bool: return ( ( @@ -5316,28 +5354,39 @@ def has_no_custom_eq_checks(t: Type) -> bool: elif operator in {"in", "not in"}: assert len(expr_indices) == 2 left_index, right_index = expr_indices - if left_index not in narrowable_operand_index_to_hash: - continue - item_type = operand_types[left_index] collection_type = operand_types[right_index] - # We only try and narrow away 'None' for now - if not is_optional(item_type): - continue + if_map, else_map = {}, {} + + if left_index in narrowable_operand_index_to_hash: + # We only try and narrow away 'None' for now + if is_optional(item_type): + collection_item_type = get_proper_type( + builtin_item_type(collection_type) + ) + if ( + collection_item_type is not None + and not is_optional(collection_item_type) + and not ( + isinstance(collection_item_type, Instance) + and collection_item_type.type.fullname == "builtins.object" + ) + and is_overlapping_erased_types(item_type, collection_item_type) + ): + if_map[operands[left_index]] = remove_optional(item_type) + + if right_index in narrowable_operand_index_to_hash: + ( + right_if_type, + right_else_type, + ) = self.contains_operator_right_operand_type_map( + item_type, collection_type + ) + expr = operands[right_index] + if_map[expr] = right_if_type + else_map[expr] = right_else_type - collection_item_type = get_proper_type(builtin_item_type(collection_type)) - if collection_item_type is None or is_optional(collection_item_type): - continue - if ( - isinstance(collection_item_type, Instance) - and collection_item_type.type.fullname == "builtins.object" - ): - continue - if is_overlapping_erased_types(item_type, collection_item_type): - if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {} - else: - continue else: if_map = {} else_map = {} @@ -5390,6 +5439,10 @@ def has_no_custom_eq_checks(t: Type) -> bool: or_conditional_maps(left_if_vars, right_if_vars), and_conditional_maps(left_else_vars, right_else_vars), ) + elif isinstance(node, OpExpr) and node.op == "in": + left_if_vars, left_else_vars = self.find_isinstance_check(node.left) + right_if_vars, right_else_vars = self.find_isinstance_check(node.right) + elif isinstance(node, UnaryExpr) and node.op == "not": left, right = self.find_isinstance_check(node.expr) return right, left diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 4c68b7b692ff..30e402557b52 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2012,6 +2012,47 @@ v = {bad2: 2} # E: Extra key "bad" for TypedDict "Value" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] +[case testFinalTypedDictTagged] +from __future__ import annotations +from typing import Literal, TypedDict +from typing_extensions import final + +@final +class D1(TypedDict): + foo: int + + +@final +class D2(TypedDict): + bar: int + +d: D1 | D2 +val: int + +val = d['foo'] # E: TypedDict "D2" has no key "foo" +if 'foo' in d: + val = d['foo'] +else: + val = d['bar'] + +foo_or_bar: Literal['foo', 'bar'] +if foo_or_bar in d: + val = d['foo'] # E: TypedDict "D2" has no key "foo" + val = d['bar'] # E: TypedDict "D1" has no key "bar" +else: + val = d['foo'] # E: TypedDict "D2" has no key "foo" + val = d['bar'] # E: TypedDict "D1" has no key "bar" + +foo_or_invalid: Literal['foo', 'invalid'] +if foo_or_invalid in d: + val = d['foo'] +else: + val = d['foo'] # E: TypedDict "D2" has no key "foo" + val = d['bar'] # E: TypedDict "D1" has no key "bar" + +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + [case testCannotSubclassFinalTypedDict] from typing import TypedDict from typing_extensions import final From 2a733252e34c3d9d3f78ede4264ea47a7d5101e1 Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Sat, 15 Oct 2022 01:29:29 -0400 Subject: [PATCH 02/17] take care of TypeVar --- mypy/checker.py | 22 ++++++++++++++-------- test-data/unit/check-typeddict.test | 11 ++++++++++- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index b9ea110d99b4..65c5d3b3d0b0 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5011,7 +5011,7 @@ def conditional_callable_type_map( def contains_operator_right_operand_type_map( self, item_type: Type, collection_type: Type - ) -> tuple[Type, Type]: + ) -> tuple[Type | None, Type | None]: """ Deduces the type of the right operand of the `in` operator. For now, we only support narrowing unions of TypedDicts based on left operand being literal string(s). @@ -5022,7 +5022,10 @@ def contains_operator_right_operand_type_map( if_types, else_types = self._contains_string_right_operand_type_map( set(item_strs), collection_type ) - return UnionType.make_union(if_types), UnionType.make_union(else_types) + return ( + UnionType.make_union(if_types) if if_types else None, + UnionType.make_union(else_types) if else_types else None, + ) def _contains_string_right_operand_type_map( self, item_strs: set[str], t: Type @@ -5377,15 +5380,18 @@ def has_no_custom_eq_checks(t: Type) -> bool: if_map[operands[left_index]] = remove_optional(item_type) if right_index in narrowable_operand_index_to_hash: - ( - right_if_type, - right_else_type, - ) = self.contains_operator_right_operand_type_map( + (if_type, else_type) = self.contains_operator_right_operand_type_map( item_type, collection_type ) expr = operands[right_index] - if_map[expr] = right_if_type - else_map[expr] = right_else_type + if if_type is None: + if_map = None + else: + if_map[expr] = if_type + if else_type is None: + else_map = None + else: + else_map[expr] = else_type else: if_map = {} diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 30e402557b52..8c4fe55975e9 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2014,7 +2014,7 @@ v = {bad2: 2} # E: Extra key "bad" for TypedDict "Value" [case testFinalTypedDictTagged] from __future__ import annotations -from typing import Literal, TypedDict +from typing import Literal, TypedDict, TypeVar from typing_extensions import final @final @@ -2050,6 +2050,15 @@ else: val = d['foo'] # E: TypedDict "D2" has no key "foo" val = d['bar'] # E: TypedDict "D1" has no key "bar" +TD = TypeVar('TD', D1, D2) + +def f(arg: TD) -> None: + if 'foo' in arg: + val = arg['foo'] + else: + val = arg['bar'] + + [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] From 9671a695663bd3a17f8899baef2e9d84addefd7e Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Sat, 15 Oct 2022 02:34:34 -0400 Subject: [PATCH 03/17] take care of totality --- mypy/checker.py | 19 +++++----- test-data/unit/check-typeddict.test | 59 +++++++++++++++++++++++------ 2 files changed, 57 insertions(+), 21 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 65c5d3b3d0b0..23fca9eabdc4 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5020,7 +5020,7 @@ def contains_operator_right_operand_type_map( item_strs = try_getting_str_literals_from_type(item_type) if item_strs: if_types, else_types = self._contains_string_right_operand_type_map( - set(item_strs), collection_type + item_strs, collection_type ) return ( UnionType.make_union(if_types) if if_types else None, @@ -5028,19 +5028,20 @@ def contains_operator_right_operand_type_map( ) def _contains_string_right_operand_type_map( - self, item_strs: set[str], t: Type + self, item_strs: Iterable[str], t: Type ) -> tuple[list[Type], list[Type]]: t = get_proper_type(t) if_types: list[Type] = [] else_types: list[Type] = [] if isinstance(t, TypedDictType): - if item_strs <= t.items.keys(): - if_types.append(t) - elif item_strs.isdisjoint(t.items.keys()): - else_types.append(t) - else: - if_types.append(t) - else_types.append(t) + for key in item_strs: + if key in t.required_keys: + if_types.append(t) + elif key in t.items: + if_types.append(t) + else_types.append(t) + else: + else_types.append(t) elif isinstance(t, UnionType): for union_item in t.items: a, b = self._contains_string_right_operand_type_map(item_strs, union_item) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 8c4fe55975e9..c345349e020c 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2014,7 +2014,7 @@ v = {bad2: 2} # E: Extra key "bad" for TypedDict "Value" [case testFinalTypedDictTagged] from __future__ import annotations -from typing import Literal, TypedDict, TypeVar +from typing import assert_type, Literal, TypedDict, TypeVar, Union from typing_extensions import final @final @@ -2029,39 +2029,74 @@ class D2(TypedDict): d: D1 | D2 val: int -val = d['foo'] # E: TypedDict "D2" has no key "foo" if 'foo' in d: - val = d['foo'] + assert_type(d, D1) else: - val = d['bar'] + assert_type(d, D2) foo_or_bar: Literal['foo', 'bar'] if foo_or_bar in d: - val = d['foo'] # E: TypedDict "D2" has no key "foo" - val = d['bar'] # E: TypedDict "D1" has no key "bar" + assert_type(d, Union[D1, D2]) else: - val = d['foo'] # E: TypedDict "D2" has no key "foo" - val = d['bar'] # E: TypedDict "D1" has no key "bar" + assert_type(d, Union[D1, D2]) foo_or_invalid: Literal['foo', 'invalid'] if foo_or_invalid in d: - val = d['foo'] + assert_type(d, D1) else: - val = d['foo'] # E: TypedDict "D2" has no key "foo" - val = d['bar'] # E: TypedDict "D1" has no key "bar" + assert_type(d, Union[D1, D2]) TD = TypeVar('TD', D1, D2) def f(arg: TD) -> None: if 'foo' in arg: - val = arg['foo'] + assert_type(d, Union[D1, D2]) # strangely enough it's seen as a union + val = arg['foo'] # but acts here as D1 else: + assert_type(d, Union[D1, D2]) # ditto here, but D2 val = arg['bar'] [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] +[case testFinalTypedDictTaggedNotRequired] +from __future__ import annotations +from typing import assert_type, Literal, TypedDict, Union +from typing_extensions import final + +@final +class DTotal(TypedDict): + required_key: int + + +@final +class DNotTotal(TypedDict, total=False): + optional_key: int + + +d: DTotal | DNotTotal +val: int + +if 'required_key' in d: + assert_type(d, DTotal) +else: + assert_type(d, DNotTotal) + +if 'optional_key' in d: + assert_type(d, DNotTotal) +else: + assert_type(d, Union[DTotal, DNotTotal]) + +key: Literal['optional_key', 'required_key'] +if key in d: + assert_type(d, Union[DTotal, DNotTotal]) +else: + assert_type(d, Union[DTotal, DNotTotal]) + +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + [case testCannotSubclassFinalTypedDict] from typing import TypedDict from typing_extensions import final From 0b3701b80ccdba6b6d98edb9fe78e102e3457f54 Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Sat, 15 Oct 2022 09:31:18 -0400 Subject: [PATCH 04/17] add missing assert_type fixture --- test-data/unit/fixtures/typing-typeddict.pyi | 35 +++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/test-data/unit/fixtures/typing-typeddict.pyi b/test-data/unit/fixtures/typing-typeddict.pyi index 378570b4c19c..b746a6db4db3 100644 --- a/test-data/unit/fixtures/typing-typeddict.pyi +++ b/test-data/unit/fixtures/typing-typeddict.pyi @@ -26,35 +26,44 @@ NoReturn = 0 Required = 0 NotRequired = 0 -T = TypeVar('T') -T_co = TypeVar('T_co', covariant=True) -V = TypeVar('V') +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +V = TypeVar("V") # Note: definitions below are different from typeshed, variances are declared # to silence the protocol variance checks. Maybe it is better to use type: ignore? class Sized(Protocol): - def __len__(self) -> int: pass + def __len__(self) -> int: + pass class Iterable(Protocol[T_co]): - def __iter__(self) -> 'Iterator[T_co]': pass + def __iter__(self) -> "Iterator[T_co]": + pass class Iterator(Iterable[T_co], Protocol): - def __next__(self) -> T_co: pass + def __next__(self) -> T_co: + pass class Sequence(Iterable[T_co]): # misc is for explicit Any. - def __getitem__(self, n: Any) -> T_co: pass # type: ignore[misc] + def __getitem__(self, n: Any) -> T_co: + pass # type: ignore[misc] class Mapping(Iterable[T], Generic[T, T_co], metaclass=ABCMeta): - def __getitem__(self, key: T) -> T_co: pass + def __getitem__(self, key: T) -> T_co: + pass @overload - def get(self, k: T) -> Optional[T_co]: pass + def get(self, k: T) -> Optional[T_co]: + pass @overload - def get(self, k: T, default: Union[T_co, V]) -> Union[T_co, V]: pass - def values(self) -> Iterable[T_co]: pass # Approximate return type + def get(self, k: T, default: Union[T_co, V]) -> Union[T_co, V]: + pass + def values(self) -> Iterable[T_co]: + pass # Approximate return type def __len__(self) -> int: ... - def __contains__(self, arg: object) -> int: pass + def __contains__(self, arg: object) -> int: + pass # Fallback type for all typed dicts (does not exist at runtime). class _TypedDict(Mapping[str, object]): @@ -68,3 +77,5 @@ class _TypedDict(Mapping[str, object]): def pop(self, k: NoReturn, default: T = ...) -> object: ... def update(self: T, __m: T) -> None: ... def __delitem__(self, k: NoReturn) -> None: ... + +def assert_type(o, t): ... From 1f4224ae1bcc005b313dfb3f1cf1a3a8208541fe Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Sat, 15 Oct 2022 09:34:21 -0400 Subject: [PATCH 05/17] rename tests slightly --- test-data/unit/check-typeddict.test | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index c345349e020c..c3b93a7cf271 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2012,7 +2012,7 @@ v = {bad2: 2} # E: Extra key "bad" for TypedDict "Value" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] -[case testFinalTypedDictTagged] +[case testOperatorContainsNarrowsTotalTypedDicts] from __future__ import annotations from typing import assert_type, Literal, TypedDict, TypeVar, Union from typing_extensions import final @@ -2027,7 +2027,6 @@ class D2(TypedDict): bar: int d: D1 | D2 -val: int if 'foo' in d: assert_type(d, D1) @@ -2049,18 +2048,19 @@ else: TD = TypeVar('TD', D1, D2) def f(arg: TD) -> None: + value: int if 'foo' in arg: assert_type(d, Union[D1, D2]) # strangely enough it's seen as a union - val = arg['foo'] # but acts here as D1 + value = arg['foo'] # but acts here as D1 else: assert_type(d, Union[D1, D2]) # ditto here, but D2 - val = arg['bar'] + value = arg['bar'] [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] -[case testFinalTypedDictTaggedNotRequired] +[case testOperatorContainsNarrowsPartialTypedDicts] from __future__ import annotations from typing import assert_type, Literal, TypedDict, Union from typing_extensions import final From e67c6f8b93c8dfb7084f88f147feaca37f0321d2 Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Sat, 15 Oct 2022 11:43:18 -0400 Subject: [PATCH 06/17] fix unintentional changes to typing-typeddict --- test-data/unit/fixtures/typing-typeddict.pyi | 36 +++++++------------- 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/test-data/unit/fixtures/typing-typeddict.pyi b/test-data/unit/fixtures/typing-typeddict.pyi index b746a6db4db3..b45c31969a00 100644 --- a/test-data/unit/fixtures/typing-typeddict.pyi +++ b/test-data/unit/fixtures/typing-typeddict.pyi @@ -9,6 +9,7 @@ from abc import ABCMeta cast = 0 +assert_type = 0 overload = 0 Any = 0 Union = 0 @@ -26,44 +27,35 @@ NoReturn = 0 Required = 0 NotRequired = 0 -T = TypeVar("T") -T_co = TypeVar("T_co", covariant=True) -V = TypeVar("V") +T = TypeVar('T') +T_co = TypeVar('T_co', covariant=True) +V = TypeVar('V') # Note: definitions below are different from typeshed, variances are declared # to silence the protocol variance checks. Maybe it is better to use type: ignore? class Sized(Protocol): - def __len__(self) -> int: - pass + def __len__(self) -> int: pass class Iterable(Protocol[T_co]): - def __iter__(self) -> "Iterator[T_co]": - pass + def __iter__(self) -> 'Iterator[T_co]': pass class Iterator(Iterable[T_co], Protocol): - def __next__(self) -> T_co: - pass + def __next__(self) -> T_co: pass class Sequence(Iterable[T_co]): # misc is for explicit Any. - def __getitem__(self, n: Any) -> T_co: - pass # type: ignore[misc] + def __getitem__(self, n: Any) -> T_co: pass # type: ignore[misc] class Mapping(Iterable[T], Generic[T, T_co], metaclass=ABCMeta): - def __getitem__(self, key: T) -> T_co: - pass + def __getitem__(self, key: T) -> T_co: pass @overload - def get(self, k: T) -> Optional[T_co]: - pass + def get(self, k: T) -> Optional[T_co]: pass @overload - def get(self, k: T, default: Union[T_co, V]) -> Union[T_co, V]: - pass - def values(self) -> Iterable[T_co]: - pass # Approximate return type + def get(self, k: T, default: Union[T_co, V]) -> Union[T_co, V]: pass + def values(self) -> Iterable[T_co]: pass # Approximate return type def __len__(self) -> int: ... - def __contains__(self, arg: object) -> int: - pass + def __contains__(self, arg: object) -> int: pass # Fallback type for all typed dicts (does not exist at runtime). class _TypedDict(Mapping[str, object]): @@ -77,5 +69,3 @@ class _TypedDict(Mapping[str, object]): def pop(self, k: NoReturn, default: T = ...) -> object: ... def update(self: T, __m: T) -> None: ... def __delitem__(self, k: NoReturn) -> None: ... - -def assert_type(o, t): ... From 88e2c9f37340563d2496fccaff158494ed9cf3bf Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Sat, 15 Oct 2022 12:44:43 -0400 Subject: [PATCH 07/17] add testOperatorContainsNarrowsTypedDicts_unionWithList --- test-data/unit/check-typeddict.test | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index c3b93a7cf271..ac04240208cf 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2012,6 +2012,27 @@ v = {bad2: 2} # E: Extra key "bad" for TypedDict "Value" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] +[case testOperatorContainsNarrowsTypedDicts_unionWithList] +from __future__ import annotations +from typing import assert_type, TypedDict, Union +from typing_extensions import final + + +@final +class D(TypedDict): + foo: int + + +d: D | list[str] + +if 'foo' in d: + assert_type(d, Union[D, list[str]]) +else: + assert_type(d, list[str]) + +[builtins fixtures / dict.pyi] +[typing fixtures / typing - typeddict.pyi] + [case testOperatorContainsNarrowsTotalTypedDicts] from __future__ import annotations from typing import assert_type, Literal, TypedDict, TypeVar, Union From f799344581f0b899670ba70da84f57d1880e1470 Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Sat, 15 Oct 2022 12:54:04 -0400 Subject: [PATCH 08/17] respect final-ity of TypedDict --- mypy/checker.py | 2 +- mypy/types.py | 4 +++ test-data/unit/check-typeddict.test | 53 ++++++++++++++++++++++++++--- 3 files changed, 53 insertions(+), 6 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 23fca9eabdc4..6f143a1daa3a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5037,7 +5037,7 @@ def _contains_string_right_operand_type_map( for key in item_strs: if key in t.required_keys: if_types.append(t) - elif key in t.items: + elif key in t.items or not t.is_final: if_types.append(t) else_types.append(t) else: diff --git a/mypy/types.py b/mypy/types.py index e322cf02505f..12e15409f45d 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2290,6 +2290,10 @@ def deserialize(cls, data: JsonDict) -> TypedDictType: Instance.deserialize(data["fallback"]), ) + @property + def is_final(self) -> bool: + return self.fallback.type.is_final + def is_anonymous(self) -> bool: return self.fallback.type.fullname in TPDICT_FB_NAMES diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index ac04240208cf..f2f39a9b997f 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2017,7 +2017,6 @@ from __future__ import annotations from typing import assert_type, TypedDict, Union from typing_extensions import final - @final class D(TypedDict): foo: int @@ -2030,10 +2029,10 @@ if 'foo' in d: else: assert_type(d, list[str]) -[builtins fixtures / dict.pyi] -[typing fixtures / typing - typeddict.pyi] +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] -[case testOperatorContainsNarrowsTotalTypedDicts] +[case testOperatorContainsNarrowsTypedDicts_total] from __future__ import annotations from typing import assert_type, Literal, TypedDict, TypeVar, Union from typing_extensions import final @@ -2048,6 +2047,13 @@ class D2(TypedDict): bar: int d: D1 | D2 +opt_d: D1 | None + +if 'foo' in opt_d: + assert_type(opt_d, D1) +else: + assert_type(opt_d, None) + if 'foo' in d: assert_type(d, D1) @@ -2081,7 +2087,44 @@ def f(arg: TD) -> None: [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] -[case testOperatorContainsNarrowsPartialTypedDicts] +[case testOperatorContainsNarrowsTypedDicts_final] +# flags: --warn-unreachable +from __future__ import annotations +from typing import assert_type, Literal, TypedDict, TypeVar, Union +from typing_extensions import final + +@final +class DFinal(TypedDict): + foo: int + + +class DNotFinal(TypedDict): + bar: int + + +d_not_final: DNotFinal + +if 'bar' in d_not_final: + assert_type(d_not_final, DNotFinal) +else: + spam = 'ham' # E: Statement is unreachable + +if 'spam' in d_not_final: + assert_type(d_not_final, DNotFinal) +else: + assert_type(d_not_final, DNotFinal) + +d_union: DFinal | DNotFinal + +if 'foo' in d_union: + assert_type(d_union, Union[DFinal, DNotFinal]) +else: + assert_type(d_union, DNotFinal) + +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testOperatorContainsNarrowsTypedDicts_partial] from __future__ import annotations from typing import assert_type, Literal, TypedDict, Union from typing_extensions import final From 20a2e6695597d6ee7e55ef7201f89ed8406f3f77 Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Sat, 15 Oct 2022 22:19:38 -0400 Subject: [PATCH 09/17] remove bogus addition --- mypy/checker.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 6f143a1daa3a..b8dc157821cb 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5446,10 +5446,6 @@ def has_no_custom_eq_checks(t: Type) -> bool: or_conditional_maps(left_if_vars, right_if_vars), and_conditional_maps(left_else_vars, right_else_vars), ) - elif isinstance(node, OpExpr) and node.op == "in": - left_if_vars, left_else_vars = self.find_isinstance_check(node.left) - right_if_vars, right_else_vars = self.find_isinstance_check(node.right) - elif isinstance(node, UnaryExpr) and node.op == "not": left, right = self.find_isinstance_check(node.expr) return right, left From 88f03f65362ba97cc2d0b0a4086f1ed760d07df0 Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Tue, 18 Oct 2022 21:44:15 -0400 Subject: [PATCH 10/17] use less risky pattern --- mypy/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index b8dc157821cb..fc76ed8110df 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5048,7 +5048,7 @@ def _contains_string_right_operand_type_map( if_types.extend(a) else_types.extend(b) else: - if_types = else_types = [t] + if_types, else_types = [t], [t] return if_types, else_types def _is_truthy_type(self, t: ProperType) -> bool: From a47bc04e0f1f916437df4b6cbcb00d35a422a6de Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Wed, 19 Oct 2022 11:58:58 -0400 Subject: [PATCH 11/17] 'D1 | None' -> 'D1 | list [str]' --- test-data/unit/check-typeddict.test | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index f2f39a9b997f..eb732805c38b 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2046,14 +2046,14 @@ class D1(TypedDict): class D2(TypedDict): bar: int -d: D1 | D2 -opt_d: D1 | None +d_or_list: D1 | list[str] -if 'foo' in opt_d: - assert_type(opt_d, D1) +if 'foo' in d_or_list: + assert_type(d_or_list, Union[D1, list[str]]) else: - assert_type(opt_d, None) + assert_type(d_or_list, list[str]) +d: D1 | D2 if 'foo' in d: assert_type(d, D1) From f72a634658e823f27a2c11a19c491d21ceb9e00c Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Wed, 19 Oct 2022 11:59:24 -0400 Subject: [PATCH 12/17] remove unused stuff in tests --- test-data/unit/check-typeddict.test | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index eb732805c38b..009dbe0bdf9a 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2090,7 +2090,7 @@ def f(arg: TD) -> None: [case testOperatorContainsNarrowsTypedDicts_final] # flags: --warn-unreachable from __future__ import annotations -from typing import assert_type, Literal, TypedDict, TypeVar, Union +from typing import assert_type, TypedDict, Union from typing_extensions import final @final @@ -2140,7 +2140,6 @@ class DNotTotal(TypedDict, total=False): d: DTotal | DNotTotal -val: int if 'required_key' in d: assert_type(d, DTotal) From bf4364f162baa6c02998b22aea61432e5d01a77e Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Wed, 19 Oct 2022 11:59:51 -0400 Subject: [PATCH 13/17] test totality through both total= and (Not)Required --- test-data/unit/check-typeddict.test | 34 ++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 009dbe0bdf9a..6b0b6736bd50 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2124,7 +2124,7 @@ else: [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] -[case testOperatorContainsNarrowsTypedDicts_partial] +[case testOperatorContainsNarrowsTypedDicts_partialThroughTotalFalse] from __future__ import annotations from typing import assert_type, Literal, TypedDict, Union from typing_extensions import final @@ -2160,6 +2160,38 @@ else: [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] +[case testOperatorContainsNarrowsTypedDicts_partialThroughNotRequired] +from __future__ import annotations +from typing import assert_type, Required, NotRequired, TypedDict, Union +from typing_extensions import final + +@final +class D1(TypedDict): + required_key: Required[int] + optional_key: NotRequired[int] + + +@final +class D2(TypedDict): + abc: int + xyz: int + + +d: D1 | D2 + +if 'required_key' in d: + assert_type(d, D1) +else: + assert_type(d, D2) + +if 'optional_key' in d: + assert_type(d, D1) +else: + assert_type(d, Union[D1, D2]) + +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + [case testCannotSubclassFinalTypedDict] from typing import TypedDict from typing_extensions import final From 062fcb19b87525eeff2b7b8c18d9288d13a8e649 Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Wed, 19 Oct 2022 12:02:54 -0400 Subject: [PATCH 14/17] add spam in d_final test --- test-data/unit/check-typeddict.test | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 6b0b6736bd50..0b03e77e17ec 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2114,6 +2114,13 @@ if 'spam' in d_not_final: else: assert_type(d_not_final, DNotFinal) +d_final: DFinal + +if 'spam' in d_final: + spam = 'ham' # E: Statement is unreachable +else: + assert_type(d_final, DFinal) + d_union: DFinal | DNotFinal if 'foo' in d_union: From 520df60f646787950e1eb6dd3796a16ba2056fd7 Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Wed, 2 Nov 2022 22:23:35 -0400 Subject: [PATCH 15/17] update tests per hauntsaninja's code review --- test-data/unit/check-typeddict.test | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 0b03e77e17ec..45d23734fce4 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2022,12 +2022,14 @@ class D(TypedDict): foo: int -d: D | list[str] +d_or_list: D | list[str] -if 'foo' in d: - assert_type(d, Union[D, list[str]]) +if 'foo' in d_or_list: + assert_type(d_or_list, Union[D, list[str]]) +elif 'bar' in d_or_list: + assert_type(d_or_list, list[str]) else: - assert_type(d, list[str]) + assert_type(d_or_list, list[str]) [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] @@ -2046,12 +2048,6 @@ class D1(TypedDict): class D2(TypedDict): bar: int -d_or_list: D1 | list[str] - -if 'foo' in d_or_list: - assert_type(d_or_list, Union[D1, list[str]]) -else: - assert_type(d_or_list, list[str]) d: D1 | D2 @@ -2077,11 +2073,9 @@ TD = TypeVar('TD', D1, D2) def f(arg: TD) -> None: value: int if 'foo' in arg: - assert_type(d, Union[D1, D2]) # strangely enough it's seen as a union - value = arg['foo'] # but acts here as D1 + assert_type(arg['foo'], int) else: - assert_type(d, Union[D1, D2]) # ditto here, but D2 - value = arg['bar'] + assert_type(arg['bar'], int) [builtins fixtures/dict.pyi] From fb564eb75bbb4a76be08c6763a5b8da69bc724ee Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Thu, 3 Nov 2022 21:57:16 -0400 Subject: [PATCH 16/17] clarify we don't narrow the left operand --- test-data/unit/check-typeddict.test | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index b90b87119ccd..1a0b478db0f2 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2078,8 +2078,12 @@ else: foo_or_invalid: Literal['foo', 'invalid'] if foo_or_invalid in d: assert_type(d, D1) + # won't narrow 'foo_or_invalid' + assert_type(foo_or_invalid, Literal['foo', 'invalid']) else: assert_type(d, Union[D1, D2]) + # won't narrow 'foo_or_invalid' + assert_type(foo_or_invalid, Literal['foo', 'invalid']) TD = TypeVar('TD', D1, D2) From 178483e6de1d6c133c1466d4a5f8ff76e9187355 Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Fri, 4 Nov 2022 00:00:02 -0400 Subject: [PATCH 17/17] no more recursion + similar naming to conditional_types... --- mypy/checker.py | 71 +++++++++++++++++++++++-------------------------- 1 file changed, 34 insertions(+), 37 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 9367541612d9..af2db52eae02 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5017,48 +5017,45 @@ def conditional_callable_type_map( return None, {} - def contains_operator_right_operand_type_map( - self, item_type: Type, collection_type: Type + def conditional_types_for_iterable( + self, item_type: Type, iterable_type: Type ) -> tuple[Type | None, Type | None]: """ - Deduces the type of the right operand of the `in` operator. + Narrows the type of `iterable_type` based on the type of `item_type`. For now, we only support narrowing unions of TypedDicts based on left operand being literal string(s). """ - if_types, else_types = [collection_type], [collection_type] - item_strs = try_getting_str_literals_from_type(item_type) - if item_strs: - if_types, else_types = self._contains_string_right_operand_type_map( - item_strs, collection_type - ) + if_types: list[Type] = [] + else_types: list[Type] = [] + + iterable_type = get_proper_type(iterable_type) + if isinstance(iterable_type, UnionType): + possible_iterable_types = get_proper_types(iterable_type.relevant_items()) + else: + possible_iterable_types = [iterable_type] + + item_str_literals = try_getting_str_literals_from_type(item_type) + + for possible_iterable_type in possible_iterable_types: + if item_str_literals and isinstance(possible_iterable_type, TypedDictType): + for key in item_str_literals: + if key in possible_iterable_type.required_keys: + if_types.append(possible_iterable_type) + elif ( + key in possible_iterable_type.items or not possible_iterable_type.is_final + ): + if_types.append(possible_iterable_type) + else_types.append(possible_iterable_type) + else: + else_types.append(possible_iterable_type) + else: + if_types.append(possible_iterable_type) + else_types.append(possible_iterable_type) + return ( UnionType.make_union(if_types) if if_types else None, UnionType.make_union(else_types) if else_types else None, ) - def _contains_string_right_operand_type_map( - self, item_strs: Iterable[str], t: Type - ) -> tuple[list[Type], list[Type]]: - t = get_proper_type(t) - if_types: list[Type] = [] - else_types: list[Type] = [] - if isinstance(t, TypedDictType): - for key in item_strs: - if key in t.required_keys: - if_types.append(t) - elif key in t.items or not t.is_final: - if_types.append(t) - else_types.append(t) - else: - else_types.append(t) - elif isinstance(t, UnionType): - for union_item in t.items: - a, b = self._contains_string_right_operand_type_map(item_strs, union_item) - if_types.extend(a) - else_types.extend(b) - else: - if_types, else_types = [t], [t] - return if_types, else_types - def _is_truthy_type(self, t: ProperType) -> bool: return ( ( @@ -5367,7 +5364,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: assert len(expr_indices) == 2 left_index, right_index = expr_indices item_type = operand_types[left_index] - collection_type = operand_types[right_index] + iterable_type = operand_types[right_index] if_map, else_map = {}, {} @@ -5375,7 +5372,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: # We only try and narrow away 'None' for now if is_optional(item_type): collection_item_type = get_proper_type( - builtin_item_type(collection_type) + builtin_item_type(iterable_type) ) if ( collection_item_type is not None @@ -5389,8 +5386,8 @@ def has_no_custom_eq_checks(t: Type) -> bool: if_map[operands[left_index]] = remove_optional(item_type) if right_index in narrowable_operand_index_to_hash: - (if_type, else_type) = self.contains_operator_right_operand_type_map( - item_type, collection_type + if_type, else_type = self.conditional_types_for_iterable( + item_type, iterable_type ) expr = operands[right_index] if if_type is None: