@@ -42,7 +42,7 @@ class User:
42
42
import types
43
43
import warnings
44
44
from enum import Enum
45
- from functools import lru_cache , partial
45
+ from functools import partial
46
46
from typing import (
47
47
Any ,
48
48
Callable ,
@@ -106,8 +106,12 @@ def get_origin(tp):
106
106
107
107
108
108
if sys .version_info >= (3 , 7 ):
109
+ from typing import OrderedDict
110
+
109
111
TypeVar_ = TypeVar
110
112
else :
113
+ from typing_extensions import OrderedDict
114
+
111
115
TypeVar_ = type
112
116
113
117
if sys .version_info >= (3 , 10 ):
@@ -118,6 +122,7 @@ def get_origin(tp):
118
122
119
123
NoneType = type (None )
120
124
_U = TypeVar ("_U" )
125
+ _V = TypeVar ("_V" )
121
126
122
127
# Whitelist of dataclass members that will be copied to generated schema.
123
128
MEMBERS_WHITELIST : Set [str ] = {"Meta" }
@@ -487,13 +492,59 @@ def class_schema(
487
492
clazz_frame = _maybe_get_callers_frame (clazz )
488
493
if clazz_frame is not None :
489
494
localns = clazz_frame .f_locals
490
- with _SchemaContext (globalns , localns ):
491
- schema = _internal_class_schema (clazz , base_schema )
495
+
496
+ if base_schema is None :
497
+ base_schema = marshmallow .Schema
498
+
499
+ with _SchemaContext (globalns , localns , base_schema ):
500
+ schema = _internal_class_schema (clazz )
492
501
493
502
assert not isinstance (schema , _Future )
494
503
return schema
495
504
496
505
506
+ class _LRUDict (OrderedDict [_U , _V ]):
507
+ """Limited-length dict which discards LRU entries."""
508
+
509
+ def __init__ (self , maxsize : int = 128 ):
510
+ self .maxsize = maxsize
511
+ super ().__init__ ()
512
+
513
+ def __setitem__ (self , key : _U , value : _V ) -> None :
514
+ super ().__setitem__ (key , value )
515
+ super ().move_to_end (key )
516
+
517
+ while len (self ) > self .maxsize :
518
+ oldkey = next (iter (self ))
519
+ super ().__delitem__ (oldkey )
520
+
521
+ def __getitem__ (self , key : _U ) -> _V :
522
+ val = super ().__getitem__ (key )
523
+ super ().move_to_end (key )
524
+ return val
525
+
526
+ _T = TypeVar ("_T" )
527
+
528
+ @overload
529
+ def get (self , key : _U ) -> Optional [_V ]:
530
+ ...
531
+
532
+ @overload
533
+ def get (self , key : _U , default : _T ) -> Union [_V , _T ]:
534
+ ...
535
+
536
+ def get (self , key : _U , default : Any = None ) -> Any :
537
+ try :
538
+ return self .__getitem__ (key )
539
+ except KeyError :
540
+ return default
541
+
542
+
543
+ _schema_cache = _LRUDict [Hashable , Type [marshmallow .Schema ]](
544
+ MAX_CLASS_SCHEMA_CACHE_SIZE
545
+ )
546
+
547
+
497
548
class InvalidStateError (Exception ):
498
549
"""Raised when an operation is performed on a future that is not
499
550
allowed in the current state.
@@ -623,7 +674,7 @@ class _SchemaContext:
623
674
624
675
globalns : Optional [Dict [str , Any ]] = None
625
676
localns : Optional [Dict [str , Any ]] = None
626
- base_schema : Optional [ Type [marshmallow .Schema ]] = None
677
+ base_schema : Type [marshmallow .Schema ] = marshmallow . Schema
627
678
generic_args : Optional [_GenericArgs ] = None
628
679
seen_classes : Dict [type , _Future [Type [marshmallow .Schema ]]] = dataclasses .field (
629
680
default_factory = dict
@@ -633,12 +684,9 @@ def get_type_mapping(
633
684
self , include_marshmallow_default : bool = False
634
685
) -> _TypeMapping :
635
686
default_mapping = marshmallow .Schema .TYPE_MAPPING
636
- if self .base_schema is not None :
637
- mappings = [self .base_schema .TYPE_MAPPING ]
638
- if include_marshmallow_default :
639
- mappings .append (default_mapping )
640
- else :
641
- mappings = [default_mapping ]
687
+ mappings = [self .base_schema .TYPE_MAPPING ]
688
+ if include_marshmallow_default and mappings [0 ] is not default_mapping :
689
+ mappings .append (default_mapping )
642
690
return _TypeMapping (* mappings )
643
691
644
692
def __enter__ (self ) -> "_SchemaContext" :
@@ -672,15 +720,19 @@ def top(self) -> _U:
672
720
_schema_ctx_stack = _LocalStack [_SchemaContext ]()
673
721
674
722
675
- @lru_cache (maxsize = MAX_CLASS_SCHEMA_CACHE_SIZE )
676
723
def _internal_class_schema (
677
724
clazz : type ,
678
- base_schema : Optional [Type [marshmallow .Schema ]] = None ,
679
725
) -> Union [Type [marshmallow .Schema ], _Future [Type [marshmallow .Schema ]]]:
680
726
schema_ctx = _schema_ctx_stack .top
681
727
if clazz in schema_ctx .seen_classes :
682
728
return schema_ctx .seen_classes [clazz ]
683
729
730
+ cache_key = clazz , schema_ctx .base_schema
731
+ try :
732
+ return _schema_cache [cache_key ]
733
+ except KeyError :
734
+ pass
735
+
684
736
future : _Future [Type [marshmallow .Schema ]] = _Future ()
685
737
schema_ctx .seen_classes [clazz ] = future
686
738
@@ -721,9 +773,7 @@ def _internal_class_schema(
721
773
type_hints = get_type_hints (
722
774
clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
723
775
)
724
- with dataclasses .replace (
725
- schema_ctx , base_schema = base_schema , generic_args = generic_args
726
- ):
776
+ with dataclasses .replace (schema_ctx , generic_args = generic_args ):
727
777
attributes .update (
728
778
(
729
779
field .name ,
@@ -738,9 +788,10 @@ def _internal_class_schema(
738
788
)
739
789
740
790
schema_class : Type [marshmallow .Schema ] = type (
741
- clazz .__name__ , (_base_schema (clazz , base_schema ),), attributes
791
+ clazz .__name__ , (_base_schema (clazz , schema_ctx . base_schema ),), attributes
742
792
)
743
793
future .set_result (schema_class )
794
+ _schema_cache [cache_key ] = schema_class
744
795
return schema_class
745
796
746
797
@@ -958,8 +1009,7 @@ def _field_for_dataclass(
958
1009
nested = typ .Schema
959
1010
else :
960
1011
assert isinstance (typ , Hashable )
961
- schema_ctx = _schema_ctx_stack .top
962
- nested = _internal_class_schema (typ , schema_ctx .base_schema )
1012
+ nested = _internal_class_schema (typ ) # type: ignore[arg-type] # FIXME
963
1013
if isinstance (nested , _Future ):
964
1014
nested = nested .result
965
1015
@@ -994,6 +1044,8 @@ def field_for_schema(
994
1044
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
995
1045
<class 'marshmallow.fields.Url'>
996
1046
"""
1047
+ if base_schema is None :
1048
+ base_schema = marshmallow .Schema
997
1049
localns = typ_frame .f_locals if typ_frame is not None else None
998
1050
with _SchemaContext (localns = localns , base_schema = base_schema ):
999
1051
return _field_for_schema (typ , default , metadata )
0 commit comments