Skip to content

Commit 5123577

Browse files
committed
Possible fix for #198: memory leak
1 parent 6391003 commit 5123577

File tree

1 file changed

+84
-49
lines changed

1 file changed

+84
-49
lines changed

marshmallow_dataclass/__init__.py

+84-49
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class User:
4646
Any,
4747
Callable,
4848
Dict,
49+
Generic,
4950
List,
5051
Mapping,
5152
NewType as typing_NewType,
@@ -79,9 +80,6 @@ class User:
7980
# Max number of generated schemas that class_schema keeps of generated schemas. Removes duplicates.
8081
MAX_CLASS_SCHEMA_CACHE_SIZE = 1024
8182

82-
# Recursion guard for class_schema()
83-
_RECURSION_GUARD = threading.local()
84-
8583

8684
@overload
8785
def dataclass(
@@ -352,20 +350,56 @@ def class_schema(
352350
clazz_frame = current_frame.f_back
353351
# Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack
354352
del current_frame
355-
_RECURSION_GUARD.seen_classes = {}
356-
try:
357-
return _internal_class_schema(clazz, base_schema, clazz_frame)
358-
finally:
359-
_RECURSION_GUARD.seen_classes.clear()
353+
354+
with _SchemaContext(clazz_frame):
355+
return _internal_class_schema(clazz, base_schema)
356+
357+
358+
class _SchemaContext:
359+
"""Global context for an invocation of class_schema."""
360+
361+
def __init__(self, frame: Optional[types.FrameType]):
362+
self.seen_classes: Dict[Type, str] = {}
363+
self.frame = frame
364+
365+
def get_type_hints(self, cls: Type) -> Dict[str, Any]:
366+
frame = self.frame
367+
localns = frame.f_locals if frame is not None else None
368+
return get_type_hints(cls, localns=localns)
369+
370+
def __enter__(self) -> "_SchemaContext":
371+
_schema_ctx_stack.push(self)
372+
return self
373+
374+
def __exit__(self, _typ, _value, _tb) -> None:
375+
_schema_ctx_stack.pop()
376+
377+
378+
class _LocalStack(threading.local, Generic[_U]):
379+
def __init__(self):
380+
self.stack: List[_U] = []
381+
382+
def push(self, value: _U) -> None:
383+
self.stack.append(value)
384+
385+
def pop(self) -> None:
386+
self.stack.pop()
387+
388+
@property
389+
def top(self) -> _U:
390+
return self.stack[-1]
391+
392+
393+
_schema_ctx_stack = _LocalStack[_SchemaContext]()
360394

361395

362396
@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE)
363397
def _internal_class_schema(
364398
clazz: type,
365399
base_schema: Optional[Type[marshmallow.Schema]] = None,
366-
clazz_frame: types.FrameType = None,
367400
) -> Type[marshmallow.Schema]:
368-
_RECURSION_GUARD.seen_classes[clazz] = clazz.__name__
401+
schema_ctx = _schema_ctx_stack.top
402+
schema_ctx.seen_classes[clazz] = clazz.__name__
369403
try:
370404
# noinspection PyDataclass
371405
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz)
@@ -383,7 +417,7 @@ def _internal_class_schema(
383417
"****** WARNING ******"
384418
)
385419
created_dataclass: type = dataclasses.dataclass(clazz)
386-
return _internal_class_schema(created_dataclass, base_schema, clazz_frame)
420+
return _internal_class_schema(created_dataclass, base_schema)
387421
except Exception:
388422
raise TypeError(
389423
f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
@@ -397,18 +431,15 @@ def _internal_class_schema(
397431
}
398432

399433
# Update the schema members to contain marshmallow fields instead of dataclass fields
400-
type_hints = get_type_hints(
401-
clazz, localns=clazz_frame.f_locals if clazz_frame else None
402-
)
434+
type_hints = schema_ctx.get_type_hints(clazz)
403435
attributes.update(
404436
(
405437
field.name,
406-
field_for_schema(
438+
_field_for_schema(
407439
type_hints[field.name],
408440
_get_field_default(field),
409441
field.metadata,
410442
base_schema,
411-
clazz_frame,
412443
),
413444
)
414445
for field in fields
@@ -433,7 +464,6 @@ def _field_by_supertype(
433464
newtype_supertype: Type,
434465
metadata: dict,
435466
base_schema: Optional[Type[marshmallow.Schema]],
436-
typ_frame: Optional[types.FrameType],
437467
) -> marshmallow.fields.Field:
438468
"""
439469
Return a new field for fields based on a super field. (Usually spawned from NewType)
@@ -459,12 +489,11 @@ def _field_by_supertype(
459489
if field:
460490
return field(**metadata)
461491
else:
462-
return field_for_schema(
492+
return _field_for_schema(
463493
newtype_supertype,
464494
metadata=metadata,
465495
default=default,
466496
base_schema=base_schema,
467-
typ_frame=typ_frame,
468497
)
469498

470499

@@ -488,7 +517,6 @@ def _generic_type_add_any(typ: type) -> type:
488517
def _field_for_generic_type(
489518
typ: type,
490519
base_schema: Optional[Type[marshmallow.Schema]],
491-
typ_frame: Optional[types.FrameType],
492520
**metadata: Any,
493521
) -> Optional[marshmallow.fields.Field]:
494522
"""
@@ -501,9 +529,7 @@ def _field_for_generic_type(
501529
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}
502530

503531
if origin in (list, List):
504-
child_type = field_for_schema(
505-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
506-
)
532+
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
507533
list_type = cast(
508534
Type[marshmallow.fields.List],
509535
type_mapping.get(List, marshmallow.fields.List),
@@ -512,32 +538,25 @@ def _field_for_generic_type(
512538
if origin in (collections.abc.Sequence, Sequence):
513539
from . import collection_field
514540

515-
child_type = field_for_schema(
516-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
517-
)
541+
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
518542
return collection_field.Sequence(cls_or_instance=child_type, **metadata)
519543
if origin in (set, Set):
520544
from . import collection_field
521545

522-
child_type = field_for_schema(
523-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
524-
)
546+
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
525547
return collection_field.Set(
526548
cls_or_instance=child_type, frozen=False, **metadata
527549
)
528550
if origin in (frozenset, FrozenSet):
529551
from . import collection_field
530552

531-
child_type = field_for_schema(
532-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
533-
)
553+
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
534554
return collection_field.Set(
535555
cls_or_instance=child_type, frozen=True, **metadata
536556
)
537557
if origin in (tuple, Tuple):
538558
children = tuple(
539-
field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame)
540-
for arg in arguments
559+
_field_for_schema(arg, base_schema=base_schema) for arg in arguments
541560
)
542561
tuple_type = cast(
543562
Type[marshmallow.fields.Tuple],
@@ -549,12 +568,8 @@ def _field_for_generic_type(
549568
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
550569
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
551570
return dict_type(
552-
keys=field_for_schema(
553-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
554-
),
555-
values=field_for_schema(
556-
arguments[1], base_schema=base_schema, typ_frame=typ_frame
557-
),
571+
keys=_field_for_schema(arguments[0], base_schema=base_schema),
572+
values=_field_for_schema(arguments[1], base_schema=base_schema),
558573
**metadata,
559574
)
560575
elif typing_inspect.is_union_type(typ):
@@ -566,23 +581,21 @@ def _field_for_generic_type(
566581
metadata.setdefault("required", False)
567582
subtypes = [t for t in arguments if t is not NoneType] # type: ignore
568583
if len(subtypes) == 1:
569-
return field_for_schema(
584+
return _field_for_schema(
570585
subtypes[0],
571586
metadata=metadata,
572587
base_schema=base_schema,
573-
typ_frame=typ_frame,
574588
)
575589
from . import union_field
576590

577591
return union_field.Union(
578592
[
579593
(
580594
subtyp,
581-
field_for_schema(
595+
_field_for_schema(
582596
subtyp,
583597
metadata={"required": True},
584598
base_schema=base_schema,
585-
typ_frame=typ_frame,
586599
),
587600
)
588601
for subtyp in subtypes
@@ -618,6 +631,29 @@ def field_for_schema(
618631
619632
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
620633
<class 'marshmallow.fields.Url'>
634+
"""
635+
with _SchemaContext(typ_frame):
636+
return _field_for_schema(typ, default, metadata, base_schema)
637+
638+
639+
def _field_for_schema(
640+
typ: type,
641+
default=marshmallow.missing,
642+
metadata: Mapping[str, Any] = None,
643+
base_schema: Optional[Type[marshmallow.Schema]] = None,
644+
) -> marshmallow.fields.Field:
645+
"""
646+
Get a marshmallow Field corresponding to the given python type.
647+
The metadata of the dataclass field is used as arguments to the marshmallow Field.
648+
649+
This is an internal version of field_for_schema. It assumes a _SchemaContext
650+
has been pushed onto the local stack.
651+
652+
:param typ: The type for which a field should be generated
653+
:param default: value to use for (de)serialization when the field is missing
654+
:param metadata: Additional parameters to pass to the marshmallow field constructor
655+
:param base_schema: marshmallow schema used as a base class when deriving dataclass schema
656+
621657
"""
622658

623659
metadata = {} if metadata is None else dict(metadata)
@@ -690,10 +726,10 @@ def field_for_schema(
690726
)
691727
else:
692728
subtyp = Any
693-
return field_for_schema(subtyp, default, metadata, base_schema, typ_frame)
729+
return _field_for_schema(subtyp, default, metadata, base_schema)
694730

695731
# Generic types
696-
generic_field = _field_for_generic_type(typ, base_schema, typ_frame, **metadata)
732+
generic_field = _field_for_generic_type(typ, base_schema, **metadata)
697733
if generic_field:
698734
return generic_field
699735

@@ -707,7 +743,6 @@ def field_for_schema(
707743
newtype_supertype=newtype_supertype,
708744
metadata=metadata,
709745
base_schema=base_schema,
710-
typ_frame=typ_frame,
711746
)
712747

713748
# enumerations
@@ -726,8 +761,8 @@ def field_for_schema(
726761
nested = (
727762
nested_schema
728763
or forward_reference
729-
or _RECURSION_GUARD.seen_classes.get(typ)
730-
or _internal_class_schema(typ, base_schema, typ_frame)
764+
or _schema_ctx_stack.top.seen_classes.get(typ)
765+
or _internal_class_schema(typ, base_schema)
731766
)
732767

733768
return marshmallow.fields.Nested(nested, **metadata)

0 commit comments

Comments
 (0)