Skip to content

Commit c005895

Browse files
committed
Generic dataclass support
1 parent 40315b7 commit c005895

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

mypy/plugins/dataclasses.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing_extensions import Final
77

88
from mypy import errorcodes, message_registry
9-
from mypy.expandtype import expand_type
9+
from mypy.expandtype import expand_type, expand_type_by_instance
1010
from mypy.messages import format_type_bare
1111
from mypy.nodes import (
1212
ARG_NAMED,
@@ -350,12 +350,15 @@ def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) ->
350350
Stashes the signature of 'dataclasses.replace(...)' for this specific dataclass
351351
to be used later whenever 'dataclasses.replace' is called for this dataclass.
352352
"""
353-
arg_types: list[Type] = [Instance(self._cls.info, [])]
354-
arg_kinds = [ARG_POS]
355-
arg_names: list[str | None] = [None]
353+
arg_types: list[Type] = []
354+
arg_kinds = []
355+
arg_names: list[str | None] = []
356+
357+
info = self._cls.info
356358
for attr in attributes:
357-
assert attr.type is not None
358-
arg_types.append(attr.type)
359+
attr_type = attr.expand_type(info)
360+
assert attr_type is not None
361+
arg_types.append(attr_type)
359362
arg_kinds.append(
360363
ARG_NAMED if attr.is_init_var and not attr.has_default else ARG_NAMED_OPT
361364
)
@@ -365,7 +368,7 @@ def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) ->
365368
arg_types=arg_types,
366369
arg_kinds=arg_kinds,
367370
arg_names=arg_names,
368-
ret_type=Instance(self._cls.info, []),
371+
ret_type=NoneType(),
369372
fallback=self._api.named_type("builtins.function"),
370373
name=f"replace of {self._cls.info.name}",
371374
)
@@ -883,4 +886,12 @@ def replace_function_sig_callback(ctx: FunctionSigContext) -> CallableType:
883886

884887
signature = get_proper_type(replace_func.type)
885888
assert isinstance(signature, CallableType)
889+
signature = expand_type_by_instance(signature, obj_type)
890+
# re-add the instance type
891+
signature = signature.copy_modified(
892+
arg_types=[obj_type, *signature.arg_types],
893+
arg_kinds=[ARG_POS, *signature.arg_kinds],
894+
arg_names=[None, *signature.arg_names],
895+
ret_type=obj_type,
896+
)
886897
return signature

test-data/unit/check-dataclasses.test

+18
Original file line numberDiff line numberDiff line change
@@ -2070,3 +2070,21 @@ class C:
20702070
pass
20712071

20722072
replace(C()) # E: Argument 1 to "replace" has incompatible type "C"; expected a dataclass
2073+
2074+
[case testReplaceGeneric]
2075+
from dataclasses import dataclass, replace, InitVar
2076+
from typing import ClassVar, Generic, TypeVar
2077+
2078+
T = TypeVar('T')
2079+
2080+
@dataclass
2081+
class A(Generic[T]):
2082+
x: T
2083+
2084+
2085+
a = A(x=42)
2086+
reveal_type(a) # N: Revealed type is "__main__.A[builtins.int]"
2087+
a2 = replace(a, x=42)
2088+
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"
2089+
a2 = replace(a, x='42') # E: Argument "x" to "replace" of "A" has incompatible type "str"; expected "int"
2090+
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"

0 commit comments

Comments
 (0)