Skip to content

Commit dbf9ba3

Browse files
committed
refactor: delete _generic_type_add_any
Handle default generic params values in _field_for_builtin_collection_type.
1 parent 35b2950 commit dbf9ba3

File tree

1 file changed

+31
-53
lines changed

1 file changed

+31
-53
lines changed

marshmallow_dataclass/__init__.py

+31-53
Original file line numberDiff line numberDiff line change
@@ -723,25 +723,12 @@ def _internal_class_schema(
723723
return schema_class
724724

725725

726-
def _generic_type_add_any(typ: type) -> type: # FIXME: signature is wrong
727-
"""if typ is generic type without arguments, replace them by Any."""
728-
if typ is list or typ is List:
729-
typ = List[Any]
730-
elif typ is dict or typ is Dict:
731-
typ = Dict[Any, Any]
732-
elif typ is Mapping:
733-
typ = Mapping[Any, Any]
734-
elif typ is Sequence:
735-
typ = Sequence[Any]
736-
elif typ is set or typ is Set:
737-
typ = Set[Any]
738-
elif typ is frozenset or typ is FrozenSet:
739-
typ = FrozenSet[Any]
740-
return typ
741-
742-
743726
def _is_builtin_collection_type(typ: object) -> bool:
744-
return get_origin(typ) in {
727+
origin = get_origin(typ)
728+
if origin is None:
729+
origin = typ
730+
731+
return origin in {
745732
list,
746733
collections.abc.Sequence,
747734
set,
@@ -759,24 +746,11 @@ def _field_for_builtin_collection_type(
759746
Handle builtin container types like list, tuple, set, etc.
760747
"""
761748
origin = get_origin(typ)
762-
assert origin is not None
763-
assert not typing_inspect.is_union_type(typ)
764-
765-
arguments = get_args(typ)
766-
# if len(arguments) == 0:
767-
# if issubclass(origin, (collections.abc.Sequence, collections.abc.Set)):
768-
# arguments = (Any,)
769-
# elif issubclass(origin, collections.abc.Mapping):
770-
# arguments = (Any, Any)
771-
# else:
772-
# print(repr(origin))
773-
# raise TypeError(f"{typ!r} requires generic arguments")
774-
775-
if origin is tuple and len(arguments) == 2 and arguments[1] is Ellipsis:
776-
origin = collections.abc.Sequence
777-
arguments = (arguments[0],)
749+
if origin is None:
750+
origin = typ
751+
assert len(get_args(typ)) == 0
778752

779-
fields = tuple(map(_field_for_schema, arguments))
753+
args = get_args(typ)
780754

781755
schema_ctx = _schema_ctx_stack.top
782756

@@ -785,31 +759,38 @@ def get_field_type(type_spec: Any, default: Type[_Field]) -> Type[_Field]:
785759
type_mapping = schema_ctx.get_type_mapping()
786760
return type_mapping.get(type_spec, default) # type: ignore[return-value]
787761

762+
if origin is tuple and (len(args) == 0 or (len(args) == 2 and args[1] is Ellipsis)):
763+
# Special case: homogeneous tuple — treat as Sequence
764+
origin = collections.abc.Sequence
765+
args = args[:1]
766+
767+
if origin is tuple:
768+
tuple_type = get_field_type(Tuple, default=marshmallow.fields.Tuple)
769+
return tuple_type(tuple(map(_field_for_schema, args)), **metadata)
770+
771+
def get_field(i: int) -> marshmallow.fields.Field:
772+
return _field_for_schema(args[i] if args else Any)
773+
774+
if origin in (dict, collections.abc.Mapping):
775+
dict_type = get_field_type(Dict, default=marshmallow.fields.Dict)
776+
return dict_type(keys=get_field(0), values=get_field(1), **metadata)
777+
788778
if origin is list:
789-
assert len(fields) == 1
790779
list_type = get_field_type(List, default=marshmallow.fields.List)
791-
return list_type(fields[0], **metadata)
780+
return list_type(get_field(0), **metadata)
792781

793782
if origin is collections.abc.Sequence:
794783
from . import collection_field
795784

796-
assert len(fields) == 1
797-
return collection_field.Sequence(fields[0], **metadata)
785+
return collection_field.Sequence(get_field(0), **metadata)
798786

799787
if origin in (set, frozenset):
800788
from . import collection_field
801789

802-
assert len(fields) == 1
803790
frozen = origin is frozenset
804-
return collection_field.Set(fields[0], frozen=frozen, **metadata)
805-
806-
if origin is tuple:
807-
tuple_type = get_field_type(Tuple, default=marshmallow.fields.Tuple)
808-
return tuple_type(fields, **metadata)
791+
return collection_field.Set(get_field(0), frozen=frozen, **metadata)
809792

810-
assert origin in (dict, collections.abc.Mapping)
811-
dict_type = get_field_type(Dict, default=marshmallow.fields.Dict)
812-
return dict_type(keys=fields[0], values=fields[1], **metadata)
793+
raise TypeError(f"{typ} is not a builtin collection type")
813794

814795

815796
def _field_for_union_type(
@@ -1038,8 +1019,8 @@ def _field_for_schema(
10381019
if schema_ctx.generic_args is not None and isinstance(typ, TypeVar):
10391020
typ = schema_ctx.generic_args.resolve(typ)
10401021

1041-
# Generic types specified without type arguments
1042-
typ = _generic_type_add_any(typ)
1022+
if _is_builtin_collection_type(typ):
1023+
return _field_for_builtin_collection_type(typ, metadata)
10431024

10441025
# Base types
10451026
type_mapping = schema_ctx.get_type_mapping(use_mro=True)
@@ -1061,9 +1042,6 @@ def _field_for_schema(
10611042
metadata=metadata,
10621043
)
10631044

1064-
if _is_builtin_collection_type(typ):
1065-
return _field_for_builtin_collection_type(typ, metadata)
1066-
10671045
if typing_inspect.is_union_type(typ):
10681046
return _field_for_union_type(typ, metadata)
10691047

0 commit comments

Comments
 (0)