Skip to content

Commit 5409d87

Browse files
committed
feat: Do not auto-dataclassify non-dataclasses
This implement the suggestions made in lovasoa#51. See lovasoa#51 (comment)
1 parent 8af4f23 commit 5409d87

File tree

2 files changed

+163
-117
lines changed

2 files changed

+163
-117
lines changed

marshmallow_dataclass/__init__.py

+148-108
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,9 @@ class _SchemaContext:
680680
default_factory=dict
681681
)
682682

683+
def replace(self, generic_args: Optional[_GenericArgs]) -> "_SchemaContext":
684+
return dataclasses.replace(self, generic_args=generic_args)
685+
683686
def get_type_mapping(
684687
self, include_marshmallow_default: bool = False
685688
) -> _TypeMapping:
@@ -738,63 +741,109 @@ def _internal_class_schema(
738741

739742
generic_args = schema_ctx.generic_args
740743

741-
if _is_generic_alias_of_dataclass(clazz):
742-
generic_args = _GenericArgs(clazz, generic_args)
743-
clazz = typing_inspect.get_origin(clazz)
744-
elif not dataclasses.is_dataclass(clazz):
745-
try:
746-
warnings.warn(
747-
"****** WARNING ****** "
748-
f"marshmallow_dataclass was called on the class {clazz}, which is not a dataclass. "
749-
"It is going to try and convert the class into a dataclass, which may have "
750-
"undesirable side effects. To avoid this message, make sure all your classes and "
751-
"all the classes of their fields are either explicitly supported by "
752-
"marshmallow_dataclass, or define the schema explicitly using "
753-
"field(metadata=dict(marshmallow_field=...)). For more information, see "
754-
"https://github.com/lovasoa/marshmallow_dataclass/issues/51 "
755-
"****** WARNING ******"
756-
)
757-
dataclasses.dataclass(clazz)
758-
except Exception as exc:
759-
raise TypeError(
760-
f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
761-
) from exc
762-
763-
fields = dataclasses.fields(clazz)
764-
765-
# Copy all marshmallow hooks and whitelisted members of the dataclass to the schema.
766-
attributes = {
767-
k: v
768-
for k, v in inspect.getmembers(clazz)
769-
if hasattr(v, "__marshmallow_hook__") or k in MEMBERS_WHITELIST
770-
}
744+
constructor: Callable[..., object]
745+
746+
if _is_simple_annotated_class(clazz):
747+
class_name = clazz.__name__
748+
constructor = _simple_class_constructor(clazz)
749+
attributes = _schema_attrs_for_simple_class(clazz)
750+
elif _is_generic_alias_of_dataclass(clazz):
751+
origin = get_origin(clazz)
752+
assert isinstance(origin, type)
753+
class_name = origin.__name__
754+
constructor = origin
755+
with schema_ctx.replace(generic_args=_GenericArgs(clazz, generic_args)):
756+
attributes = _schema_attrs_for_dataclass(origin)
757+
elif dataclasses.is_dataclass(clazz):
758+
class_name = clazz.__name__
759+
constructor = clazz
760+
attributes = _schema_attrs_for_dataclass(clazz)
761+
else:
762+
raise TypeError(f"{clazz} is not a dataclass or a simple annotated class")
763+
764+
base_schema = marshmallow.Schema
765+
if schema_ctx.base_schema is not None:
766+
base_schema = schema_ctx.base_schema
767+
768+
load_to_dict = base_schema.load
769+
770+
def load(
771+
self: marshmallow.Schema,
772+
data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]],
773+
*,
774+
many: Optional[bool] = None,
775+
unknown: Optional[str] = None,
776+
**kwargs: Any,
777+
) -> Any:
778+
many = self.many if many is None else bool(many)
779+
loaded = load_to_dict(self, data, many=many, unknown=unknown, **kwargs)
780+
if many:
781+
return [constructor(**item) for item in loaded]
782+
else:
783+
return constructor(**loaded)
771784

772-
# Update the schema members to contain marshmallow fields instead of dataclass fields
773-
type_hints = get_type_hints(
774-
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
775-
)
776-
with dataclasses.replace(schema_ctx, generic_args=generic_args):
777-
attributes.update(
778-
(
779-
field.name,
780-
_field_for_schema(
781-
type_hints[field.name],
782-
_get_field_default(field),
783-
field.metadata,
784-
),
785-
)
786-
for field in fields
787-
if field.init
788-
)
785+
attributes["load"] = load
789786

790787
schema_class: Type[marshmallow.Schema] = type(
791-
clazz.__name__, (_base_schema(clazz, schema_ctx.base_schema),), attributes
788+
f"{class_name}Schema", (base_schema,), attributes
792789
)
790+
793791
future.set_result(schema_class)
794792
_schema_cache[cache_key] = schema_class
795793
return schema_class
796794

797795

