Skip to content

Commit e714a52

Browse files
committed
'in' can narrow TypedDict unions
1 parent 1a8e6c8 commit e714a52

File tree

2 files changed

+66
-18
lines changed

2 files changed

+66
-18
lines changed

mypy/checker.py

+42-18
Original file line numberDiff line numberDiff line change
@@ -5006,6 +5006,24 @@ def conditional_callable_type_map(
50065006

50075007
return None, {}
50085008

5009+
def _string_in_type_map(self, item_str: str, collection_expr: Expression, t: Type) -> tuple[TypeMap, TypeMap]:
5010+
t = get_proper_type(t)
5011+
if isinstance(t, TypedDictType):
5012+
m = {collection_expr: t}
5013+
if item_str in t.items:
5014+
return m, {}
5015+
else:
5016+
return {}, m
5017+
elif isinstance(t, UnionType):
5018+
if_map, else_map = {}, {}
5019+
for union_item in t.items:
5020+
union_if_map, union_else_map = self._string_in_type_map(item_str, collection_expr, union_item)
5021+
if_map.update(union_if_map)
5022+
else_map.update(union_else_map)
5023+
return if_map, else_map
5024+
else:
5025+
return {}, {}
5026+
50095027
def _is_truthy_type(self, t: ProperType) -> bool:
50105028
return (
50115029
(
@@ -5313,28 +5331,30 @@ def has_no_custom_eq_checks(t: Type) -> bool:
53135331
elif operator in {"in", "not in"}:
53145332
assert len(expr_indices) == 2
53155333
left_index, right_index = expr_indices
5316-
if left_index not in narrowable_operand_index_to_hash:
5317-
continue
5318-
53195334
item_type = operand_types[left_index]
53205335
collection_type = operand_types[right_index]
53215336

5322-
# We only try and narrow away 'None' for now
5323-
if not is_optional(item_type):
5324-
continue
5337+
if left_index in narrowable_operand_index_to_hash:
5338+
# We only try and narrow away 'None' for now
5339+
if not is_optional(item_type):
5340+
continue
5341+
5342+
collection_item_type = get_proper_type(builtin_item_type(collection_type))
5343+
if collection_item_type is None or is_optional(collection_item_type):
5344+
continue
5345+
if (
5346+
isinstance(collection_item_type, Instance)
5347+
and collection_item_type.type.fullname == "builtins.object"
5348+
):
5349+
continue
5350+
if is_overlapping_erased_types(item_type, collection_item_type):
5351+
if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {}
5352+
else:
5353+
continue
5354+
5355+
elif isinstance(item_type.last_known_value, LiteralType) and isinstance(item_type.last_known_value.value, str):
5356+
if_map, else_map = self._string_in_type_map(item_type.last_known_value.value, operands[right_index], collection_type)
53255357

5326-
collection_item_type = get_proper_type(builtin_item_type(collection_type))
5327-
if collection_item_type is None or is_optional(collection_item_type):
5328-
continue
5329-
if (
5330-
isinstance(collection_item_type, Instance)
5331-
and collection_item_type.type.fullname == "builtins.object"
5332-
):
5333-
continue
5334-
if is_overlapping_erased_types(item_type, collection_item_type):
5335-
if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {}
5336-
else:
5337-
continue
53385358
else:
53395359
if_map = {}
53405360
else_map = {}
@@ -5387,6 +5407,10 @@ def has_no_custom_eq_checks(t: Type) -> bool:
53875407
or_conditional_maps(left_if_vars, right_if_vars),
53885408
and_conditional_maps(left_else_vars, right_else_vars),
53895409
)
5410+
elif isinstance(node, OpExpr) and node.op == "in":
5411+
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
5412+
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
5413+
53905414
elif isinstance(node, UnaryExpr) and node.op == "not":
53915415
left, right = self.find_isinstance_check(node.expr)
53925416
return right, left

test-data/unit/check-typeddict.test

+24
Original file line numberDiff line numberDiff line change
@@ -2012,6 +2012,30 @@ v = {bad2: 2} # E: Extra key "bad" for TypedDict "Value"
20122012
[builtins fixtures/dict.pyi]
20132013
[typing fixtures/typing-typeddict.pyi]
20142014

2015+
[case testFinalTypedDictTagged]
2016+
from __future__ import annotations
2017+
from typing import TypedDict
2018+
from typing_extensions import final
2019+
2020+
@final
2021+
class D1(TypedDict):
2022+
foo: int
2023+
2024+
2025+
@final
2026+
class D2(TypedDict):
2027+
bar: int
2028+
2029+
d: D1 | D2
2030+
val = d['foo'] # E: TypedDict "D2" has no key "foo"
2031+
if 'foo' in d:
2032+
val = d['foo']
2033+
else:
2034+
val = d['bar']
2035+
2036+
[builtins fixtures/dict.pyi]
2037+
[typing fixtures/typing-typeddict.pyi]
2038+
20152039
[case testCannotSubclassFinalTypedDict]
20162040
from typing import TypedDict
20172041
from typing_extensions import final

0 commit comments

Comments
 (0)