Skip to content

Commit 41574e0

Browse files
authored
Allow 'in' to narrow TypedDict unions (#13838)
`in` could narrow unions of TypeDicts, e.g. ```python class A(TypedDict) foo: int @Final class B(TypedDict): bar: int union: Union[A, B] = ... value: int if 'foo' in union: # Cannot be a B as it is final and has no "foo" field, so must be an A value = union['foo'] else: # Cannot be an A as those went to the if branch value = union['bar'] ```
1 parent 7ed4f5e commit 41574e0

File tree

4 files changed

+262
-19
lines changed

4 files changed

+262
-19
lines changed

mypy/checker.py

+72-19
Original file line numberDiff line numberDiff line change
@@ -5097,6 +5097,45 @@ def conditional_callable_type_map(
50975097

50985098
return None, {}
50995099

5100+
def conditional_types_for_iterable(
5101+
self, item_type: Type, iterable_type: Type
5102+
) -> tuple[Type | None, Type | None]:
5103+
"""
5104+
Narrows the type of `iterable_type` based on the type of `item_type`.
5105+
For now, we only support narrowing unions of TypedDicts based on left operand being literal string(s).
5106+
"""
5107+
if_types: list[Type] = []
5108+
else_types: list[Type] = []
5109+
5110+
iterable_type = get_proper_type(iterable_type)
5111+
if isinstance(iterable_type, UnionType):
5112+
possible_iterable_types = get_proper_types(iterable_type.relevant_items())
5113+
else:
5114+
possible_iterable_types = [iterable_type]
5115+
5116+
item_str_literals = try_getting_str_literals_from_type(item_type)
5117+
5118+
for possible_iterable_type in possible_iterable_types:
5119+
if item_str_literals and isinstance(possible_iterable_type, TypedDictType):
5120+
for key in item_str_literals:
5121+
if key in possible_iterable_type.required_keys:
5122+
if_types.append(possible_iterable_type)
5123+
elif (
5124+
key in possible_iterable_type.items or not possible_iterable_type.is_final
5125+
):
5126+
if_types.append(possible_iterable_type)
5127+
else_types.append(possible_iterable_type)
5128+
else:
5129+
else_types.append(possible_iterable_type)
5130+
else:
5131+
if_types.append(possible_iterable_type)
5132+
else_types.append(possible_iterable_type)
5133+
5134+
return (
5135+
UnionType.make_union(if_types) if if_types else None,
5136+
UnionType.make_union(else_types) if else_types else None,
5137+
)
5138+
51005139
def _is_truthy_type(self, t: ProperType) -> bool:
51015140
return (
51025141
(
@@ -5412,28 +5451,42 @@ def has_no_custom_eq_checks(t: Type) -> bool:
54125451
elif operator in {"in", "not in"}:
54135452
assert len(expr_indices) == 2
54145453
left_index, right_index = expr_indices
5415-
if left_index not in narrowable_operand_index_to_hash:
5416-
continue
5417-
54185454
item_type = operand_types[left_index]
5419-
collection_type = operand_types[right_index]
5455+
iterable_type = operand_types[right_index]
54205456

5421-
# We only try and narrow away 'None' for now
5422-
if not is_optional(item_type):
5423-
continue
5457+
if_map, else_map = {}, {}
5458+
5459+
if left_index in narrowable_operand_index_to_hash:
5460+
# We only try and narrow away 'None' for now
5461+
if is_optional(item_type):
5462+
collection_item_type = get_proper_type(
5463+
builtin_item_type(iterable_type)
5464+
)
5465+
if (
5466+
collection_item_type is not None
5467+
and not is_optional(collection_item_type)
5468+
and not (
5469+
isinstance(collection_item_type, Instance)
5470+
and collection_item_type.type.fullname == "builtins.object"
5471+
)
5472+
and is_overlapping_erased_types(item_type, collection_item_type)
5473+
):
5474+
if_map[operands[left_index]] = remove_optional(item_type)
5475+
5476+
if right_index in narrowable_operand_index_to_hash:
5477+
if_type, else_type = self.conditional_types_for_iterable(
5478+
item_type, iterable_type
5479+
)
5480+
expr = operands[right_index]
5481+
if if_type is None:
5482+
if_map = None
5483+
else:
5484+
if_map[expr] = if_type
5485+
if else_type is None:
5486+
else_map = None
5487+
else:
5488+
else_map[expr] = else_type
54245489

5425-
collection_item_type = get_proper_type(builtin_item_type(collection_type))
5426-
if collection_item_type is None or is_optional(collection_item_type):
5427-
continue
5428-
if (
5429-
isinstance(collection_item_type, Instance)
5430-
and collection_item_type.type.fullname == "builtins.object"
5431-
):
5432-
continue
5433-
if is_overlapping_erased_types(item_type, collection_item_type):
5434-
if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {}
5435-
else:
5436-
continue
54375490
else:
54385491
if_map = {}
54395492
else_map = {}

mypy/types.py

+4
Original file line numberDiff line numberDiff line change
@@ -2334,6 +2334,10 @@ def deserialize(cls, data: JsonDict) -> TypedDictType:
23342334
Instance.deserialize(data["fallback"]),
23352335
)
23362336

2337+
@property
2338+
def is_final(self) -> bool:
2339+
return self.fallback.type.is_final
2340+
23372341
def is_anonymous(self) -> bool:
23382342
return self.fallback.type.fullname in TPDICT_FB_NAMES
23392343

test-data/unit/check-typeddict.test

+185
Original file line numberDiff line numberDiff line change
@@ -2025,6 +2025,191 @@ v = {bad2: 2} # E: Extra key "bad" for TypedDict "Value"
20252025
[builtins fixtures/dict.pyi]
20262026
[typing fixtures/typing-typeddict.pyi]
20272027

2028+
[case testOperatorContainsNarrowsTypedDicts_unionWithList]
2029+
from __future__ import annotations
2030+
from typing import assert_type, TypedDict, Union
2031+
from typing_extensions import final
2032+
2033+
@final
2034+
class D(TypedDict):
2035+
foo: int
2036+
2037+
2038+
d_or_list: D | list[str]
2039+
2040+
if 'foo' in d_or_list:
2041+
assert_type(d_or_list, Union[D, list[str]])
2042+
elif 'bar' in d_or_list:
2043+
assert_type(d_or_list, list[str])
2044+
else:
2045+
assert_type(d_or_list, list[str])
2046+
2047+
[builtins fixtures/dict.pyi]
2048+
[typing fixtures/typing-typeddict.pyi]
2049+
2050+
[case testOperatorContainsNarrowsTypedDicts_total]
2051+
from __future__ import annotations
2052+
from typing import assert_type, Literal, TypedDict, TypeVar, Union
2053+
from typing_extensions import final
2054+
2055+
@final
2056+
class D1(TypedDict):
2057+
foo: int
2058+
2059+
2060+
@final
2061+
class D2(TypedDict):
2062+
bar: int
2063+
2064+
2065+
d: D1 | D2
2066+
2067+
if 'foo' in d:
2068+
assert_type(d, D1)
2069+
else:
2070+
assert_type(d, D2)
2071+
2072+
foo_or_bar: Literal['foo', 'bar']
2073+
if foo_or_bar in d:
2074+
assert_type(d, Union[D1, D2])
2075+
else:
2076+
assert_type(d, Union[D1, D2])
2077+
2078+
foo_or_invalid: Literal['foo', 'invalid']
2079+
if foo_or_invalid in d:
2080+
assert_type(d, D1)
2081+
# won't narrow 'foo_or_invalid'
2082+
assert_type(foo_or_invalid, Literal['foo', 'invalid'])
2083+
else:
2084+
assert_type(d, Union[D1, D2])
2085+
# won't narrow 'foo_or_invalid'
2086+
assert_type(foo_or_invalid, Literal['foo', 'invalid'])
2087+
2088+
TD = TypeVar('TD', D1, D2)
2089+
2090+
def f(arg: TD) -> None:
2091+
value: int
2092+
if 'foo' in arg:
2093+
assert_type(arg['foo'], int)
2094+
else:
2095+
assert_type(arg['bar'], int)
2096+
2097+
2098+
[builtins fixtures/dict.pyi]
2099+
[typing fixtures/typing-typeddict.pyi]
2100+
2101+
[case testOperatorContainsNarrowsTypedDicts_final]
2102+
# flags: --warn-unreachable
2103+
from __future__ import annotations
2104+
from typing import assert_type, TypedDict, Union
2105+
from typing_extensions import final
2106+
2107+
@final
2108+
class DFinal(TypedDict):
2109+
foo: int
2110+
2111+
2112+
class DNotFinal(TypedDict):
2113+
bar: int
2114+
2115+
2116+
d_not_final: DNotFinal
2117+
2118+
if 'bar' in d_not_final:
2119+
assert_type(d_not_final, DNotFinal)
2120+
else:
2121+
spam = 'ham' # E: Statement is unreachable
2122+
2123+
if 'spam' in d_not_final:
2124+
assert_type(d_not_final, DNotFinal)
2125+
else:
2126+
assert_type(d_not_final, DNotFinal)
2127+
2128+
d_final: DFinal
2129+
2130+
if 'spam' in d_final:
2131+
spam = 'ham' # E: Statement is unreachable
2132+
else:
2133+
assert_type(d_final, DFinal)
2134+
2135+
d_union: DFinal | DNotFinal
2136+
2137+
if 'foo' in d_union:
2138+
assert_type(d_union, Union[DFinal, DNotFinal])
2139+
else:
2140+
assert_type(d_union, DNotFinal)
2141+
2142+
[builtins fixtures/dict.pyi]
2143+
[typing fixtures/typing-typeddict.pyi]
2144+
2145+
[case testOperatorContainsNarrowsTypedDicts_partialThroughTotalFalse]
2146+
from __future__ import annotations
2147+
from typing import assert_type, Literal, TypedDict, Union
2148+
from typing_extensions import final
2149+
2150+
@final
2151+
class DTotal(TypedDict):
2152+
required_key: int
2153+
2154+
2155+
@final
2156+
class DNotTotal(TypedDict, total=False):
2157+
optional_key: int
2158+
2159+
2160+
d: DTotal | DNotTotal
2161+
2162+
if 'required_key' in d:
2163+
assert_type(d, DTotal)
2164+
else:
2165+
assert_type(d, DNotTotal)
2166+
2167+
if 'optional_key' in d:
2168+
assert_type(d, DNotTotal)
2169+
else:
2170+
assert_type(d, Union[DTotal, DNotTotal])
2171+
2172+
key: Literal['optional_key', 'required_key']
2173+
if key in d:
2174+
assert_type(d, Union[DTotal, DNotTotal])
2175+
else:
2176+
assert_type(d, Union[DTotal, DNotTotal])
2177+
2178+
[builtins fixtures/dict.pyi]
2179+
[typing fixtures/typing-typeddict.pyi]
2180+
2181+
[case testOperatorContainsNarrowsTypedDicts_partialThroughNotRequired]
2182+
from __future__ import annotations
2183+
from typing import assert_type, Required, NotRequired, TypedDict, Union
2184+
from typing_extensions import final
2185+
2186+
@final
2187+
class D1(TypedDict):
2188+
required_key: Required[int]
2189+
optional_key: NotRequired[int]
2190+
2191+
2192+
@final
2193+
class D2(TypedDict):
2194+
abc: int
2195+
xyz: int
2196+
2197+
2198+
d: D1 | D2
2199+
2200+
if 'required_key' in d:
2201+
assert_type(d, D1)
2202+
else:
2203+
assert_type(d, D2)
2204+
2205+
if 'optional_key' in d:
2206+
assert_type(d, D1)
2207+
else:
2208+
assert_type(d, Union[D1, D2])
2209+
2210+
[builtins fixtures/dict.pyi]
2211+
[typing fixtures/typing-typeddict.pyi]
2212+
20282213
[case testCannotSubclassFinalTypedDict]
20292214
from typing import TypedDict
20302215
from typing_extensions import final

test-data/unit/fixtures/typing-typeddict.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from abc import ABCMeta
1010

1111
cast = 0
12+
assert_type = 0
1213
overload = 0
1314
Any = 0
1415
Union = 0

0 commit comments

Comments
 (0)