Skip to content

Commit 73c5fcc

Browse files
committed
feat: Do not auto-dataclassify non-dataclasses
This implement the suggestions made in lovasoa#51. See lovasoa#51 (comment)
1 parent 264e355 commit 73c5fcc

File tree

3 files changed

+176
-117
lines changed

3 files changed

+176
-117
lines changed

marshmallow_dataclass/__init__.py

+148-108
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,9 @@ class _SchemaContext:
653653
default_factory=dict
654654
)
655655

656+
def replace(self, generic_args: Optional[_GenericArgs]) -> "_SchemaContext":
657+
return dataclasses.replace(self, generic_args=generic_args)
658+
656659
def get_type_mapping(
657660
self, use_mro: bool = False
658661
) -> Mapping[Any, Type[marshmallow.fields.Field]]:
@@ -717,63 +720,109 @@ def _internal_class_schema(
717720

718721
generic_args = schema_ctx.generic_args
719722

720-
if _is_generic_alias_of_dataclass(clazz):
721-
generic_args = _GenericArgs(clazz, generic_args)
722-
clazz = typing_inspect.get_origin(clazz)
723-
elif not dataclasses.is_dataclass(clazz):
724-
try:
725-
warnings.warn(
726-
"****** WARNING ****** "
727-
f"marshmallow_dataclass was called on the class {clazz}, which is not a dataclass. "
728-
"It is going to try and convert the class into a dataclass, which may have "
729-
"undesirable side effects. To avoid this message, make sure all your classes and "
730-
"all the classes of their fields are either explicitly supported by "
731-
"marshmallow_dataclass, or define the schema explicitly using "
732-
"field(metadata=dict(marshmallow_field=...)). For more information, see "
733-
"https://github.com/lovasoa/marshmallow_dataclass/issues/51 "
734-
"****** WARNING ******"
735-
)
736-
dataclasses.dataclass(clazz)
737-
except Exception as exc:
738-
raise TypeError(
739-
f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
740-
) from exc
741-
742-
fields = dataclasses.fields(clazz)
743-
744-
# Copy all marshmallow hooks and whitelisted members of the dataclass to the schema.
745-
attributes = {
746-
k: v
747-
for k, v in inspect.getmembers(clazz)
748-
if hasattr(v, "__marshmallow_hook__") or k in MEMBERS_WHITELIST
749-
}
723+
constructor: Callable[..., object]
724+
725+
if _is_simple_annotated_class(clazz):
726+
class_name = clazz.__name__
727+
constructor = _simple_class_constructor(clazz)
728+
attributes = _schema_attrs_for_simple_class(clazz)
729+
elif _is_generic_alias_of_dataclass(clazz):
730+
origin = get_origin(clazz)
731+
assert isinstance(origin, type)
732+
class_name = origin.__name__
733+
constructor = origin
734+
with schema_ctx.replace(generic_args=_GenericArgs(clazz, generic_args)):
735+
attributes = _schema_attrs_for_dataclass(origin)
736+
elif dataclasses.is_dataclass(clazz):
737+
class_name = clazz.__name__
738+
constructor = clazz
739+
attributes = _schema_attrs_for_dataclass(clazz)
740+
else:
741+
raise TypeError(f"{clazz} is not a dataclass or a simple annotated class")
742+
743+
base_schema = marshmallow.Schema
744+
if schema_ctx.base_schema is not None:
745+
base_schema = schema_ctx.base_schema
746+
747+
load_to_dict = base_schema.load
748+
749+
def load(
750+
self: marshmallow.Schema,
751+
data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]],
752+
*,
753+
many: Optional[bool] = None,
754+
unknown: Optional[str] = None,
755+
**kwargs: Any,
756+
) -> Any:
757+
many = self.many if many is None else bool(many)
758+
loaded = load_to_dict(self, data, many=many, unknown=unknown, **kwargs)
759+
if many:
760+
return [constructor(**item) for item in loaded]
761+
else:
762+
return constructor(**loaded)
750763

