Skip to content

Commit 769ebc6

Browse files
committed
feat: Do not auto-dataclassify non-dataclasses
This implements the suggestions made in lovasoa#51. See lovasoa#51 (comment)
1 parent 39cfa30 commit 769ebc6

File tree

3 files changed

+171
-113
lines changed

3 files changed

+171
-113
lines changed

marshmallow_dataclass/__init__.py

+143-104
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,9 @@ class _SchemaContext:
653653
default_factory=dict
654654
)
655655

656+
def replace(self, generic_args: Optional[_GenericArgs]) -> "_SchemaContext":
657+
return dataclasses.replace(self, generic_args=generic_args)
658+
656659
def get_type_mapping(
657660
self, use_mro: bool = False
658661
) -> Mapping[Any, Type[marshmallow.fields.Field]]:
@@ -717,63 +720,109 @@ def _internal_class_schema(
717720

718721
generic_args = schema_ctx.generic_args
719722

720-
if _is_generic_alias_of_dataclass(clazz):
721-
generic_args = _GenericArgs(clazz, generic_args)
722-
clazz = typing_inspect.get_origin(clazz)
723-
elif not dataclasses.is_dataclass(clazz):
724-
try:
725-
warnings.warn(
726-
"****** WARNING ****** "
727-
f"marshmallow_dataclass was called on the class {clazz}, which is not a dataclass. "
728-
"It is going to try and convert the class into a dataclass, which may have "
729-
"undesirable side effects. To avoid this message, make sure all your classes and "
730-
"all the classes of their fields are either explicitly supported by "
731-
"marshmallow_dataclass, or define the schema explicitly using "
732-
"field(metadata=dict(marshmallow_field=...)). For more information, see "
733-
"https://github.com/lovasoa/marshmallow_dataclass/issues/51 "
734-
"****** WARNING ******"
735-
)
736-
dataclasses.dataclass(clazz)
737-
except Exception as exc:
738-
raise TypeError(
739-
f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
740-
) from exc
741-
742-
fields = dataclasses.fields(clazz)
743-
744-
# Copy all marshmallow hooks and whitelisted members of the dataclass to the schema.
745-
attributes = {
746-
k: v
747-
for k, v in inspect.getmembers(clazz)
748-
if hasattr(v, "__marshmallow_hook__") or k in MEMBERS_WHITELIST
749-
}
723+
constructor: Callable[..., object]
724+
725+
if _is_simple_annotated_class(clazz):
726+
class_name = clazz.__name__
727+
constructor = _simple_class_constructor(clazz)
728+
attributes = _schema_attrs_for_simple_class(clazz)
729+
elif _is_generic_alias_of_dataclass(clazz):
730+
origin = get_origin(clazz)
731+
assert isinstance(origin, type)
732+
class_name = origin.__name__
733+
constructor = origin
734+
with schema_ctx.replace(generic_args=_GenericArgs(clazz, generic_args)):
735+
attributes = _schema_attrs_for_dataclass(origin)
736+
elif dataclasses.is_dataclass(clazz):
737+
class_name = clazz.__name__
738+
constructor = clazz
739+
attributes = _schema_attrs_for_dataclass(clazz)
740+
else:
741+
raise TypeError(f"{clazz} is not a dataclass or a simple annotated class")
742+
743+
base_schema = marshmallow.Schema
744+
if schema_ctx.base_schema is not None:
745+
base_schema = schema_ctx.base_schema
746+
747+
load_to_dict = base_schema.load
748+
749+
def load(
750+
self: marshmallow.Schema,
751+
data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]],
752+
*,
753+
many: Optional[bool] = None,
754+
unknown: Optional[str] = None,
755+
**kwargs: Any,
756+
) -> Any:
757+
many = self.many if many is None else bool(many)
758+
loaded = load_to_dict(self, data, many=many, unknown=unknown, **kwargs)
759+
if many:
760+
return [constructor(**item) for item in loaded]
761+
else:
762+
return constructor(**loaded)
750763

