Skip to content

Commit 264e355

Browse files
committed
refactor: disuse lru_cache in favor of an _LRUDict
1 parent ab52c66 commit 264e355

File tree

2 files changed

+102
-15
lines changed

2 files changed

+102
-15
lines changed

marshmallow_dataclass/__init__.py

+67-15
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class User:
4242
import types
4343
import warnings
4444
from enum import Enum
45-
from functools import lru_cache, partial
45+
from functools import partial
4646
from typing import (
4747
Any,
4848
Callable,
@@ -107,8 +107,12 @@ def get_origin(tp):
107107

108108

109109
if sys.version_info >= (3, 7):
110+
from typing import OrderedDict
111+
110112
TypeVar_ = TypeVar
111113
else:
114+
from typing_extensions import OrderedDict
115+
112116
TypeVar_ = type
113117

114118
if sys.version_info >= (3, 10):
@@ -119,9 +123,9 @@ def get_origin(tp):
119123

120124
NoneType = type(None)
121125
_U = TypeVar("_U")
126+
_V = TypeVar("_V")
122127
_Field = TypeVar("_Field", bound=marshmallow.fields.Field)
123128

124-
125129
# Whitelist of dataclass members that will be copied to generated schema.
126130
MEMBERS_WHITELIST: Set[str] = {"Meta"}
127131

@@ -490,13 +494,59 @@ def class_schema(
490494
clazz_frame = _maybe_get_callers_frame(clazz)
491495
if clazz_frame is not None:
492496
localns = clazz_frame.f_locals
493-
with _SchemaContext(globalns, localns):
494-
schema = _internal_class_schema(clazz, base_schema)
497+
498+
if base_schema is None:
499+
base_schema = marshmallow.Schema
500+
501+
with _SchemaContext(globalns, localns, base_schema):
502+
schema = _internal_class_schema(clazz)
495503

496504
assert not isinstance(schema, _Future)
497505
return schema
498506

499507

508+
class _LRUDict(OrderedDict[_U, _V]):
509+
"""Limited-length dict which discards LRU entries."""
510+
511+
def __init__(self, maxsize: int = 128):
512+
self.maxsize = maxsize
513+
super().__init__()
514+
515+
def __setitem__(self, key: _U, value: _V) -> None:
516+
super().__setitem__(key, value)
517+
super().move_to_end(key)
518+
519+
while len(self) > self.maxsize:
520+
oldkey = next(iter(self))
521+
super().__delitem__(oldkey)
522+
523+
def __getitem__(self, key: _U) -> _V:
524+
val = super().__getitem__(key)
525+
super().move_to_end(key)
526+
return val
527+
528+
_T = TypeVar("_T")
529+
530+
@overload
531+
def get(self, key: _U) -> Optional[_V]:
532+
...
533+
534+
@overload
535+
def get(self, key: _U, default: _T) -> Union[_V, _T]:
536+
...
537+
538+
def get(self, key: _U, default: Any = None) -> Any:
539+
try:
540+
return self.__getitem__(key)
541+
except KeyError:
542+
return default
543+
544+
545+
_schema_cache = _LRUDict[Hashable, Type[marshmallow.Schema]](
546+
MAX_CLASS_SCHEMA_CACHE_SIZE
547+
)
548+
549+
500550
class InvalidStateError(Exception):
501551
"""Raised when an operation is performed on a future that is not
502552
allowed in the current state.
@@ -597,7 +647,7 @@ class _SchemaContext:
597647

598648
globalns: Optional[Dict[str, Any]] = None
599649
localns: Optional[Dict[str, Any]] = None
600-
base_schema: Optional[Type[marshmallow.Schema]] = None
650+
base_schema: Type[marshmallow.Schema] = marshmallow.Schema
601651
generic_args: Optional[_GenericArgs] = None
602652
seen_classes: Dict[type, _Future[Type[marshmallow.Schema]]] = dataclasses.field(
603653
default_factory=dict
@@ -612,8 +662,6 @@ def get_type_mapping(
612662
all bases in base_schema's MRO.
613663
"""
614664
base_schema = self.base_schema
615-
if base_schema is None:
616-
base_schema = marshmallow.Schema
617665
if use_mro:
618666
return ChainMap(
619667
*(getattr(cls, "TYPE_MAPPING", {}) for cls in base_schema.__mro__)
@@ -651,15 +699,19 @@ def top(self) -> _U:
651699
_schema_ctx_stack = _LocalStack[_SchemaContext]()
652700

653701

654-
@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE)
655702
def _internal_class_schema(
656703
clazz: type,
657-
base_schema: Optional[Type[marshmallow.Schema]] = None,
658704
) -> Union[Type[marshmallow.Schema], _Future[Type[marshmallow.Schema]]]:
659705
schema_ctx = _schema_ctx_stack.top
660706
if clazz in schema_ctx.seen_classes:
661707
return schema_ctx.seen_classes[clazz]
662708

709+
cache_key = clazz, schema_ctx.base_schema
710+
try:
711+
return _schema_cache[cache_key]
712+
except KeyError:
713+
pass
714+
663715
future: _Future[Type[marshmallow.Schema]] = _Future()
664716
schema_ctx.seen_classes[clazz] = future
665717

@@ -700,9 +752,7 @@ def _internal_class_schema(
700752
type_hints = get_type_hints(
701753
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
702754
)
703-
with dataclasses.replace(
704-
schema_ctx, base_schema=base_schema, generic_args=generic_args
705-
):
755+
with dataclasses.replace(schema_ctx, generic_args=generic_args):
706756
attributes.update(
707757
(
708758
field.name,
@@ -717,9 +767,10 @@ def _internal_class_schema(
717767
)
718768

719769
schema_class: Type[marshmallow.Schema] = type(
720-
clazz.__name__, (_base_schema(clazz, base_schema),), attributes
770+
clazz.__name__, (_base_schema(clazz, schema_ctx.base_schema),), attributes
721771
)
722772
future.set_result(schema_class)
773+
_schema_cache[cache_key] = schema_class
723774
return schema_class
724775

725776

@@ -940,8 +991,7 @@ def _field_for_dataclass(
940991
nested = typ.Schema
941992
else:
942993
assert isinstance(typ, Hashable)
943-
schema_ctx = _schema_ctx_stack.top
944-
nested = _internal_class_schema(typ, schema_ctx.base_schema)
994+
nested = _internal_class_schema(typ) # type: ignore[arg-type] # FIXME
945995
if isinstance(nested, _Future):
946996
nested = nested.result
947997

@@ -976,6 +1026,8 @@ def field_for_schema(
9761026
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
9771027
<class 'marshmallow.fields.Url'>
9781028
"""
1029+
if base_schema is None:
1030+
base_schema = marshmallow.Schema
9791031
localns = typ_frame.f_locals if typ_frame is not None else None
9801032
with _SchemaContext(localns=localns, base_schema=base_schema):
9811033
return _field_for_schema(typ, default, metadata)

tests/test_lrudict.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from marshmallow_dataclass import _LRUDict
2+
3+
4+
def test_LRUDict_getitem_moves_to_end() -> None:
5+
d = _LRUDict[str, str]()
6+
d["a"] = "aval"
7+
d["b"] = "bval"
8+
assert list(d.items()) == [("a", "aval"), ("b", "bval")]
9+
assert d["a"] == "aval"
10+
assert list(d.items()) == [("b", "bval"), ("a", "aval")]
11+
12+
13+
def test_LRUDict_get_moves_to_end() -> None:
14+
d = _LRUDict[str, str]()
15+
d["a"] = "aval"
16+
d["b"] = "bval"
17+
assert list(d.items()) == [("a", "aval"), ("b", "bval")]
18+
assert d.get("a") == "aval"
19+
assert list(d.items()) == [("b", "bval"), ("a", "aval")]
20+
21+
22+
def test_LRUDict_setitem_moves_to_end() -> None:
23+
d = _LRUDict[str, str]()
24+
d["a"] = "aval"
25+
d["b"] = "bval"
26+
assert list(d.items()) == [("a", "aval"), ("b", "bval")]
27+
d["a"] = "newval"
28+
assert list(d.items()) == [("b", "bval"), ("a", "newval")]
29+
30+
31+
def test_LRUDict_discards_oldest() -> None:
32+
d = _LRUDict[str, str](maxsize=1)
33+
d["a"] = "aval"
34+
d["b"] = "bval"
35+
assert list(d.items()) == [("b", "bval")]

0 commit comments

Comments
 (0)