@@ -273,10 +273,36 @@ def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]:
273
273
return decorator (_cls , stacklevel = stacklevel + 1 )
274
274
275
275
276
+ @overload
277
+ def class_schema (
278
+ clazz : type ,
279
+ base_schema : Optional [Type [marshmallow .Schema ]] = None ,
280
+ * ,
281
+ globalns : Optional [Dict [str , Any ]] = None ,
282
+ localns : Optional [Dict [str , Any ]] = None ,
283
+ ) -> Type [marshmallow .Schema ]:
284
+ ...
285
+
286
+
287
+ @overload
288
+ def class_schema (
289
+ clazz : type ,
290
+ base_schema : Optional [Type [marshmallow .Schema ]] = None ,
291
+ clazz_frame : Optional [types .FrameType ] = None ,
292
+ * ,
293
+ globalns : Optional [Dict [str , Any ]] = None ,
294
+ ) -> Type [marshmallow .Schema ]:
295
+ ...
296
+
297
+
276
298
def class_schema (
277
299
clazz : type ,
278
300
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
301
+ # FIXME: delete clazz_frame from API?
279
302
clazz_frame : Optional [types .FrameType ] = None ,
303
+ * ,
304
+ globalns : Optional [Dict [str , Any ]] = None ,
305
+ localns : Optional [Dict [str , Any ]] = None ,
280
306
) -> Type [marshmallow .Schema ]:
281
307
"""
282
308
Convert a class to a marshmallow schema
@@ -398,24 +424,26 @@ def class_schema(
398
424
"""
399
425
if not dataclasses .is_dataclass (clazz ):
400
426
clazz = dataclasses .dataclass (clazz )
401
- if not clazz_frame :
402
- clazz_frame = _maybe_get_callers_frame (clazz )
403
-
404
- with _SchemaContext (clazz_frame ):
427
+ if localns is None :
428
+ if clazz_frame is None :
429
+ clazz_frame = _maybe_get_callers_frame (clazz )
430
+ if clazz_frame is not None :
431
+ localns = clazz_frame .f_locals
432
+ with _SchemaContext (globalns , localns ):
405
433
return _internal_class_schema (clazz , base_schema )
406
434
407
435
408
436
class _SchemaContext :
409
437
"""Global context for an invocation of class_schema."""
410
438
411
- def __init__ (self , frame : Optional [types .FrameType ]):
439
+ def __init__ (
440
+ self ,
441
+ globalns : Optional [Dict [str , Any ]] = None ,
442
+ localns : Optional [Dict [str , Any ]] = None ,
443
+ ):
412
444
self .seen_classes : Dict [type , str ] = {}
413
- self .frame = frame
414
-
415
- def get_type_hints (self , cls : Type ) -> Dict [str , Any ]:
416
- frame = self .frame
417
- localns = frame .f_locals if frame is not None else None
418
- return get_type_hints (cls , localns = localns )
445
+ self .globalns = globalns
446
+ self .localns = localns
419
447
420
448
def __enter__ (self ) -> "_SchemaContext" :
421
449
_schema_ctx_stack .push (self )
@@ -486,7 +514,9 @@ def _internal_class_schema(
486
514
}
487
515
488
516
# Update the schema members to contain marshmallow fields instead of dataclass fields
489
- type_hints = schema_ctx .get_type_hints (clazz )
517
+ type_hints = get_type_hints (
518
+ clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
519
+ )
490
520
attributes .update (
491
521
(
492
522
field .name ,
@@ -670,6 +700,7 @@ def field_for_schema(
670
700
default : Any = marshmallow .missing ,
671
701
metadata : Optional [Mapping [str , Any ]] = None ,
672
702
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
703
+ # FIXME: delete typ_frame from API?
673
704
typ_frame : Optional [types .FrameType ] = None ,
674
705
) -> marshmallow .fields .Field :
675
706
"""
@@ -692,7 +723,7 @@ def field_for_schema(
692
723
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
693
724
<class 'marshmallow.fields.Url'>
694
725
"""
695
- with _SchemaContext (typ_frame ):
726
+ with _SchemaContext (localns = typ_frame . f_locals if typ_frame is not None else None ):
696
727
return _field_for_schema (typ , default , metadata , base_schema )
697
728
698
729
0 commit comments