796+
def _marshmallow_hooks(clazz: type) -> Iterator[Tuple[str, Any]]:
797+
for name, attr in inspect.getmembers(clazz):
798+
if hasattr(attr, "__marshmallow_hook__") or name in MEMBERS_WHITELIST:
799+
yield name, attr
800+
801+
802+
def _schema_attrs_for_dataclass(clazz: type) -> Dict[str, Any]:
803+
schema_ctx = _schema_ctx_stack.top
804+
type_hints = get_type_hints(
805+
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
806+
)
807+
808+
attrs = dict(_marshmallow_hooks(clazz))
809+
for field in dataclasses.fields(clazz):
810+
if field.init:
811+
typ = type_hints[field.name]
812+
default = (
813+
field.default_factory
814+
if field.default_factory is not dataclasses.MISSING
815+
else field.default
816+
if field.default is not dataclasses.MISSING
817+
else marshmallow.missing
818+
)
819+
attrs[field.name] = _field_for_schema(typ, default, field.metadata)
820+
return attrs
821+
822+
823+
def _schema_attrs_for_simple_class(clazz: type) -> Dict[str, Any]:
824+
schema_ctx = _schema_ctx_stack.top
825+
type_hints = get_type_hints(
826+
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
827+
)
828+
829+
attrs = dict(_marshmallow_hooks(clazz))
830+
for field_name, typ in type_hints.items():
831+
if not typing_inspect.is_classvar(typ):
832+
default = getattr(clazz, field_name, marshmallow.missing)
833+
attrs[field_name] = _field_for_schema(typ, default)
834+
return attrs
835+
836+
837+
def _simple_class_constructor(clazz: Type[_U]) -> Callable[..., _U]:
838+
def constructor(**kwargs: Any) -> _U:
839+
obj = clazz.__new__(clazz)
840+
for k, v in kwargs.items():
841+
setattr(obj, k, v)
842+
return obj
843+
844+
return constructor
845+
846+
798847
def _is_builtin_collection_type(typ: object) -> bool:
799848
origin = get_origin(typ)
800849
if origin is None:
@@ -863,7 +912,7 @@ def get_field(i: int) -> marshmallow.fields.Field:
863912

864913

