Skip to content

Commit 9093446

Browse files
committed
Add ability to pass explicit localns (and globalns) to class_schema
When class_schema is called, it doesn't need the caller's whole stack frame. What it really wants is a `localns` to pass to `typing.get_type_hints` to be used to resolve type references. Here we add the ability to pass an explicit `localns` parameter to `class_schema`. We also add the ability to pass an explicit `globalns`, because ... might as well — it might come in useful. (Since we need these only to pass to `get_type_hints`, we might as well match `get_type_hints` API as closely as possible.)
1 parent fd04f8c commit 9093446

File tree

1 file changed

+44
-13
lines changed

1 file changed

+44
-13
lines changed

marshmallow_dataclass/__init__.py

+44-13
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,36 @@ def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]:
273273
return decorator(_cls, stacklevel=stacklevel + 1)
274274

275275

276+
@overload
277+
def class_schema(
278+
clazz: type,
279+
base_schema: Optional[Type[marshmallow.Schema]] = None,
280+
*,
281+
globalns: Optional[Dict[str, Any]] = None,
282+
localns: Optional[Dict[str, Any]] = None,
283+
) -> Type[marshmallow.Schema]:
284+
...
285+
286+
287+
@overload
288+
def class_schema(
289+
clazz: type,
290+
base_schema: Optional[Type[marshmallow.Schema]] = None,
291+
clazz_frame: Optional[types.FrameType] = None,
292+
*,
293+
globalns: Optional[Dict[str, Any]] = None,
294+
) -> Type[marshmallow.Schema]:
295+
...
296+
297+
276298
def class_schema(
277299
clazz: type,
278300
base_schema: Optional[Type[marshmallow.Schema]] = None,
301+
# FIXME: delete clazz_frame from API?
279302
clazz_frame: Optional[types.FrameType] = None,
303+
*,
304+
globalns: Optional[Dict[str, Any]] = None,
305+
localns: Optional[Dict[str, Any]] = None,
280306
) -> Type[marshmallow.Schema]:
281307
"""
282308
Convert a class to a marshmallow schema
@@ -398,24 +424,26 @@ def class_schema(
398424
"""
399425
if not dataclasses.is_dataclass(clazz):
400426
clazz = dataclasses.dataclass(clazz)
401-
if not clazz_frame:
402-
clazz_frame = _maybe_get_callers_frame(clazz)
403-
404-
with _SchemaContext(clazz_frame):
427+
if localns is None:
428+
if clazz_frame is None:
429+
clazz_frame = _maybe_get_callers_frame(clazz)
430+
if clazz_frame is not None:
431+
localns = clazz_frame.f_locals
432+
with _SchemaContext(globalns, localns):
405433
return _internal_class_schema(clazz, base_schema)
406434

407435

408436
class _SchemaContext:
409437
"""Global context for an invocation of class_schema."""
410438

411-
def __init__(self, frame: Optional[types.FrameType]):
439+
def __init__(
440+
self,
441+
globalns: Optional[Dict[str, Any]] = None,
442+
localns: Optional[Dict[str, Any]] = None,
443+
):
412444
self.seen_classes: Dict[type, str] = {}
413-
self.frame = frame
414-
415-
def get_type_hints(self, cls: Type) -> Dict[str, Any]:
416-
frame = self.frame
417-
localns = frame.f_locals if frame is not None else None
418-
return get_type_hints(cls, localns=localns)
445+
self.globalns = globalns
446+
self.localns = localns
419447

420448
def __enter__(self) -> "_SchemaContext":
421449
_schema_ctx_stack.push(self)
@@ -486,7 +514,9 @@ def _internal_class_schema(
486514
}
487515

488516
# Update the schema members to contain marshmallow fields instead of dataclass fields
489-
type_hints = schema_ctx.get_type_hints(clazz)
517+
type_hints = get_type_hints(
518+
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
519+
)
490520
attributes.update(
491521
(
492522
field.name,
@@ -670,6 +700,7 @@ def field_for_schema(
670700
default: Any = marshmallow.missing,
671701
metadata: Optional[Mapping[str, Any]] = None,
672702
base_schema: Optional[Type[marshmallow.Schema]] = None,
703+
# FIXME: delete typ_frame from API?
673704
typ_frame: Optional[types.FrameType] = None,
674705
) -> marshmallow.fields.Field:
675706
"""
@@ -692,7 +723,7 @@ def field_for_schema(
692723
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
693724
<class 'marshmallow.fields.Url'>
694725
"""
695-
with _SchemaContext(typ_frame):
726+
with _SchemaContext(localns=typ_frame.f_locals if typ_frame is not None else None):
696727
return _field_for_schema(typ, default, metadata, base_schema)
697728

698729

0 commit comments

Comments
 (0)