Skip to content

Commit 387afb3

Browse files
committed
Possible fix for lovasoa#198: memory leak
1 parent 259b119 commit 387afb3

File tree

1 file changed

+91
-50
lines changed

1 file changed

+91
-50
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 91 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class User:
4646
Any,
4747
Callable,
4848
Dict,
49+
Generic,
4950
List,
5051
Mapping,
5152
NewType as typing_NewType,
@@ -79,9 +80,6 @@ class User:
7980
# Maximum number of generated schemas kept in the class_schema cache, so repeated requests reuse a schema instead of generating a duplicate.
8081
MAX_CLASS_SCHEMA_CACHE_SIZE = 1024
8182

82-
# Recursion guard for class_schema()
83-
_RECURSION_GUARD = threading.local()
84-
8583

8684
@overload
8785
def dataclass(
@@ -352,20 +350,61 @@ def class_schema(
352350
clazz_frame = current_frame.f_back
353351
# Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack
354352
del current_frame
355-
_RECURSION_GUARD.seen_classes = {}
356-
try:
357-
return _internal_class_schema(clazz, base_schema, clazz_frame)
358-
finally:
359-
_RECURSION_GUARD.seen_classes.clear()
353+
354+
with _SchemaContext(clazz_frame):
355+
return _internal_class_schema(clazz, base_schema)
356+
357+
358+
class _SchemaContext:
    """Global context for an invocation of class_schema.

    Used as a context manager: entering pushes the context onto
    ``_schema_ctx_stack`` and exiting pops it again. While it is active,
    schema generation code reads it back through ``_schema_ctx_stack.top``.

    :param frame: frame whose ``f_locals`` are used as the local namespace
        when resolving type hints, or ``None`` to resolve without one
    """

    def __init__(self, frame: Optional[types.FrameType]):
        # Maps each class whose schema is being generated to its __name__;
        # field generation consults this before building a nested schema.
        self.seen_classes: Dict[type, str] = {}
        self.frame: Optional[types.FrameType] = frame

    def get_type_hints(self, cls: Type) -> Dict[str, Any]:
        """Resolve the type hints of ``cls`` using the stored frame's locals.

        :param cls: the class whose annotations should be resolved
        :return: mapping of attribute name to resolved type
        """
        frame = self.frame
        localns = frame.f_locals if frame is not None else None
        return get_type_hints(cls, localns=localns)

    def __enter__(self) -> "_SchemaContext":
        """Push this context onto the stack and return it."""
        _schema_ctx_stack.push(self)
        return self

    def __exit__(
        self,
        _typ: Optional[Type[BaseException]],
        _value: Optional[BaseException],
        _tb: Optional[types.TracebackType],
    ) -> None:
        """Pop this context and release what it was holding on to.

        Exceptions are never suppressed (the method returns ``None``).
        """
        try:
            _schema_ctx_stack.pop()
        finally:
            # Drop the references even if the pop failed: a lingering frame
            # reference keeps the caller's locals alive, and the seen-class
            # map is only meaningful for the duration of this invocation.
            self.seen_classes.clear()
            self.frame = None
381+
382+
383+
class _LocalStack(threading.local, Generic[_U]):
    """A minimal LIFO stack stored in thread-local storage.

    Because the class derives from ``threading.local``, the ``stack``
    attribute set in ``__init__`` is held separately for each thread.
    """

    def __init__(self) -> None:
        self.stack: List[_U] = []

    def push(self, value: _U) -> None:
        """Place ``value`` on top of the stack."""
        self.stack.append(value)

    def pop(self) -> None:
        """Discard the top value.

        :raises IndexError: if the stack is empty
        """
        self.stack.pop()

    @property
    def top(self) -> _U:
        """The most recently pushed value that has not been popped.

        :raises IndexError: if the stack is empty
        """
        if not self.stack:
            # Same exception type as plain list indexing, but with a message
            # that points at the real problem instead of "list index out of range".
            raise IndexError("top of empty _LocalStack: nothing has been pushed")
        return self.stack[-1]
396+
397+
398+
# Thread-local stack of the _SchemaContext objects that are currently entered;
# the innermost (most recently entered) context is on top.
_schema_ctx_stack = _LocalStack[_SchemaContext]()
360399

361400

362401
@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE)
363402
def _internal_class_schema(
364403
clazz: type,
365404
base_schema: Optional[Type[marshmallow.Schema]] = None,
366-
clazz_frame: Optional[types.FrameType] = None,
367405
) -> Type[marshmallow.Schema]:
368-
_RECURSION_GUARD.seen_classes[clazz] = clazz.__name__
406+
schema_ctx = _schema_ctx_stack.top
407+
schema_ctx.seen_classes[clazz] = clazz.__name__
369408
try:
370409
# noinspection PyDataclass
371410
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz)
@@ -383,7 +422,7 @@ def _internal_class_schema(
383422
"****** WARNING ******"
384423
)
385424
created_dataclass: type = dataclasses.dataclass(clazz)
386-
return _internal_class_schema(created_dataclass, base_schema, clazz_frame)
425+
return _internal_class_schema(created_dataclass, base_schema)
387426
except Exception as exc:
388427
raise TypeError(
389428
f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
@@ -397,18 +436,15 @@ def _internal_class_schema(
397436
}
398437

399438
# Update the schema members to contain marshmallow fields instead of dataclass fields
400-
type_hints = get_type_hints(
401-
clazz, localns=clazz_frame.f_locals if clazz_frame else None
402-
)
439+
type_hints = schema_ctx.get_type_hints(clazz)
403440
attributes.update(
404441
(
405442
field.name,
406-
field_for_schema(
443+
_field_for_schema(
407444
type_hints[field.name],
408445
_get_field_default(field),
409446
field.metadata,
410447
base_schema,
411-
clazz_frame,
412448
),
413449
)
414450
for field in fields
@@ -433,7 +469,6 @@ def _field_by_supertype(
433469
newtype_supertype: Type,
434470
metadata: dict,
435471
base_schema: Optional[Type[marshmallow.Schema]],
436-
typ_frame: Optional[types.FrameType],
437472
) -> marshmallow.fields.Field:
438473
"""
439474
Return a new field for fields based on a super field. (Usually spawned from NewType)
@@ -459,12 +494,11 @@ def _field_by_supertype(
459494
if field:
460495
return field(**metadata)
461496
else:
462-
return field_for_schema(
497+
return _field_for_schema(
463498
newtype_supertype,
464499
metadata=metadata,
465500
default=default,
466501
base_schema=base_schema,
467-
typ_frame=typ_frame,
468502
)
469503

470504

@@ -488,7 +522,6 @@ def _generic_type_add_any(typ: type) -> type:
488522
def _field_for_generic_type(
489523
typ: type,
490524
base_schema: Optional[Type[marshmallow.Schema]],
491-
typ_frame: Optional[types.FrameType],
492525
**metadata: Any,
493526
) -> Optional[marshmallow.fields.Field]:
494527
"""
@@ -501,9 +534,7 @@ def _field_for_generic_type(
501534
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}
502535

503536
if origin in (list, List):
504-
child_type = field_for_schema(
505-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
506-
)
537+
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
507538
list_type = cast(
508539
Type[marshmallow.fields.List],
509540
type_mapping.get(List, marshmallow.fields.List),
@@ -516,32 +547,25 @@ def _field_for_generic_type(
516547
):
517548
from . import collection_field
518549

519-
child_type = field_for_schema(
520-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
521-
)
550+
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
522551
return collection_field.Sequence(cls_or_instance=child_type, **metadata)
523552
if origin in (set, Set):
524553
from . import collection_field
525554

526-
child_type = field_for_schema(
527-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
528-
)
555+
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
529556
return collection_field.Set(
530557
cls_or_instance=child_type, frozen=False, **metadata
531558
)
532559
if origin in (frozenset, FrozenSet):
533560
from . import collection_field
534561

535-
child_type = field_for_schema(
536-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
537-
)
562+
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
538563
return collection_field.Set(
539564
cls_or_instance=child_type, frozen=True, **metadata
540565
)
541566
if origin in (tuple, Tuple):
542567
children = tuple(
543-
field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame)
544-
for arg in arguments
568+
_field_for_schema(arg, base_schema=base_schema) for arg in arguments
545569
)
546570
tuple_type = cast(
547571
Type[marshmallow.fields.Tuple],
@@ -553,14 +577,11 @@ def _field_for_generic_type(
553577
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
554578
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
555579
return dict_type(
556-
keys=field_for_schema(
557-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
558-
),
559-
values=field_for_schema(
560-
arguments[1], base_schema=base_schema, typ_frame=typ_frame
561-
),
580+
keys=_field_for_schema(arguments[0], base_schema=base_schema),
581+
values=_field_for_schema(arguments[1], base_schema=base_schema),
562582
**metadata,
563583
)
584+
564585
if typing_inspect.is_union_type(typ):
565586
if typing_inspect.is_optional_type(typ):
566587
metadata["allow_none"] = metadata.get("allow_none", True)
@@ -570,23 +591,21 @@ def _field_for_generic_type(
570591
metadata.setdefault("required", False)
571592
subtypes = [t for t in arguments if t is not NoneType] # type: ignore
572593
if len(subtypes) == 1:
573-
return field_for_schema(
594+
return _field_for_schema(
574595
subtypes[0],
575596
metadata=metadata,
576597
base_schema=base_schema,
577-
typ_frame=typ_frame,
578598
)
579599
from . import union_field
580600

581601
return union_field.Union(
582602
[
583603
(
584604
subtyp,
585-
field_for_schema(
605+
_field_for_schema(
586606
subtyp,
587607
metadata={"required": True},
588608
base_schema=base_schema,
589-
typ_frame=typ_frame,
590609
),
591610
)
592611
for subtyp in subtypes
@@ -598,7 +617,7 @@ def _field_for_generic_type(
598617

599618
def field_for_schema(
600619
typ: type,
601-
default=marshmallow.missing,
620+
default: Any = marshmallow.missing,
602621
metadata: Optional[Mapping[str, Any]] = None,
603622
base_schema: Optional[Type[marshmallow.Schema]] = None,
604623
typ_frame: Optional[types.FrameType] = None,
@@ -622,6 +641,29 @@ def field_for_schema(
622641
623642
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
624643
<class 'marshmallow.fields.Url'>
644+
"""
645+
with _SchemaContext(typ_frame):
646+
return _field_for_schema(typ, default, metadata, base_schema)
647+
648+
649+
def _field_for_schema(
650+
typ: type,
651+
default: Any = marshmallow.missing,
652+
metadata: Optional[Mapping[str, Any]] = None,
653+
base_schema: Optional[Type[marshmallow.Schema]] = None,
654+
) -> marshmallow.fields.Field:
655+
"""
656+
Get a marshmallow Field corresponding to the given python type.
657+
The metadata of the dataclass field is used as arguments to the marshmallow Field.
658+
659+
This is an internal version of field_for_schema. It assumes a _SchemaContext
660+
has been pushed onto the local stack.
661+
662+
:param typ: The type for which a field should be generated
663+
:param default: value to use for (de)serialization when the field is missing
664+
:param metadata: Additional parameters to pass to the marshmallow field constructor
665+
:param base_schema: marshmallow schema used as a base class when deriving dataclass schema
666+
625667
"""
626668

627669
metadata = {} if metadata is None else dict(metadata)
@@ -694,10 +736,10 @@ def field_for_schema(
694736
)
695737
else:
696738
subtyp = Any
697-
return field_for_schema(subtyp, default, metadata, base_schema, typ_frame)
739+
return _field_for_schema(subtyp, default, metadata, base_schema)
698740

699741
# Generic types
700-
generic_field = _field_for_generic_type(typ, base_schema, typ_frame, **metadata)
742+
generic_field = _field_for_generic_type(typ, base_schema, **metadata)
701743
if generic_field:
702744
return generic_field
703745

@@ -711,7 +753,6 @@ def field_for_schema(
711753
newtype_supertype=newtype_supertype,
712754
metadata=metadata,
713755
base_schema=base_schema,
714-
typ_frame=typ_frame,
715756
)
716757

717758
# enumerations
@@ -734,8 +775,8 @@ def field_for_schema(
734775
nested = (
735776
nested_schema
736777
or forward_reference
737-
or _RECURSION_GUARD.seen_classes.get(typ)
738-
or _internal_class_schema(typ, base_schema, typ_frame) # type: ignore [arg-type]
778+
or _schema_ctx_stack.top.seen_classes.get(typ)
779+
or _internal_class_schema(typ, base_schema) # type: ignore[arg-type] # FIXME
739780
)
740781

741782
return marshmallow.fields.Nested(nested, **metadata)

0 commit comments

Comments
 (0)