|
4 | 4 | import json
|
5 | 5 | import struct
|
6 | 6 | import sys
|
| 7 | +import warnings |
7 | 8 | from abc import ABC
|
8 | 9 | from base64 import b64decode, b64encode
|
9 | 10 | from datetime import datetime, timedelta, timezone
|
|
21 | 22 | get_type_hints,
|
22 | 23 | )
|
23 | 24 |
|
| 25 | +import typing |
| 26 | + |
24 | 27 | from ._types import T
|
25 | 28 | from .casing import camel_case, safe_snake_case, snake_case
|
26 | 29 | from .grpc.grpclib_client import ServiceStub
|
@@ -251,7 +254,7 @@ def map_field(
|
251 | 254 | )
|
252 | 255 |
|
253 | 256 |
|
254 |
| -class Enum(int, enum.Enum): |
| 257 | +class Enum(enum.IntEnum): |
255 | 258 | """Protocol buffers enumeration base class. Acts like `enum.IntEnum`."""
|
256 | 259 |
|
257 | 260 | @classmethod
|
@@ -641,9 +644,13 @@ def __bytes__(self) -> bytes:
|
641 | 644 |
|
642 | 645 | @classmethod
|
643 | 646 | def _type_hint(cls, field_name: str) -> Type:
|
| 647 | + return cls._type_hints()[field_name] |
| 648 | + |
| 649 | + @classmethod |
| 650 | + def _type_hints(cls) -> Dict[str, Type]: |
644 | 651 | module = inspect.getmodule(cls)
|
645 | 652 | type_hints = get_type_hints(cls, vars(module))
|
646 |
| - return type_hints[field_name] |
| 653 | + return type_hints |
647 | 654 |
|
648 | 655 | @classmethod
|
649 | 656 | def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
|
@@ -795,55 +802,67 @@ def to_dict(
|
795 | 802 | `False`.
|
796 | 803 | """
|
797 | 804 | output: Dict[str, Any] = {}
|
| 805 | + field_types = self._type_hints() |
798 | 806 | for field_name, meta in self._betterproto.meta_by_field_name.items():
|
799 |
| - v = getattr(self, field_name) |
| 807 | + field_type = field_types[field_name] |
| 808 | + field_is_repeated = type(field_type) is type(typing.List) |
| 809 | + value = getattr(self, field_name) |
800 | 810 | cased_name = casing(field_name).rstrip("_") # type: ignore
|
801 |
| - if meta.proto_type == "message": |
802 |
| - if isinstance(v, datetime): |
803 |
| - if v != DATETIME_ZERO or include_default_values: |
804 |
| - output[cased_name] = _Timestamp.timestamp_to_json(v) |
805 |
| - elif isinstance(v, timedelta): |
806 |
| - if v != timedelta(0) or include_default_values: |
807 |
| - output[cased_name] = _Duration.delta_to_json(v) |
| 811 | + if meta.proto_type == TYPE_MESSAGE: |
| 812 | + if isinstance(value, datetime): |
| 813 | + if value != DATETIME_ZERO or include_default_values: |
| 814 | + output[cased_name] = _Timestamp.timestamp_to_json(value) |
| 815 | + elif isinstance(value, timedelta): |
| 816 | + if value != timedelta(0) or include_default_values: |
| 817 | + output[cased_name] = _Duration.delta_to_json(value) |
808 | 818 | elif meta.wraps:
|
809 |
| - if v is not None or include_default_values: |
810 |
| - output[cased_name] = v |
811 |
| - elif isinstance(v, list): |
| 819 | + if value is not None or include_default_values: |
| 820 | + output[cased_name] = value |
| 821 | + elif field_is_repeated: |
812 | 822 | # Convert each item.
|
813 |
| - v = [i.to_dict(casing, include_default_values) for i in v] |
814 |
| - if v or include_default_values: |
815 |
| - output[cased_name] = v |
| 823 | + value = [i.to_dict(casing, include_default_values) for i in value] |
| 824 | + if value or include_default_values: |
| 825 | + output[cased_name] = value |
816 | 826 | else:
|
817 |
| - if v._serialized_on_wire or include_default_values: |
818 |
| - output[cased_name] = v.to_dict(casing, include_default_values) |
819 |
| - elif meta.proto_type == "map": |
820 |
| - for k in v: |
821 |
| - if hasattr(v[k], "to_dict"): |
822 |
| - v[k] = v[k].to_dict(casing, include_default_values) |
823 |
| - |
824 |
| - if v or include_default_values: |
825 |
| - output[cased_name] = v |
826 |
| - elif v != self._get_field_default(field_name) or include_default_values: |
| 827 | + if value._serialized_on_wire or include_default_values: |
| 828 | + output[cased_name] = value.to_dict( |
| 829 | + casing, include_default_values |
| 830 | + ) |
| 831 | + elif meta.proto_type == TYPE_MAP: |
| 832 | + for k in value: |
| 833 | + if hasattr(value[k], "to_dict"): |
| 834 | + value[k] = value[k].to_dict(casing, include_default_values) |
| 835 | + |
| 836 | + if value or include_default_values: |
| 837 | + output[cased_name] = value |
| 838 | + elif value != self._get_field_default(field_name) or include_default_values: |
827 | 839 | if meta.proto_type in INT_64_TYPES:
|
828 |
| - if isinstance(v, list): |
829 |
| - output[cased_name] = [str(n) for n in v] |
| 840 | + if field_is_repeated: |
| 841 | + output[cased_name] = [str(n) for n in value] |
830 | 842 | else:
|
831 |
| - output[cased_name] = str(v) |
| 843 | + output[cased_name] = str(value) |
832 | 844 | elif meta.proto_type == TYPE_BYTES:
|
833 |
| - if isinstance(v, list): |
834 |
| - output[cased_name] = [b64encode(b).decode("utf8") for b in v] |
| 845 | + if field_is_repeated: |
| 846 | + output[cased_name] = [ |
| 847 | + b64encode(b).decode("utf8") for b in value |
| 848 | + ] |
835 | 849 | else:
|
836 |
| - output[cased_name] = b64encode(v).decode("utf8") |
| 850 | + output[cased_name] = b64encode(value).decode("utf8") |
837 | 851 | elif meta.proto_type == TYPE_ENUM:
|
838 |
| - enum_values = list( |
839 |
| - self._betterproto.cls_by_field[field_name] |
840 |
| - ) # type: ignore |
841 |
| - if isinstance(v, list): |
842 |
| - output[cased_name] = [enum_values[e].name for e in v] |
| 852 | + if field_is_repeated: |
| 853 | + enum_class: Type[Enum] = field_type.__args__[0] |
| 854 | + if isinstance(value, typing.Iterable) and not isinstance( |
| 855 | + value, str |
| 856 | + ): |
| 857 | + output[cased_name] = [enum_class(el).name for el in value] |
| 858 | + else: |
| 859 | + # transparently upgrade single value to repeated |
| 860 | + output[cased_name] = [enum_class(value).name] |
843 | 861 | else:
|
844 |
| - output[cased_name] = enum_values[v].name |
| 862 | + enum_class: Type[Enum] = field_type # noqa |
| 863 | + output[cased_name] = enum_class(value).name |
845 | 864 | else:
|
846 |
| - output[cased_name] = v |
| 865 | + output[cased_name] = value |
847 | 866 | return output
|
848 | 867 |
|
849 | 868 | def from_dict(self: T, value: dict) -> T:
|
|
0 commit comments