Skip to content

Commit c3f5da1

Browse files
committed
Add support for TypeVar defaults
Tested with 3.12.0rc2
1 parent c361f3a commit c3f5da1

File tree

2 files changed

+210
-2
lines changed

2 files changed

+210
-2
lines changed

marshmallow_dataclass/generic_resolver.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,31 @@
1818
else:
1919
from typing_extensions import Annotated, get_args, get_origin
2020

21+
if sys.version_info >= (3, 13):
22+
from typing import NoDefault
23+
else:
24+
from typing import final
25+
26+
@final
27+
class NoDefault:
28+
pass
29+
30+
2131
_U = TypeVar("_U")
2232

2333

2434
class UnboundTypeVarError(TypeError):
2535
"""TypeVar instance can not be resolved to a type spec.
2636
2737
This exception is raised when an unbound TypeVar is encountered.
38+
"""
2839

40+
41+
class InvalidTypeVarDefaultError(TypeError):
42+
"""TypeVar default can not be resolved to a type spec.
43+
44+
This exception is raised when an invalid TypeVar default is encountered.
45+
This is most likely a scoping error: https://peps.python.org/pep-0696/#scoping-rules
2946
"""
3047

3148

@@ -42,9 +59,11 @@ class _Future(Generic[_U]):
4259

4360
_done: bool
4461
_result: _U
62+
_default: _U | "_Future[_U]"
4563

46-
def __init__(self) -> None:
64+
def __init__(self, default=NoDefault) -> None:
4765
self._done = False
66+
self._default = default
4867

