Skip to content

Commit 7beaec2

Browse files
committed
Support descriptors in dataclass transform (#15006)
Infer `__init__` argument types from the signatures of descriptor `__set__` methods, if present. We can't (easily) perform full type inference in a plugin, so we cheat and use a simplified implementation that should still cover most use cases. Here we assume that `__set__` is not decorated or overloaded, in particular. Fixes #14868.
1 parent a7a995a commit 7beaec2

File tree

2 files changed

+269
-6
lines changed

2 files changed

+269
-6
lines changed

Diff for: mypy/plugins/dataclasses.py

+55-6
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing_extensions import Final
77

88
from mypy import errorcodes, message_registry
9-
from mypy.expandtype import expand_type
9+
from mypy.expandtype import expand_type, expand_type_by_instance
1010
from mypy.nodes import (
1111
ARG_NAMED,
1212
ARG_NAMED_OPT,
@@ -23,6 +23,7 @@
2323
Context,
2424
DataclassTransformSpec,
2525
Expression,
26+
FuncDef,
2627
IfStmt,
2728
JsonDict,
2829
NameExpr,
@@ -98,7 +99,7 @@ def __init__(
9899
self.has_default = has_default
99100
self.line = line
100101
self.column = column
101-
self.type = type
102+
self.type = type # Type as __init__ argument
102103
self.info = info
103104
self.kw_only = kw_only
104105
self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen
@@ -535,9 +536,12 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
535536
elif not isinstance(stmt.rvalue, TempNode):
536537
has_default = True
537538

538-
if not has_default:
539-
# Make all non-default attributes implicit because they are de-facto set
540-
# on self in the generated __init__(), not in the class body.
539+
if not has_default and self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES:
540+
# Make all non-default dataclass attributes implicit because they are de-facto
541+
# set on self in the generated __init__(), not in the class body. On the other
542+
# hand, we don't know how custom dataclass transforms initialize attributes,
543+
# so we don't treat them as implicit. This is required to support descriptors
544+
# (https://github.com/python/mypy/issues/14868).
541545
sym.implicit = True
542546

543547
is_kw_only = kw_only
@@ -578,6 +582,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
578582
)
579583

580584
current_attr_names.add(lhs.name)
585+
init_type = self._infer_dataclass_attr_init_type(sym, lhs.name, stmt)
581586
found_attrs[lhs.name] = DataclassAttribute(
582587
name=lhs.name,
583588
alias=alias,
@@ -586,7 +591,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
586591
has_default=has_default,
587592
line=stmt.line,
588593
column=stmt.column,
589-
type=sym.type,
594+
type=init_type,
590595
info=cls.info,
591596
kw_only=is_kw_only,
592597
is_neither_frozen_nor_nonfrozen=_has_direct_dataclass_transform_metaclass(
@@ -755,6 +760,50 @@ def _get_bool_arg(self, name: str, default: bool) -> bool:
755760
return require_bool_literal_argument(self._api, expression, name, default)
756761
return default
757762

763+
def _infer_dataclass_attr_init_type(
764+
self, sym: SymbolTableNode, name: str, context: Context
765+
) -> Type | None:
766+
"""Infer __init__ argument type for an attribute.
767+
768+
In particular, possibly use the signature of __set__.
769+
"""
770+
default = sym.type
771+
if sym.implicit:
772+
return default
773+
t = get_proper_type(sym.type)
774+
775+
# Perform a simple-minded inference from the signature of __set__, if present.
776+
# We can't use mypy.checkmember here, since this plugin runs before type checking.
777+
# We only support some basic scanerios here, which is hopefully sufficient for
778+
# the vast majority of use cases.
779+
if not isinstance(t, Instance):
780+
return default
781+
setter = t.type.get("__set__")
782+
if setter:
783+
if isinstance(setter.node, FuncDef):
784+
super_info = t.type.get_containing_type_info("__set__")
785+
assert super_info
786+
if setter.type:
787+
setter_type = get_proper_type(
788+
map_type_from_supertype(setter.type, t.type, super_info)
789+
)
790+
else:
791+
return AnyType(TypeOfAny.unannotated)
792+
if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [
793+
ARG_POS,
794+
ARG_POS,
795+
ARG_POS,
796+
]:
797+
return expand_type_by_instance(setter_type.arg_types[2], t)
798+
else:
799+
self._api.fail(
800+
f'Unsupported signature for "__set__" in "{t.type.name}"', context
801+
)
802+
else:
803+
self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context)
804+
805+
return default
806+
758807

759808
def add_dataclass_tag(info: TypeInfo) -> None:
760809
# The value is ignored, only the existence matters.

Diff for: test-data/unit/check-dataclass-transform.test

+214
Original file line numberDiff line numberDiff line change
@@ -807,3 +807,217 @@ reveal_type(bar.base) # N: Revealed type is "builtins.int"
807807

808808
[typing fixtures/typing-full.pyi]
809809
[builtins fixtures/dataclasses.pyi]
810+
811+
[case testDataclassTransformSimpleDescriptor]
812+
# flags: --python-version 3.11
813+
814+
from typing import dataclass_transform, overload, Any
815+
816+
@dataclass_transform()
817+
def my_dataclass(cls): ...
818+
819+
class Desc:
820+
@overload
821+
def __get__(self, instance: None, owner: Any) -> Desc: ...
822+
@overload
823+
def __get__(self, instance: object, owner: Any) -> str: ...
824+
def __get__(self, instance: object | None, owner: Any) -> Desc | str: ...
825+
826+
def __set__(self, instance: Any, value: str) -> None: ...
827+
828+
@my_dataclass
829+
class C:
830+
x: Desc
831+
y: int
832+
833+
C(x='x', y=1)
834+
C(x=1, y=1) # E: Argument "x" to "C" has incompatible type "int"; expected "str"
835+
reveal_type(C(x='x', y=1).x) # N: Revealed type is "builtins.str"
836+
reveal_type(C(x='x', y=1).y) # N: Revealed type is "builtins.int"
837+
reveal_type(C.x) # N: Revealed type is "__main__.Desc"
838+
839+
[typing fixtures/typing-full.pyi]
840+
[builtins fixtures/dataclasses.pyi]
841+
842+
[case testDataclassTransformUnannotatedDescriptor]
843+
# flags: --python-version 3.11
844+
845+
from typing import dataclass_transform, overload, Any
846+
847+
@dataclass_transform()
848+
def my_dataclass(cls): ...
849+
850+
class Desc:
851+
@overload
852+
def __get__(self, instance: None, owner: Any) -> Desc: ...
853+
@overload
854+
def __get__(self, instance: object, owner: Any) -> str: ...
855+
def __get__(self, instance: object | None, owner: Any) -> Desc | str: ...
856+
857+
def __set__(*args, **kwargs): ...
858+
859+
@my_dataclass
860+
class C:
861+
x: Desc
862+
y: int
863+
864+
C(x='x', y=1)
865+
C(x=1, y=1)
866+
reveal_type(C(x='x', y=1).x) # N: Revealed type is "builtins.str"
867+
reveal_type(C(x='x', y=1).y) # N: Revealed type is "builtins.int"
868+
reveal_type(C.x) # N: Revealed type is "__main__.Desc"
869+
870+
[typing fixtures/typing-full.pyi]
871+
[builtins fixtures/dataclasses.pyi]
872+
873+
[case testDataclassTransformGenericDescriptor]
874+
# flags: --python-version 3.11
875+
876+
from typing import dataclass_transform, overload, Any, TypeVar, Generic
877+
878+
@dataclass_transform()
879+
def my_dataclass(frozen: bool = False): ...
880+
881+
T = TypeVar("T")
882+
883+
class Desc(Generic[T]):
884+
@overload
885+
def __get__(self, instance: None, owner: Any) -> Desc[T]: ...
886+
@overload
887+
def __get__(self, instance: object, owner: Any) -> T: ...
888+
def __get__(self, instance: object | None, owner: Any) -> Desc | T: ...
889+
890+
def __set__(self, instance: Any, value: T) -> None: ...
891+
892+
@my_dataclass()
893+
class C:
894+
x: Desc[str]
895+
896+
C(x='x')
897+
C(x=1) # E: Argument "x" to "C" has incompatible type "int"; expected "str"
898+
reveal_type(C(x='x').x) # N: Revealed type is "builtins.str"
899+
reveal_type(C.x) # N: Revealed type is "__main__.Desc[builtins.str]"
900+
901+
@my_dataclass()
902+
class D(C):
903+
y: Desc[int]
904+
905+
d = D(x='x', y=1)
906+
reveal_type(d.x) # N: Revealed type is "builtins.str"
907+
reveal_type(d.y) # N: Revealed type is "builtins.int"
908+
reveal_type(D.x) # N: Revealed type is "__main__.Desc[builtins.str]"
909+
reveal_type(D.y) # N: Revealed type is "__main__.Desc[builtins.int]"
910+
911+
@my_dataclass(frozen=True)
912+
class F:
913+
x: Desc[str] = Desc()
914+
915+
F(x='x')
916+
F(x=1) # E: Argument "x" to "F" has incompatible type "int"; expected "str"
917+
reveal_type(F(x='x').x) # N: Revealed type is "builtins.str"
918+
reveal_type(F.x) # N: Revealed type is "__main__.Desc[builtins.str]"
919+
920+
[typing fixtures/typing-full.pyi]
921+
[builtins fixtures/dataclasses.pyi]
922+
923+
[case testDataclassTransformGenericDescriptorWithInheritance]
924+
# flags: --python-version 3.11
925+
926+
from typing import dataclass_transform, overload, Any, TypeVar, Generic
927+
928+
@dataclass_transform()
929+
def my_dataclass(cls): ...
930+
931+
T = TypeVar("T")
932+
933+
class Desc(Generic[T]):
934+
@overload
935+
def __get__(self, instance: None, owner: Any) -> Desc[T]: ...
936+
@overload
937+
def __get__(self, instance: object, owner: Any) -> T: ...
938+
def __get__(self, instance: object | None, owner: Any) -> Desc | T: ...
939+
940+
def __set__(self, instance: Any, value: T) -> None: ...
941+
942+
class Desc2(Desc[str]):
943+
pass
944+
945+
@my_dataclass
946+
class C:
947+
x: Desc2
948+
949+
C(x='x')
950+
C(x=1) # E: Argument "x" to "C" has incompatible type "int"; expected "str"
951+
reveal_type(C(x='x').x) # N: Revealed type is "builtins.str"
952+
reveal_type(C.x) # N: Revealed type is "__main__.Desc[builtins.str]"
953+
954+
[typing fixtures/typing-full.pyi]
955+
[builtins fixtures/dataclasses.pyi]
956+
957+
[case testDataclassTransformDescriptorWithDifferentGetSetTypes]
958+
# flags: --python-version 3.11
959+
960+
from typing import dataclass_transform, overload, Any
961+
962+
@dataclass_transform()
963+
def my_dataclass(cls): ...
964+
965+
class Desc:
966+
@overload
967+
def __get__(self, instance: None, owner: Any) -> int: ...
968+
@overload
969+
def __get__(self, instance: object, owner: Any) -> str: ...
970+
def __get__(self, instance, owner): ...
971+
972+
def __set__(self, instance: Any, value: bytes) -> None: ...
973+
974+
@my_dataclass
975+
class C:
976+
x: Desc
977+
978+
c = C(x=b'x')
979+
C(x=1) # E: Argument "x" to "C" has incompatible type "int"; expected "bytes"
980+
reveal_type(c.x) # N: Revealed type is "builtins.str"
981+
reveal_type(C.x) # N: Revealed type is "builtins.int"
982+
c.x = b'x'
983+
c.x = 1 # E: Incompatible types in assignment (expression has type "int", variable has type "bytes")
984+
985+
[typing fixtures/typing-full.pyi]
986+
[builtins fixtures/dataclasses.pyi]
987+
988+
[case testDataclassTransformUnsupportedDescriptors]
989+
# flags: --python-version 3.11
990+
991+
from typing import dataclass_transform, overload, Any
992+
993+
@dataclass_transform()
994+
def my_dataclass(cls): ...
995+
996+
class Desc:
997+
@overload
998+
def __get__(self, instance: None, owner: Any) -> int: ...
999+
@overload
1000+
def __get__(self, instance: object, owner: Any) -> str: ...
1001+
def __get__(self, instance, owner): ...
1002+
1003+
def __set__(*args, **kwargs) -> None: ...
1004+
1005+
class Desc2:
1006+
@overload
1007+
def __get__(self, instance: None, owner: Any) -> int: ...
1008+
@overload
1009+
def __get__(self, instance: object, owner: Any) -> str: ...
1010+
def __get__(self, instance, owner): ...
1011+
1012+
@overload
1013+
def __set__(self, instance: Any, value: bytes) -> None: ...
1014+
@overload
1015+
def __set__(self) -> None: ...
1016+
def __set__(self, *args, **kawrga) -> None: ...
1017+
1018+
@my_dataclass
1019+
class C:
1020+
x: Desc # E: Unsupported signature for "__set__" in "Desc"
1021+
y: Desc2 # E: Unsupported "__set__" in "Desc2"
1022+
[typing fixtures/typing-full.pyi]
1023+
[builtins fixtures/dataclasses.pyi]

0 commit comments

Comments
 (0)