@@ -46,6 +46,7 @@ class User:
46
46
Any ,
47
47
Callable ,
48
48
Dict ,
49
+ Generic ,
49
50
List ,
50
51
Mapping ,
51
52
NewType as typing_NewType ,
@@ -79,9 +80,6 @@ class User:
79
80
# Max number of generated schemas that class_schema keeps of generated schemas. Removes duplicates.
80
81
MAX_CLASS_SCHEMA_CACHE_SIZE = 1024
81
82
82
- # Recursion guard for class_schema()
83
- _RECURSION_GUARD = threading .local ()
84
-
85
83
86
84
@overload
87
85
def dataclass (
@@ -352,20 +350,61 @@ def class_schema(
352
350
clazz_frame = current_frame .f_back
353
351
# Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack
354
352
del current_frame
355
- _RECURSION_GUARD .seen_classes = {}
356
- try :
357
- return _internal_class_schema (clazz , base_schema , clazz_frame )
358
- finally :
359
- _RECURSION_GUARD .seen_classes .clear ()
353
+
354
+ with _SchemaContext (clazz_frame ):
355
+ return _internal_class_schema (clazz , base_schema )
356
+
357
+
358
+ class _SchemaContext :
359
+ """Global context for an invocation of class_schema."""
360
+
361
+ def __init__ (self , frame : Optional [types .FrameType ]):
362
+ self .seen_classes : Dict [type , str ] = {}
363
+ self .frame = frame
364
+
365
+ def get_type_hints (self , cls : Type ) -> Dict [str , Any ]:
366
+ frame = self .frame
367
+ localns = frame .f_locals if frame is not None else None
368
+ return get_type_hints (cls , localns = localns )
369
+
370
+ def __enter__ (self ) -> "_SchemaContext" :
371
+ _schema_ctx_stack .push (self )
372
+ return self
373
+
374
+ def __exit__ (
375
+ self ,
376
+ _typ : Optional [Type [BaseException ]],
377
+ _value : Optional [BaseException ],
378
+ _tb : Optional [types .TracebackType ],
379
+ ) -> None :
380
+ _schema_ctx_stack .pop ()
381
+
382
+
383
+ class _LocalStack (threading .local , Generic [_U ]):
384
+ def __init__ (self ) -> None :
385
+ self .stack : List [_U ] = []
386
+
387
+ def push (self , value : _U ) -> None :
388
+ self .stack .append (value )
389
+
390
+ def pop (self ) -> None :
391
+ self .stack .pop ()
392
+
393
+ @property
394
+ def top (self ) -> _U :
395
+ return self .stack [- 1 ]
396
+
397
+
398
+ _schema_ctx_stack = _LocalStack [_SchemaContext ]()
360
399
361
400
362
401
@lru_cache (maxsize = MAX_CLASS_SCHEMA_CACHE_SIZE )
363
402
def _internal_class_schema (
364
403
clazz : type ,
365
404
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
366
- clazz_frame : Optional [types .FrameType ] = None ,
367
405
) -> Type [marshmallow .Schema ]:
368
- _RECURSION_GUARD .seen_classes [clazz ] = clazz .__name__
406
+ schema_ctx = _schema_ctx_stack .top
407
+ schema_ctx .seen_classes [clazz ] = clazz .__name__
369
408
try :
370
409
# noinspection PyDataclass
371
410
fields : Tuple [dataclasses .Field , ...] = dataclasses .fields (clazz )
@@ -383,7 +422,7 @@ def _internal_class_schema(
383
422
"****** WARNING ******"
384
423
)
385
424
created_dataclass : type = dataclasses .dataclass (clazz )
386
- return _internal_class_schema (created_dataclass , base_schema , clazz_frame )
425
+ return _internal_class_schema (created_dataclass , base_schema )
387
426
except Exception as exc :
388
427
raise TypeError (
389
428
f"{ getattr (clazz , '__name__' , repr (clazz ))} is not a dataclass and cannot be turned into one."
@@ -397,18 +436,15 @@ def _internal_class_schema(
397
436
}
398
437
399
438
# Update the schema members to contain marshmallow fields instead of dataclass fields
400
- type_hints = get_type_hints (
401
- clazz , localns = clazz_frame .f_locals if clazz_frame else None
402
- )
439
+ type_hints = schema_ctx .get_type_hints (clazz )
403
440
attributes .update (
404
441
(
405
442
field .name ,
406
- field_for_schema (
443
+ _field_for_schema (
407
444
type_hints [field .name ],
408
445
_get_field_default (field ),
409
446
field .metadata ,
410
447
base_schema ,
411
- clazz_frame ,
412
448
),
413
449
)
414
450
for field in fields
@@ -433,7 +469,6 @@ def _field_by_supertype(
433
469
newtype_supertype : Type ,
434
470
metadata : dict ,
435
471
base_schema : Optional [Type [marshmallow .Schema ]],
436
- typ_frame : Optional [types .FrameType ],
437
472
) -> marshmallow .fields .Field :
438
473
"""
439
474
Return a new field for fields based on a super field. (Usually spawned from NewType)
@@ -459,12 +494,11 @@ def _field_by_supertype(
459
494
if field :
460
495
return field (** metadata )
461
496
else :
462
- return field_for_schema (
497
+ return _field_for_schema (
463
498
newtype_supertype ,
464
499
metadata = metadata ,
465
500
default = default ,
466
501
base_schema = base_schema ,
467
- typ_frame = typ_frame ,
468
502
)
469
503
470
504
@@ -488,7 +522,6 @@ def _generic_type_add_any(typ: type) -> type:
488
522
def _field_for_generic_type (
489
523
typ : type ,
490
524
base_schema : Optional [Type [marshmallow .Schema ]],
491
- typ_frame : Optional [types .FrameType ],
492
525
** metadata : Any ,
493
526
) -> Optional [marshmallow .fields .Field ]:
494
527
"""
@@ -501,9 +534,7 @@ def _field_for_generic_type(
501
534
type_mapping = base_schema .TYPE_MAPPING if base_schema else {}
502
535
503
536
if origin in (list , List ):
504
- child_type = field_for_schema (
505
- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
506
- )
537
+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
507
538
list_type = cast (
508
539
Type [marshmallow .fields .List ],
509
540
type_mapping .get (List , marshmallow .fields .List ),
@@ -516,32 +547,25 @@ def _field_for_generic_type(
516
547
):
517
548
from . import collection_field
518
549
519
- child_type = field_for_schema (
520
- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
521
- )
550
+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
522
551
return collection_field .Sequence (cls_or_instance = child_type , ** metadata )
523
552
if origin in (set , Set ):
524
553
from . import collection_field
525
554
526
- child_type = field_for_schema (
527
- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
528
- )
555
+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
529
556
return collection_field .Set (
530
557
cls_or_instance = child_type , frozen = False , ** metadata
531
558
)
532
559
if origin in (frozenset , FrozenSet ):
533
560
from . import collection_field
534
561
535
- child_type = field_for_schema (
536
- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
537
- )
562
+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
538
563
return collection_field .Set (
539
564
cls_or_instance = child_type , frozen = True , ** metadata
540
565
)
541
566
if origin in (tuple , Tuple ):
542
567
children = tuple (
543
- field_for_schema (arg , base_schema = base_schema , typ_frame = typ_frame )
544
- for arg in arguments
568
+ _field_for_schema (arg , base_schema = base_schema ) for arg in arguments
545
569
)
546
570
tuple_type = cast (
547
571
Type [marshmallow .fields .Tuple ],
@@ -553,14 +577,11 @@ def _field_for_generic_type(
553
577
elif origin in (dict , Dict , collections .abc .Mapping , Mapping ):
554
578
dict_type = type_mapping .get (Dict , marshmallow .fields .Dict )
555
579
return dict_type (
556
- keys = field_for_schema (
557
- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
558
- ),
559
- values = field_for_schema (
560
- arguments [1 ], base_schema = base_schema , typ_frame = typ_frame
561
- ),
580
+ keys = _field_for_schema (arguments [0 ], base_schema = base_schema ),
581
+ values = _field_for_schema (arguments [1 ], base_schema = base_schema ),
562
582
** metadata ,
563
583
)
584
+
564
585
if typing_inspect .is_union_type (typ ):
565
586
if typing_inspect .is_optional_type (typ ):
566
587
metadata ["allow_none" ] = metadata .get ("allow_none" , True )
@@ -570,23 +591,21 @@ def _field_for_generic_type(
570
591
metadata .setdefault ("required" , False )
571
592
subtypes = [t for t in arguments if t is not NoneType ] # type: ignore
572
593
if len (subtypes ) == 1 :
573
- return field_for_schema (
594
+ return _field_for_schema (
574
595
subtypes [0 ],
575
596
metadata = metadata ,
576
597
base_schema = base_schema ,
577
- typ_frame = typ_frame ,
578
598
)
579
599
from . import union_field
580
600
581
601
return union_field .Union (
582
602
[
583
603
(
584
604
subtyp ,
585
- field_for_schema (
605
+ _field_for_schema (
586
606
subtyp ,
587
607
metadata = {"required" : True },
588
608
base_schema = base_schema ,
589
- typ_frame = typ_frame ,
590
609
),
591
610
)
592
611
for subtyp in subtypes
@@ -598,7 +617,7 @@ def _field_for_generic_type(
598
617
599
618
def field_for_schema (
600
619
typ : type ,
601
- default = marshmallow .missing ,
620
+ default : Any = marshmallow .missing ,
602
621
metadata : Optional [Mapping [str , Any ]] = None ,
603
622
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
604
623
typ_frame : Optional [types .FrameType ] = None ,
@@ -622,6 +641,29 @@ def field_for_schema(
622
641
623
642
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
624
643
<class 'marshmallow.fields.Url'>
644
+ """
645
+ with _SchemaContext (typ_frame ):
646
+ return _field_for_schema (typ , default , metadata , base_schema )
647
+
648
+
649
+ def _field_for_schema (
650
+ typ : type ,
651
+ default : Any = marshmallow .missing ,
652
+ metadata : Optional [Mapping [str , Any ]] = None ,
653
+ base_schema : Optional [Type [marshmallow .Schema ]] = None ,
654
+ ) -> marshmallow .fields .Field :
655
+ """
656
+ Get a marshmallow Field corresponding to the given python type.
657
+ The metadata of the dataclass field is used as arguments to the marshmallow Field.
658
+
659
+ This is an internal version of field_for_schema. It assumes a _SchemaContext
660
+ has been pushed onto the local stack.
661
+
662
+ :param typ: The type for which a field should be generated
663
+ :param default: value to use for (de)serialization when the field is missing
664
+ :param metadata: Additional parameters to pass to the marshmallow field constructor
665
+ :param base_schema: marshmallow schema used as a base class when deriving dataclass schema
666
+
625
667
"""
626
668
627
669
metadata = {} if metadata is None else dict (metadata )
@@ -694,10 +736,10 @@ def field_for_schema(
694
736
)
695
737
else :
696
738
subtyp = Any
697
- return field_for_schema (subtyp , default , metadata , base_schema , typ_frame )
739
+ return _field_for_schema (subtyp , default , metadata , base_schema )
698
740
699
741
# Generic types
700
- generic_field = _field_for_generic_type (typ , base_schema , typ_frame , ** metadata )
742
+ generic_field = _field_for_generic_type (typ , base_schema , ** metadata )
701
743
if generic_field :
702
744
return generic_field
703
745
@@ -711,7 +753,6 @@ def field_for_schema(
711
753
newtype_supertype = newtype_supertype ,
712
754
metadata = metadata ,
713
755
base_schema = base_schema ,
714
- typ_frame = typ_frame ,
715
756
)
716
757
717
758
# enumerations
@@ -734,8 +775,8 @@ def field_for_schema(
734
775
nested = (
735
776
nested_schema
736
777
or forward_reference
737
- or _RECURSION_GUARD .seen_classes .get (typ )
738
- or _internal_class_schema (typ , base_schema , typ_frame ) # type: ignore [arg-type]
778
+ or _schema_ctx_stack . top .seen_classes .get (typ )
779
+ or _internal_class_schema (typ , base_schema ) # type: ignore[arg-type] # FIXME
739
780
)
740
781
741
782
return marshmallow .fields .Nested (nested , ** metadata )
0 commit comments