Skip to content

Commit fc66fc9

Browse files
committed
🐛 support py3.9 native collection types with generics. i.e.: list[T]
1 parent 2ef5a71 commit fc66fc9

File tree

2 files changed

+43
-21
lines changed

2 files changed

+43
-21
lines changed

marshmallow_dataclass/generic_resolver.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import copy
22
import dataclasses
3-
import inspect
43
import sys
54
from typing import (
65
Dict,
@@ -11,14 +10,18 @@
1110
TypeVar,
1211
Union,
1312
)
14-
1513
import typing_inspect
14+
import warnings
1615

1716
if sys.version_info >= (3, 9):
1817
from typing import Annotated, get_args, get_origin
18+
from types import GenericAlias
1919
else:
2020
from typing_extensions import Annotated, get_args, get_origin
2121

22+
GenericAlias = type(list)
23+
24+
2225
if sys.version_info >= (3, 13):
2326
from typing import NoDefault
2427
else:
@@ -194,25 +197,27 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
194197
def _replace_typevars(
195198
clazz: type, resolved_generics: Optional[Dict[TypeVar, _Future]] = None
196199
) -> type:
197-
if (
198-
not resolved_generics
199-
or inspect.isclass(clazz)
200-
or not may_contain_typevars(clazz)
201-
):
200+
if not resolved_generics or not may_contain_typevars(clazz):
202201
return clazz
203202

204-
return clazz.copy_with( # type: ignore
205-
tuple(
206-
(
207-
_replace_typevars(arg, resolved_generics)
208-
if may_contain_typevars(arg)
209-
else (
210-
resolved_generics[arg].result() if arg in resolved_generics else arg
211-
)
212-
)
213-
for arg in get_args(clazz)
203+
new_args = tuple(
204+
(
205+
_replace_typevars(arg, resolved_generics)
206+
if may_contain_typevars(arg)
207+
else (resolved_generics[arg].result() if arg in resolved_generics else arg)
214208
)
209+
for arg in get_args(clazz)
215210
)
211+
# i.e.: typing.List, typing.Dict, but not list, and dict
212+
if hasattr(clazz, "copy_with"):
213+
return clazz.copy_with(new_args)
214+
# i.e.: list, dict - inspired by typing._strip_annotations
215+
if sys.version_info >= (3, 9) and isinstance(clazz, GenericAlias):
216+
return GenericAlias(clazz.__origin__, new_args) # type:ignore[return-value]
217+
218+
# I'm not sure how we'd end up here. But raise a warnings so people can create an issue
219+
warnings.warn(f"Unable to replace typevars in {clazz}")
220+
return clazz
216221

217222

218223
def get_generic_dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
@@ -237,7 +242,7 @@ def get_generic_dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
237242
# If it's a class we handle it later as a Nested. Nothing to resolve now.
238243
new_field = field
239244
field_type: type = field.type # type: ignore[assignment]
240-
if not inspect.isclass(field_type) and may_contain_typevars(field_type):
245+
if may_contain_typevars(field_type):
241246
new_field = copy.copy(field)
242247
new_field.type = _replace_typevars(
243248
field_type, resolved_typevars[subclass]

tests/test_generics.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,23 @@ class TestClass(typing.Generic[T, U]):
183183
test_schema.load({"pairs": [("first", "1")]}), TestClass([("first", 1)])
184184
)
185185

186+
@pytest.mark.skipif(
187+
sys.version_info < (3, 9), reason="requires python 3.9 or higher"
188+
)
189+
def test_deep_generic_native(self):
190+
T = typing.TypeVar("T")
191+
U = typing.TypeVar("U")
192+
193+
@dataclasses.dataclass
194+
class TestClass(typing.Generic[T, U]):
195+
pairs: list[tuple[T, U]]
196+
197+
test_schema = class_schema(TestClass[str, int])()
198+
199+
self.assertEqual(
200+
test_schema.load({"pairs": [("first", "1")]}), TestClass([("first", 1)])
201+
)
202+
186203
def test_deep_generic_with_union(self):
187204
T = typing.TypeVar("T")
188205
U = typing.TypeVar("U")
@@ -348,7 +365,7 @@ class OptionalGeneric(typing.Generic[T]):
348365
schema_s.load({"data": 2})
349366

350367
@pytest.mark.skipif(
351-
sys.version_info <= (3, 13), reason="requires python 3.13 or higher"
368+
sys.version_info < (3, 13), reason="requires python 3.13 or higher"
352369
)
353370
def test_generic_default(self):
354371
T = typing.TypeVar("T", default=str)
@@ -399,7 +416,7 @@ class NestedGeneric(typing.Generic[T]):
399416
schema_nested_generic.load({"data": {"data": "str"}})
400417

401418
@pytest.mark.skipif(
402-
sys.version_info <= (3, 13), reason="requires python 3.13 or higher"
419+
sys.version_info < (3, 13), reason="requires python 3.13 or higher"
403420
)
404421
def test_deep_generic_with_default_overrides(self):
405422
T = typing.TypeVar("T", default=bool)
@@ -482,7 +499,7 @@ class TestClass3(TestAlias[typing.List[int]]): # type: ignore[override]
482499
)
483500

484501
@pytest.mark.skipif(
485-
sys.version_info <= (3, 13), reason="requires python 3.13 or higher"
502+
sys.version_info < (3, 13), reason="requires python 3.13 or higher"
486503
)
487504
def test_generic_default_recursion(self):
488505
T = typing.TypeVar("T", default=str)

0 commit comments

Comments
 (0)