Skip to content

Commit 4642a9c

Browse files
committed
'in' can narrow TypedDict unions
1 parent 1a8e6c8 commit 4642a9c

File tree

2 files changed

+60
-18
lines changed

2 files changed

+60
-18
lines changed

mypy/checker.py

+36-18
Original file line numberDiff line numberDiff line change
@@ -5006,6 +5006,18 @@ def conditional_callable_type_map(
50065006

50075007
return None, {}
50085008

5009+
def _string_in_type_map(self, item_str: str, collection_expr: Expression, t: Type) -> tuple[TypeMap, TypeMap]:
5010+
if_map, else_map = {}, {}
5011+
t = get_proper_type(t)
5012+
if isinstance(t, TypedDictType):
5013+
(if_map if item_str in t.items else else_map)[collection_expr] = t
5014+
elif isinstance(t, UnionType):
5015+
for union_item in t.items:
5016+
union_if_map, union_else_map = self._string_in_type_map(item_str, collection_expr, union_item)
5017+
if_map.update(union_if_map)
5018+
else_map.update(union_else_map)
5019+
return if_map, else_map
5020+
50095021
def _is_truthy_type(self, t: ProperType) -> bool:
50105022
return (
50115023
(
@@ -5313,28 +5325,30 @@ def has_no_custom_eq_checks(t: Type) -> bool:
53135325
elif operator in {"in", "not in"}:
53145326
assert len(expr_indices) == 2
53155327
left_index, right_index = expr_indices
5316-
if left_index not in narrowable_operand_index_to_hash:
5317-
continue
5318-
53195328
item_type = operand_types[left_index]
53205329
collection_type = operand_types[right_index]
53215330

5322-
# We only try and narrow away 'None' for now
5323-
if not is_optional(item_type):
5324-
continue
5331+
if left_index in narrowable_operand_index_to_hash:
5332+
# We only try and narrow away 'None' for now
5333+
if not is_optional(item_type):
5334+
continue
5335+
5336+
collection_item_type = get_proper_type(builtin_item_type(collection_type))
5337+
if collection_item_type is None or is_optional(collection_item_type):
5338+
continue
5339+
if (
5340+
isinstance(collection_item_type, Instance)
5341+
and collection_item_type.type.fullname == "builtins.object"
5342+
):
5343+
continue
5344+
if is_overlapping_erased_types(item_type, collection_item_type):
5345+
if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {}
5346+
else:
5347+
continue
5348+
5349+
elif isinstance(item_type.last_known_value, LiteralType) and isinstance(item_type.last_known_value.value, str):
5350+
if_map, else_map = self._string_in_type_map(item_type.last_known_value.value, operands[right_index], collection_type)
53255351

5326-
collection_item_type = get_proper_type(builtin_item_type(collection_type))
5327-
if collection_item_type is None or is_optional(collection_item_type):
5328-
continue
5329-
if (
5330-
isinstance(collection_item_type, Instance)
5331-
and collection_item_type.type.fullname == "builtins.object"
5332-
):
5333-
continue
5334-
if is_overlapping_erased_types(item_type, collection_item_type):
5335-
if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {}
5336-
else:
5337-
continue
53385352
else:
53395353
if_map = {}
53405354
else_map = {}
@@ -5387,6 +5401,10 @@ def has_no_custom_eq_checks(t: Type) -> bool:
53875401
or_conditional_maps(left_if_vars, right_if_vars),
53885402
and_conditional_maps(left_else_vars, right_else_vars),
53895403
)
5404+
elif isinstance(node, OpExpr) and node.op == "in":
5405+
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
5406+
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
5407+
53905408
elif isinstance(node, UnaryExpr) and node.op == "not":
53915409
left, right = self.find_isinstance_check(node.expr)
53925410
return right, left

test-data/unit/check-typeddict.test

+24
Original file line numberDiff line numberDiff line change
@@ -2012,6 +2012,30 @@ v = {bad2: 2} # E: Extra key "bad" for TypedDict "Value"
20122012
[builtins fixtures/dict.pyi]
20132013
[typing fixtures/typing-typeddict.pyi]
20142014

2015+
[case testFinalTypedDictTagged]
2016+
from __future__ import annotations
2017+
from typing import TypedDict
2018+
from typing_extensions import final
2019+
2020+
@final
2021+
class D1(TypedDict):
2022+
foo: int
2023+
2024+
2025+
@final
2026+
class D2(TypedDict):
2027+
bar: int
2028+
2029+
d: D1 | D2
2030+
val = d['foo'] # E: TypedDict "D2" has no key "foo"
2031+
if 'foo' in d:
2032+
val = d['foo']
2033+
else:
2034+
val = d['bar']
2035+
2036+
[builtins fixtures/dict.pyi]
2037+
[typing fixtures/typing-typeddict.pyi]
2038+
20152039
[case testCannotSubclassFinalTypedDict]
20162040
from typing import TypedDict
20172041
from typing_extensions import final

0 commit comments

Comments
 (0)