@@ -680,6 +680,9 @@ class _SchemaContext:
680
680
default_factory = dict
681
681
)
682
682
683
+ def replace (self , generic_args : Optional [_GenericArgs ]) -> "_SchemaContext" :
684
+ return dataclasses .replace (self , generic_args = generic_args )
685
+
683
686
def get_type_mapping (
684
687
self , include_marshmallow_default : bool = False
685
688
) -> _TypeMapping :
@@ -738,63 +741,109 @@ def _internal_class_schema(
738
741
739
742
generic_args = schema_ctx .generic_args
740
743
741
- if _is_generic_alias_of_dataclass (clazz ):
742
- generic_args = _GenericArgs (clazz , generic_args )
743
- clazz = typing_inspect .get_origin (clazz )
744
- elif not dataclasses .is_dataclass (clazz ):
745
- try :
746
- warnings .warn (
747
- "****** WARNING ****** "
748
- f"marshmallow_dataclass was called on the class { clazz } , which is not a dataclass. "
749
- "It is going to try and convert the class into a dataclass, which may have "
750
- "undesirable side effects. To avoid this message, make sure all your classes and "
751
- "all the classes of their fields are either explicitly supported by "
752
- "marshmallow_dataclass, or define the schema explicitly using "
753
- "field(metadata=dict(marshmallow_field=...)). For more information, see "
754
- "https://github.com/lovasoa/marshmallow_dataclass/issues/51 "
755
- "****** WARNING ******"
756
- )
757
- dataclasses .dataclass (clazz )
758
- except Exception as exc :
759
- raise TypeError (
760
- f"{ getattr (clazz , '__name__' , repr (clazz ))} is not a dataclass and cannot be turned into one."
761
- ) from exc
762
-
763
- fields = dataclasses .fields (clazz )
764
-
765
- # Copy all marshmallow hooks and whitelisted members of the dataclass to the schema.
766
- attributes = {
767
- k : v
768
- for k , v in inspect .getmembers (clazz )
769
- if hasattr (v , "__marshmallow_hook__" ) or k in MEMBERS_WHITELIST
770
- }
744
+ constructor : Callable [..., object ]
745
+
746
+ if _is_simple_annotated_class (clazz ):
747
+ class_name = clazz .__name__
748
+ constructor = _simple_class_constructor (clazz )
749
+ attributes = _schema_attrs_for_simple_class (clazz )
750
+ elif _is_generic_alias_of_dataclass (clazz ):
751
+ origin = get_origin (clazz )
752
+ assert isinstance (origin , type )
753
+ class_name = origin .__name__
754
+ constructor = origin
755
+ with schema_ctx .replace (generic_args = _GenericArgs (clazz , generic_args )):
756
+ attributes = _schema_attrs_for_dataclass (origin )
757
+ elif dataclasses .is_dataclass (clazz ):
758
+ class_name = clazz .__name__
759
+ constructor = clazz
760
+ attributes = _schema_attrs_for_dataclass (clazz )
761
+ else :
762
+ raise TypeError (f"{ clazz } is not a dataclass or a simple annotated class" )
763
+
764
+ base_schema = marshmallow .Schema
765
+ if schema_ctx .base_schema is not None :
766
+ base_schema = schema_ctx .base_schema
767
+
768
+ load_to_dict = base_schema .load
769
+
770
+ def load (
771
+ self : marshmallow .Schema ,
772
+ data : Union [Mapping [str , Any ], Iterable [Mapping [str , Any ]]],
773
+ * ,
774
+ many : Optional [bool ] = None ,
775
+ unknown : Optional [str ] = None ,
776
+ ** kwargs : Any ,
777
+ ) -> Any :
778
+ many = self .many if many is None else bool (many )
779
+ loaded = load_to_dict (self , data , many = many , unknown = unknown , ** kwargs )
780
+ if many :
781
+ return [constructor (** item ) for item in loaded ]
782
+ else :
783
+ return constructor (** loaded )
771
784
772
- # Update the schema members to contain marshmallow fields instead of dataclass fields
773
- type_hints = get_type_hints (
774
- clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
775
- )
776
- with dataclasses .replace (schema_ctx , generic_args = generic_args ):
777
- attributes .update (
778
- (
779
- field .name ,
780
- _field_for_schema (
781
- type_hints [field .name ],
782
- _get_field_default (field ),
783
- field .metadata ,
784
- ),
785
- )
786
- for field in fields
787
- if field .init
788
- )
785
+ attributes ["load" ] = load
789
786
790
787
schema_class : Type [marshmallow .Schema ] = type (
791
- clazz . __name__ , (_base_schema ( clazz , schema_ctx . base_schema ) ,), attributes
788
+ f" { class_name } Schema" , (base_schema ,), attributes
792
789
)
790
+
793
791
future .set_result (schema_class )
794
792
_schema_cache [cache_key ] = schema_class
795
793
return schema_class
796
794
797
795
796
+ def _marshmallow_hooks (clazz : type ) -> Iterator [Tuple [str , Any ]]:
797
+ for name , attr in inspect .getmembers (clazz ):
798
+ if hasattr (attr , "__marshmallow_hook__" ) or name in MEMBERS_WHITELIST :
799
+ yield name , attr
800
+
801
+
802
+ def _schema_attrs_for_dataclass (clazz : type ) -> Dict [str , Any ]:
803
+ schema_ctx = _schema_ctx_stack .top
804
+ type_hints = get_type_hints (
805
+ clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
806
+ )
807
+
808
+ attrs = dict (_marshmallow_hooks (clazz ))
809
+ for field in dataclasses .fields (clazz ):
810
+ if field .init :
811
+ typ = type_hints [field .name ]
812
+ default = (
813
+ field .default_factory
814
+ if field .default_factory is not dataclasses .MISSING
815
+ else field .default
816
+ if field .default is not dataclasses .MISSING
817
+ else marshmallow .missing
818
+ )
819
+ attrs [field .name ] = _field_for_schema (typ , default , field .metadata )
820
+ return attrs
821
+
822
+
823
+ def _schema_attrs_for_simple_class (clazz : type ) -> Dict [str , Any ]:
824
+ schema_ctx = _schema_ctx_stack .top
825
+ type_hints = get_type_hints (
826
+ clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
827
+ )
828
+
829
+ attrs = dict (_marshmallow_hooks (clazz ))
830
+ for field_name , typ in type_hints .items ():
831
+ if not typing_inspect .is_classvar (typ ):
832
+ default = getattr (clazz , field_name , marshmallow .missing )
833
+ attrs [field_name ] = _field_for_schema (typ , default )
834
+ return attrs
835
+
836
+
837
+ def _simple_class_constructor (clazz : Type [_U ]) -> Callable [..., _U ]:
838
+ def constructor (** kwargs : Any ) -> _U :
839
+ obj = clazz .__new__ (clazz )
840
+ for k , v in kwargs .items ():
841
+ setattr (obj , k , v )
842
+ return obj
843
+
844
+ return constructor
845
+
846
+
798
847
def _is_builtin_collection_type (typ : object ) -> bool :
799
848
origin = get_origin (typ )
800
849
if origin is None :
@@ -863,7 +912,7 @@ def get_field(i: int) -> marshmallow.fields.Field:
863
912
864
913
865
914
def _field_for_union_type (
866
- typ : type , metadata : Dict [str , Any ]
915
+ typ : object , metadata : Dict [str , Any ]
867
916
) -> marshmallow .fields .Field :
868
917
"""
869
918
Construct the appropriate Field for a union or optional type.
@@ -895,7 +944,7 @@ def _field_for_union_type(
895
944
896
945
897
946
def _field_for_literal_type (
898
- typ : type , metadata : Dict [str , Any ]
947
+ typ : object , metadata : Dict [str , Any ]
899
948
) -> marshmallow .fields .Field :
900
949
"""
901
950
Construct the appropriate Field for a Literal type.
@@ -911,7 +960,7 @@ def _field_for_literal_type(
911
960
return marshmallow .fields .Raw (validate = validate , ** metadata )
912
961
913
962
914
- def _get_subtype_for_final_type (typ : type , default : Any ) -> Any :
963
+ def _get_subtype_for_final_type (typ : object , default : Any ) -> Any :
915
964
"""
916
965
Construct the appropriate Field for a Final type.
917
966
"""
@@ -948,7 +997,7 @@ def _get_subtype_for_final_type(typ: type, default: Any) -> Any:
948
997
949
998
950
999
def _field_for_new_type (
951
- typ : Type , default : Any , metadata : Dict [str , Any ]
1000
+ typ : object , default : Any , metadata : Dict [str , Any ]
952
1001
) -> marshmallow .fields .Field :
953
1002
"""
954
1003
Return a new field for fields based on a NewType.
@@ -971,7 +1020,8 @@ def _field_for_new_type(
971
1020
** metadata ,
972
1021
"validate" : validators if validators else None ,
973
1022
}
974
- metadata .setdefault ("metadata" , {}).setdefault ("description" , typ .__name__ )
1023
+ type_name = getattr (typ , "__name__" , repr (typ ))
1024
+ metadata .setdefault ("metadata" , {}).setdefault ("description" , type_name )
975
1025
976
1026
field : Optional [Type [marshmallow .fields .Field ]] = getattr (
977
1027
typ , "_marshmallow_field" , None
@@ -998,22 +1048,41 @@ def _field_for_enum(typ: type, metadata: Dict[str, Any]) -> marshmallow.fields.F
998
1048
return marshmallow_enum .EnumField (typ , ** metadata )
999
1049
1000
1050
1001
- def _field_for_dataclass (
1002
- typ : Union [ Type , object ], metadata : Dict [ str , Any ]
1003
- ) -> marshmallow .fields . Field :
1051
+ def _schema_for_nested (
1052
+ typ : object ,
1053
+ ) -> Union [ Type [ marshmallow .Schema ], Callable [[], Type [ marshmallow . Schema ]]] :
1004
1054
"""
1005
- Return a new field for a nested dataclass field.
1055
+ Return a marshmallow.Schema for a nested dataclass (or simple annotated class)
1006
1056
"""
1007
1057
if isinstance (typ , type ) and hasattr (typ , "Schema" ):
1008
1058
# marshmallow_dataclass.dataclass
1009
- nested = typ .Schema
1010
- else :
1011
- assert isinstance (typ , Hashable )
1012
- nested = _internal_class_schema (typ ) # type: ignore[arg-type] # FIXME
1013
- if isinstance (nested , _Future ):
1014
- nested = nested .result
1059
+ # Defer evaluation of .Schema attribute, to avoid forward reference issues
1060
+ return partial (getattr , typ , "Schema" )
1015
1061
1016
- return marshmallow .fields .Nested (nested , ** metadata )
1062
+ class_schema = _internal_class_schema (typ ) # type: ignore[arg-type] # FIXME
1063
+ if isinstance (class_schema , _Future ):
1064
+ return class_schema .result
1065
+ return class_schema
1066
+
1067
+
1068
+ def _is_simple_annotated_class (obj : object ) -> bool :
1069
+ """Determine whether obj is a "simple annotated class".
1070
+
1071
+ The ```class_schema``` function can generate schemas for
1072
+ simple annotated classes (as well as for dataclasses).
1073
+ """
1074
+ if not isinstance (obj , type ):
1075
+ return False
1076
+ if getattr (obj , "__init__" , None ) is not object .__init__ :
1077
+ return False
1078
+ if getattr (obj , "__new__" , None ) is not object .__new__ :
1079
+ return False
1080
+
1081
+ schema_ctx = _schema_ctx_stack .top
1082
+ type_hints = get_type_hints (
1083
+ obj , globalns = schema_ctx .globalns , localns = schema_ctx .localns
1084
+ )
1085
+ return any (not typing_inspect .is_classvar (th ) for th in type_hints .values ())
1017
1086
1018
1087
1019
1088
def field_for_schema (
@@ -1052,7 +1121,7 @@ def field_for_schema(
1052
1121
1053
1122
1054
1123
def _field_for_schema (
1055
- typ : type ,
1124
+ typ : Union [ type , object ] ,
1056
1125
default : Any = marshmallow .missing ,
1057
1126
metadata : Optional [Mapping [str , Any ]] = None ,
1058
1127
) -> marshmallow .fields .Field :
@@ -1122,54 +1191,25 @@ def _field_for_schema(
1122
1191
if isinstance (typ , type ) and issubclass (typ , Enum ):
1123
1192
return _field_for_enum (typ , metadata )
1124
1193
1125
- # Assume nested marshmallow dataclass (and hope for the best)
1126
- return _field_for_dataclass (typ , metadata )
1127
-
1128
-
1129
- def _base_schema (
1130
- clazz : type , base_schema : Optional [Type [marshmallow .Schema ]] = None
1131
- ) -> Type [marshmallow .Schema ]:
1132
- """
1133
- Base schema factory that creates a schema for `clazz` derived either from `base_schema`
1134
- or `BaseSchema`
1135
- """
1136
-
1137
- # Remove `type: ignore` when mypy handles dynamic base classes
1138
- # https://github.com/python/mypy/issues/2813
1139
- class BaseSchema (base_schema or marshmallow .Schema ): # type: ignore
1140
- def load (self , data : Mapping , * , many : Optional [bool ] = None , ** kwargs ):
1141
- all_loaded = super ().load (data , many = many , ** kwargs )
1142
- many = self .many if many is None else bool (many )
1143
- if many :
1144
- return [clazz (** loaded ) for loaded in all_loaded ]
1145
- else :
1146
- return clazz (** all_loaded )
1147
-
1148
- return BaseSchema
1149
-
1150
-
1151
- def _get_field_default (field : dataclasses .Field ):
1152
- """
1153
- Return a marshmallow default value given a dataclass default value
1194
+ # nested dataclasses
1195
+ if (
1196
+ dataclasses .is_dataclass (typ )
1197
+ or _is_generic_alias_of_dataclass (typ )
1198
+ or _is_simple_annotated_class (typ )
1199
+ ):
1200
+ nested = _schema_for_nested (typ )
1201
+ # type spec for Nested.__init__ is not correct
1202
+ return marshmallow .fields .Nested (nested , ** metadata ) # type: ignore[arg-type]
1154
1203
1155
- >>> _get_field_default(dataclasses.field())
1156
- <marshmallow.missing>
1157
- """
1158
- # Remove `type: ignore` when https://github.com/python/mypy/issues/6910 is fixed
1159
- default_factory = field .default_factory # type: ignore
1160
- if default_factory is not dataclasses .MISSING :
1161
- return default_factory
1162
- elif field .default is dataclasses .MISSING :
1163
- return marshmallow .missing
1164
- return field .default
1204
+ raise TypeError (f"can not deduce field type for { typ } " )
1165
1205
1166
1206
1167
1207
def NewType (
1168
1208
name : str ,
1169
1209
typ : Type [_U ],
1170
1210
field : Optional [Type [marshmallow .fields .Field ]] = None ,
1171
- ** kwargs ,
1172
- ) -> Callable [[ _U ], _U ] :
1211
+ ** kwargs : Any ,
1212
+ ) -> type :
1173
1213
"""NewType creates simple unique types
1174
1214
to which you can attach custom marshmallow attributes.
1175
1215
All the keyword arguments passed to this function will be transmitted
@@ -1202,9 +1242,9 @@ def NewType(
1202
1242
# noinspection PyTypeHints
1203
1243
new_type = typing_NewType (name , typ ) # type: ignore
1204
1244
# noinspection PyTypeHints
1205
- new_type ._marshmallow_field = field # type: ignore
1245
+ new_type ._marshmallow_field = field
1206
1246
# noinspection PyTypeHints
1207
- new_type ._marshmallow_args = kwargs # type: ignore
1247
+ new_type ._marshmallow_args = kwargs
1208
1248
return new_type
1209
1249
1210
1250
0 commit comments