751-
# Update the schema members to contain marshmallow fields instead of dataclass fields
752-
type_hints = get_type_hints(
753-
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
754-
)
755-
with dataclasses.replace(schema_ctx, generic_args=generic_args):
756-
attributes.update(
757-
(
758-
field.name,
759-
_field_for_schema(
760-
type_hints[field.name],
761-
_get_field_default(field),
762-
field.metadata,
763-
),
764-
)
765-
for field in fields
766-
if field.init
767-
)
764+
attributes["load"] = load
768765

769766
schema_class: Type[marshmallow.Schema] = type(
770-
clazz.__name__, (_base_schema(clazz, schema_ctx.base_schema),), attributes
767+
f"{class_name}Schema", (base_schema,), attributes
771768
)
769+
772770
future.set_result(schema_class)
773771
_schema_cache[cache_key] = schema_class
774772
return schema_class
775773

776774

775+
def _marshmallow_hooks(clazz: type) -> Iterator[Tuple[str, Any]]:
776+
for name, attr in inspect.getmembers(clazz):
777+
if hasattr(attr, "__marshmallow_hook__") or name in MEMBERS_WHITELIST:
778+
yield name, attr
779+
780+
781+
def _schema_attrs_for_dataclass(clazz: type) -> Dict[str, Any]:
782+
schema_ctx = _schema_ctx_stack.top
783+
type_hints = get_type_hints(
784+
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
785+
)
786+
787+
attrs = dict(_marshmallow_hooks(clazz))
788+
for field in dataclasses.fields(clazz):
789+
if field.init:
790+
typ = type_hints[field.name]
791+
default = (
792+
field.default_factory
793+
if field.default_factory is not dataclasses.MISSING
794+
else field.default
795+
if field.default is not dataclasses.MISSING
796+
else marshmallow.missing
797+
)
798+
attrs[field.name] = _field_for_schema(typ, default, field.metadata)
799+
return attrs
800+
801+
802+
def _schema_attrs_for_simple_class(clazz: type) -> Dict[str, Any]:
803+
schema_ctx = _schema_ctx_stack.top
804+
type_hints = get_type_hints(
805+
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
806+
)
807+
808+
attrs = dict(_marshmallow_hooks(clazz))
809+
for field_name, typ in type_hints.items():
810+
if not typing_inspect.is_classvar(typ):
811+
default = getattr(clazz, field_name, marshmallow.missing)
812+
attrs[field_name] = _field_for_schema(typ, default)
813+
return attrs
814+
815+
816+
def _simple_class_constructor(clazz: Type[_U]) -> Callable[..., _U]:
817+
def constructor(**kwargs: Any) -> _U:
818+
obj = clazz.__new__(clazz)
819+
for k, v in kwargs.items():
820+
setattr(obj, k, v)
821+
return obj
822+
823+
return constructor
824+
825+
777826
def _is_builtin_collection_type(typ: object) -> bool:
778827
origin = get_origin(typ)
779828
if origin is None:
@@ -953,8 +1002,8 @@ def _field_for_new_type(
9531002
**metadata,
9541003
"validate": validators if validators else None,
9551004
}
956-
if hasattr(typ, "__name__"):
957-
metadata.setdefault("metadata", {}).setdefault("description", typ.__name__)
1005+
type_name = getattr(typ, "__name__", repr(typ))
1006+
metadata.setdefault("metadata", {}).setdefault("description", type_name)
9581007

