Skip to content

Commit 4028dfc

Browse files
author
Roy Williams
committed
Implement type-aware get for TypedDict
Previously, `get` would simply fallback to the type of the underlying dictionary which made TypedDicts hard to use with code that's parsing objects where fields may or may not be present (for example, parsing a response). This implementation _explicitly_ ignores the default parameter's type as it's quite useful to chain together get calls (Until something like PEP 505 hits 😄) ```python foo.get('a', {}).get('b', {}).get('c') ``` This fixes python#2612
1 parent 011ba37 commit 4028dfc

File tree

6 files changed

+72
-8
lines changed

6 files changed

+72
-8
lines changed

mypy/checkexpr.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
PartialType, DeletedType, UnboundType, UninhabitedType, TypeType,
1010
true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike,
1111
get_typ_args, set_typ_args,
12-
)
12+
TypedDictGetFunction)
1313
from mypy.nodes import (
1414
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
1515
MemberExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr,
@@ -341,6 +341,10 @@ def check_call(self, callee: Type, args: List[Expression],
341341
"""
342342
arg_messages = arg_messages or self.msg
343343
if isinstance(callee, CallableType):
344+
if isinstance(callee, TypedDictGetFunction):
345+
if 1 <= len(args) <= 2 and isinstance(args[0], (StrExpr, UnicodeExpr)):
346+
return_type = self.get_typeddict_index_type(callee.typed_dict, args[0])
347+
return return_type, callee
344348
if callee.is_concrete_type_obj() and callee.type_object().is_abstract:
345349
type = callee.type_object()
346350
self.msg.cannot_instantiate_abstract_class(
@@ -1484,11 +1488,13 @@ def _get_value(self, index: Expression) -> Optional[int]:
14841488
return None
14851489

14861490
def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type:
1491+
return self.get_typeddict_index_type(td_type, index)
1492+
1493+
def get_typeddict_index_type(self, td_type: TypedDictType, index: Expression) -> Type:
14871494
if not isinstance(index, (StrExpr, UnicodeExpr)):
14881495
self.msg.typeddict_item_name_must_be_string_literal(td_type, index)
14891496
return AnyType()
14901497
item_name = index.value
1491-
14921498
item_type = td_type.items.get(item_name)
14931499
if item_type is None:
14941500
self.msg.typeddict_item_name_not_found(td_type, item_name, index)

mypy/checkmember.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from mypy.types import (
66
Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarDef,
77
Overloaded, TypeVarType, UnionType, PartialType,
8-
DeletedType, NoneTyp, TypeType, function_type
9-
)
8+
DeletedType, NoneTyp, TypeType, function_type,
9+
TypedDictGetFunction)
1010
from mypy.nodes import (
1111
TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context, MypyFile, TypeVarExpr,
1212
ARG_POS, ARG_STAR, ARG_STAR2,
@@ -120,9 +120,12 @@ def analyze_member_access(name: str,
120120
original_type=original_type, chk=chk)
121121
elif isinstance(typ, TypedDictType):
122122
# Actually look up from the fallback instance type.
123-
return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
124-
is_operator, builtin_type, not_ready_callback, msg,
125-
original_type=original_type, chk=chk)
123+
result = analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
124+
is_operator, builtin_type, not_ready_callback, msg,
125+
original_type=original_type, chk=chk)
126+
if name == 'get' and isinstance(result, CallableType):
127+
result = TypedDictGetFunction(typ, result)
128+
return result
126129
elif isinstance(typ, FunctionLike) and typ.is_type_obj():
127130
# Class attribute.
128131
# TODO super?

mypy/types.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,26 @@ def zipall(self, right: 'TypedDictType') \
980980
yield (item_name, None, right_item_type)
981981

982982

983+
class TypedDictGetFunction(CallableType):
984+
"""A special callable type containing a reference to the TypedDict `get` callable instance.
985+
This is needed to delay determining the signature of a TypedDict's `get` method until the
986+
method is actually called. This allows `get` to behave just as indexing into the TypedDict
987+
would.
988+
989+
This is not a real type, but is needed to allow TypedDict.get to behave as expected.
990+
"""
991+
def __init__(self, typed_dict: TypedDictType, fallback_callable: CallableType) -> None:
992+
super().__init__(fallback_callable.arg_types, fallback_callable.arg_kinds,
993+
fallback_callable.arg_names, fallback_callable.ret_type,
994+
fallback_callable.fallback, fallback_callable.name,
995+
fallback_callable.definition, fallback_callable.variables,
996+
fallback_callable.line, fallback_callable.column,
997+
fallback_callable.is_ellipsis_args, fallback_callable.implicit,
998+
fallback_callable.is_classmethod_class, fallback_callable.special_sig)
999+
self.typed_dict = typed_dict
1000+
self.fallback_callable = fallback_callable
1001+
1002+
9831003
class StarType(Type):
9841004
"""The star type *type_parameter.
9851005

test-data/unit/check-typeddict.test

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,38 @@ def set_coordinate(p: TaggedPoint, key: str, value: int) -> None:
431431

432432
-- Special Method: get
433433

434+
[case testCanUseGetMethodWithStringLiteralKey]
435+
from mypy_extensions import TypedDict
436+
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
437+
p = TaggedPoint(type='2d', x=42, y=1337)
438+
reveal_type(p.get('type')) # E: Revealed type is 'builtins.str'
439+
reveal_type(p.get('x')) # E: Revealed type is 'builtins.int'
440+
reveal_type(p.get('y')) # E: Revealed type is 'builtins.int'
441+
[builtins fixtures/dict.pyi]
442+
443+
[case testCannotGetMethodWithInvalidStringLiteralKey]
444+
from mypy_extensions import TypedDict
445+
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
446+
p = TaggedPoint(type='2d', x=42, y=1337)
447+
p.get('z') # E: 'z' is not a valid item name; expected one of ['type', 'x', 'y']
448+
[builtins fixtures/dict.pyi]
449+
450+
[case testGetMethodWithVariableKeyFallsBack]
451+
from mypy_extensions import TypedDict
452+
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
453+
p = TaggedPoint(type='2d', x=42, y=1337)
454+
key = 'type'
455+
reveal_type(p.get(key)) # E: Revealed type is 'builtins.object*'
456+
[builtins fixtures/dict.pyi]
457+
458+
[case testChainedGetMethodWithFallback]
459+
from mypy_extensions import TypedDict
460+
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
461+
PointSet = TypedDict('PointSet', {'first_point': TaggedPoint})
462+
p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337))
463+
reveal_type(p.get('first_point', {}).get('x')) # E: Revealed type is 'builtins.int'
464+
[builtins fixtures/dict.pyi]
465+
434466
-- TODO: Implement support for these cases:
435467
--[case testGetOfTypedDictWithValidStringLiteralKeyReturnsPreciseType]
436468
--[case testGetOfTypedDictWithInvalidStringLiteralKeyIsError]

test-data/unit/fixtures/dict.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class dict(Iterable[KT], Mapping[KT, VT], Generic[KT, VT]):
1818
def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass
1919
def __setitem__(self, k: KT, v: VT) -> None: pass
2020
def __iter__(self) -> Iterator[KT]: pass
21+
def get(self, k: KT, default: VT=None) -> VT: pass
2122
def update(self, a: Mapping[KT, VT]) -> None: pass
2223

2324
class int: # for convenience

test-data/unit/lib-stub/typing.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ class Sequence(Iterable[T], Generic[T]):
7878
@abstractmethod
7979
def __getitem__(self, n: Any) -> T: pass
8080

81-
class Mapping(Generic[T, U]): pass
81+
class Mapping(Generic[T, U]):
82+
@abstractmethod
83+
def get(self, k: T, default: U=None) -> U: pass
8284

8385
class MutableMapping(Generic[T, U]): pass
8486

0 commit comments

Comments
 (0)