Skip to content

Commit 8236c93

Browse files
Add |= and | operators support for TypedDict (python#16249)
Please, note that there are several problems with `__ror__` definitions. 1. `dict.__ror__` does not define support for `Mapping?` types. For example: ```python >>> import types >>> {'a': 1} | types.MappingProxyType({'b': 2}) {'a': 1, 'b': 2} >>> ``` 2. `TypedDict.__ror__` also does not define this support So, I would like to defer this feature for the future, we need some discussion to happen. However, this PR does fully solve the problem OP had. Closes python#16244 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cda163d commit 8236c93

File tree

9 files changed

+316
-11
lines changed

9 files changed

+316
-11
lines changed

mypy/checker.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7783,14 +7783,25 @@ def infer_operator_assignment_method(typ: Type, operator: str) -> tuple[bool, st
77837783
"""
77847784
typ = get_proper_type(typ)
77857785
method = operators.op_methods[operator]
7786+
existing_method = None
77867787
if isinstance(typ, Instance):
7787-
if operator in operators.ops_with_inplace_method:
7788-
inplace_method = "__i" + method[2:]
7789-
if typ.type.has_readable_member(inplace_method):
7790-
return True, inplace_method
7788+
existing_method = _find_inplace_method(typ, method, operator)
7789+
elif isinstance(typ, TypedDictType):
7790+
existing_method = _find_inplace_method(typ.fallback, method, operator)
7791+
7792+
if existing_method is not None:
7793+
return True, existing_method
77917794
return False, method
77927795

77937796

7797+
def _find_inplace_method(inst: Instance, method: str, operator: str) -> str | None:
7798+
if operator in operators.ops_with_inplace_method:
7799+
inplace_method = "__i" + method[2:]
7800+
if inst.type.has_readable_member(inplace_method):
7801+
return inplace_method
7802+
return None
7803+
7804+
77947805
def is_valid_inferred_type(typ: Type, is_lvalue_final: bool = False) -> bool:
77957806
"""Is an inferred type valid and needs no further refinement?
77967807

mypy/checkexpr.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
from __future__ import annotations
44

5+
import enum
56
import itertools
67
import time
78
from collections import defaultdict
89
from contextlib import contextmanager
910
from typing import Callable, ClassVar, Final, Iterable, Iterator, List, Optional, Sequence, cast
10-
from typing_extensions import TypeAlias as _TypeAlias, overload
11+
from typing_extensions import TypeAlias as _TypeAlias, assert_never, overload
1112

1213
import mypy.checker
1314
import mypy.errorcodes as codes
@@ -277,6 +278,20 @@ class Finished(Exception):
277278
"""Raised if we can terminate overload argument check early (no match)."""
278279

279280

281+
@enum.unique
282+
class UseReverse(enum.Enum):
283+
"""Used in `visit_op_expr` to enable or disable reverse method checks."""
284+
285+
DEFAULT = 0
286+
ALWAYS = 1
287+
NEVER = 2
288+
289+
290+
USE_REVERSE_DEFAULT: Final = UseReverse.DEFAULT
291+
USE_REVERSE_ALWAYS: Final = UseReverse.ALWAYS
292+
USE_REVERSE_NEVER: Final = UseReverse.NEVER
293+
294+
280295
class ExpressionChecker(ExpressionVisitor[Type]):
281296
"""Expression type checker.
282297
@@ -3371,6 +3386,24 @@ def visit_op_expr(self, e: OpExpr) -> Type:
33713386
return proper_left_type.copy_modified(
33723387
items=proper_left_type.items + [UnpackType(mapped)]
33733388
)
3389+
3390+
use_reverse: UseReverse = USE_REVERSE_DEFAULT
3391+
if e.op == "|":
3392+
if is_named_instance(proper_left_type, "builtins.dict"):
3393+
# This is a special case for `dict | TypedDict`.
3394+
# 1. Find `dict | TypedDict` case
3395+
# 2. Switch `dict.__or__` to `TypedDict.__ror__` (the same from both runtime and typing perspective)
3396+
proper_right_type = get_proper_type(self.accept(e.right))
3397+
if isinstance(proper_right_type, TypedDictType):
3398+
use_reverse = USE_REVERSE_ALWAYS
3399+
if isinstance(proper_left_type, TypedDictType):
3400+
# This is the reverse case: `TypedDict | dict`,
3401+
# simply do not allow the reverse checking:
3402+
# do not call `__dict__.__ror__`.
3403+
proper_right_type = get_proper_type(self.accept(e.right))
3404+
if is_named_instance(proper_right_type, "builtins.dict"):
3405+
use_reverse = USE_REVERSE_NEVER
3406+
33743407
if TYPE_VAR_TUPLE in self.chk.options.enable_incomplete_feature:
33753408
# Handle tuple[X, ...] + tuple[Y, Z] = tuple[*tuple[X, ...], Y, Z].
33763409
if (
@@ -3390,7 +3423,25 @@ def visit_op_expr(self, e: OpExpr) -> Type:
33903423

33913424
if e.op in operators.op_methods:
33923425
method = operators.op_methods[e.op]
3393-
result, method_type = self.check_op(method, left_type, e.right, e, allow_reverse=True)
3426+
if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER:
3427+
result, method_type = self.check_op(
3428+
method,
3429+
base_type=left_type,
3430+
arg=e.right,
3431+
context=e,
3432+
allow_reverse=use_reverse is UseReverse.DEFAULT,
3433+
)
3434+
elif use_reverse is UseReverse.ALWAYS:
3435+
result, method_type = self.check_op(
3436+
# The reverse operator here gives better error messages:
3437+
operators.reverse_op_methods[method],
3438+
base_type=self.accept(e.right),
3439+
arg=e.left,
3440+
context=e,
3441+
allow_reverse=False,
3442+
)
3443+
else:
3444+
assert_never(use_reverse)
33943445
e.method_type = method_type
33953446
return result
33963447
else:

mypy/plugins/default.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,21 @@ def get_method_signature_hook(
7474
return typed_dict_setdefault_signature_callback
7575
elif fullname in {n + ".pop" for n in TPDICT_FB_NAMES}:
7676
return typed_dict_pop_signature_callback
77-
elif fullname in {n + ".update" for n in TPDICT_FB_NAMES}:
78-
return typed_dict_update_signature_callback
7977
elif fullname == "_ctypes.Array.__setitem__":
8078
return ctypes.array_setitem_callback
8179
elif fullname == singledispatch.SINGLEDISPATCH_CALLABLE_CALL_METHOD:
8280
return singledispatch.call_singledispatch_function_callback
81+
82+
typed_dict_updates = set()
83+
for n in TPDICT_FB_NAMES:
84+
typed_dict_updates.add(n + ".update")
85+
typed_dict_updates.add(n + ".__or__")
86+
typed_dict_updates.add(n + ".__ror__")
87+
typed_dict_updates.add(n + ".__ior__")
88+
89+
if fullname in typed_dict_updates:
90+
return typed_dict_update_signature_callback
91+
8392
return None
8493

8594
def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
@@ -401,11 +410,16 @@ def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
401410

402411

403412
def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
404-
"""Try to infer a better signature type for TypedDict.update."""
413+
"""Try to infer a better signature type for methods that update `TypedDict`.
414+
415+
This includes: `TypedDict.update`, `TypedDict.__or__`, `TypedDict.__ror__`,
416+
and `TypedDict.__ior__`.
417+
"""
405418
signature = ctx.default_signature
406419
if isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1:
407420
arg_type = get_proper_type(signature.arg_types[0])
408-
assert isinstance(arg_type, TypedDictType)
421+
if not isinstance(arg_type, TypedDictType):
422+
return signature
409423
arg_type = arg_type.as_anonymous()
410424
arg_type = arg_type.copy_modified(required_keys=set())
411425
if ctx.args and ctx.args[0]:

test-data/unit/check-typeddict.test

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3236,3 +3236,146 @@ def foo(x: int) -> Foo: ...
32363236
f: Foo = {**foo("no")} # E: Argument 1 to "foo" has incompatible type "str"; expected "int"
32373237
[builtins fixtures/dict.pyi]
32383238
[typing fixtures/typing-typeddict.pyi]
3239+
3240+
3241+
[case testTypedDictWith__or__method]
3242+
from typing import Dict
3243+
from mypy_extensions import TypedDict
3244+
3245+
class Foo(TypedDict):
3246+
key: int
3247+
3248+
foo1: Foo = {'key': 1}
3249+
foo2: Foo = {'key': 2}
3250+
3251+
reveal_type(foo1 | foo2) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})"
3252+
reveal_type(foo1 | {'key': 1}) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})"
3253+
reveal_type(foo1 | {'key': 'a'}) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
3254+
reveal_type(foo1 | {}) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})"
3255+
3256+
d1: Dict[str, int]
3257+
d2: Dict[int, str]
3258+
3259+
reveal_type(foo1 | d1) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
3260+
foo1 | d2 # E: Unsupported operand types for | ("Foo" and "Dict[int, str]")
3261+
3262+
3263+
class Bar(TypedDict):
3264+
key: int
3265+
value: str
3266+
3267+
bar: Bar
3268+
reveal_type(bar | {}) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
3269+
reveal_type(bar | {'key': 1, 'value': 'v'}) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
3270+
reveal_type(bar | {'key': 1}) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
3271+
reveal_type(bar | {'value': 'v'}) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
3272+
reveal_type(bar | {'key': 'a'}) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
3273+
reveal_type(bar | {'value': 1}) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
3274+
reveal_type(bar | {'key': 'a', 'value': 1}) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
3275+
3276+
reveal_type(bar | foo1) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
3277+
reveal_type(bar | d1) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
3278+
bar | d2 # E: Unsupported operand types for | ("Bar" and "Dict[int, str]")
3279+
[builtins fixtures/dict.pyi]
3280+
[typing fixtures/typing-typeddict-iror.pyi]
3281+
3282+
[case testTypedDictWith__or__method_error]
3283+
from mypy_extensions import TypedDict
3284+
3285+
class Foo(TypedDict):
3286+
key: int
3287+
3288+
foo: Foo = {'key': 1}
3289+
foo | 1
3290+
3291+
class SubDict(dict): ...
3292+
foo | SubDict()
3293+
[out]
3294+
main:7: error: No overload variant of "__or__" of "TypedDict" matches argument type "int"
3295+
main:7: note: Possible overload variants:
3296+
main:7: note: def __or__(self, TypedDict({'key'?: int}), /) -> Foo
3297+
main:7: note: def __or__(self, Dict[str, Any], /) -> Dict[str, object]
3298+
main:10: error: No overload variant of "__ror__" of "dict" matches argument type "Foo"
3299+
main:10: note: Possible overload variants:
3300+
main:10: note: def __ror__(self, Dict[Any, Any], /) -> Dict[Any, Any]
3301+
main:10: note: def [T, T2] __ror__(self, Dict[T, T2], /) -> Dict[Union[Any, T], Union[Any, T2]]
3302+
[builtins fixtures/dict.pyi]
3303+
[typing fixtures/typing-typeddict-iror.pyi]
3304+
3305+
[case testTypedDictWith__ror__method]
3306+
from typing import Dict
3307+
from mypy_extensions import TypedDict
3308+
3309+
class Foo(TypedDict):
3310+
key: int
3311+
3312+
foo: Foo = {'key': 1}
3313+
3314+
reveal_type({'key': 1} | foo) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})"
3315+
reveal_type({'key': 'a'} | foo) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
3316+
reveal_type({} | foo) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})"
3317+
{1: 'a'} | foo # E: Dict entry 0 has incompatible type "int": "str"; expected "str": "Any"
3318+
3319+
d1: Dict[str, int]
3320+
d2: Dict[int, str]
3321+
3322+
reveal_type(d1 | foo) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
3323+
d2 | foo # E: Unsupported operand types for | ("Dict[int, str]" and "Foo")
3324+
1 | foo # E: Unsupported left operand type for | ("int")
3325+
3326+
3327+
class Bar(TypedDict):
3328+
key: int
3329+
value: str
3330+
3331+
bar: Bar
3332+
reveal_type({} | bar) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
3333+
reveal_type({'key': 1, 'value': 'v'} | bar) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
3334+
reveal_type({'key': 1} | bar) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
3335+
reveal_type({'value': 'v'} | bar) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
3336+
reveal_type({'key': 'a'} | bar) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
3337+
reveal_type({'value': 1} | bar) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
3338+
reveal_type({'key': 'a', 'value': 1} | bar) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
3339+
3340+
reveal_type(d1 | bar) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
3341+
d2 | bar # E: Unsupported operand types for | ("Dict[int, str]" and "Bar")
3342+
[builtins fixtures/dict.pyi]
3343+
[typing fixtures/typing-typeddict-iror.pyi]
3344+
3345+
[case testTypedDictWith__ior__method]
3346+
from typing import Dict
3347+
from mypy_extensions import TypedDict
3348+
3349+
class Foo(TypedDict):
3350+
key: int
3351+
3352+
foo: Foo = {'key': 1}
3353+
foo |= {'key': 2}
3354+
3355+
foo |= {}
3356+
foo |= {'key': 'a', 'b': 'a'} # E: Expected TypedDict key "key" but found keys ("key", "b") \
3357+
# E: Incompatible types (expression has type "str", TypedDict item "key" has type "int")
3358+
foo |= {'b': 2} # E: Unexpected TypedDict key "b"
3359+
3360+
d1: Dict[str, int]
3361+
d2: Dict[int, str]
3362+
3363+
foo |= d1 # E: Argument 1 to "__ior__" of "TypedDict" has incompatible type "Dict[str, int]"; expected "TypedDict({'key'?: int})"
3364+
foo |= d2 # E: Argument 1 to "__ior__" of "TypedDict" has incompatible type "Dict[int, str]"; expected "TypedDict({'key'?: int})"
3365+
3366+
3367+
class Bar(TypedDict):
3368+
key: int
3369+
value: str
3370+
3371+
bar: Bar
3372+
bar |= {}
3373+
bar |= {'key': 1, 'value': 'a'}
3374+
bar |= {'key': 'a', 'value': 'a', 'b': 'a'} # E: Expected TypedDict keys ("key", "value") but found keys ("key", "value", "b") \
3375+
# E: Incompatible types (expression has type "str", TypedDict item "key" has type "int")
3376+
3377+
bar |= foo
3378+
bar |= d1 # E: Argument 1 to "__ior__" of "TypedDict" has incompatible type "Dict[str, int]"; expected "TypedDict({'key'?: int, 'value'?: str})"
3379+
bar |= d2 # E: Argument 1 to "__ior__" of "TypedDict" has incompatible type "Dict[int, str]"; expected "TypedDict({'key'?: int, 'value'?: str})"
3380+
[builtins fixtures/dict.pyi]
3381+
[typing fixtures/typing-typeddict-iror.pyi]

test-data/unit/fixtures/dict.pyi

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from _typeshed import SupportsKeysAndGetItem
44
import _typeshed
55
from typing import (
6-
TypeVar, Generic, Iterable, Iterator, Mapping, Tuple, overload, Optional, Union, Sequence
6+
TypeVar, Generic, Iterable, Iterator, Mapping, Tuple, overload, Optional, Union, Sequence,
7+
Self,
78
)
89

910
T = TypeVar('T')
11+
T2 = TypeVar('T2')
1012
KT = TypeVar('KT')
1113
VT = TypeVar('VT')
1214

@@ -34,6 +36,21 @@ class dict(Mapping[KT, VT]):
3436
def get(self, k: KT, default: Union[VT, T]) -> Union[VT, T]: pass
3537
def __len__(self) -> int: ...
3638

39+
# This was actually added in 3.9:
40+
@overload
41+
def __or__(self, __value: dict[KT, VT]) -> dict[KT, VT]: ...
42+
@overload
43+
def __or__(self, __value: dict[T, T2]) -> dict[Union[KT, T], Union[VT, T2]]: ...
44+
@overload
45+
def __ror__(self, __value: dict[KT, VT]) -> dict[KT, VT]: ...
46+
@overload
47+
def __ror__(self, __value: dict[T, T2]) -> dict[Union[KT, T], Union[VT, T2]]: ...
48+
# dict.__ior__ should be kept roughly in line with MutableMapping.update()
49+
@overload # type: ignore[misc]
50+
def __ior__(self, __value: _typeshed.SupportsKeysAndGetItem[KT, VT]) -> Self: ...
51+
@overload
52+
def __ior__(self, __value: Iterable[Tuple[KT, VT]]) -> Self: ...
53+
3754
class int: # for convenience
3855
def __add__(self, x: Union[int, complex]) -> int: pass
3956
def __radd__(self, x: int) -> int: pass

test-data/unit/fixtures/typing-async.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ ClassVar = 0
2424
Final = 0
2525
Literal = 0
2626
NoReturn = 0
27+
Self = 0
2728

2829
T = TypeVar('T')
2930
T_co = TypeVar('T_co', covariant=True)

test-data/unit/fixtures/typing-full.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Literal = 0
3030
TypedDict = 0
3131
NoReturn = 0
3232
NewType = 0
33+
Self = 0
3334

3435
T = TypeVar('T')
3536
T_co = TypeVar('T_co', covariant=True)

test-data/unit/fixtures/typing-medium.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ NoReturn = 0
2828
NewType = 0
2929
TypeAlias = 0
3030
LiteralString = 0
31+
Self = 0
3132

3233
T = TypeVar('T')
3334
T_co = TypeVar('T_co', covariant=True)

0 commit comments

Comments
 (0)