@@ -653,6 +653,9 @@ class _SchemaContext:
653
653
default_factory = dict
654
654
)
655
655
656
+ def replace (self , generic_args : Optional [_GenericArgs ]) -> "_SchemaContext" :
657
+ return dataclasses .replace (self , generic_args = generic_args )
658
+
656
659
def get_type_mapping (
657
660
self , use_mro : bool = False
658
661
) -> Mapping [Any , Type [marshmallow .fields .Field ]]:
@@ -717,63 +720,109 @@ def _internal_class_schema(
717
720
718
721
generic_args = schema_ctx .generic_args
719
722
720
- if _is_generic_alias_of_dataclass (clazz ):
721
- generic_args = _GenericArgs (clazz , generic_args )
722
- clazz = typing_inspect .get_origin (clazz )
723
- elif not dataclasses .is_dataclass (clazz ):
724
- try :
725
- warnings .warn (
726
- "****** WARNING ****** "
727
- f"marshmallow_dataclass was called on the class { clazz } , which is not a dataclass. "
728
- "It is going to try and convert the class into a dataclass, which may have "
729
- "undesirable side effects. To avoid this message, make sure all your classes and "
730
- "all the classes of their fields are either explicitly supported by "
731
- "marshmallow_dataclass, or define the schema explicitly using "
732
- "field(metadata=dict(marshmallow_field=...)). For more information, see "
733
- "https://github.com/lovasoa/marshmallow_dataclass/issues/51 "
734
- "****** WARNING ******"
735
- )
736
- dataclasses .dataclass (clazz )
737
- except Exception as exc :
738
- raise TypeError (
739
- f"{ getattr (clazz , '__name__' , repr (clazz ))} is not a dataclass and cannot be turned into one."
740
- ) from exc
741
-
742
- fields = dataclasses .fields (clazz )
743
-
744
- # Copy all marshmallow hooks and whitelisted members of the dataclass to the schema.
745
- attributes = {
746
- k : v
747
- for k , v in inspect .getmembers (clazz )
748
- if hasattr (v , "__marshmallow_hook__" ) or k in MEMBERS_WHITELIST
749
- }
723
+ constructor : Callable [..., object ]
724
+
725
+ if _is_simple_annotated_class (clazz ):
726
+ class_name = clazz .__name__
727
+ constructor = _simple_class_constructor (clazz )
728
+ attributes = _schema_attrs_for_simple_class (clazz )
729
+ elif _is_generic_alias_of_dataclass (clazz ):
730
+ origin = get_origin (clazz )
731
+ assert isinstance (origin , type )
732
+ class_name = origin .__name__
733
+ constructor = origin
734
+ with schema_ctx .replace (generic_args = _GenericArgs (clazz , generic_args )):
735
+ attributes = _schema_attrs_for_dataclass (origin )
736
+ elif dataclasses .is_dataclass (clazz ):
737
+ class_name = clazz .__name__
738
+ constructor = clazz
739
+ attributes = _schema_attrs_for_dataclass (clazz )
740
+ else :
741
+ raise TypeError (f"{ clazz } is not a dataclass or a simple annotated class" )
742
+
743
+ base_schema = marshmallow .Schema
744
+ if schema_ctx .base_schema is not None :
745
+ base_schema = schema_ctx .base_schema
746
+
747
+ load_to_dict = base_schema .load
748
+
749
+ def load (
750
+ self : marshmallow .Schema ,
751
+ data : Union [Mapping [str , Any ], Iterable [Mapping [str , Any ]]],
752
+ * ,
753
+ many : Optional [bool ] = None ,
754
+ unknown : Optional [str ] = None ,
755
+ ** kwargs : Any ,
756
+ ) -> Any :
757
+ many = self .many if many is None else bool (many )
758
+ loaded = load_to_dict (self , data , many = many , unknown = unknown , ** kwargs )
759
+ if many :
760
+ return [constructor (** item ) for item in loaded ]
761
+ else :
762
+ return constructor (** loaded )
750
763
751
- # Update the schema members to contain marshmallow fields instead of dataclass fields
752
- type_hints = get_type_hints (
753
- clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
754
- )
755
- with dataclasses .replace (schema_ctx , generic_args = generic_args ):
756
- attributes .update (
757
- (
758
- field .name ,
759
- _field_for_schema (
760
- type_hints [field .name ],
761
- _get_field_default (field ),
762
- field .metadata ,
763
- ),
764
- )
765
- for field in fields
766
- if field .init
767
- )
764
+ attributes ["load" ] = load
768
765
769
766
schema_class : Type [marshmallow .Schema ] = type (
770
- clazz . __name__ , (_base_schema ( clazz , schema_ctx . base_schema ) ,), attributes
767
+ f" { class_name } Schema" , (base_schema ,), attributes
771
768
)
769
+
772
770
future .set_result (schema_class )
773
771
_schema_cache [cache_key ] = schema_class
774
772
return schema_class
775
773
776
774
775
+ def _marshmallow_hooks (clazz : type ) -> Iterator [Tuple [str , Any ]]:
776
+ for name , attr in inspect .getmembers (clazz ):
777
+ if hasattr (attr , "__marshmallow_hook__" ) or name in MEMBERS_WHITELIST :
778
+ yield name , attr
779
+
780
+
781
+ def _schema_attrs_for_dataclass (clazz : type ) -> Dict [str , Any ]:
782
+ schema_ctx = _schema_ctx_stack .top
783
+ type_hints = get_type_hints (
784
+ clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
785
+ )
786
+
787
+ attrs = dict (_marshmallow_hooks (clazz ))
788
+ for field in dataclasses .fields (clazz ):
789
+ if field .init :
790
+ typ = type_hints [field .name ]
791
+ default = (
792
+ field .default_factory
793
+ if field .default_factory is not dataclasses .MISSING
794
+ else field .default
795
+ if field .default is not dataclasses .MISSING
796
+ else marshmallow .missing
797
+ )
798
+ attrs [field .name ] = _field_for_schema (typ , default , field .metadata )
799
+ return attrs
800
+
801
+
802
+ def _schema_attrs_for_simple_class (clazz : type ) -> Dict [str , Any ]:
803
+ schema_ctx = _schema_ctx_stack .top
804
+ type_hints = get_type_hints (
805
+ clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
806
+ )
807
+
808
+ attrs = dict (_marshmallow_hooks (clazz ))
809
+ for field_name , typ in type_hints .items ():
810
+ if not typing_inspect .is_classvar (typ ):
811
+ default = getattr (clazz , field_name , marshmallow .missing )
812
+ attrs [field_name ] = _field_for_schema (typ , default )
813
+ return attrs
814
+
815
+
816
+ def _simple_class_constructor (clazz : Type [_U ]) -> Callable [..., _U ]:
817
+ def constructor (** kwargs : Any ) -> _U :
818
+ obj = clazz .__new__ (clazz )
819
+ for k , v in kwargs .items ():
820
+ setattr (obj , k , v )
821
+ return obj
822
+
823
+ return constructor
824
+
825
+
777
826
def _is_builtin_collection_type (typ : object ) -> bool :
778
827
origin = get_origin (typ )
779
828
if origin is None :
@@ -953,8 +1002,8 @@ def _field_for_new_type(
953
1002
** metadata ,
954
1003
"validate" : validators if validators else None ,
955
1004
}
956
- if hasattr (typ , "__name__" ):
957
- metadata .setdefault ("metadata" , {}).setdefault ("description" , typ . __name__ )
1005
+ type_name = getattr (typ , "__name__" , repr ( typ ))
1006
+ metadata .setdefault ("metadata" , {}).setdefault ("description" , type_name )
958
1007
959
1008
field : Optional [Type [marshmallow .fields .Field ]] = getattr (
960
1009
typ , "_marshmallow_field" , None
@@ -981,22 +1030,41 @@ def _field_for_enum(typ: type, metadata: Dict[str, Any]) -> marshmallow.fields.F
981
1030
return marshmallow_enum .EnumField (typ , ** metadata )
982
1031
983
1032
984
- def _field_for_dataclass (
985
- typ : Union [ Type , object ], metadata : Dict [ str , Any ]
986
- ) -> marshmallow .fields . Field :
1033
+ def _schema_for_nested (
1034
+ typ : object ,
1035
+ ) -> Union [ Type [ marshmallow .Schema ], Callable [[], Type [ marshmallow . Schema ]]] :
987
1036
"""
988
- Return a new field for a nested dataclass field.
1037
+ Return a marshmallow.Schema for a nested dataclass (or simple annotated class)
989
1038
"""
990
1039
if isinstance (typ , type ) and hasattr (typ , "Schema" ):
991
1040
# marshmallow_dataclass.dataclass
992
- nested = typ .Schema
993
- else :
994
- assert isinstance (typ , Hashable )
995
- nested = _internal_class_schema (typ ) # type: ignore[arg-type] # FIXME
996
- if isinstance (nested , _Future ):
997
- nested = nested .result
1041
+ # Defer evaluation of .Schema attribute, to avoid forward reference issues
1042
+ return partial (getattr , typ , "Schema" )
998
1043
999
- return marshmallow .fields .Nested (nested , ** metadata )
1044
+ class_schema = _internal_class_schema (typ ) # type: ignore[arg-type] # FIXME
1045
+ if isinstance (class_schema , _Future ):
1046
+ return class_schema .result
1047
+ return class_schema
1048
+
1049
+
1050
+ def _is_simple_annotated_class (obj : object ) -> bool :
1051
+ """Determine whether obj is a "simple annotated class".
1052
+
1053
+ The ```class_schema``` function can generate schemas for
1054
+ simple annotated classes (as well as for dataclasses).
1055
+ """
1056
+ if not isinstance (obj , type ):
1057
+ return False
1058
+ if getattr (obj , "__init__" , None ) is not object .__init__ :
1059
+ return False
1060
+ if getattr (obj , "__new__" , None ) is not object .__new__ :
1061
+ return False
1062
+
1063
+ schema_ctx = _schema_ctx_stack .top
1064
+ type_hints = get_type_hints (
1065
+ obj , globalns = schema_ctx .globalns , localns = schema_ctx .localns
1066
+ )
1067
+ return any (not typing_inspect .is_classvar (th ) for th in type_hints .values ())
1000
1068
1001
1069
1002
1070
def field_for_schema (
@@ -1105,54 +1173,25 @@ def _field_for_schema(
1105
1173
if isinstance (typ , type ) and issubclass (typ , Enum ):
1106
1174
return _field_for_enum (typ , metadata )
1107
1175
1108
- # Assume nested marshmallow dataclass (and hope for the best)
1109
- return _field_for_dataclass (typ , metadata )
1110
-
1111
-
1112
- def _base_schema (
1113
- clazz : type , base_schema : Optional [Type [marshmallow .Schema ]] = None
1114
- ) -> Type [marshmallow .Schema ]:
1115
- """
1116
- Base schema factory that creates a schema for `clazz` derived either from `base_schema`
1117
- or `BaseSchema`
1118
- """
1119
-
1120
- # Remove `type: ignore` when mypy handles dynamic base classes
1121
- # https://github.com/python/mypy/issues/2813
1122
- class BaseSchema (base_schema or marshmallow .Schema ): # type: ignore
1123
- def load (self , data : Mapping , * , many : Optional [bool ] = None , ** kwargs ):
1124
- all_loaded = super ().load (data , many = many , ** kwargs )
1125
- many = self .many if many is None else bool (many )
1126
- if many :
1127
- return [clazz (** loaded ) for loaded in all_loaded ]
1128
- else :
1129
- return clazz (** all_loaded )
1130
-
1131
- return BaseSchema
1132
-
1133
-
1134
- def _get_field_default (field : dataclasses .Field ):
1135
- """
1136
- Return a marshmallow default value given a dataclass default value
1176
+ # nested dataclasses
1177
+ if (
1178
+ dataclasses .is_dataclass (typ )
1179
+ or _is_generic_alias_of_dataclass (typ )
1180
+ or _is_simple_annotated_class (typ )
1181
+ ):
1182
+ nested = _schema_for_nested (typ )
1183
+ # type spec for Nested.__init__ is not correct
1184
+ return marshmallow .fields .Nested (nested , ** metadata ) # type: ignore[arg-type]
1137
1185
1138
- >>> _get_field_default(dataclasses.field())
1139
- <marshmallow.missing>
1140
- """
1141
- # Remove `type: ignore` when https://github.com/python/mypy/issues/6910 is fixed
1142
- default_factory = field .default_factory # type: ignore
1143
- if default_factory is not dataclasses .MISSING :
1144
- return default_factory
1145
- elif field .default is dataclasses .MISSING :
1146
- return marshmallow .missing
1147
- return field .default
1186
+ raise TypeError (f"can not deduce field type for { typ } " )
1148
1187
1149
1188
1150
1189
def NewType (
1151
1190
name : str ,
1152
1191
typ : Type [_U ],
1153
1192
field : Optional [Type [marshmallow .fields .Field ]] = None ,
1154
- ** kwargs ,
1155
- ) -> Callable [[ _U ], _U ] :
1193
+ ** kwargs : Any ,
1194
+ ) -> type :
1156
1195
"""NewType creates simple unique types
1157
1196
to which you can attach custom marshmallow attributes.
1158
1197
All the keyword arguments passed to this function will be transmitted
@@ -1185,9 +1224,9 @@ def NewType(
1185
1224
# noinspection PyTypeHints
1186
1225
new_type = typing_NewType (name , typ ) # type: ignore
1187
1226
# noinspection PyTypeHints
1188
- new_type ._marshmallow_field = field # type: ignore
1227
+ new_type ._marshmallow_field = field
1189
1228
# noinspection PyTypeHints
1190
- new_type ._marshmallow_args = kwargs # type: ignore
1229
+ new_type ._marshmallow_args = kwargs
1191
1230
return new_type
1192
1231
1193
1232
0 commit comments