@@ -723,25 +723,12 @@ def _internal_class_schema(
723
723
return schema_class
724
724
725
725
726
- def _generic_type_add_any (typ : type ) -> type : # FIXME: signature is wrong
727
- """if typ is generic type without arguments, replace them by Any."""
728
- if typ is list or typ is List :
729
- typ = List [Any ]
730
- elif typ is dict or typ is Dict :
731
- typ = Dict [Any , Any ]
732
- elif typ is Mapping :
733
- typ = Mapping [Any , Any ]
734
- elif typ is Sequence :
735
- typ = Sequence [Any ]
736
- elif typ is set or typ is Set :
737
- typ = Set [Any ]
738
- elif typ is frozenset or typ is FrozenSet :
739
- typ = FrozenSet [Any ]
740
- return typ
741
-
742
-
743
726
def _is_builtin_collection_type (typ : object ) -> bool :
744
- return get_origin (typ ) in {
727
+ origin = get_origin (typ )
728
+ if origin is None :
729
+ origin = typ
730
+
731
+ return origin in {
745
732
list ,
746
733
collections .abc .Sequence ,
747
734
set ,
@@ -759,24 +746,11 @@ def _field_for_builtin_collection_type(
759
746
Handle builtin container types like list, tuple, set, etc.
760
747
"""
761
748
origin = get_origin (typ )
762
- assert origin is not None
763
- assert not typing_inspect .is_union_type (typ )
764
-
765
- arguments = get_args (typ )
766
- # if len(arguments) == 0:
767
- # if issubclass(origin, (collections.abc.Sequence, collections.abc.Set)):
768
- # arguments = (Any,)
769
- # elif issubclass(origin, collections.abc.Mapping):
770
- # arguments = (Any, Any)
771
- # else:
772
- # print(repr(origin))
773
- # raise TypeError(f"{typ!r} requires generic arguments")
774
-
775
- if origin is tuple and len (arguments ) == 2 and arguments [1 ] is Ellipsis :
776
- origin = collections .abc .Sequence
777
- arguments = (arguments [0 ],)
749
+ if origin is None :
750
+ origin = typ
751
+ assert len (get_args (typ )) == 0
778
752
779
- fields = tuple ( map ( _field_for_schema , arguments ) )
753
+ args = get_args ( typ )
780
754
781
755
schema_ctx = _schema_ctx_stack .top
782
756
@@ -785,31 +759,38 @@ def get_field_type(type_spec: Any, default: Type[_Field]) -> Type[_Field]:
785
759
type_mapping = schema_ctx .get_type_mapping ()
786
760
return type_mapping .get (type_spec , default ) # type: ignore[return-value]
787
761
762
+ if origin is tuple and (len (args ) == 0 or (len (args ) == 2 and args [1 ] is Ellipsis )):
763
+ # Special case: homogeneous tuple — treat as Sequence
764
+ origin = collections .abc .Sequence
765
+ args = args [:1 ]
766
+
767
+ if origin is tuple :
768
+ tuple_type = get_field_type (Tuple , default = marshmallow .fields .Tuple )
769
+ return tuple_type (tuple (map (_field_for_schema , args )), ** metadata )
770
+
771
+ def get_field (i : int ) -> marshmallow .fields .Field :
772
+ return _field_for_schema (args [i ] if args else Any )
773
+
774
+ if origin in (dict , collections .abc .Mapping ):
775
+ dict_type = get_field_type (Dict , default = marshmallow .fields .Dict )
776
+ return dict_type (keys = get_field (0 ), values = get_field (1 ), ** metadata )
777
+
788
778
if origin is list :
789
- assert len (fields ) == 1
790
779
list_type = get_field_type (List , default = marshmallow .fields .List )
791
- return list_type (fields [ 0 ] , ** metadata )
780
+ return list_type (get_field ( 0 ) , ** metadata )
792
781
793
782
if origin is collections .abc .Sequence :
794
783
from . import collection_field
795
784
796
- assert len (fields ) == 1
797
- return collection_field .Sequence (fields [0 ], ** metadata )
785
+ return collection_field .Sequence (get_field (0 ), ** metadata )
798
786
799
787
if origin in (set , frozenset ):
800
788
from . import collection_field
801
789
802
- assert len (fields ) == 1
803
790
frozen = origin is frozenset
804
- return collection_field .Set (fields [0 ], frozen = frozen , ** metadata )
805
-
806
- if origin is tuple :
807
- tuple_type = get_field_type (Tuple , default = marshmallow .fields .Tuple )
808
- return tuple_type (fields , ** metadata )
791
+ return collection_field .Set (get_field (0 ), frozen = frozen , ** metadata )
809
792
810
- assert origin in (dict , collections .abc .Mapping )
811
- dict_type = get_field_type (Dict , default = marshmallow .fields .Dict )
812
- return dict_type (keys = fields [0 ], values = fields [1 ], ** metadata )
793
+ raise TypeError (f"{ typ } is not a builtin collection type" )
813
794
814
795
815
796
def _field_for_union_type (
@@ -1038,8 +1019,8 @@ def _field_for_schema(
1038
1019
if schema_ctx .generic_args is not None and isinstance (typ , TypeVar ):
1039
1020
typ = schema_ctx .generic_args .resolve (typ )
1040
1021
1041
- # Generic types specified without type arguments
1042
- typ = _generic_type_add_any (typ )
1022
+ if _is_builtin_collection_type ( typ ):
1023
+ return _field_for_builtin_collection_type (typ , metadata )
1043
1024
1044
1025
# Base types
1045
1026
type_mapping = schema_ctx .get_type_mapping (use_mro = True )
@@ -1061,9 +1042,6 @@ def _field_for_schema(
1061
1042
metadata = metadata ,
1062
1043
)
1063
1044
1064
- if _is_builtin_collection_type (typ ):
1065
- return _field_for_builtin_collection_type (typ , metadata )
1066
-
1067
1045
if typing_inspect .is_union_type (typ ):
1068
1046
return _field_for_union_type (typ , metadata )
1069
1047
0 commit comments