751-
# Update the schema members to contain marshmallow fields instead of dataclass fields
752-
type_hints = get_type_hints(
753-
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
754-
)
755-
with dataclasses.replace(schema_ctx, generic_args=generic_args):
756-
attributes.update(
757-
(
758-
field.name,
759-
_field_for_schema(
760-
type_hints[field.name],
761-
_get_field_default(field),
762-
field.metadata,
763-
),
764-
)
765-
for field in fields
766-
if field.init
767-
)
764+
attributes["load"] = load
768765

769766
schema_class: Type[marshmallow.Schema] = type(
770-
clazz.__name__, (_base_schema(clazz, schema_ctx.base_schema),), attributes
767+
f"{class_name}Schema", (base_schema,), attributes
771768
)
769+
772770
future.set_result(schema_class)
773771
_schema_cache[cache_key] = schema_class
774772
return schema_class
775773

776774

775+
def _marshmallow_hooks(clazz: type) -> Iterator[Tuple[str, Any]]:
776+
for name, attr in inspect.getmembers(clazz):
777+
if hasattr(attr, "__marshmallow_hook__") or name in MEMBERS_WHITELIST:
778+
yield name, attr
779+
780+
781+
def _schema_attrs_for_dataclass(clazz: type) -> Dict[str, Any]:
782+
schema_ctx = _schema_ctx_stack.top
783+
type_hints = get_type_hints(
784+
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
785+
)
786+
787+
attrs = dict(_marshmallow_hooks(clazz))
788+
for field in dataclasses.fields(clazz):
789+
if field.init:
790+
typ = type_hints[field.name]
791+
default = (
792+
field.default_factory
793+
if field.default_factory is not dataclasses.MISSING
794+
else field.default
795+
if field.default is not dataclasses.MISSING
796+
else marshmallow.missing
797+
)
798+
attrs[field.name] = _field_for_schema(typ, default, field.metadata)
799+
return attrs
800+
801+
802+
def _schema_attrs_for_simple_class(clazz: type) -> Dict[str, Any]:
803+
schema_ctx = _schema_ctx_stack.top
804+
type_hints = get_type_hints(
805+
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
806+
)
807+
808+
attrs = dict(_marshmallow_hooks(clazz))
809+
for field_name, typ in type_hints.items():
810+
if not typing_inspect.is_classvar(typ):
811+
default = getattr(clazz, field_name, marshmallow.missing)
812+
attrs[field_name] = _field_for_schema(typ, default)
813+
return attrs
814+
815+
816+
def _simple_class_constructor(clazz: Type[_U]) -> Callable[..., _U]:
817+
def constructor(**kwargs: Any) -> _U:
818+
obj = clazz.__new__(clazz)
819+
for k, v in kwargs.items():
820+
setattr(obj, k, v)
821+
return obj
822+
823+
return constructor
824+
825+
777826
def _is_builtin_collection_type(typ: object) -> bool:
778827
origin = get_origin(typ)
779828
if origin is None:
@@ -845,7 +894,7 @@ def get_field(i: int) -> marshmallow.fields.Field:
845894

846895

