@@ -46,6 +46,7 @@ class User:
46
46
Any ,
47
47
Callable ,
48
48
Dict ,
49
+ Generic ,
49
50
List ,
50
51
Mapping ,
51
52
NewType as typing_NewType ,
@@ -79,9 +80,6 @@ class User:
79
80
# Max number of generated schemas that class_schema keeps of generated schemas. Removes duplicates.
80
81
MAX_CLASS_SCHEMA_CACHE_SIZE = 1024
81
82
82
- # Recursion guard for class_schema()
83
- _RECURSION_GUARD = threading .local ()
84
-
85
83
86
84
@overload
87
85
def dataclass (
@@ -352,20 +350,56 @@ def class_schema(
352
350
clazz_frame = current_frame .f_back
353
351
# Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack
354
352
del current_frame
355
- _RECURSION_GUARD .seen_classes = {}
356
- try :
357
- return _internal_class_schema (clazz , base_schema , clazz_frame )
358
- finally :
359
- _RECURSION_GUARD .seen_classes .clear ()
353
+
354
+ with _SchemaContext (clazz_frame ):
355
+ return _internal_class_schema (clazz , base_schema )
356
+
357
+
358
+ class _SchemaContext :
359
+ """Global context for an invocation of class_schema."""
360
+
361
+ def __init__ (self , frame : Optional [types .FrameType ]):
362
+ self .seen_classes : Dict [Type , str ] = {}
363
+ self .frame = frame
364
+
365
+ def get_type_hints (self , cls : Type ) -> Dict [str , Any ]:
366
+ frame = self .frame
367
+ localns = frame .f_locals if frame is not None else None
368
+ return get_type_hints (cls , localns = localns )
369
+
370
+ def __enter__ (self ) -> "_SchemaContext" :
371
+ _schema_ctx_stack .push (self )
372
+ return self
373
+
374
+ def __exit__ (self , _typ , _value , _tb ) -> None :
375
+ _schema_ctx_stack .pop ()
376
+
377
+
378
+ class _LocalStack (threading .local , Generic [_U ]):
379
+ def __init__ (self ):
380
+ self .stack : List [_U ] = []
381
+
382
+ def push (self , value : _U ) -> None :
383
+ self .stack .append (value )
384
+
385
+ def pop (self ) -> None :
386
+ self .stack .pop ()
387
+
388
+ @property
389
+ def top (self ) -> _U :
390
+ return self .stack [- 1 ]
391
+
392
+
393
+ _schema_ctx_stack = _LocalStack [_SchemaContext ]()
360
394
361
395
362
396
@lru_cache (maxsize = MAX_CLASS_SCHEMA_CACHE_SIZE )
363
397
def _internal_class_schema (
364
398
clazz : type ,
365
399
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
366
- clazz_frame : types .FrameType = None ,
367
400
) -> Type [marshmallow .Schema ]:
368
- _RECURSION_GUARD .seen_classes [clazz ] = clazz .__name__
401
+ schema_ctx = _schema_ctx_stack .top
402
+ schema_ctx .seen_classes [clazz ] = clazz .__name__
369
403
try :
370
404
# noinspection PyDataclass
371
405
fields : Tuple [dataclasses .Field , ...] = dataclasses .fields (clazz )
@@ -383,7 +417,7 @@ def _internal_class_schema(
383
417
"****** WARNING ******"
384
418
)
385
419
created_dataclass : type = dataclasses .dataclass (clazz )
386
- return _internal_class_schema (created_dataclass , base_schema , clazz_frame )
420
+ return _internal_class_schema (created_dataclass , base_schema )
387
421
except Exception :
388
422
raise TypeError (
389
423
f"{ getattr (clazz , '__name__' , repr (clazz ))} is not a dataclass and cannot be turned into one."
@@ -397,18 +431,15 @@ def _internal_class_schema(
397
431
}
398
432
399
433
# Update the schema members to contain marshmallow fields instead of dataclass fields
400
- type_hints = get_type_hints (
401
- clazz , localns = clazz_frame .f_locals if clazz_frame else None
402
- )
434
+ type_hints = schema_ctx .get_type_hints (clazz )
403
435
attributes .update (
404
436
(
405
437
field .name ,
406
- field_for_schema (
438
+ _field_for_schema (
407
439
type_hints [field .name ],
408
440
_get_field_default (field ),
409
441
field .metadata ,
410
442
base_schema ,
411
- clazz_frame ,
412
443
),
413
444
)
414
445
for field in fields
@@ -433,7 +464,6 @@ def _field_by_supertype(
433
464
newtype_supertype : Type ,
434
465
metadata : dict ,
435
466
base_schema : Optional [Type [marshmallow .Schema ]],
436
- typ_frame : Optional [types .FrameType ],
437
467
) -> marshmallow .fields .Field :
438
468
"""
439
469
Return a new field for fields based on a super field. (Usually spawned from NewType)
@@ -459,12 +489,11 @@ def _field_by_supertype(
459
489
if field :
460
490
return field (** metadata )
461
491
else :
462
- return field_for_schema (
492
+ return _field_for_schema (
463
493
newtype_supertype ,
464
494
metadata = metadata ,
465
495
default = default ,
466
496
base_schema = base_schema ,
467
- typ_frame = typ_frame ,
468
497
)
469
498
470
499
@@ -488,7 +517,6 @@ def _generic_type_add_any(typ: type) -> type:
488
517
def _field_for_generic_type (
489
518
typ : type ,
490
519
base_schema : Optional [Type [marshmallow .Schema ]],
491
- typ_frame : Optional [types .FrameType ],
492
520
** metadata : Any ,
493
521
) -> Optional [marshmallow .fields .Field ]:
494
522
"""
@@ -501,9 +529,7 @@ def _field_for_generic_type(
501
529
type_mapping = base_schema .TYPE_MAPPING if base_schema else {}
502
530
503
531
if origin in (list , List ):
504
- child_type = field_for_schema (
505
- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
506
- )
532
+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
507
533
list_type = cast (
508
534
Type [marshmallow .fields .List ],
509
535
type_mapping .get (List , marshmallow .fields .List ),
@@ -512,32 +538,25 @@ def _field_for_generic_type(
512
538
if origin in (collections .abc .Sequence , Sequence ):
513
539
from . import collection_field
514
540
515
- child_type = field_for_schema (
516
- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
517
- )
541
+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
518
542
return collection_field .Sequence (cls_or_instance = child_type , ** metadata )
519
543
if origin in (set , Set ):
520
544
from . import collection_field
521
545
522
- child_type = field_for_schema (
523
- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
524
- )
546
+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
525
547
return collection_field .Set (
526
548
cls_or_instance = child_type , frozen = False , ** metadata
527
549
)
528
550
if origin in (frozenset , FrozenSet ):
529
551
from . import collection_field
530
552
531
- child_type = field_for_schema (
532
- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
533
- )
553
+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
534
554
return collection_field .Set (
535
555
cls_or_instance = child_type , frozen = True , ** metadata
536
556
)
537
557
if origin in (tuple , Tuple ):
538
558
children = tuple (
539
- field_for_schema (arg , base_schema = base_schema , typ_frame = typ_frame )
540
- for arg in arguments
559
+ _field_for_schema (arg , base_schema = base_schema ) for arg in arguments
541
560
)
542
561
tuple_type = cast (
543
562
Type [marshmallow .fields .Tuple ],
@@ -549,12 +568,8 @@ def _field_for_generic_type(
549
568
elif origin in (dict , Dict , collections .abc .Mapping , Mapping ):
550
569
dict_type = type_mapping .get (Dict , marshmallow .fields .Dict )
551
570
return dict_type (
552
- keys = field_for_schema (
553
- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
554
- ),
555
- values = field_for_schema (
556
- arguments [1 ], base_schema = base_schema , typ_frame = typ_frame
557
- ),
571
+ keys = _field_for_schema (arguments [0 ], base_schema = base_schema ),
572
+ values = _field_for_schema (arguments [1 ], base_schema = base_schema ),
558
573
** metadata ,
559
574
)
560
575
elif typing_inspect .is_union_type (typ ):
@@ -566,23 +581,21 @@ def _field_for_generic_type(
566
581
metadata .setdefault ("required" , False )
567
582
subtypes = [t for t in arguments if t is not NoneType ] # type: ignore
568
583
if len (subtypes ) == 1 :
569
- return field_for_schema (
584
+ return _field_for_schema (
570
585
subtypes [0 ],
571
586
metadata = metadata ,
572
587
base_schema = base_schema ,
573
- typ_frame = typ_frame ,
574
588
)
575
589
from . import union_field
576
590
577
591
return union_field .Union (
578
592
[
579
593
(
580
594
subtyp ,
581
- field_for_schema (
595
+ _field_for_schema (
582
596
subtyp ,
583
597
metadata = {"required" : True },
584
598
base_schema = base_schema ,
585
- typ_frame = typ_frame ,
586
599
),
587
600
)
588
601
for subtyp in subtypes
@@ -618,6 +631,29 @@ def field_for_schema(
618
631
619
632
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
620
633
<class 'marshmallow.fields.Url'>
634
+ """
635
+ with _SchemaContext (typ_frame ):
636
+ return _field_for_schema (typ , default , metadata , base_schema )
637
+
638
+
639
+ def _field_for_schema (
640
+ typ : type ,
641
+ default = marshmallow .missing ,
642
+ metadata : Mapping [str , Any ] = None ,
643
+ base_schema : Optional [Type [marshmallow .Schema ]] = None ,
644
+ ) -> marshmallow .fields .Field :
645
+ """
646
+ Get a marshmallow Field corresponding to the given python type.
647
+ The metadata of the dataclass field is used as arguments to the marshmallow Field.
648
+
649
+ This is an internal version of field_for_schema. It assumes a _SchemaContext
650
+ has been pushed onto the local stack.
651
+
652
+ :param typ: The type for which a field should be generated
653
+ :param default: value to use for (de)serialization when the field is missing
654
+ :param metadata: Additional parameters to pass to the marshmallow field constructor
655
+ :param base_schema: marshmallow schema used as a base class when deriving dataclass schema
656
+
621
657
"""
622
658
623
659
metadata = {} if metadata is None else dict (metadata )
@@ -690,10 +726,10 @@ def field_for_schema(
690
726
)
691
727
else :
692
728
subtyp = Any
693
- return field_for_schema (subtyp , default , metadata , base_schema , typ_frame )
729
+ return _field_for_schema (subtyp , default , metadata , base_schema )
694
730
695
731
# Generic types
696
- generic_field = _field_for_generic_type (typ , base_schema , typ_frame , ** metadata )
732
+ generic_field = _field_for_generic_type (typ , base_schema , ** metadata )
697
733
if generic_field :
698
734
return generic_field
699
735
@@ -707,7 +743,6 @@ def field_for_schema(
707
743
newtype_supertype = newtype_supertype ,
708
744
metadata = metadata ,
709
745
base_schema = base_schema ,
710
- typ_frame = typ_frame ,
711
746
)
712
747
713
748
# enumerations
@@ -726,8 +761,8 @@ def field_for_schema(
726
761
nested = (
727
762
nested_schema
728
763
or forward_reference
729
- or _RECURSION_GUARD .seen_classes .get (typ )
730
- or _internal_class_schema (typ , base_schema , typ_frame )
764
+ or _schema_ctx_stack . top .seen_classes .get (typ )
765
+ or _internal_class_schema (typ , base_schema )
731
766
)
732
767
733
768
return marshmallow .fields .Nested (nested , ** metadata )
0 commit comments