@@ -42,7 +42,7 @@ class User:
42
42
import types
43
43
import warnings
44
44
from enum import Enum
45
- from functools import lru_cache , partial
45
+ from functools import partial
46
46
from typing import (
47
47
Any ,
48
48
Callable ,
@@ -107,8 +107,12 @@ def get_origin(tp):
107
107
108
108
109
109
if sys .version_info >= (3 , 7 ):
110
+ from typing import OrderedDict
111
+
110
112
TypeVar_ = TypeVar
111
113
else :
114
+ from typing_extensions import OrderedDict
115
+
112
116
TypeVar_ = type
113
117
114
118
if sys .version_info >= (3 , 10 ):
@@ -119,9 +123,9 @@ def get_origin(tp):
119
123
120
124
NoneType = type (None )
121
125
_U = TypeVar ("_U" )
126
+ _V = TypeVar ("_V" )
122
127
_Field = TypeVar ("_Field" , bound = marshmallow .fields .Field )
123
128
124
-
125
129
# Whitelist of dataclass members that will be copied to generated schema.
126
130
MEMBERS_WHITELIST : Set [str ] = {"Meta" }
127
131
@@ -490,13 +494,59 @@ def class_schema(
490
494
clazz_frame = _maybe_get_callers_frame (clazz )
491
495
if clazz_frame is not None :
492
496
localns = clazz_frame .f_locals
493
- with _SchemaContext (globalns , localns ):
494
- schema = _internal_class_schema (clazz , base_schema )
497
+
498
+ if base_schema is None :
499
+ base_schema = marshmallow .Schema
500
+
501
+ with _SchemaContext (globalns , localns , base_schema ):
502
+ schema = _internal_class_schema (clazz )
495
503
496
504
assert not isinstance (schema , _Future )
497
505
return schema
498
506
499
507
508
+ class _LRUDict (OrderedDict [_U , _V ]):
509
+ """Limited-length dict which discards LRU entries."""
510
+
511
+ def __init__ (self , maxsize : int = 128 ):
512
+ self .maxsize = maxsize
513
+ super ().__init__ ()
514
+
515
+ def __setitem__ (self , key : _U , value : _V ) -> None :
516
+ super ().__setitem__ (key , value )
517
+ super ().move_to_end (key )
518
+
519
+ while len (self ) > self .maxsize :
520
+ oldkey = next (iter (self ))
521
+ super ().__delitem__ (oldkey )
522
+
523
+ def __getitem__ (self , key : _U ) -> _V :
524
+ val = super ().__getitem__ (key )
525
+ super ().move_to_end (key )
526
+ return val
527
+
528
+ _T = TypeVar ("_T" )
529
+
530
+ @overload
531
+ def get (self , key : _U ) -> Optional [_V ]:
532
+ ...
533
+
534
+ @overload
535
+ def get (self , key : _U , default : _T ) -> Union [_V , _T ]:
536
+ ...
537
+
538
+ def get (self , key : _U , default : Any = None ) -> Any :
539
+ try :
540
+ return self .__getitem__ (key )
541
+ except KeyError :
542
+ return default
543
+
544
+
545
+ _schema_cache = _LRUDict [Hashable , Type [marshmallow .Schema ]](
546
+ MAX_CLASS_SCHEMA_CACHE_SIZE
547
+ )
548
+
549
+
500
550
class InvalidStateError (Exception ):
501
551
"""Raised when an operation is performed on a future that is not
502
552
allowed in the current state.
@@ -597,7 +647,7 @@ class _SchemaContext:
597
647
598
648
globalns : Optional [Dict [str , Any ]] = None
599
649
localns : Optional [Dict [str , Any ]] = None
600
- base_schema : Optional [ Type [marshmallow .Schema ]] = None
650
+ base_schema : Type [marshmallow .Schema ] = marshmallow . Schema
601
651
generic_args : Optional [_GenericArgs ] = None
602
652
seen_classes : Dict [type , _Future [Type [marshmallow .Schema ]]] = dataclasses .field (
603
653
default_factory = dict
@@ -612,8 +662,6 @@ def get_type_mapping(
612
662
all bases in base_schema's MRO.
613
663
"""
614
664
base_schema = self .base_schema
615
- if base_schema is None :
616
- base_schema = marshmallow .Schema
617
665
if use_mro :
618
666
return ChainMap (
619
667
* (getattr (cls , "TYPE_MAPPING" , {}) for cls in base_schema .__mro__ )
@@ -651,15 +699,19 @@ def top(self) -> _U:
651
699
_schema_ctx_stack = _LocalStack [_SchemaContext ]()
652
700
653
701
654
- @lru_cache (maxsize = MAX_CLASS_SCHEMA_CACHE_SIZE )
655
702
def _internal_class_schema (
656
703
clazz : type ,
657
- base_schema : Optional [Type [marshmallow .Schema ]] = None ,
658
704
) -> Union [Type [marshmallow .Schema ], _Future [Type [marshmallow .Schema ]]]:
659
705
schema_ctx = _schema_ctx_stack .top
660
706
if clazz in schema_ctx .seen_classes :
661
707
return schema_ctx .seen_classes [clazz ]
662
708
709
+ cache_key = clazz , schema_ctx .base_schema
710
+ try :
711
+ return _schema_cache [cache_key ]
712
+ except KeyError :
713
+ pass
714
+
663
715
future : _Future [Type [marshmallow .Schema ]] = _Future ()
664
716
schema_ctx .seen_classes [clazz ] = future
665
717
@@ -700,9 +752,7 @@ def _internal_class_schema(
700
752
type_hints = get_type_hints (
701
753
clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
702
754
)
703
- with dataclasses .replace (
704
- schema_ctx , base_schema = base_schema , generic_args = generic_args
705
- ):
755
+ with dataclasses .replace (schema_ctx , generic_args = generic_args ):
706
756
attributes .update (
707
757
(
708
758
field .name ,
@@ -717,9 +767,10 @@ def _internal_class_schema(
717
767
)
718
768
719
769
schema_class : Type [marshmallow .Schema ] = type (
720
- clazz .__name__ , (_base_schema (clazz , base_schema ),), attributes
770
+ clazz .__name__ , (_base_schema (clazz , schema_ctx . base_schema ),), attributes
721
771
)
722
772
future .set_result (schema_class )
773
+ _schema_cache [cache_key ] = schema_class
723
774
return schema_class
724
775
725
776
@@ -940,8 +991,7 @@ def _field_for_dataclass(
940
991
nested = typ .Schema
941
992
else :
942
993
assert isinstance (typ , Hashable )
943
- schema_ctx = _schema_ctx_stack .top
944
- nested = _internal_class_schema (typ , schema_ctx .base_schema )
994
+ nested = _internal_class_schema (typ ) # type: ignore[arg-type] # FIXME
945
995
if isinstance (nested , _Future ):
946
996
nested = nested .result
947
997
@@ -976,6 +1026,8 @@ def field_for_schema(
976
1026
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
977
1027
<class 'marshmallow.fields.Url'>
978
1028
"""
1029
+ if base_schema is None :
1030
+ base_schema = marshmallow .Schema
979
1031
localns = typ_frame .f_locals if typ_frame is not None else None
980
1032
with _SchemaContext (localns = localns , base_schema = base_schema ):
981
1033
return _field_for_schema (typ , default , metadata )
0 commit comments