diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 7312374..1fe9813 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -47,6 +47,7 @@ class User: Any, Callable, Dict, + Generic, List, Mapping, NewType as typing_NewType, @@ -90,8 +91,50 @@ def dataclass_transform(**kwargs): # Max number of generated schemas that class_schema keeps of generated schemas. Removes duplicates. MAX_CLASS_SCHEMA_CACHE_SIZE = 1024 -# Recursion guard for class_schema() -_RECURSION_GUARD = threading.local() + +def _maybe_get_callers_frame( + cls: type, stacklevel: int = 1 +) -> Optional[types.FrameType]: + """Return the caller's frame, but only if it will help resolve forward type references. + + We sometimes need the caller's frame to get access to the caller's + local namespace in order to be able to resolve forward type + references in dataclasses. + + Notes + ----- + + If the caller's locals are the same as the dataclass' module + globals — this is the case for the common case of dataclasses + defined at the module top-level — we don't need the locals. + (typing.get_type_hints() knows how to check the class module + globals on its own.) + + In that case, we don't need the caller's frame. Not holding a + reference to the frame in our lazy ``.Schema`` class attribute + is a significant win, memory-wise. 
+ + """ + try: + frame = inspect.currentframe() + for _ in range(stacklevel + 1): + if frame is None: + return None + frame = frame.f_back + + if frame is None: + return None + + globalns = getattr(sys.modules.get(cls.__module__), "__dict__", None) + if frame.f_locals is globalns: + # Locals are the globals + return None + + return frame + + finally: + # Paranoia, per https://docs.python.org/3/library/inspect.html#the-interpreter-stack + del frame @overload @@ -137,6 +180,7 @@ def dataclass( frozen: bool = False, base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: Optional[types.FrameType] = None, + stacklevel: int = 1, ) -> Union[Type[_U], Callable[[Type[_U]], Type[_U]]]: """ This decorator does the same as dataclasses.dataclass, but also applies :func:`add_schema`. @@ -163,19 +207,18 @@ def dataclass( >>> Point.Schema().load({'x':0, 'y':0}) # This line can be statically type checked Point(x=0.0, y=0.0) """ - # dataclass's typing doesn't expect it to be called as a function, so ignore type check - dc = dataclasses.dataclass( # type: ignore - _cls, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen + dc = dataclasses.dataclass( + repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen ) - if not cls_frame: - current_frame = inspect.currentframe() - if current_frame: - cls_frame = current_frame.f_back - # Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack - del current_frame + + def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: + return add_schema( + dc(cls), base_schema, cls_frame=cls_frame, stacklevel=stacklevel + 1 + ) + if _cls is None: - return lambda cls: add_schema(dc(cls), base_schema, cls_frame=cls_frame) - return add_schema(dc, base_schema, cls_frame=cls_frame) + return decorator + return decorator(_cls, stacklevel=stacklevel + 1) @overload @@ -195,11 +238,12 @@ def add_schema( _cls: Type[_U], base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: 
Optional[types.FrameType] = None, + stacklevel: int = 1, ) -> Type[_U]: ... -def add_schema(_cls=None, base_schema=None, cls_frame=None): +def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1): """ This decorator adds a marshmallow schema as the 'Schema' attribute in a dataclass. It uses :func:`class_schema` internally. @@ -221,22 +265,55 @@ def add_schema(_cls=None, base_schema=None, cls_frame=None): Artist(names=('Martin', 'Ramirez')) """ - def decorator(clazz: Type[_U]) -> Type[_U]: + def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]: + if cls_frame is not None: + frame = cls_frame + else: + frame = _maybe_get_callers_frame(clazz, stacklevel=stacklevel) + # noinspection PyTypeHints clazz.Schema = lazy_class_attribute( # type: ignore - partial(class_schema, clazz, base_schema, cls_frame), + partial(class_schema, clazz, base_schema, frame), "Schema", clazz.__name__, ) return clazz - return decorator(_cls) if _cls else decorator + if _cls is None: + return decorator + return decorator(_cls, stacklevel=stacklevel + 1) +@overload +def class_schema( + clazz: type, + base_schema: Optional[Type[marshmallow.Schema]] = None, + *, + globalns: Optional[Dict[str, Any]] = None, + localns: Optional[Dict[str, Any]] = None, +) -> Type[marshmallow.Schema]: + ... + + +@overload def class_schema( clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None, clazz_frame: Optional[types.FrameType] = None, + *, + globalns: Optional[Dict[str, Any]] = None, +) -> Type[marshmallow.Schema]: + ... + + +def class_schema( + clazz: type, + base_schema: Optional[Type[marshmallow.Schema]] = None, + # FIXME: delete clazz_frame from API? 
+ clazz_frame: Optional[types.FrameType] = None, + *, + globalns: Optional[Dict[str, Any]] = None, + localns: Optional[Dict[str, Any]] = None, ) -> Type[marshmallow.Schema]: """ Convert a class to a marshmallow schema @@ -376,26 +453,65 @@ def class_schema( """ if not dataclasses.is_dataclass(clazz): clazz = dataclasses.dataclass(clazz) - if not clazz_frame: - current_frame = inspect.currentframe() - if current_frame: - clazz_frame = current_frame.f_back - # Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack - del current_frame - _RECURSION_GUARD.seen_classes = {} - try: - return _internal_class_schema(clazz, base_schema, clazz_frame) - finally: - _RECURSION_GUARD.seen_classes.clear() + if localns is None: + if clazz_frame is None: + clazz_frame = _maybe_get_callers_frame(clazz) + if clazz_frame is not None: + localns = clazz_frame.f_locals + with _SchemaContext(globalns, localns): + return _internal_class_schema(clazz, base_schema) + + +class _SchemaContext: + """Global context for an invocation of class_schema.""" + + def __init__( + self, + globalns: Optional[Dict[str, Any]] = None, + localns: Optional[Dict[str, Any]] = None, + ): + self.seen_classes: Dict[type, str] = {} + self.globalns = globalns + self.localns = localns + + def __enter__(self) -> "_SchemaContext": + _schema_ctx_stack.push(self) + return self + + def __exit__( + self, + _typ: Optional[Type[BaseException]], + _value: Optional[BaseException], + _tb: Optional[types.TracebackType], + ) -> None: + _schema_ctx_stack.pop() + + +class _LocalStack(threading.local, Generic[_U]): + def __init__(self) -> None: + self.stack: List[_U] = [] + + def push(self, value: _U) -> None: + self.stack.append(value) + + def pop(self) -> None: + self.stack.pop() + + @property + def top(self) -> _U: + return self.stack[-1] + + +_schema_ctx_stack = _LocalStack[_SchemaContext]() @lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE) def _internal_class_schema( clazz: type, base_schema: 
Optional[Type[marshmallow.Schema]] = None, - clazz_frame: Optional[types.FrameType] = None, ) -> Type[marshmallow.Schema]: - _RECURSION_GUARD.seen_classes[clazz] = clazz.__name__ + schema_ctx = _schema_ctx_stack.top + schema_ctx.seen_classes[clazz] = clazz.__name__ try: # noinspection PyDataclass fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz) @@ -413,7 +529,7 @@ def _internal_class_schema( "****** WARNING ******" ) created_dataclass: type = dataclasses.dataclass(clazz) - return _internal_class_schema(created_dataclass, base_schema, clazz_frame) + return _internal_class_schema(created_dataclass, base_schema) except Exception as exc: raise TypeError( f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one." @@ -431,17 +547,16 @@ def _internal_class_schema( # Update the schema members to contain marshmallow fields instead of dataclass fields type_hints = get_type_hints( - clazz, localns=clazz_frame.f_locals if clazz_frame else None + clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns ) attributes.update( ( field.name, - field_for_schema( + _field_for_schema( type_hints[field.name], _get_field_default(field), field.metadata, base_schema, - clazz_frame, ), ) for field in fields @@ -466,7 +581,6 @@ def _field_by_supertype( newtype_supertype: Type, metadata: dict, base_schema: Optional[Type[marshmallow.Schema]], - typ_frame: Optional[types.FrameType], ) -> marshmallow.fields.Field: """ Return a new field for fields based on a super field. 
(Usually spawned from NewType) @@ -492,12 +606,11 @@ def _field_by_supertype( if field: return field(**metadata) else: - return field_for_schema( + return _field_for_schema( newtype_supertype, metadata=metadata, default=default, base_schema=base_schema, - typ_frame=typ_frame, ) @@ -521,7 +634,6 @@ def _generic_type_add_any(typ: type) -> type: def _field_for_generic_type( typ: type, base_schema: Optional[Type[marshmallow.Schema]], - typ_frame: Optional[types.FrameType], **metadata: Any, ) -> Optional[marshmallow.fields.Field]: """ @@ -534,9 +646,7 @@ def _field_for_generic_type( type_mapping = base_schema.TYPE_MAPPING if base_schema else {} if origin in (list, List): - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) list_type = cast( Type[marshmallow.fields.List], type_mapping.get(List, marshmallow.fields.List), @@ -549,32 +659,25 @@ def _field_for_generic_type( ): from . import collection_field - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) return collection_field.Sequence(cls_or_instance=child_type, **metadata) if origin in (set, Set): from . import collection_field - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) return collection_field.Set( cls_or_instance=child_type, frozen=False, **metadata ) if origin in (frozenset, FrozenSet): from . 
import collection_field - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) return collection_field.Set( cls_or_instance=child_type, frozen=True, **metadata ) if origin in (tuple, Tuple): children = tuple( - field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame) - for arg in arguments + _field_for_schema(arg, base_schema=base_schema) for arg in arguments ) tuple_type = cast( Type[marshmallow.fields.Tuple], @@ -586,14 +689,11 @@ def _field_for_generic_type( elif origin in (dict, Dict, collections.abc.Mapping, Mapping): dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( - keys=field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ), - values=field_for_schema( - arguments[1], base_schema=base_schema, typ_frame=typ_frame - ), + keys=_field_for_schema(arguments[0], base_schema=base_schema), + values=_field_for_schema(arguments[1], base_schema=base_schema), **metadata, ) + if typing_inspect.is_union_type(typ): if typing_inspect.is_optional_type(typ): metadata["allow_none"] = metadata.get("allow_none", True) @@ -603,11 +703,10 @@ def _field_for_generic_type( metadata.setdefault("required", False) subtypes = [t for t in arguments if t is not NoneType] # type: ignore if len(subtypes) == 1: - return field_for_schema( + return _field_for_schema( subtypes[0], metadata=metadata, base_schema=base_schema, - typ_frame=typ_frame, ) from . 
import union_field @@ -615,11 +714,10 @@ def _field_for_generic_type( [ ( subtyp, - field_for_schema( + _field_for_schema( subtyp, metadata={"required": True}, base_schema=base_schema, - typ_frame=typ_frame, ), ) for subtyp in subtypes @@ -631,9 +729,10 @@ def _field_for_generic_type( def field_for_schema( typ: type, - default=marshmallow.missing, + default: Any = marshmallow.missing, metadata: Optional[Mapping[str, Any]] = None, base_schema: Optional[Type[marshmallow.Schema]] = None, + # FIXME: delete typ_frame from API? typ_frame: Optional[types.FrameType] = None, ) -> marshmallow.fields.Field: """ @@ -655,6 +754,29 @@ def field_for_schema( >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__ + """ + with _SchemaContext(localns=typ_frame.f_locals if typ_frame is not None else None): + return _field_for_schema(typ, default, metadata, base_schema) + + +def _field_for_schema( + typ: type, + default: Any = marshmallow.missing, + metadata: Optional[Mapping[str, Any]] = None, + base_schema: Optional[Type[marshmallow.Schema]] = None, +) -> marshmallow.fields.Field: + """ + Get a marshmallow Field corresponding to the given python type. + The metadata of the dataclass field is used as arguments to the marshmallow Field. + + This is an internal version of field_for_schema. It assumes a _SchemaContext + has been pushed onto the local stack. 
+ + :param typ: The type for which a field should be generated + :param default: value to use for (de)serialization when the field is missing + :param metadata: Additional parameters to pass to the marshmallow field constructor + :param base_schema: marshmallow schema used as a base class when deriving dataclass schema + """ metadata = {} if metadata is None else dict(metadata) @@ -727,10 +849,10 @@ def field_for_schema( ) else: subtyp = Any - return field_for_schema(subtyp, default, metadata, base_schema, typ_frame) + return _field_for_schema(subtyp, default, metadata, base_schema) # Generic types - generic_field = _field_for_generic_type(typ, base_schema, typ_frame, **metadata) + generic_field = _field_for_generic_type(typ, base_schema, **metadata) if generic_field: return generic_field @@ -744,7 +866,6 @@ def field_for_schema( newtype_supertype=newtype_supertype, metadata=metadata, base_schema=base_schema, - typ_frame=typ_frame, ) # enumerations @@ -767,8 +888,8 @@ def field_for_schema( nested = ( nested_schema or forward_reference - or _RECURSION_GUARD.seen_classes.get(typ) - or _internal_class_schema(typ, base_schema, typ_frame) # type: ignore [arg-type] + or _schema_ctx_stack.top.seen_classes.get(typ) + or _internal_class_schema(typ, base_schema) # type: ignore[arg-type] # FIXME ) return marshmallow.fields.Nested(nested, **metadata) diff --git a/tests/test_forward_references.py b/tests/test_forward_references.py index fc05b12..2a2fa96 100644 --- a/tests/test_forward_references.py +++ b/tests/test_forward_references.py @@ -133,3 +133,19 @@ class B: B.Schema().load(dict(a=dict(c=1))) # marshmallow.exceptions.ValidationError: # {'a': {'d': ['Missing data for required field.'], 'c': ['Unknown field.']}} + + def test_locals_from_decoration_ns(self): + # Test that locals are picked-up at decoration-time rather + # than when the decorator is constructed. 
+ @frozen_dataclass + class A: + b: "B" + + @frozen_dataclass + class B: + x: int + + assert A.Schema().load({"b": {"x": 42}}) == A(b=B(x=42)) + + +frozen_dataclass = dataclass(frozen=True) diff --git a/tests/test_memory_leak.py b/tests/test_memory_leak.py new file mode 100644 index 0000000..306430e --- /dev/null +++ b/tests/test_memory_leak.py @@ -0,0 +1,140 @@ +import gc +import inspect +import sys +import unittest +import weakref +from dataclasses import dataclass +from unittest import mock + +import marshmallow +import marshmallow_dataclass as md + + +class Referenceable: + pass + + +class TestMemoryLeak(unittest.TestCase): + """Test for memory leaks as described in `#198`_. + + .. _#198: https://github.com/lovasoa/marshmallow_dataclass/issues/198 + """ + + def setUp(self): + gc.collect() + gc.disable() + self.frame_collected = False + + def tearDown(self): + gc.enable() + + def trackFrame(self): + """Create a tracked local variable in the caller's frame. + + A finalizer on the local sets self.frame_collected once it is GCed. + + When the caller's frame is freed, the locals will be GCed as well. + In this way we can check that the caller's frame has been collected. 
+ """ + local = Referenceable() + weakref.finalize(local, self._set_frame_collected) + try: + frame = inspect.currentframe() + frame.f_back.f_locals["local_variable"] = local + finally: + del frame + + def _set_frame_collected(self): + self.frame_collected = True + + def assertFrameCollected(self): + """Check that all locals created by trackFrame have been GCed""" + if not hasattr(sys, "getrefcount"): + # pypy does not do reference counting + gc.collect(0) + self.assertTrue(self.frame_collected) + + def test_sanity(self): + """Test that our scheme for detecting leaked frames works.""" + frames = [] + + def f(): + frames.append(inspect.currentframe()) + self.trackFrame() + + f() + + gc.collect(0) + self.assertFalse( + self.frame_collected + ) # with frame leaked, f's locals are still alive + frames.clear() + self.assertFrameCollected() + + def test_class_schema(self): + def f(): + @dataclass + class Foo: + value: int + + md.class_schema(Foo) + + self.trackFrame() + + f() + self.assertFrameCollected() + + def test_md_dataclass_lazy_schema(self): + def f(): + @md.dataclass + class Foo: + value: int + + self.trackFrame() + + f() + # NB: The "lazy" Foo.Schema attribute descriptor holds a reference to f's frame, + # which, in turn, holds a reference to class Foo, thereby creating a ref cycle. + # So, a gc pass is required to clean that up. 
+ gc.collect(0) + self.assertFrameCollected() + + def test_md_dataclass(self): + def f(): + @md.dataclass + class Foo: + value: int + + self.assertIsInstance(Foo.Schema(), marshmallow.Schema) + self.trackFrame() + + f() + self.assertFrameCollected() + + def assertDecoratorDoesNotLeakFrame(self, decorator): + def f() -> None: + class Foo: + value: int + + self.trackFrame() + with self.assertRaisesRegex(Exception, "forced exception"): + decorator(Foo) + + with mock.patch( + "marshmallow_dataclass.lazy_class_attribute", + side_effect=Exception("forced exception"), + ) as m: + f() + + assert m.mock_calls == [mock.call(mock.ANY, "Schema", mock.ANY)] + # NB: The Mock holds a reference to its arguments, one of which is the + # lazy_class_attribute which holds a reference to the caller's frame + m.reset_mock() + + self.assertFrameCollected() + + def test_exception_in_dataclass(self): + self.assertDecoratorDoesNotLeakFrame(md.dataclass) + + def test_exception_in_add_schema(self): + self.assertDecoratorDoesNotLeakFrame(md.add_schema) diff --git a/tests/test_mypy.yml b/tests/test_mypy.yml index d4f9c86..479abd5 100644 --- a/tests/test_mypy.yml +++ b/tests/test_mypy.yml @@ -6,7 +6,7 @@ follow_imports = silent plugins = marshmallow_dataclass.mypy show_error_codes = true - python_version = 3.6 + python_version = 3.8 env: - PYTHONPATH=. main: |