847896
def _field_for_union_type(
848-
typ: type, metadata: Dict[str, Any]
897+
typ: object, metadata: Dict[str, Any]
849898
) -> marshmallow.fields.Field:
850899
"""
851900
Construct the appropriate Field for a union or optional type.
@@ -877,7 +926,7 @@ def _field_for_union_type(
877926

878927

879928
def _field_for_literal_type(
880-
typ: type, metadata: Dict[str, Any]
929+
typ: object, metadata: Dict[str, Any]
881930
) -> marshmallow.fields.Field:
882931
"""
883932
Construct the appropriate Field for a Literal type.
@@ -893,7 +942,7 @@ def _field_for_literal_type(
893942
return marshmallow.fields.Raw(validate=validate, **metadata)
894943

895944

896-
def _get_subtype_for_final_type(typ: type, default: Any) -> Any:
945+
def _get_subtype_for_final_type(typ: object, default: Any) -> Any:
897946
"""
898947
Construct the appropriate Field for a Final type.
899948
"""
@@ -930,7 +979,7 @@ def _get_subtype_for_final_type(typ: type, default: Any) -> Any:
930979

931980

932981
def _field_for_new_type(
933-
typ: Type, default: Any, metadata: Dict[str, Any]
982+
typ: object, default: Any, metadata: Dict[str, Any]
934983
) -> marshmallow.fields.Field:
935984
"""
936985
Return a new field for fields based on a NewType.
@@ -953,7 +1002,8 @@ def _field_for_new_type(
9531002
**metadata,
9541003
"validate": validators if validators else None,
9551004
}
956-
metadata.setdefault("metadata", {}).setdefault("description", typ.__name__)
1005+
type_name = getattr(typ, "__name__", repr(typ))
1006+
metadata.setdefault("metadata", {}).setdefault("description", type_name)
9571007

9581008
field: Optional[Type[marshmallow.fields.Field]] = getattr(
9591009
typ, "_marshmallow_field", None
@@ -980,22 +1030,41 @@ def _field_for_enum(typ: type, metadata: Dict[str, Any]) -> marshmallow.fields.F
9801030
return marshmallow_enum.EnumField(typ, **metadata)
9811031

9821032

983-
def _field_for_dataclass(
984-
typ: Union[Type, object], metadata: Dict[str, Any]
985-
) -> marshmallow.fields.Field:
1033+
def _schema_for_nested(
1034+
typ: object,
1035+
) -> Union[Type[marshmallow.Schema], Callable[[], Type[marshmallow.Schema]]]:
9861036
"""
987-
Return a new field for a nested dataclass field.
1037+
Return a marshmallow.Schema for a nested dataclass (or simple annotated class)
9881038
"""
9891039
if isinstance(typ, type) and hasattr(typ, "Schema"):
9901040
# marshmallow_dataclass.dataclass
991-
nested = typ.Schema
992-
else:
993-
assert isinstance(typ, Hashable)
994-
nested = _internal_class_schema(typ) # type: ignore[arg-type] # FIXME
995-
if isinstance(nested, _Future):
996-
nested = nested.result
1041+
# Defer evaluation of .Schema attribute, to avoid forward reference issues
1042+
return partial(getattr, typ, "Schema")
9971043

998-
return marshmallow.fields.Nested(nested, **metadata)
1044+
class_schema = _internal_class_schema(typ) # type: ignore[arg-type] # FIXME
1045+
if isinstance(class_schema, _Future):
1046+
return class_schema.result
1047+
return class_schema
1048+
1049+
1050+
def _is_simple_annotated_class(obj: object) -> bool:
1051+
"""Determine whether obj is a "simple annotated class".
1052+
1053+
The ```class_schema``` function can generate schemas for
1054+
simple annotated classes (as well as for dataclasses).
1055+
"""
1056+
if not isinstance(obj, type):
1057+
return False
1058+
if getattr(obj, "__init__", None) is not object.__init__:
1059+
return False
1060+
if getattr(obj, "__new__", None) is not object.__new__:
1061+
return False
1062+
1063+
schema_ctx = _schema_ctx_stack.top
1064+
type_hints = get_type_hints(
1065+
obj, globalns=schema_ctx.globalns, localns=schema_ctx.localns
1066+
)
1067+
return any(not typing_inspect.is_classvar(th) for th in type_hints.values())
9991068

10001069

10011070
def field_for_schema(
@@ -1034,7 +1103,7 @@ def field_for_schema(
10341103

10351104

10361105
def _field_for_schema(
1037-
typ: type,
1106+
typ: Union[type, object],
10381107
default: Any = marshmallow.missing,
10391108
metadata: Optional[Mapping[str, Any]] = None,
10401109
) -> marshmallow.fields.Field:
@@ -1104,54 +1173,25 @@ def _field_for_schema(
11041173
if isinstance(typ, type) and issubclass(typ, Enum):
11051174
return _field_for_enum(typ, metadata)
11061175

1107-
# Assume nested marshmallow dataclass (and hope for the best)
1108-
return _field_for_dataclass(typ, metadata)
1109-
1110-
1111-
def _base_schema(
1112-
clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None
1113-
) -> Type[marshmallow.Schema]:
1114-
"""
1115-
Base schema factory that creates a schema for `clazz` derived either from `base_schema`
1116-
or `BaseSchema`
1117-
"""
1118-
1119-
# Remove `type: ignore` when mypy handles dynamic base classes
1120-
# https://github.com/python/mypy/issues/2813
1121-
class BaseSchema(base_schema or marshmallow.Schema): # type: ignore
1122-
def load(self, data: Mapping, *, many: Optional[bool] = None, **kwargs):
1123-
all_loaded = super().load(data, many=many, **kwargs)
1124-
many = self.many if many is None else bool(many)
1125-
if many:
1126-
return [clazz(**loaded) for loaded in all_loaded]
1127-
else:
1128-
return clazz(**all_loaded)
1129-
1130-
return BaseSchema
1131-
1132-
1133-
def _get_field_default(field: dataclasses.Field):
1134-
"""
1135-
Return a marshmallow default value given a dataclass default value
1176+
# nested dataclasses
1177+
if (
1178+
dataclasses.is_dataclass(typ)
1179+
or _is_generic_alias_of_dataclass(typ)
1180+
or _is_simple_annotated_class(typ)
1181+
):
1182+
nested = _schema_for_nested(typ)
1183+
# type spec for Nested.__init__ is not correct
1184+
return marshmallow.fields.Nested(nested, **metadata) # type: ignore[arg-type]
11361185

1137-
>>> _get_field_default(dataclasses.field())
1138-
<marshmallow.missing>
1139-
"""
1140-
# Remove `type: ignore` when https://github.com/python/mypy/issues/6910 is fixed
1141-
default_factory = field.default_factory # type: ignore
1142-
if default_factory is not dataclasses.MISSING:
1143-
return default_factory
1144-
elif field.default is dataclasses.MISSING:
1145-
return marshmallow.missing
1146-
return field.default
1186+
raise TypeError(f"can not deduce field type for {typ}")
11471187

11481188

11491189
def NewType(
11501190
name: str,
11511191
typ: Type[_U],
11521192
field: Optional[Type[marshmallow.fields.Field]] = None,
1153-
**kwargs,
1154-
) -> Callable[[_U], _U]:
1193+
**kwargs: Any,
1194+
) -> type:
11551195
"""NewType creates simple unique types
11561196
to which you can attach custom marshmallow attributes.
11571197
All the keyword arguments passed to this function will be transmitted
@@ -1184,9 +1224,9 @@ def NewType(
11841224
# noinspection PyTypeHints
11851225
new_type = typing_NewType(name, typ) # type: ignore
11861226
# noinspection PyTypeHints
1187-
new_type._marshmallow_field = field # type: ignore
1227+
new_type._marshmallow_field = field
11881228
# noinspection PyTypeHints
1189-
new_type._marshmallow_args = kwargs # type: ignore
1229+
new_type._marshmallow_args = kwargs
11901230
return new_type
11911231

11921232

tests/test_class_schema.py

+13
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,19 @@ class J:
406406
[validator_a, validator_b, validator_c, validator_d],
407407
)
408408

409+
def test_simple_annotated_class(self):
410+
class Child:
411+
x: int
412+
413+
@dataclasses.dataclass
414+
class Container:
415+
child: Child
416+
417+
schema = class_schema(Container)()
418+
419+
loaded = schema.load({"child": {"x": "42"}})
420+
self.assertEqual(loaded.child.x, 42)
421+
409422
def test_generic_dataclass(self):
410423
T = typing.TypeVar("T")
411424

0 commit comments

Comments
 (0)