@@ -653,6 +653,9 @@ class _SchemaContext:
653
653
default_factory = dict
654
654
)
655
655
656
+ def replace (self , generic_args : Optional [_GenericArgs ]) -> "_SchemaContext" :
657
+ return dataclasses .replace (self , generic_args = generic_args )
658
+
656
659
def get_type_mapping (
657
660
self , use_mro : bool = False
658
661
) -> Mapping [Any , Type [marshmallow .fields .Field ]]:
@@ -717,63 +720,109 @@ def _internal_class_schema(
717
720
718
721
generic_args = schema_ctx .generic_args
719
722
720
- if _is_generic_alias_of_dataclass (clazz ):
721
- generic_args = _GenericArgs (clazz , generic_args )
722
- clazz = typing_inspect .get_origin (clazz )
723
- elif not dataclasses .is_dataclass (clazz ):
724
- try :
725
- warnings .warn (
726
- "****** WARNING ****** "
727
- f"marshmallow_dataclass was called on the class { clazz } , which is not a dataclass. "
728
- "It is going to try and convert the class into a dataclass, which may have "
729
- "undesirable side effects. To avoid this message, make sure all your classes and "
730
- "all the classes of their fields are either explicitly supported by "
731
- "marshmallow_dataclass, or define the schema explicitly using "
732
- "field(metadata=dict(marshmallow_field=...)). For more information, see "
733
- "https://github.com/lovasoa/marshmallow_dataclass/issues/51 "
734
- "****** WARNING ******"
735
- )
736
- dataclasses .dataclass (clazz )
737
- except Exception as exc :
738
- raise TypeError (
739
- f"{ getattr (clazz , '__name__' , repr (clazz ))} is not a dataclass and cannot be turned into one."
740
- ) from exc
741
-
742
- fields = dataclasses .fields (clazz )
743
-
744
- # Copy all marshmallow hooks and whitelisted members of the dataclass to the schema.
745
- attributes = {
746
- k : v
747
- for k , v in inspect .getmembers (clazz )
748
- if hasattr (v , "__marshmallow_hook__" ) or k in MEMBERS_WHITELIST
749
- }
723
+ constructor : Callable [..., object ]
724
+
725
+ if _is_simple_annotated_class (clazz ):
726
+ class_name = clazz .__name__
727
+ constructor = _simple_class_constructor (clazz )
728
+ attributes = _schema_attrs_for_simple_class (clazz )
729
+ elif _is_generic_alias_of_dataclass (clazz ):
730
+ origin = get_origin (clazz )
731
+ assert isinstance (origin , type )
732
+ class_name = origin .__name__
733
+ constructor = origin
734
+ with schema_ctx .replace (generic_args = _GenericArgs (clazz , generic_args )):
735
+ attributes = _schema_attrs_for_dataclass (origin )
736
+ elif dataclasses .is_dataclass (clazz ):
737
+ class_name = clazz .__name__
738
+ constructor = clazz
739
+ attributes = _schema_attrs_for_dataclass (clazz )
740
+ else :
741
+ raise TypeError (f"{ clazz } is not a dataclass or a simple annotated class" )
742
+
743
+ base_schema = marshmallow .Schema
744
+ if schema_ctx .base_schema is not None :
745
+ base_schema = schema_ctx .base_schema
746
+
747
+ load_to_dict = base_schema .load
748
+
749
+ def load (
750
+ self : marshmallow .Schema ,
751
+ data : Union [Mapping [str , Any ], Iterable [Mapping [str , Any ]]],
752
+ * ,
753
+ many : Optional [bool ] = None ,
754
+ unknown : Optional [str ] = None ,
755
+ ** kwargs : Any ,
756
+ ) -> Any :
757
+ many = self .many if many is None else bool (many )
758
+ loaded = load_to_dict (self , data , many = many , unknown = unknown , ** kwargs )
759
+ if many :
760
+ return [constructor (** item ) for item in loaded ]
761
+ else :
762
+ return constructor (** loaded )
750
763
751
- # Update the schema members to contain marshmallow fields instead of dataclass fields
752
- type_hints = get_type_hints (
753
- clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
754
- )
755
- with dataclasses .replace (schema_ctx , generic_args = generic_args ):
756
- attributes .update (
757
- (
758
- field .name ,
759
- _field_for_schema (
760
- type_hints [field .name ],
761
- _get_field_default (field ),
762
- field .metadata ,
763
- ),
764
- )
765
- for field in fields
766
- if field .init
767
- )
764
+ attributes ["load" ] = load
768
765
769
766
schema_class : Type [marshmallow .Schema ] = type (
770
- clazz . __name__ , (_base_schema ( clazz , schema_ctx . base_schema ) ,), attributes
767
+ f" { class_name } Schema" , (base_schema ,), attributes
771
768
)
769
+
772
770
future .set_result (schema_class )
773
771
_schema_cache [cache_key ] = schema_class
774
772
return schema_class
775
773
776
774
775
+ def _marshmallow_hooks (clazz : type ) -> Iterator [Tuple [str , Any ]]:
776
+ for name , attr in inspect .getmembers (clazz ):
777
+ if hasattr (attr , "__marshmallow_hook__" ) or name in MEMBERS_WHITELIST :
778
+ yield name , attr
779
+
780
+
781
+ def _schema_attrs_for_dataclass (clazz : type ) -> Dict [str , Any ]:
782
+ schema_ctx = _schema_ctx_stack .top
783
+ type_hints = get_type_hints (
784
+ clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
785
+ )
786
+
787
+ attrs = dict (_marshmallow_hooks (clazz ))
788
+ for field in dataclasses .fields (clazz ):
789
+ if field .init :
790
+ typ = type_hints [field .name ]
791
+ default = (
792
+ field .default_factory
793
+ if field .default_factory is not dataclasses .MISSING
794
+ else field .default
795
+ if field .default is not dataclasses .MISSING
796
+ else marshmallow .missing
797
+ )
798
+ attrs [field .name ] = _field_for_schema (typ , default , field .metadata )
799
+ return attrs
800
+
801
+
802
+ def _schema_attrs_for_simple_class (clazz : type ) -> Dict [str , Any ]:
803
+ schema_ctx = _schema_ctx_stack .top
804
+ type_hints = get_type_hints (
805
+ clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
806
+ )
807
+
808
+ attrs = dict (_marshmallow_hooks (clazz ))
809
+ for field_name , typ in type_hints .items ():
810
+ if not typing_inspect .is_classvar (typ ):
811
+ default = getattr (clazz , field_name , marshmallow .missing )
812
+ attrs [field_name ] = _field_for_schema (typ , default )
813
+ return attrs
814
+
815
+
816
+ def _simple_class_constructor (clazz : Type [_U ]) -> Callable [..., _U ]:
817
+ def constructor (** kwargs : Any ) -> _U :
818
+ obj = clazz .__new__ (clazz )
819
+ for k , v in kwargs .items ():
820
+ setattr (obj , k , v )
821
+ return obj
822
+
823
+ return constructor
824
+
825
+
777
826
def _is_builtin_collection_type (typ : object ) -> bool :
778
827
origin = get_origin (typ )
779
828
if origin is None :
@@ -845,7 +894,7 @@ def get_field(i: int) -> marshmallow.fields.Field:
845
894
846
895
847
896
def _field_for_union_type (
848
- typ : type , metadata : Dict [str , Any ]
897
+ typ : object , metadata : Dict [str , Any ]
849
898
) -> marshmallow .fields .Field :
850
899
"""
851
900
Construct the appropriate Field for a union or optional type.
@@ -877,7 +926,7 @@ def _field_for_union_type(
877
926
878
927
879
928
def _field_for_literal_type (
880
- typ : type , metadata : Dict [str , Any ]
929
+ typ : object , metadata : Dict [str , Any ]
881
930
) -> marshmallow .fields .Field :
882
931
"""
883
932
Construct the appropriate Field for a Literal type.
@@ -893,7 +942,7 @@ def _field_for_literal_type(
893
942
return marshmallow .fields .Raw (validate = validate , ** metadata )
894
943
895
944
896
- def _get_subtype_for_final_type (typ : type , default : Any ) -> Any :
945
+ def _get_subtype_for_final_type (typ : object , default : Any ) -> Any :
897
946
"""
898
947
Construct the appropriate Field for a Final type.
899
948
"""
@@ -930,7 +979,7 @@ def _get_subtype_for_final_type(typ: type, default: Any) -> Any:
930
979
931
980
932
981
def _field_for_new_type (
933
- typ : Type , default : Any , metadata : Dict [str , Any ]
982
+ typ : object , default : Any , metadata : Dict [str , Any ]
934
983
) -> marshmallow .fields .Field :
935
984
"""
936
985
Return a new field for fields based on a NewType.
@@ -953,7 +1002,8 @@ def _field_for_new_type(
953
1002
** metadata ,
954
1003
"validate" : validators if validators else None ,
955
1004
}
956
- metadata .setdefault ("metadata" , {}).setdefault ("description" , typ .__name__ )
1005
+ type_name = getattr (typ , "__name__" , repr (typ ))
1006
+ metadata .setdefault ("metadata" , {}).setdefault ("description" , type_name )
957
1007
958
1008
field : Optional [Type [marshmallow .fields .Field ]] = getattr (
959
1009
typ , "_marshmallow_field" , None
@@ -980,22 +1030,41 @@ def _field_for_enum(typ: type, metadata: Dict[str, Any]) -> marshmallow.fields.F
980
1030
return marshmallow_enum .EnumField (typ , ** metadata )
981
1031
982
1032
983
- def _field_for_dataclass (
984
- typ : Union [ Type , object ], metadata : Dict [ str , Any ]
985
- ) -> marshmallow .fields . Field :
1033
+ def _schema_for_nested (
1034
+ typ : object ,
1035
+ ) -> Union [ Type [ marshmallow .Schema ], Callable [[], Type [ marshmallow . Schema ]]] :
986
1036
"""
987
- Return a new field for a nested dataclass field.
1037
+ Return a marshmallow.Schema for a nested dataclass (or simple annotated class)
988
1038
"""
989
1039
if isinstance (typ , type ) and hasattr (typ , "Schema" ):
990
1040
# marshmallow_dataclass.dataclass
991
- nested = typ .Schema
992
- else :
993
- assert isinstance (typ , Hashable )
994
- nested = _internal_class_schema (typ ) # type: ignore[arg-type] # FIXME
995
- if isinstance (nested , _Future ):
996
- nested = nested .result
1041
+ # Defer evaluation of .Schema attribute, to avoid forward reference issues
1042
+ return partial (getattr , typ , "Schema" )
997
1043
998
- return marshmallow .fields .Nested (nested , ** metadata )
1044
+ class_schema = _internal_class_schema (typ ) # type: ignore[arg-type] # FIXME
1045
+ if isinstance (class_schema , _Future ):
1046
+ return class_schema .result
1047
+ return class_schema
1048
+
1049
+
1050
+ def _is_simple_annotated_class (obj : object ) -> bool :
1051
+ """Determine whether obj is a "simple annotated class".
1052
+
1053
+ The ```class_schema``` function can generate schemas for
1054
+ simple annotated classes (as well as for dataclasses).
1055
+ """
1056
+ if not isinstance (obj , type ):
1057
+ return False
1058
+ if getattr (obj , "__init__" , None ) is not object .__init__ :
1059
+ return False
1060
+ if getattr (obj , "__new__" , None ) is not object .__new__ :
1061
+ return False
1062
+
1063
+ schema_ctx = _schema_ctx_stack .top
1064
+ type_hints = get_type_hints (
1065
+ obj , globalns = schema_ctx .globalns , localns = schema_ctx .localns
1066
+ )
1067
+ return any (not typing_inspect .is_classvar (th ) for th in type_hints .values ())
999
1068
1000
1069
1001
1070
def field_for_schema (
@@ -1034,7 +1103,7 @@ def field_for_schema(
1034
1103
1035
1104
1036
1105
def _field_for_schema (
1037
- typ : type ,
1106
+ typ : Union [ type , object ] ,
1038
1107
default : Any = marshmallow .missing ,
1039
1108
metadata : Optional [Mapping [str , Any ]] = None ,
1040
1109
) -> marshmallow .fields .Field :
@@ -1104,54 +1173,25 @@ def _field_for_schema(
1104
1173
if isinstance (typ , type ) and issubclass (typ , Enum ):
1105
1174
return _field_for_enum (typ , metadata )
1106
1175
1107
- # Assume nested marshmallow dataclass (and hope for the best)
1108
- return _field_for_dataclass (typ , metadata )
1109
-
1110
-
1111
- def _base_schema (
1112
- clazz : type , base_schema : Optional [Type [marshmallow .Schema ]] = None
1113
- ) -> Type [marshmallow .Schema ]:
1114
- """
1115
- Base schema factory that creates a schema for `clazz` derived either from `base_schema`
1116
- or `BaseSchema`
1117
- """
1118
-
1119
- # Remove `type: ignore` when mypy handles dynamic base classes
1120
- # https://github.com/python/mypy/issues/2813
1121
- class BaseSchema (base_schema or marshmallow .Schema ): # type: ignore
1122
- def load (self , data : Mapping , * , many : Optional [bool ] = None , ** kwargs ):
1123
- all_loaded = super ().load (data , many = many , ** kwargs )
1124
- many = self .many if many is None else bool (many )
1125
- if many :
1126
- return [clazz (** loaded ) for loaded in all_loaded ]
1127
- else :
1128
- return clazz (** all_loaded )
1129
-
1130
- return BaseSchema
1131
-
1132
-
1133
- def _get_field_default (field : dataclasses .Field ):
1134
- """
1135
- Return a marshmallow default value given a dataclass default value
1176
+ # nested dataclasses
1177
+ if (
1178
+ dataclasses .is_dataclass (typ )
1179
+ or _is_generic_alias_of_dataclass (typ )
1180
+ or _is_simple_annotated_class (typ )
1181
+ ):
1182
+ nested = _schema_for_nested (typ )
1183
+ # type spec for Nested.__init__ is not correct
1184
+ return marshmallow .fields .Nested (nested , ** metadata ) # type: ignore[arg-type]
1136
1185
1137
- >>> _get_field_default(dataclasses.field())
1138
- <marshmallow.missing>
1139
- """
1140
- # Remove `type: ignore` when https://github.com/python/mypy/issues/6910 is fixed
1141
- default_factory = field .default_factory # type: ignore
1142
- if default_factory is not dataclasses .MISSING :
1143
- return default_factory
1144
- elif field .default is dataclasses .MISSING :
1145
- return marshmallow .missing
1146
- return field .default
1186
+ raise TypeError (f"can not deduce field type for { typ } " )
1147
1187
1148
1188
1149
1189
def NewType (
1150
1190
name : str ,
1151
1191
typ : Type [_U ],
1152
1192
field : Optional [Type [marshmallow .fields .Field ]] = None ,
1153
- ** kwargs ,
1154
- ) -> Callable [[ _U ], _U ] :
1193
+ ** kwargs : Any ,
1194
+ ) -> type :
1155
1195
"""NewType creates simple unique types
1156
1196
to which you can attach custom marshmallow attributes.
1157
1197
All the keyword arguments passed to this function will be transmitted
@@ -1184,9 +1224,9 @@ def NewType(
1184
1224
# noinspection PyTypeHints
1185
1225
new_type = typing_NewType (name , typ ) # type: ignore
1186
1226
# noinspection PyTypeHints
1187
- new_type ._marshmallow_field = field # type: ignore
1227
+ new_type ._marshmallow_field = field
1188
1228
# noinspection PyTypeHints
1189
- new_type ._marshmallow_args = kwargs # type: ignore
1229
+ new_type ._marshmallow_args = kwargs
1190
1230
return new_type
1191
1231
1192
1232
0 commit comments