diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 1c25b8ea7a12..c2cf226ef210 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4188,6 +4188,17 @@ def fast_dict_type(self, e: DictExpr) -> Type | None: self.resolved_type[e] = dt return dt + def check_typeddict_literal_in_context( + self, e: DictExpr, typeddict_context: TypedDictType + ) -> Type: + orig_ret_type = self.check_typeddict_call_with_dict( + callee=typeddict_context, kwargs=e, context=e, orig_callee=None + ) + ret_type = get_proper_type(orig_ret_type) + if isinstance(ret_type, TypedDictType): + return ret_type.copy_modified() + return typeddict_context.copy_modified() + def visit_dict_expr(self, e: DictExpr) -> Type: """Type check a dict expression. @@ -4197,15 +4208,20 @@ def visit_dict_expr(self, e: DictExpr) -> Type: # an error, but returns the TypedDict type that matches the literal it found # that would cause a second error when that TypedDict type is returned upstream # to avoid the second error, we always return TypedDict type that was requested - typeddict_context = self.find_typeddict_context(self.type_context[-1], e) - if typeddict_context: - orig_ret_type = self.check_typeddict_call_with_dict( - callee=typeddict_context, kwargs=e, context=e, orig_callee=None - ) - ret_type = get_proper_type(orig_ret_type) - if isinstance(ret_type, TypedDictType): - return ret_type.copy_modified() - return typeddict_context.copy_modified() + typeddict_contexts = self.find_typeddict_context(self.type_context[-1], e) + if typeddict_contexts: + if len(typeddict_contexts) == 1: + return self.check_typeddict_literal_in_context(e, typeddict_contexts[0]) + # Multiple items union, check if at least one of them matches cleanly. + for typeddict_context in typeddict_contexts: + with self.msg.filter_errors() as err, self.chk.local_type_map() as tmap: + ret_type = self.check_typeddict_literal_in_context(e, typeddict_context) + if err.has_new_errors(): + continue + self.chk.store_types(tmap) + return ret_type + # No item matched without an error, so we can't unambiguously choose the item. + self.msg.typeddict_context_ambiguous(typeddict_contexts, e) # fast path attempt dt = self.fast_dict_type(e) @@ -4271,26 +4287,20 @@ def visit_dict_expr(self, e: DictExpr) -> Type: def find_typeddict_context( self, context: Type | None, dict_expr: DictExpr - ) -> TypedDictType | None: + ) -> list[TypedDictType]: context = get_proper_type(context) if isinstance(context, TypedDictType): - return context + return [context] elif isinstance(context, UnionType): items = [] for item in context.items: - item_context = self.find_typeddict_context(item, dict_expr) - if item_context is not None and self.match_typeddict_call_with_dict( - item_context, dict_expr, dict_expr - ): - items.append(item_context) - if len(items) == 1: - # Only one union item is valid TypedDict for the given dict_expr, so use the - # context as it's unambiguous. - return items[0] - if len(items) > 1: - self.msg.typeddict_context_ambiguous(items, dict_expr) + item_contexts = self.find_typeddict_context(item, dict_expr) + for item_context in item_contexts: + if self.match_typeddict_call_with_dict(item_context, dict_expr, dict_expr): + items.append(item_context) + return items # No TypedDict type in context. - return None + return [] def visit_lambda_expr(self, e: LambdaExpr) -> Type: """Type check lambda expression.""" diff --git a/mypy/messages.py b/mypy/messages.py index 5d8bf79ec8a3..94a97f696b6c 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1705,7 +1705,9 @@ def typeddict_key_not_found( def typeddict_context_ambiguous(self, types: list[TypedDictType], context: Context) -> None: formatted_types = ", ".join(list(format_type_distinctly(*types))) - self.fail(f"Type of TypedDict is ambiguous, could be any of ({formatted_types})", context) + self.fail( + f"Type of TypedDict is ambiguous, none of ({formatted_types}) matches cleanly", context + ) def typeddict_key_cannot_be_deleted( self, typ: TypedDictType, item_name: str, context: Context diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index e426b8a7630b..70ff6a4a6759 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -895,15 +895,25 @@ c: Union[A, B] = {'@type': 'a-type', 'a': 'Test'} reveal_type(c) # N: Revealed type is "Union[TypedDict('__main__.A', {'@type': Literal['a-type'], 'a': builtins.str}), TypedDict('__main__.B', {'@type': Literal['b-type'], 'b': builtins.int})]" [builtins fixtures/dict.pyi] -[case testTypedDictUnionAmbiguousCase] +[case testTypedDictUnionAmbiguousCaseBothMatch] from typing import Union, Mapping, Any, cast from typing_extensions import TypedDict, Literal -A = TypedDict('A', {'@type': Literal['a-type'], 'a': str}) -B = TypedDict('B', {'@type': Literal['a-type'], 'a': str}) +A = TypedDict('A', {'@type': Literal['a-type'], 'value': str}) +B = TypedDict('B', {'@type': Literal['b-type'], 'value': str}) + +c: Union[A, B] = {'@type': 'a-type', 'value': 'Test'} +[builtins fixtures/dict.pyi] + +[case testTypedDictUnionAmbiguousCaseNoMatch] +from typing import Union, Mapping, Any, cast +from typing_extensions import TypedDict, Literal -c: Union[A, B] = {'@type': 'a-type', 'a': 'Test'} # E: Type of TypedDict is ambiguous, could be any of ("A", "B") \ - # E: Incompatible types in assignment (expression has type "Dict[str, str]", variable has type "Union[A, B]") +A = TypedDict('A', {'@type': Literal['a-type'], 'value': int}) +B = TypedDict('B', {'@type': Literal['b-type'], 'value': int}) + +c: Union[A, B] = {'@type': 'a-type', 'value': 'Test'} # E: Type of TypedDict is ambiguous, none of ("A", "B") matches cleanly \ + # E: Incompatible types in assignment (expression has type "Dict[str, str]", variable has type "Union[A, B]") [builtins fixtures/dict.pyi] -- Use dict literals @@ -2786,3 +2796,79 @@ TDC = TypedDict("TDC", {"val": int, "next": Optional[Self]}) # E: Self type can [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] + +[case testUnionOfEquivalentTypedDictsInferred] +from typing import TypedDict, Dict + +D = TypedDict("D", {"foo": int}, total=False) + +def f(d: Dict[str, D]) -> None: + args = d["a"] + args.update(d.get("b", {})) # OK +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testUnionOfEquivalentTypedDictsDeclared] +from typing import TypedDict, Union + +class A(TypedDict, total=False): + name: str +class B(TypedDict, total=False): + name: str + +def foo(data: Union[A, B]) -> None: ... +foo({"name": "Robert"}) # OK +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testUnionOfEquivalentTypedDictsEmpty] +from typing import TypedDict, Union + +class Foo(TypedDict, total=False): + foo: str +class Bar(TypedDict, total=False): + bar: str + +def foo(body: Union[Foo, Bar] = {}) -> None: # OK + ... +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testUnionOfEquivalentTypedDictsDistinct] +from typing import TypedDict, Union, Literal + +class A(TypedDict): + type: Literal['a'] + value: bool +class B(TypedDict): + type: Literal['b'] + value: str + +Response = Union[A, B] +def method(message: Response) -> None: ... + +method({'type': 'a', 'value': True}) # OK +method({'type': 'b', 'value': 'abc'}) # OK +method({'type': 'a', 'value': 'abc'}) # E: Type of TypedDict is ambiguous, none of ("A", "B") matches cleanly \ + # E: Argument 1 to "method" has incompatible type "Dict[str, str]"; expected "Union[A, B]" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testUnionOfEquivalentTypedDictsNested] +from typing import TypedDict, Union + +class A(TypedDict, total=False): + foo: C +class B(TypedDict, total=False): + foo: D +class C(TypedDict, total=False): + c: str +class D(TypedDict, total=False): + d: str + +def foo(data: Union[A, B]) -> None: ... +foo({"foo": {"c": "foo"}}) # OK +foo({"foo": {"e": "foo"}}) # E: Type of TypedDict is ambiguous, none of ("A", "B") matches cleanly \ + # E: Argument 1 to "foo" has incompatible type "Dict[str, Dict[str, str]]"; expected "Union[A, B]" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] diff --git a/test-data/unit/check-unions.test b/test-data/unit/check-unions.test index 4c4fbc32ec3f..cabc28e786b2 100644 --- a/test-data/unit/check-unions.test +++ b/test-data/unit/check-unions.test @@ -971,14 +971,14 @@ if x: [builtins fixtures/dict.pyi] [out] -[case testUnpackUnionNoCrashOnPartialNoneList] +[case testUnpackUnionNoCrashOnPartialList] # flags: --strict-optional from typing import Dict, Tuple, List, Any a: Any d: Dict[str, Tuple[List[Tuple[str, str]], str]] -x, _ = d.get(a, ([], [])) -reveal_type(x) # N: Revealed type is "Union[builtins.list[Tuple[builtins.str, builtins.str]], builtins.list[]]" +x, _ = d.get(a, ([], "")) +reveal_type(x) # N: Revealed type is "builtins.list[Tuple[builtins.str, builtins.str]]" for y in x: pass [builtins fixtures/dict.pyi] diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index f4ec15e4fa9a..856b8b7266c1 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -29,7 +29,7 @@ class dict(Mapping[KT, VT]): @overload def get(self, k: KT) -> Optional[VT]: pass @overload - def get(self, k: KT, default: Union[KT, T]) -> Union[VT, T]: pass + def get(self, k: KT, default: Union[VT, T]) -> Union[VT, T]: pass def __len__(self) -> int: ... class int: # for convenience