Skip to content

Commit ca87360

Browse files
committed
'in' can narrow TypedDict unions
1 parent abc9d15 commit ca87360

File tree

2 files changed

+112
-18
lines changed

2 files changed

+112
-18
lines changed

mypy/checker.py

+71-18
Original file line numberDiff line numberDiff line change
@@ -5009,6 +5009,44 @@ def conditional_callable_type_map(
50095009

50105010
return None, {}
50115011

5012+
def contains_operator_right_operand_type_map(
5013+
self, item_type: Type, collection_type: Type
5014+
) -> tuple[Type, Type]:
5015+
"""
5016+
Deduces the type of the right operand of the `in` operator.
5017+
For now, we only support narrowing unions of TypedDicts based on left operand being literal string(s).
5018+
"""
5019+
if_types, else_types = [collection_type], [collection_type]
5020+
item_strs = try_getting_str_literals_from_type(item_type)
5021+
if item_strs:
5022+
if_types, else_types = self._contains_string_right_operand_type_map(
5023+
set(item_strs), collection_type
5024+
)
5025+
return UnionType.make_union(if_types), UnionType.make_union(else_types)
5026+
5027+
def _contains_string_right_operand_type_map(
5028+
self, item_strs: set[str], t: Type
5029+
) -> tuple[list[Type], list[Type]]:
5030+
t = get_proper_type(t)
5031+
if_types: list[Type] = []
5032+
else_types: list[Type] = []
5033+
if isinstance(t, TypedDictType):
5034+
if item_strs <= t.items.keys():
5035+
if_types.append(t)
5036+
elif item_strs.isdisjoint(t.items.keys()):
5037+
else_types.append(t)
5038+
else:
5039+
if_types.append(t)
5040+
else_types.append(t)
5041+
elif isinstance(t, UnionType):
5042+
for union_item in t.items:
5043+
a, b = self._contains_string_right_operand_type_map(item_strs, union_item)
5044+
if_types.extend(a)
5045+
else_types.extend(b)
5046+
else:
5047+
if_types = else_types = [t]
5048+
return if_types, else_types
5049+
50125050
def _is_truthy_type(self, t: ProperType) -> bool:
50135051
return (
50145052
(
@@ -5316,28 +5354,39 @@ def has_no_custom_eq_checks(t: Type) -> bool:
53165354
elif operator in {"in", "not in"}:
53175355
assert len(expr_indices) == 2
53185356
left_index, right_index = expr_indices
5319-
if left_index not in narrowable_operand_index_to_hash:
5320-
continue
5321-
53225357
item_type = operand_types[left_index]
53235358
collection_type = operand_types[right_index]
53245359

5325-
# We only try and narrow away 'None' for now
5326-
if not is_optional(item_type):
5327-
continue
5360+
if_map, else_map = {}, {}
5361+
5362+
if left_index in narrowable_operand_index_to_hash:
5363+
# We only try and narrow away 'None' for now
5364+
if is_optional(item_type):
5365+
collection_item_type = get_proper_type(
5366+
builtin_item_type(collection_type)
5367+
)
5368+
if (
5369+
collection_item_type is not None
5370+
and not is_optional(collection_item_type)
5371+
and not (
5372+
isinstance(collection_item_type, Instance)
5373+
and collection_item_type.type.fullname == "builtins.object"
5374+
)
5375+
and is_overlapping_erased_types(item_type, collection_item_type)
5376+
):
5377+
if_map[operands[left_index]] = remove_optional(item_type)
5378+
5379+
if right_index in narrowable_operand_index_to_hash:
5380+
(
5381+
right_if_type,
5382+
right_else_type,
5383+
) = self.contains_operator_right_operand_type_map(
5384+
item_type, collection_type
5385+
)
5386+
expr = operands[right_index]
5387+
if_map[expr] = right_if_type
5388+
else_map[expr] = right_else_type
53285389

5329-
collection_item_type = get_proper_type(builtin_item_type(collection_type))
5330-
if collection_item_type is None or is_optional(collection_item_type):
5331-
continue
5332-
if (
5333-
isinstance(collection_item_type, Instance)
5334-
and collection_item_type.type.fullname == "builtins.object"
5335-
):
5336-
continue
5337-
if is_overlapping_erased_types(item_type, collection_item_type):
5338-
if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {}
5339-
else:
5340-
continue
53415390
else:
53425391
if_map = {}
53435392
else_map = {}
@@ -5390,6 +5439,10 @@ def has_no_custom_eq_checks(t: Type) -> bool:
53905439
or_conditional_maps(left_if_vars, right_if_vars),
53915440
and_conditional_maps(left_else_vars, right_else_vars),
53925441
)
5442+
elif isinstance(node, OpExpr) and node.op == "in":
5443+
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
5444+
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
5445+
53935446
elif isinstance(node, UnaryExpr) and node.op == "not":
53945447
left, right = self.find_isinstance_check(node.expr)
53955448
return right, left

test-data/unit/check-typeddict.test

+41
Original file line numberDiff line numberDiff line change
@@ -2012,6 +2012,47 @@ v = {bad2: 2} # E: Extra key "bad" for TypedDict "Value"
20122012
[builtins fixtures/dict.pyi]
20132013
[typing fixtures/typing-typeddict.pyi]
20142014

2015+
[case testFinalTypedDictTagged]
2016+
from __future__ import annotations
2017+
from typing import Literal, TypedDict
2018+
from typing_extensions import final
2019+
2020+
@final
2021+
class D1(TypedDict):
2022+
foo: int
2023+
2024+
2025+
@final
2026+
class D2(TypedDict):
2027+
bar: int
2028+
2029+
d: D1 | D2
2030+
val: int
2031+
2032+
val = d['foo'] # E: TypedDict "D2" has no key "foo"
2033+
if 'foo' in d:
2034+
val = d['foo']
2035+
else:
2036+
val = d['bar']
2037+
2038+
foo_or_bar: Literal['foo', 'bar']
2039+
if foo_or_bar in d:
2040+
val = d['foo'] # E: TypedDict "D2" has no key "foo"
2041+
val = d['bar'] # E: TypedDict "D1" has no key "bar"
2042+
else:
2043+
val = d['foo'] # E: TypedDict "D2" has no key "foo"
2044+
val = d['bar'] # E: TypedDict "D1" has no key "bar"
2045+
2046+
foo_or_invalid: Literal['foo', 'invalid']
2047+
if foo_or_invalid in d:
2048+
val = d['foo']
2049+
else:
2050+
val = d['foo'] # E: TypedDict "D2" has no key "foo"
2051+
val = d['bar'] # E: TypedDict "D1" has no key "bar"
2052+
2053+
[builtins fixtures/dict.pyi]
2054+
[typing fixtures/typing-typeddict.pyi]
2055+
20152056
[case testCannotSubclassFinalTypedDict]
20162057
from typing import TypedDict
20172058
from typing_extensions import final

0 commit comments

Comments
 (0)