|
4 | 4 | import json
|
5 | 5 | import struct
|
6 | 6 | import sys
|
| 7 | +import warnings |
7 | 8 | from abc import ABC
|
8 | 9 | from base64 import b64decode, b64encode
|
9 | 10 | from datetime import datetime, timedelta, timezone
|
|
21 | 22 | get_type_hints,
|
22 | 23 | )
|
23 | 24 |
|
| 25 | +import typing |
| 26 | + |
24 | 27 | from ._types import T
|
25 | 28 | from .casing import camel_case, safe_snake_case, snake_case
|
26 | 29 | from .grpc.grpclib_client import ServiceStub
|
@@ -251,7 +254,7 @@ def map_field(
|
251 | 254 | )
|
252 | 255 |
|
253 | 256 |
|
254 |
| -class Enum(int, enum.Enum): |
| 257 | +class Enum(enum.IntEnum): |
255 | 258 | """Protocol buffers enumeration base class. Acts like `enum.IntEnum`."""
|
256 | 259 |
|
257 | 260 | @classmethod
|
@@ -635,9 +638,13 @@ def __bytes__(self) -> bytes:
|
635 | 638 |
|
636 | 639 | @classmethod
|
637 | 640 | def _type_hint(cls, field_name: str) -> Type:
|
| 641 | + return cls._type_hints()[field_name] |
| 642 | + |
| 643 | + @classmethod |
| 644 | + def _type_hints(cls) -> Dict[str, Type]: |
638 | 645 | module = inspect.getmodule(cls)
|
639 | 646 | type_hints = get_type_hints(cls, vars(module))
|
640 |
| - return type_hints[field_name] |
| 647 | + return type_hints |
641 | 648 |
|
642 | 649 | @classmethod
|
643 | 650 | def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
|
@@ -789,55 +796,68 @@ def to_dict(
|
789 | 796 | `False`.
|
790 | 797 | """
|
791 | 798 | output: Dict[str, Any] = {}
|
| 799 | + field_types = self._type_hints() |
792 | 800 | for field_name, meta in self._betterproto.meta_by_field_name.items():
|
793 |
| - v = getattr(self, field_name) |
| 801 | + field_type = field_types[field_name] |
| 802 | + field_is_repeated = type(field_type) is type(typing.List) |
| 803 | + value = getattr(self, field_name) |
794 | 804 | cased_name = casing(field_name).rstrip("_") # type: ignore
|
795 |
| - if meta.proto_type == "message": |
796 |
| - if isinstance(v, datetime): |
797 |
| - if v != DATETIME_ZERO or include_default_values: |
798 |
| - output[cased_name] = _Timestamp.timestamp_to_json(v) |
799 |
| - elif isinstance(v, timedelta): |
800 |
| - if v != timedelta(0) or include_default_values: |
801 |
| - output[cased_name] = _Duration.delta_to_json(v) |
| 805 | + if meta.proto_type == TYPE_MESSAGE: |
| 806 | + if isinstance(value, datetime): |
| 807 | + if value != DATETIME_ZERO or include_default_values: |
| 808 | + output[cased_name] = _Timestamp.timestamp_to_json(value) |
| 809 | + elif isinstance(value, timedelta): |
| 810 | + if value != timedelta(0) or include_default_values: |
| 811 | + output[cased_name] = _Duration.delta_to_json(value) |
802 | 812 | elif meta.wraps:
|
803 |
| - if v is not None or include_default_values: |
804 |
| - output[cased_name] = v |
805 |
| - elif isinstance(v, list): |
| 813 | + if value is not None or include_default_values: |
| 814 | + output[cased_name] = value |
| 815 | + elif field_is_repeated: |
806 | 816 | # Convert each item.
|
807 |
| - v = [i.to_dict(casing, include_default_values) for i in v] |
808 |
| - if v or include_default_values: |
809 |
| - output[cased_name] = v |
| 817 | + value = [i.to_dict(casing, include_default_values) for i in value] |
| 818 | + if value or include_default_values: |
| 819 | + output[cased_name] = value |
810 | 820 | else:
|
811 |
| - if v._serialized_on_wire or include_default_values: |
812 |
| - output[cased_name] = v.to_dict(casing, include_default_values) |
813 |
| - elif meta.proto_type == "map": |
814 |
| - for k in v: |
815 |
| - if hasattr(v[k], "to_dict"): |
816 |
| - v[k] = v[k].to_dict(casing, include_default_values) |
817 |
| - |
818 |
| - if v or include_default_values: |
819 |
| - output[cased_name] = v |
820 |
| - elif v != self._get_field_default(field_name) or include_default_values: |
| 821 | + if value._serialized_on_wire or include_default_values: |
| 822 | + output[cased_name] = value.to_dict( |
| 823 | + casing, include_default_values |
| 824 | + ) |
| 825 | + elif meta.proto_type == TYPE_MAP: |
| 826 | + for k in value: |
| 827 | + if hasattr(value[k], "to_dict"): |
| 828 | + value[k] = value[k].to_dict(casing, include_default_values) |
| 829 | + |
| 830 | + if value or include_default_values: |
| 831 | + output[cased_name] = value |
| 832 | + elif value != self._get_field_default(field_name) or include_default_values: |
821 | 833 | if meta.proto_type in INT_64_TYPES:
|
822 |
| - if isinstance(v, list): |
823 |
| - output[cased_name] = [str(n) for n in v] |
| 834 | + if field_is_repeated: |
| 835 | + output[cased_name] = [str(n) for n in value] |
824 | 836 | else:
|
825 |
| - output[cased_name] = str(v) |
| 837 | + output[cased_name] = str(value) |
826 | 838 | elif meta.proto_type == TYPE_BYTES:
|
827 |
| - if isinstance(v, list): |
828 |
| - output[cased_name] = [b64encode(b).decode("utf8") for b in v] |
| 839 | + if field_is_repeated: |
| 840 | + output[cased_name] = [ |
| 841 | + b64encode(b).decode("utf8") for b in value |
| 842 | + ] |
829 | 843 | else:
|
830 |
| - output[cased_name] = b64encode(v).decode("utf8") |
| 844 | + output[cased_name] = b64encode(value).decode("utf8") |
831 | 845 | elif meta.proto_type == TYPE_ENUM:
|
832 |
| - enum_values = list( |
833 |
| - self._betterproto.cls_by_field[field_name] |
834 |
| - ) # type: ignore |
835 |
| - if isinstance(v, list): |
836 |
| - output[cased_name] = [enum_values[e].name for e in v] |
| 846 | + if field_is_repeated: |
| 847 | + enum_class = field_type.__args__[0] |
| 848 | + if isinstance(value, typing.Iterable): |
| 849 | + output[cased_name] = [ |
| 850 | + enum_class(element).name for element in value |
| 851 | + ] |
| 852 | + else: |
| 853 | + warnings.warn( |
| 854 | + f"Non-iterable value for repeated enum field {field_name}" |
| 855 | + ) |
837 | 856 | else:
|
838 |
| - output[cased_name] = enum_values[v].name |
| 857 | + enum_class = field_type |
| 858 | + output[cased_name] = enum_class(value).name |
839 | 859 | else:
|
840 |
| - output[cased_name] = v |
| 860 | + output[cased_name] = value |
841 | 861 | return output
|
842 | 862 |
|
843 | 863 | def from_dict(self: T, value: dict) -> T:
|
|
0 commit comments