Skip to content

Commit 374f524

Browse files
committed
🐛 Getting generic schema more than once
1 parent 27d985f commit 374f524

File tree

4 files changed

+110
-11
lines changed

4 files changed

+110
-11
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# marshmallow\_dataclass2 change log
22

3+
## v8.8.1 (2025-02-01)
4+
5+
- Update Readme
6+
- Fix getting generic schema more than once
7+
38
## v8.8.0 (2025-02-01)
49

510
- Drop support for python 3.8

marshmallow_dataclass2/__init__.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,29 @@ class User:
105105
MAX_CLASS_SCHEMA_CACHE_SIZE = 1024
106106

107107

108+
class LazyGenericSchema:
109+
"""Exists to cache generic instances"""
110+
111+
def __init__(self, base_schema, frame):
112+
self.base_schema = base_schema
113+
self.frame = frame
114+
115+
self.__resolved_generic_schemas = {}
116+
117+
def get_schema(self, instance):
118+
instance_args = get_args(instance)
119+
schema = self.__resolved_generic_schemas.get(instance_args)
120+
if schema is None:
121+
schema = class_schema(
122+
instance,
123+
self.base_schema,
124+
self.frame,
125+
)
126+
self.__resolved_generic_schemas[instance_args] = schema
127+
128+
return schema
129+
130+
108131
def _maybe_get_callers_frame(
109132
cls: type, stacklevel: int = 1
110133
) -> Optional[types.FrameType]:
@@ -296,12 +319,17 @@ def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]:
296319
else:
297320
frame = _maybe_get_callers_frame(clazz, stacklevel=stacklevel)
298321

299-
# noinspection PyTypeHints
300-
clazz.Schema = lazy_class_attribute( # type: ignore
301-
partial(class_schema, clazz, base_schema, frame),
302-
"Schema",
303-
clazz.__name__,
304-
)
322+
if not typing_inspect.is_generic_type(clazz):
323+
# noinspection PyTypeHints
324+
clazz.Schema = lazy_class_attribute( # type: ignore
325+
partial(class_schema, clazz, base_schema, frame),
326+
"Schema",
327+
clazz.__name__,
328+
)
329+
else:
330+
# noinspection PyTypeHints
331+
clazz.Schema = LazyGenericSchema(base_schema, frame) # type: ignore
332+
305333
return clazz
306334

307335
if _cls is None:
@@ -999,10 +1027,15 @@ def _field_for_schema(
9991027
forward_reference = getattr(typ, "__forward_arg__", None)
10001028

10011029
nested = (
1002-
nested_schema
1003-
or forward_reference
1004-
or _schema_ctx_stack.top.seen_classes.get(typ)
1005-
or _internal_class_schema(typ, base_schema) # type: ignore [arg-type]
1030+
# Pass the type instance. This is required for generics
1031+
nested_schema.get_schema(typ)
1032+
if isinstance(nested_schema, LazyGenericSchema)
1033+
else (
1034+
nested_schema
1035+
or forward_reference
1036+
or _schema_ctx_stack.top.seen_classes.get(typ)
1037+
or _internal_class_schema(typ, base_schema) # type: ignore [arg-type]
1038+
)
10061039
)
10071040

10081041
return marshmallow.fields.Nested(nested, **metadata)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from setuptools import find_packages, setup
22

3-
VERSION = "8.8.0"
3+
VERSION = "8.8.1"
44

55
CLASSIFIERS = [
66
"Development Status :: 4 - Beta",

tests/test_generics.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,67 @@ class NestedGeneric(typing.Generic[T]):
110110
with self.assertRaises(ValidationError):
111111
schema_nested_generic.load({"data": {"data": "str"}})
112112

113+
def test_generic_dataclass_cached(self):
114+
T = typing.TypeVar("T")
115+
116+
@dataclass
117+
class SimpleGeneric(typing.Generic[T]):
118+
data1: T
119+
120+
@dataclass
121+
class NestedFixed:
122+
data2: SimpleGeneric[int]
123+
124+
@dataclass
125+
class NestedGeneric(typing.Generic[T]):
126+
data3: SimpleGeneric[T]
127+
128+
self.assertTrue(is_generic_alias_of_dataclass(SimpleGeneric[int]))
129+
self.assertFalse(is_generic_alias_of_dataclass(SimpleGeneric))
130+
131+
schema_s = class_schema(SimpleGeneric[str])()
132+
self.assertEqual(SimpleGeneric(data1="a"), schema_s.load({"data1": "a"}))
133+
self.assertEqual(schema_s.dump(SimpleGeneric(data1="a")), {"data1": "a"})
134+
with self.assertRaises(ValidationError):
135+
schema_s.load({"data1": 2})
136+
137+
schema_nested = class_schema(NestedFixed)()
138+
self.assertEqual(
139+
NestedFixed(data2=SimpleGeneric(1)),
140+
schema_nested.load({"data2": {"data1": 1}}),
141+
)
142+
self.assertEqual(
143+
schema_nested.dump(NestedFixed(data2=SimpleGeneric(data1=1))),
144+
{"data2": {"data1": 1}},
145+
)
146+
with self.assertRaises(ValidationError):
147+
schema_nested.load({"data2": {"data1": "str"}})
148+
149+
schema_nested_generic = class_schema(NestedGeneric[int])()
150+
self.assertEqual(
151+
NestedGeneric(data3=SimpleGeneric(1)),
152+
schema_nested_generic.load({"data3": {"data1": 1}}),
153+
)
154+
self.assertEqual(
155+
schema_nested_generic.dump(NestedGeneric(data3=SimpleGeneric(data1=1))),
156+
{"data3": {"data1": 1}},
157+
)
158+
with self.assertRaises(ValidationError):
159+
schema_nested_generic.load({"data3": {"data1": "str"}})
160+
161+
# Copy test again so that we trigger a cache hit
162+
schema_nested_generic = class_schema(NestedGeneric[int])()
163+
self.assertEqual(
164+
NestedGeneric(data3=SimpleGeneric(1)),
165+
schema_nested_generic.load({"data3": {"data1": 1}}),
166+
)
167+
self.assertEqual(
168+
schema_nested_generic.dump(NestedGeneric(data3=SimpleGeneric(data1=1))),
169+
{"data3": {"data1": 1}},
170+
)
171+
with self.assertRaises(ValidationError):
172+
schema_nested_generic.load({"data3": {"data1": "str"}})
173+
113174
def test_generic_dataclass_repeated_fields(self):
114175
T = typing.TypeVar("T")
115176

0 commit comments

Comments
 (0)