Skip to content

Commit 8af4f23

Browse files
committed
refactor: disuse lru_cache in favor of an _LRUDict
1 parent de29bbc commit 8af4f23

File tree

2 files changed

+105
-18
lines changed

2 files changed

+105
-18
lines changed

marshmallow_dataclass/__init__.py

+70-18
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class User:
4242
import types
4343
import warnings
4444
from enum import Enum
45-
from functools import lru_cache, partial
45+
from functools import partial
4646
from typing import (
4747
Any,
4848
Callable,
@@ -106,8 +106,12 @@ def get_origin(tp):
106106

107107

108108
if sys.version_info >= (3, 7):
109+
from typing import OrderedDict
110+
109111
TypeVar_ = TypeVar
110112
else:
113+
from typing_extensions import OrderedDict
114+
111115
TypeVar_ = type
112116

113117
if sys.version_info >= (3, 10):
@@ -118,6 +122,7 @@ def get_origin(tp):
118122

119123
NoneType = type(None)
120124
_U = TypeVar("_U")
125+
_V = TypeVar("_V")
121126

122127
# Whitelist of dataclass members that will be copied to generated schema.
123128
MEMBERS_WHITELIST: Set[str] = {"Meta"}
@@ -487,13 +492,59 @@ def class_schema(
487492
clazz_frame = _maybe_get_callers_frame(clazz)
488493
if clazz_frame is not None:
489494
localns = clazz_frame.f_locals
490-
with _SchemaContext(globalns, localns):
491-
schema = _internal_class_schema(clazz, base_schema)
495+
496+
if base_schema is None:
497+
base_schema = marshmallow.Schema
498+
499+
with _SchemaContext(globalns, localns, base_schema):
500+
schema = _internal_class_schema(clazz)
492501

493502
assert not isinstance(schema, _Future)
494503
return schema
495504

496505

506+
class _LRUDict(OrderedDict[_U, _V]):
507+
"""Limited-length dict which discards LRU entries."""
508+
509+
def __init__(self, maxsize: int = 128):
510+
self.maxsize = maxsize
511+
super().__init__()
512+
513+
def __setitem__(self, key: _U, value: _V) -> None:
514+
super().__setitem__(key, value)
515+
super().move_to_end(key)
516+
517+
while len(self) > self.maxsize:
518+
oldkey = next(iter(self))
519+
super().__delitem__(oldkey)
520+
521+
def __getitem__(self, key: _U) -> _V:
522+
val = super().__getitem__(key)
523+
super().move_to_end(key)
524+
return val
525+
526+
_T = TypeVar("_T")
527+
528+
@overload
529+
def get(self, key: _U) -> Optional[_V]:
530+
...
531+
532+
@overload
533+
def get(self, key: _U, default: _T) -> Union[_V, _T]:
534+
...
535+
536+
def get(self, key: _U, default: Any = None) -> Any:
537+
try:
538+
return self.__getitem__(key)
539+
except KeyError:
540+
return default
541+
542+
543+
_schema_cache = _LRUDict[Hashable, Type[marshmallow.Schema]](
544+
MAX_CLASS_SCHEMA_CACHE_SIZE
545+
)
546+
547+
497548
class InvalidStateError(Exception):
498549
"""Raised when an operation is performed on a future that is not
499550
allowed in the current state.
@@ -623,7 +674,7 @@ class _SchemaContext:
623674

624675
globalns: Optional[Dict[str, Any]] = None
625676
localns: Optional[Dict[str, Any]] = None
626-
base_schema: Optional[Type[marshmallow.Schema]] = None
677+
base_schema: Type[marshmallow.Schema] = marshmallow.Schema
627678
generic_args: Optional[_GenericArgs] = None
628679
seen_classes: Dict[type, _Future[Type[marshmallow.Schema]]] = dataclasses.field(
629680
default_factory=dict
@@ -633,12 +684,9 @@ def get_type_mapping(
633684
self, include_marshmallow_default: bool = False
634685
) -> _TypeMapping:
635686
default_mapping = marshmallow.Schema.TYPE_MAPPING
636-
if self.base_schema is not None:
637-
mappings = [self.base_schema.TYPE_MAPPING]
638-
if include_marshmallow_default:
639-
mappings.append(default_mapping)
640-
else:
641-
mappings = [default_mapping]
687+
mappings = [self.base_schema.TYPE_MAPPING]
688+
if include_marshmallow_default and mappings[0] is not default_mapping:
689+
mappings.append(default_mapping)
642690
return _TypeMapping(*mappings)
643691

644692
def __enter__(self) -> "_SchemaContext":
@@ -672,15 +720,19 @@ def top(self) -> _U:
672720
_schema_ctx_stack = _LocalStack[_SchemaContext]()
673721

674722

675-
@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE)
676723
def _internal_class_schema(
677724
clazz: type,
678-
base_schema: Optional[Type[marshmallow.Schema]] = None,
679725
) -> Union[Type[marshmallow.Schema], _Future[Type[marshmallow.Schema]]]:
680726
schema_ctx = _schema_ctx_stack.top
681727
if clazz in schema_ctx.seen_classes:
682728
return schema_ctx.seen_classes[clazz]
683729

730+
cache_key = clazz, schema_ctx.base_schema
731+
try:
732+
return _schema_cache[cache_key]
733+
except KeyError:
734+
pass
735+
684736
future: _Future[Type[marshmallow.Schema]] = _Future()
685737
schema_ctx.seen_classes[clazz] = future
686738

@@ -721,9 +773,7 @@ def _internal_class_schema(
721773
type_hints = get_type_hints(
722774
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
723775
)
724-
with dataclasses.replace(
725-
schema_ctx, base_schema=base_schema, generic_args=generic_args
726-
):
776+
with dataclasses.replace(schema_ctx, generic_args=generic_args):
727777
attributes.update(
728778
(
729779
field.name,
@@ -738,9 +788,10 @@ def _internal_class_schema(
738788
)
739789

740790
schema_class: Type[marshmallow.Schema] = type(
741-
clazz.__name__, (_base_schema(clazz, base_schema),), attributes
791+
clazz.__name__, (_base_schema(clazz, schema_ctx.base_schema),), attributes
742792
)
743793
future.set_result(schema_class)
794+
_schema_cache[cache_key] = schema_class
744795
return schema_class
745796

746797

@@ -958,8 +1009,7 @@ def _field_for_dataclass(
9581009
nested = typ.Schema
9591010
else:
9601011
assert isinstance(typ, Hashable)
961-
schema_ctx = _schema_ctx_stack.top
962-
nested = _internal_class_schema(typ, schema_ctx.base_schema)
1012+
nested = _internal_class_schema(typ) # type: ignore[arg-type] # FIXME
9631013
if isinstance(nested, _Future):
9641014
nested = nested.result
9651015

@@ -994,6 +1044,8 @@ def field_for_schema(
9941044
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
9951045
<class 'marshmallow.fields.Url'>
9961046
"""
1047+
if base_schema is None:
1048+
base_schema = marshmallow.Schema
9971049
localns = typ_frame.f_locals if typ_frame is not None else None
9981050
with _SchemaContext(localns=localns, base_schema=base_schema):
9991051
return _field_for_schema(typ, default, metadata)

tests/test_lrudict.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from marshmallow_dataclass import _LRUDict
2+
3+
4+
def test_LRUDict_getitem_moves_to_end() -> None:
5+
d = _LRUDict[str, str]()
6+
d["a"] = "aval"
7+
d["b"] = "bval"
8+
assert list(d.items()) == [("a", "aval"), ("b", "bval")]
9+
assert d["a"] == "aval"
10+
assert list(d.items()) == [("b", "bval"), ("a", "aval")]
11+
12+
13+
def test_LRUDict_get_moves_to_end() -> None:
14+
d = _LRUDict[str, str]()
15+
d["a"] = "aval"
16+
d["b"] = "bval"
17+
assert list(d.items()) == [("a", "aval"), ("b", "bval")]
18+
assert d.get("a") == "aval"
19+
assert list(d.items()) == [("b", "bval"), ("a", "aval")]
20+
21+
22+
def test_LRUDict_setitem_moves_to_end() -> None:
23+
d = _LRUDict[str, str]()
24+
d["a"] = "aval"
25+
d["b"] = "bval"
26+
assert list(d.items()) == [("a", "aval"), ("b", "bval")]
27+
d["a"] = "newval"
28+
assert list(d.items()) == [("b", "bval"), ("a", "newval")]
29+
30+
31+
def test_LRUDict_discards_oldest() -> None:
32+
d = _LRUDict[str, str](maxsize=1)
33+
d["a"] = "aval"
34+
d["b"] = "bval"
35+
assert list(d.items()) == [("b", "bval")]

0 commit comments

Comments
 (0)