9591008
field: Optional[Type[marshmallow.fields.Field]] = getattr(
9601009
typ, "_marshmallow_field", None
@@ -981,22 +1030,41 @@ def _field_for_enum(typ: type, metadata: Dict[str, Any]) -> marshmallow.fields.F
9811030
return marshmallow_enum.EnumField(typ, **metadata)
9821031

9831032

984-
def _field_for_dataclass(
985-
typ: Union[Type, object], metadata: Dict[str, Any]
986-
) -> marshmallow.fields.Field:
1033+
def _schema_for_nested(
1034+
typ: object,
1035+
) -> Union[Type[marshmallow.Schema], Callable[[], Type[marshmallow.Schema]]]:
9871036
"""
988-
Return a new field for a nested dataclass field.
1037+
Return a marshmallow.Schema for a nested dataclass (or simple annotated class)
9891038
"""
9901039
if isinstance(typ, type) and hasattr(typ, "Schema"):
9911040
# marshmallow_dataclass.dataclass
992-
nested = typ.Schema
993-
else:
994-
assert isinstance(typ, Hashable)
995-
nested = _internal_class_schema(typ) # type: ignore[arg-type] # FIXME
996-
if isinstance(nested, _Future):
997-
nested = nested.result
1041+
# Defer evaluation of .Schema attribute, to avoid forward reference issues
1042+
return partial(getattr, typ, "Schema")
9981043

999-
return marshmallow.fields.Nested(nested, **metadata)
1044+
class_schema = _internal_class_schema(typ) # type: ignore[arg-type] # FIXME
1045+
if isinstance(class_schema, _Future):
1046+
return class_schema.result
1047+
return class_schema
1048+
1049+
1050+
def _is_simple_annotated_class(obj: object) -> bool:
1051+
"""Determine whether obj is a "simple annotated class".
1052+
1053+
The ```class_schema``` function can generate schemas for
1054+
simple annotated classes (as well as for dataclasses).
1055+
"""
1056+
if not isinstance(obj, type):
1057+
return False
1058+
if getattr(obj, "__init__", None) is not object.__init__:
1059+
return False
1060+
if getattr(obj, "__new__", None) is not object.__new__:
1061+
return False
1062+
1063+
schema_ctx = _schema_ctx_stack.top
1064+
type_hints = get_type_hints(
1065+
obj, globalns=schema_ctx.globalns, localns=schema_ctx.localns
1066+
)
1067+
return any(not typing_inspect.is_classvar(th) for th in type_hints.values())
10001068

10011069

10021070
def field_for_schema(
@@ -1105,54 +1173,25 @@ def _field_for_schema(
11051173
if isinstance(typ, type) and issubclass(typ, Enum):
11061174
return _field_for_enum(typ, metadata)
11071175

1108-
# Assume nested marshmallow dataclass (and hope for the best)
1109-
return _field_for_dataclass(typ, metadata)
1110-
1111-
1112-
def _base_schema(
1113-
clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None
1114-
) -> Type[marshmallow.Schema]:
1115-
"""
1116-
Base schema factory that creates a schema for `clazz` derived either from `base_schema`
1117-
or `BaseSchema`
1118-
"""
1119-
1120-
# Remove `type: ignore` when mypy handles dynamic base classes
1121-
# https://github.com/python/mypy/issues/2813
1122-
class BaseSchema(base_schema or marshmallow.Schema): # type: ignore
1123-
def load(self, data: Mapping, *, many: Optional[bool] = None, **kwargs):
1124-
all_loaded = super().load(data, many=many, **kwargs)
1125-
many = self.many if many is None else bool(many)
1126-
if many:
1127-
return [clazz(**loaded) for loaded in all_loaded]
1128-
else:
1129-
return clazz(**all_loaded)
1130-
1131-
return BaseSchema
1132-
1133-
1134-
def _get_field_default(field: dataclasses.Field):
1135-
"""
1136-
Return a marshmallow default value given a dataclass default value
1176+
# nested dataclasses
1177+
if (
1178+
dataclasses.is_dataclass(typ)
1179+
or _is_generic_alias_of_dataclass(typ)
1180+
or _is_simple_annotated_class(typ)
1181+
):
1182+
nested = _schema_for_nested(typ)
1183+
# type spec for Nested.__init__ is not correct
1184+
return marshmallow.fields.Nested(nested, **metadata) # type: ignore[arg-type]
11371185

1138-
>>> _get_field_default(dataclasses.field())
1139-
<marshmallow.missing>
1140-
"""
1141-
# Remove `type: ignore` when https://github.com/python/mypy/issues/6910 is fixed
1142-
default_factory = field.default_factory # type: ignore
1143-
if default_factory is not dataclasses.MISSING:
1144-
return default_factory
1145-
elif field.default is dataclasses.MISSING:
1146-
return marshmallow.missing
1147-
return field.default
1186+
raise TypeError(f"can not deduce field type for {typ}")
11481187

11491188

11501189
def NewType(
11511190
name: str,
11521191
typ: Type[_U],
11531192
field: Optional[Type[marshmallow.fields.Field]] = None,
1154-
**kwargs,
1155-
) -> Callable[[_U], _U]:
1193+
**kwargs: Any,
1194+
) -> type:
11561195
"""NewType creates simple unique types
11571196
to which you can attach custom marshmallow attributes.
11581197
All the keyword arguments passed to this function will be transmitted
@@ -1185,9 +1224,9 @@ def NewType(
11851224
# noinspection PyTypeHints
11861225
new_type = typing_NewType(name, typ) # type: ignore
11871226
# noinspection PyTypeHints
1188-
new_type._marshmallow_field = field # type: ignore
1227+
new_type._marshmallow_field = field
11891228
# noinspection PyTypeHints
1190-
new_type._marshmallow_args = kwargs # type: ignore
1229+
new_type._marshmallow_args = kwargs
11911230
return new_type
11921231

11931232

tests/test_class_schema.py

+13
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,19 @@ class J:
406406
[validator_a, validator_b, validator_c, validator_d],
407407
)
408408

409+
def test_simple_annotated_class(self):
410+
class Child:
411+
x: int
412+
413+
@dataclasses.dataclass
414+
class Container:
415+
child: Child
416+
417+
schema = class_schema(Container)()
418+
419+
loaded = schema.load({"child": {"x": "42"}})
420+
self.assertEqual(loaded.child.x, 42)
421+
409422
def test_generic_dataclass(self):
410423
T = typing.TypeVar("T")
411424

tests/test_field_for_schema.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
import typing
44
import unittest
55
from enum import Enum
6-
from typing import Dict, Optional, Union, Any, List, Tuple
6+
from typing import Dict, Optional, Union, Any, List, Tuple, Iterable
77

8-
try:
9-
from typing import Final, Literal # type: ignore[attr-defined]
10-
except ImportError:
11-
from typing_extensions import Final, Literal # type: ignore[assignment]
8+
if sys.version_info >= (3, 8):
9+
from typing import Final, Literal
10+
else:
11+
from typing_extensions import Final, Literal
1212

1313
from marshmallow import fields, Schema, validate
1414

@@ -21,14 +21,18 @@
2121

2222

2323
class TestFieldForSchema(unittest.TestCase):
24-
def assertFieldsEqual(self, a: fields.Field, b: fields.Field):
24+
def assertFieldsEqual(
25+
self, a: fields.Field, b: fields.Field, *, ignore_attrs: Iterable[str] = ()
26+
) -> None:
27+
ignored = set(ignore_attrs)
28+
2529
self.assertEqual(a.__class__, b.__class__, "field class")
2630

2731
def attrs(x):
2832
return {
2933
k: f"{v!r} ({v.__mro__!r})" if inspect.isclass(v) else repr(v)
3034
for k, v in x.__dict__.items()
31-
if not k.startswith("_")
35+
if not (k in ignored or k.startswith("_"))
3236
}
3337

3438
self.assertEqual(attrs(a), attrs(b))
@@ -213,10 +217,12 @@ class NewSchema(Schema):
213217
class NewDataclass:
214218
pass
215219

220+
field = field_for_schema(NewDataclass, metadata=dict(required=False))
221+
216222
self.assertFieldsEqual(
217-
field_for_schema(NewDataclass, metadata=dict(required=False)),
218-
fields.Nested(NewDataclass.Schema),
223+
field, fields.Nested(NewDataclass.Schema), ignore_attrs=["nested"]
219224
)
225+
self.assertIs(type(field.schema), NewDataclass.Schema)
220226

221227
def test_override_container_type_with_type_mapping(self):
222228
type_mapping = [

0 commit comments

Comments
 (0)