4968
def done(self) -> bool:
5069
"""Return ``True`` if the value is available"""
@@ -57,6 +76,13 @@ def result(self) -> _U:
5776
"""
5877
if self.done():
5978
return self._result
79+
80+
if self._default is not NoDefault:
81+
if isinstance(self._default, _Future):
82+
return self._default.result()
83+
84+
return self._default
85+
6086
raise InvalidStateError("result has not been set")
6187

6288
def set_result(self, result: _U) -> None:
@@ -120,13 +146,35 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
120146
):
121147
if isinstance(potential_type, TypeVar):
122148
subclass_generic_params_to_args.append((potential_type, future))
149+
default = getattr(potential_type, "__default__", NoDefault)
150+
if default is not None:
151+
future._default = default
123152
else:
124153
future.set_result(potential_type)
125154

126155
args_by_class[subclass] = tuple(subclass_generic_params_to_args)
127156

128157
else:
129-
args_by_class[subclass] = tuple((arg, _Future()) for arg in args)
158+
# PEP-696: Typevar's may be used as defaults, but T1 must be used before T2
159+
# https://peps.python.org/pep-0696/#scoping-rules
160+
seen_type_args: Dict[TypeVar, _Future] = {}
161+
for arg in args:
162+
default = getattr(arg, "__default__", NoDefault)
163+
if default is not None:
164+
if isinstance(default, TypeVar):
165+
if default in seen_type_args:
166+
# We've already seen this TypeVar, Set the default to it's _Future
167+
default = seen_type_args[default]
168+
169+
else:
170+
# We haven't seen this yet, according to PEP-696 this is invalid.
171+
raise InvalidTypeVarDefaultError(
172+
f"{subclass.__name__} has an invalid TypeVar default for field {arg}"
173+
)
174+
175+
seen_type_args[arg] = _Future(default=default)
176+
177+
args_by_class[subclass] = tuple(seen_type_args.items())
130178

131179
parent_class = subclass
132180

tests/test_generics.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing_inspect import is_generic_type
77

88
import marshmallow.fields
9+
import pytest
910
from marshmallow import ValidationError
1011

1112
from marshmallow_dataclass import (
@@ -345,6 +346,165 @@ class OptionalGeneric(typing.Generic[T]):
345346
with self.assertRaises(ValidationError):
346347
schema_s.load({"data": 2})
347348

349+
@pytest.mark.skipif(
350+
sys.version_info <= (3, 13), reason="requires python 3.13 or higher"
351+
)
352+
def test_generic_default(self):
353+
T = typing.TypeVar("T", default=str)
354+
355+
@dataclasses.dataclass
356+
class SimpleGeneric(typing.Generic[T]):
357+
data: T
358+
359+
@dataclasses.dataclass
360+
class NestedFixed:
361+
data: SimpleGeneric[int]
362+
363+
@dataclasses.dataclass
364+
class NestedGeneric(typing.Generic[T]):
365+
data: SimpleGeneric[T]
366+
367+
self.assertTrue(is_generic_alias_of_dataclass(SimpleGeneric[int]))
368+
self.assertFalse(is_generic_alias_of_dataclass(SimpleGeneric))
369+
370+
schema_s = class_schema(SimpleGeneric)()
371+
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
372+
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
373+
with self.assertRaises(ValidationError):
374+
schema_s.load({"data": 2})
375+
376+
schema_nested = class_schema(NestedFixed)()
377+
self.assertEqual(
378+
NestedFixed(data=SimpleGeneric(1)),
379+
schema_nested.load({"data": {"data": 1}}),
380+
)
381+
self.assertEqual(
382+
schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))),
383+
{"data": {"data": 1}},
384+
)
385+
with self.assertRaises(ValidationError):
386+
schema_nested.load({"data": {"data": "str"}})
387+
388+
schema_nested_generic = class_schema(NestedGeneric[int])()
389+
self.assertEqual(
390+
NestedGeneric(data=SimpleGeneric(1)),
391+
schema_nested_generic.load({"data": {"data": 1}}),
392+
)
393+
self.assertEqual(
394+
schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))),
395+
{"data": {"data": 1}},
396+
)
397+
with self.assertRaises(ValidationError):
398+
schema_nested_generic.load({"data": {"data": "str"}})
399+
400+
@pytest.mark.skipif(
401+
sys.version_info <= (3, 13), reason="requires python 3.13 or higher"
402+
)
403+
def test_deep_generic_with_default_overrides(self):
404+
T = typing.TypeVar("T", default=bool)
405+
U = typing.TypeVar("U", default=int)
406+
V = typing.TypeVar("V", default=str)
407+
W = typing.TypeVar("W", default=float)
408+
409+
@dataclasses.dataclass
410+
class TestClass(typing.Generic[T, U, V]):
411+
pairs: typing.List[typing.Tuple[T, U]]
412+
gen: V
413+
override: int
414+
415+
test_schema = class_schema(TestClass)()
416+
assert list(test_schema.fields) == ["pairs", "gen", "override"]
417+
assert isinstance(test_schema.fields["pairs"], marshmallow.fields.List)
418+
assert isinstance(test_schema.fields["pairs"].inner, marshmallow.fields.Tuple)
419+
assert isinstance(
420+
test_schema.fields["pairs"].inner.tuple_fields[0],
421+
marshmallow.fields.Boolean,
422+
)
423+
assert isinstance(
424+
test_schema.fields["pairs"].inner.tuple_fields[1],
425+
marshmallow.fields.Integer,
426+
)
427+
428+
assert isinstance(test_schema.fields["gen"], marshmallow.fields.String)
429+
assert isinstance(test_schema.fields["override"], marshmallow.fields.Integer)
430+
431+
# Don't only override typevar, but switch order to further confuse things
432+
@dataclasses.dataclass
433+
class TestClass2(TestClass[str, W, U]):
434+
override: str # type: ignore # Want to test that it works, even if incompatible types
435+
436+
TestAlias = TestClass2[int, T]
437+
test_schema2 = class_schema(TestClass2)()
438+
assert list(test_schema2.fields) == ["pairs", "gen", "override"]
439+
assert isinstance(test_schema2.fields["pairs"], marshmallow.fields.List)
440+
assert isinstance(test_schema2.fields["pairs"].inner, marshmallow.fields.Tuple)
441+
assert isinstance(
442+
test_schema2.fields["pairs"].inner.tuple_fields[0],
443+
marshmallow.fields.String,
444+
)
445+
assert isinstance(
446+
test_schema2.fields["pairs"].inner.tuple_fields[1],
447+
marshmallow.fields.Float,
448+
)
449+
450+
assert isinstance(test_schema2.fields["gen"], marshmallow.fields.Integer)
451+
assert isinstance(test_schema2.fields["override"], marshmallow.fields.String)
452+
453+
# inherit from alias
454+
@dataclasses.dataclass
455+
class TestClass3(TestAlias[typing.List[int]]):
456+
pass
457+
458+
test_schema3 = class_schema(TestClass3)()
459+
assert list(test_schema3.fields) == ["pairs", "gen", "override"]
460+
assert isinstance(test_schema3.fields["pairs"], marshmallow.fields.List)
461+
assert isinstance(test_schema3.fields["pairs"].inner, marshmallow.fields.Tuple)
462+
assert isinstance(
463+
test_schema3.fields["pairs"].inner.tuple_fields[0],
464+
marshmallow.fields.String,
465+
)
466+
assert isinstance(
467+
test_schema3.fields["pairs"].inner.tuple_fields[1],
468+
marshmallow.fields.Integer,
469+
)
470+
471+
assert isinstance(test_schema3.fields["gen"], marshmallow.fields.List)
472+
assert isinstance(test_schema3.fields["gen"].inner, marshmallow.fields.Integer)
473+
assert isinstance(test_schema3.fields["override"], marshmallow.fields.String)
474+
475+
self.assertEqual(
476+
test_schema3.load(
477+
{"pairs": [("first", "1")], "gen": ["1", 2], "override": "overridden"}
478+
),
479+
TestClass3([("first", 1)], [1, 2], "overridden"),
480+
)
481+
482+
@pytest.mark.skipif(
483+
sys.version_info <= (3, 13), reason="requires python 3.13 or higher"
484+
)
485+
def test_generic_default_recursion(self):
486+
T = typing.TypeVar("T", default=str)
487+
U = typing.TypeVar("U", default=T)
488+
V = typing.TypeVar("V", default=U)
489+
490+
@dataclasses.dataclass
491+
class DefaultGenerics(typing.Generic[T, U, V]):
492+
a: T
493+
b: U
494+
c: V
495+
496+
test_schema = class_schema(DefaultGenerics)()
497+
assert list(test_schema.fields) == ["a", "b", "c"]
498+
assert isinstance(test_schema.fields["a"], marshmallow.fields.String)
499+
assert isinstance(test_schema.fields["b"], marshmallow.fields.String)
500+
assert isinstance(test_schema.fields["c"], marshmallow.fields.String)
501+
502+
test_schema2 = class_schema(DefaultGenerics[int])()
503+
assert list(test_schema2.fields) == ["a", "b", "c"]
504+
assert isinstance(test_schema2.fields["a"], marshmallow.fields.Integer)
505+
assert isinstance(test_schema2.fields["b"], marshmallow.fields.Integer)
506+
assert isinstance(test_schema2.fields["c"], marshmallow.fields.Integer)
507+
348508

349509
if __name__ == "__main__":
350510
unittest.main()

0 commit comments

Comments
 (0)