865914
def _field_for_union_type(
866-
typ: type, metadata: Dict[str, Any]
915+
typ: object, metadata: Dict[str, Any]
867916
) -> marshmallow.fields.Field:
868917
"""
869918
Construct the appropriate Field for a union or optional type.
@@ -895,7 +944,7 @@ def _field_for_union_type(
895944

896945

897946
def _field_for_literal_type(
898-
typ: type, metadata: Dict[str, Any]
947+
typ: object, metadata: Dict[str, Any]
899948
) -> marshmallow.fields.Field:
900949
"""
901950
Construct the appropriate Field for a Literal type.
@@ -911,7 +960,7 @@ def _field_for_literal_type(
911960
return marshmallow.fields.Raw(validate=validate, **metadata)
912961

913962

914-
def _get_subtype_for_final_type(typ: type, default: Any) -> Any:
963+
def _get_subtype_for_final_type(typ: object, default: Any) -> Any:
915964
"""
916965
Construct the appropriate Field for a Final type.
917966
"""
@@ -948,7 +997,7 @@ def _get_subtype_for_final_type(typ: type, default: Any) -> Any:
948997

949998

950999
def _field_for_new_type(
951-
typ: Type, default: Any, metadata: Dict[str, Any]
1000+
typ: object, default: Any, metadata: Dict[str, Any]
9521001
) -> marshmallow.fields.Field:
9531002
"""
9541003
Return a new field for fields based on a NewType.
@@ -971,7 +1020,8 @@ def _field_for_new_type(
9711020
**metadata,
9721021
"validate": validators if validators else None,
9731022
}
974-
metadata.setdefault("metadata", {}).setdefault("description", typ.__name__)
1023+
type_name = getattr(typ, "__name__", repr(typ))
1024+
metadata.setdefault("metadata", {}).setdefault("description", type_name)
9751025

9761026
field: Optional[Type[marshmallow.fields.Field]] = getattr(
9771027
typ, "_marshmallow_field", None
@@ -998,22 +1048,41 @@ def _field_for_enum(typ: type, metadata: Dict[str, Any]) -> marshmallow.fields.F
9981048
return marshmallow_enum.EnumField(typ, **metadata)
9991049

10001050

1001-
def _field_for_dataclass(
1002-
typ: Union[Type, object], metadata: Dict[str, Any]
1003-
) -> marshmallow.fields.Field:
1051+
def _schema_for_nested(
1052+
typ: object,
1053+
) -> Union[Type[marshmallow.Schema], Callable[[], Type[marshmallow.Schema]]]:
10041054
"""
1005-
Return a new field for a nested dataclass field.
1055+
Return a marshmallow.Schema for a nested dataclass (or simple annotated class)
10061056
"""
10071057
if isinstance(typ, type) and hasattr(typ, "Schema"):
10081058
# marshmallow_dataclass.dataclass
1009-
nested = typ.Schema
1010-
else:
1011-
assert isinstance(typ, Hashable)
1012-
nested = _internal_class_schema(typ) # type: ignore[arg-type] # FIXME
1013-
if isinstance(nested, _Future):
1014-
nested = nested.result
1059+
# Defer evaluation of .Schema attribute, to avoid forward reference issues
1060+
return partial(getattr, typ, "Schema")
10151061

1016-
return marshmallow.fields.Nested(nested, **metadata)
1062+
class_schema = _internal_class_schema(typ) # type: ignore[arg-type] # FIXME
1063+
if isinstance(class_schema, _Future):
1064+
return class_schema.result
1065+
return class_schema
1066+
1067+
1068+
def _is_simple_annotated_class(obj: object) -> bool:
1069+
"""Determine whether obj is a "simple annotated class".
1070+
1071+
The ```class_schema``` function can generate schemas for
1072+
simple annotated classes (as well as for dataclasses).
1073+
"""
1074+
if not isinstance(obj, type):
1075+
return False
1076+
if getattr(obj, "__init__", None) is not object.__init__:
1077+
return False
1078+
if getattr(obj, "__new__", None) is not object.__new__:
1079+
return False
1080+
1081+
schema_ctx = _schema_ctx_stack.top
1082+
type_hints = get_type_hints(
1083+
obj, globalns=schema_ctx.globalns, localns=schema_ctx.localns
1084+
)
1085+
return any(not typing_inspect.is_classvar(th) for th in type_hints.values())
10171086

10181087

10191088
def field_for_schema(
@@ -1052,7 +1121,7 @@ def field_for_schema(
10521121

10531122

10541123
def _field_for_schema(
1055-
typ: type,
1124+
typ: Union[type, object],
10561125
default: Any = marshmallow.missing,
10571126
metadata: Optional[Mapping[str, Any]] = None,
10581127
) -> marshmallow.fields.Field:
@@ -1122,54 +1191,25 @@ def _field_for_schema(
11221191
if isinstance(typ, type) and issubclass(typ, Enum):
11231192
return _field_for_enum(typ, metadata)
11241193

1125-
# Assume nested marshmallow dataclass (and hope for the best)
1126-
return _field_for_dataclass(typ, metadata)
1127-
1128-
1129-
def _base_schema(
1130-
clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None
1131-
) -> Type[marshmallow.Schema]:
1132-
"""
1133-
Base schema factory that creates a schema for `clazz` derived either from `base_schema`
1134-
or `BaseSchema`
1135-
"""
1136-
1137-
# Remove `type: ignore` when mypy handles dynamic base classes
1138-
# https://github.com/python/mypy/issues/2813
1139-
class BaseSchema(base_schema or marshmallow.Schema): # type: ignore
1140-
def load(self, data: Mapping, *, many: Optional[bool] = None, **kwargs):
1141-
all_loaded = super().load(data, many=many, **kwargs)
1142-
many = self.many if many is None else bool(many)
1143-
if many:
1144-
return [clazz(**loaded) for loaded in all_loaded]
1145-
else:
1146-
return clazz(**all_loaded)
1147-
1148-
return BaseSchema
1149-
1150-
1151-
def _get_field_default(field: dataclasses.Field):
1152-
"""
1153-
Return a marshmallow default value given a dataclass default value
1194+
# nested dataclasses
1195+
if (
1196+
dataclasses.is_dataclass(typ)
1197+
or _is_generic_alias_of_dataclass(typ)
1198+
or _is_simple_annotated_class(typ)
1199+
):
1200+
nested = _schema_for_nested(typ)
1201+
# type spec for Nested.__init__ is not correct
1202+
return marshmallow.fields.Nested(nested, **metadata) # type: ignore[arg-type]
11541203

1155-
>>> _get_field_default(dataclasses.field())
1156-
<marshmallow.missing>
1157-
"""
1158-
# Remove `type: ignore` when https://github.com/python/mypy/issues/6910 is fixed
1159-
default_factory = field.default_factory # type: ignore
1160-
if default_factory is not dataclasses.MISSING:
1161-
return default_factory
1162-
elif field.default is dataclasses.MISSING:
1163-
return marshmallow.missing
1164-
return field.default
1204+
raise TypeError(f"can not deduce field type for {typ}")
11651205

11661206

11671207
def NewType(
11681208
name: str,
11691209
typ: Type[_U],
11701210
field: Optional[Type[marshmallow.fields.Field]] = None,
1171-
**kwargs,
1172-
) -> Callable[[_U], _U]:
1211+
**kwargs: Any,
1212+
) -> type:
11731213
"""NewType creates simple unique types
11741214
to which you can attach custom marshmallow attributes.
11751215
All the keyword arguments passed to this function will be transmitted
@@ -1202,9 +1242,9 @@ def NewType(
12021242
# noinspection PyTypeHints
12031243
new_type = typing_NewType(name, typ) # type: ignore
12041244
# noinspection PyTypeHints
1205-
new_type._marshmallow_field = field # type: ignore
1245+
new_type._marshmallow_field = field
12061246
# noinspection PyTypeHints
1207-
new_type._marshmallow_args = kwargs # type: ignore
1247+
new_type._marshmallow_args = kwargs
12081248
return new_type
12091249

12101250

0 commit comments

Comments
 (0)