Skip to content

Commit 6e08097

Browse files
boukeversteeghGobot1234
authored andcommitted
Fix: to_dict returns wrong enum fields when numbering is not consecutive (danielgtaylor#102)
Fixes danielgtaylor#93 to_dict returns wrong enum fields when numbering is not consecutive
1 parent ac197fd commit 6e08097

File tree

7 files changed

+167
-57
lines changed

7 files changed

+167
-57
lines changed

src/betterproto/__init__.py

+58-39
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import json
55
import struct
66
import sys
7+
import warnings
78
from abc import ABC
89
from base64 import b64decode, b64encode
910
from datetime import datetime, timedelta, timezone
@@ -21,6 +22,8 @@
2122
get_type_hints,
2223
)
2324

25+
import typing
26+
2427
from ._types import T
2528
from .casing import camel_case, safe_snake_case, snake_case
2629
from .grpc.grpclib_client import ServiceStub
@@ -251,7 +254,7 @@ def map_field(
251254
)
252255

253256

254-
class Enum(int, enum.Enum):
257+
class Enum(enum.IntEnum):
255258
"""Protocol buffers enumeration base class. Acts like `enum.IntEnum`."""
256259

257260
@classmethod
@@ -641,9 +644,13 @@ def __bytes__(self) -> bytes:
641644

642645
@classmethod
643646
def _type_hint(cls, field_name: str) -> Type:
647+
return cls._type_hints()[field_name]
648+
649+
@classmethod
650+
def _type_hints(cls) -> Dict[str, Type]:
644651
module = inspect.getmodule(cls)
645652
type_hints = get_type_hints(cls, vars(module))
646-
return type_hints[field_name]
653+
return type_hints
647654

648655
@classmethod
649656
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
@@ -795,55 +802,67 @@ def to_dict(
795802
`False`.
796803
"""
797804
output: Dict[str, Any] = {}
805+
field_types = self._type_hints()
798806
for field_name, meta in self._betterproto.meta_by_field_name.items():
799-
v = getattr(self, field_name)
807+
field_type = field_types[field_name]
808+
field_is_repeated = type(field_type) is type(typing.List)
809+
value = getattr(self, field_name)
800810
cased_name = casing(field_name).rstrip("_") # type: ignore
801-
if meta.proto_type == "message":
802-
if isinstance(v, datetime):
803-
if v != DATETIME_ZERO or include_default_values:
804-
output[cased_name] = _Timestamp.timestamp_to_json(v)
805-
elif isinstance(v, timedelta):
806-
if v != timedelta(0) or include_default_values:
807-
output[cased_name] = _Duration.delta_to_json(v)
811+
if meta.proto_type == TYPE_MESSAGE:
812+
if isinstance(value, datetime):
813+
if value != DATETIME_ZERO or include_default_values:
814+
output[cased_name] = _Timestamp.timestamp_to_json(value)
815+
elif isinstance(value, timedelta):
816+
if value != timedelta(0) or include_default_values:
817+
output[cased_name] = _Duration.delta_to_json(value)
808818
elif meta.wraps:
809-
if v is not None or include_default_values:
810-
output[cased_name] = v
811-
elif isinstance(v, list):
819+
if value is not None or include_default_values:
820+
output[cased_name] = value
821+
elif field_is_repeated:
812822
# Convert each item.
813-
v = [i.to_dict(casing, include_default_values) for i in v]
814-
if v or include_default_values:
815-
output[cased_name] = v
823+
value = [i.to_dict(casing, include_default_values) for i in value]
824+
if value or include_default_values:
825+
output[cased_name] = value
816826
else:
817-
if v._serialized_on_wire or include_default_values:
818-
output[cased_name] = v.to_dict(casing, include_default_values)
819-
elif meta.proto_type == "map":
820-
for k in v:
821-
if hasattr(v[k], "to_dict"):
822-
v[k] = v[k].to_dict(casing, include_default_values)
823-
824-
if v or include_default_values:
825-
output[cased_name] = v
826-
elif v != self._get_field_default(field_name) or include_default_values:
827+
if value._serialized_on_wire or include_default_values:
828+
output[cased_name] = value.to_dict(
829+
casing, include_default_values
830+
)
831+
elif meta.proto_type == TYPE_MAP:
832+
for k in value:
833+
if hasattr(value[k], "to_dict"):
834+
value[k] = value[k].to_dict(casing, include_default_values)
835+
836+
if value or include_default_values:
837+
output[cased_name] = value
838+
elif value != self._get_field_default(field_name) or include_default_values:
827839
if meta.proto_type in INT_64_TYPES:
828-
if isinstance(v, list):
829-
output[cased_name] = [str(n) for n in v]
840+
if field_is_repeated:
841+
output[cased_name] = [str(n) for n in value]
830842
else:
831-
output[cased_name] = str(v)
843+
output[cased_name] = str(value)
832844
elif meta.proto_type == TYPE_BYTES:
833-
if isinstance(v, list):
834-
output[cased_name] = [b64encode(b).decode("utf8") for b in v]
845+
if field_is_repeated:
846+
output[cased_name] = [
847+
b64encode(b).decode("utf8") for b in value
848+
]
835849
else:
836-
output[cased_name] = b64encode(v).decode("utf8")
850+
output[cased_name] = b64encode(value).decode("utf8")
837851
elif meta.proto_type == TYPE_ENUM:
838-
enum_values = list(
839-
self._betterproto.cls_by_field[field_name]
840-
) # type: ignore
841-
if isinstance(v, list):
842-
output[cased_name] = [enum_values[e].name for e in v]
852+
if field_is_repeated:
853+
enum_class: Type[Enum] = field_type.__args__[0]
854+
if isinstance(value, typing.Iterable) and not isinstance(
855+
value, str
856+
):
857+
output[cased_name] = [enum_class(el).name for el in value]
858+
else:
859+
# transparently upgrade single value to repeated
860+
output[cased_name] = [enum_class(value).name]
843861
else:
844-
output[cased_name] = enum_values[v].name
862+
enum_class: Type[Enum] = field_type # noqa
863+
output[cased_name] = enum_class(value).name
845864
else:
846-
output[cased_name] = v
865+
output[cased_name] = value
847866
return output
848867

849868
def from_dict(self: T, value: dict) -> T:

tests/inputs/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"namespace_keywords", # 70
66
"namespace_builtin_types", # 53
77
"googletypes_struct", # 9
8-
"googletypes_value", # 9,
8+
"googletypes_value", # 9
99
"import_capitalized_package",
1010
"example", # This is the example in the readme. Not a test.
1111
}

tests/inputs/enum/enum.json

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"choice": "FOUR",
3+
"choices": [
4+
"ZERO",
5+
"ONE",
6+
"THREE",
7+
"FOUR"
8+
]
9+
}

tests/inputs/enum/enum.proto

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
syntax = "proto3";
2+
3+
// Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values
4+
message Test {
5+
Choice choice = 1;
6+
repeated Choice choices = 2;
7+
}
8+
9+
enum Choice {
10+
ZERO = 0;
11+
ONE = 1;
12+
// TWO = 2;
13+
FOUR = 4;
14+
THREE = 3;
15+
}

tests/inputs/enum/test_enum.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from tests.output_betterproto.enum import (
2+
Test,
3+
Choice,
4+
)
5+
6+
7+
def test_enum_set_and_get():
8+
assert Test(choice=Choice.ZERO).choice == Choice.ZERO
9+
assert Test(choice=Choice.ONE).choice == Choice.ONE
10+
assert Test(choice=Choice.THREE).choice == Choice.THREE
11+
assert Test(choice=Choice.FOUR).choice == Choice.FOUR
12+
13+
14+
def test_enum_set_with_int():
15+
assert Test(choice=0).choice == Choice.ZERO
16+
assert Test(choice=1).choice == Choice.ONE
17+
assert Test(choice=3).choice == Choice.THREE
18+
assert Test(choice=4).choice == Choice.FOUR
19+
20+
21+
def test_enum_is_comparable_with_int():
22+
assert Test(choice=Choice.ZERO).choice == 0
23+
assert Test(choice=Choice.ONE).choice == 1
24+
assert Test(choice=Choice.THREE).choice == 3
25+
assert Test(choice=Choice.FOUR).choice == 4
26+
27+
28+
def test_enum_to_dict():
29+
assert (
30+
"choice" not in Test(choice=Choice.ZERO).to_dict()
31+
), "Default enum value is not serialized"
32+
assert (
33+
Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"]
34+
== "ZERO"
35+
)
36+
assert Test(choice=Choice.ONE).to_dict()["choice"] == "ONE"
37+
assert Test(choice=Choice.THREE).to_dict()["choice"] == "THREE"
38+
assert Test(choice=Choice.FOUR).to_dict()["choice"] == "FOUR"
39+
40+
41+
def test_repeated_enum_is_comparable_with_int():
42+
assert Test(choices=[Choice.ZERO]).choices == [0]
43+
assert Test(choices=[Choice.ONE]).choices == [1]
44+
assert Test(choices=[Choice.THREE]).choices == [3]
45+
assert Test(choices=[Choice.FOUR]).choices == [4]
46+
47+
48+
def test_repeated_enum_set_and_get():
49+
assert Test(choices=[Choice.ZERO]).choices == [Choice.ZERO]
50+
assert Test(choices=[Choice.ONE]).choices == [Choice.ONE]
51+
assert Test(choices=[Choice.THREE]).choices == [Choice.THREE]
52+
assert Test(choices=[Choice.FOUR]).choices == [Choice.FOUR]
53+
54+
55+
def test_repeated_enum_to_dict():
56+
assert Test(choices=[Choice.ZERO]).to_dict()["choices"] == ["ZERO"]
57+
assert Test(choices=[Choice.ONE]).to_dict()["choices"] == ["ONE"]
58+
assert Test(choices=[Choice.THREE]).to_dict()["choices"] == ["THREE"]
59+
assert Test(choices=[Choice.FOUR]).to_dict()["choices"] == ["FOUR"]
60+
61+
all_enums_dict = Test(
62+
choices=[Choice.ZERO, Choice.ONE, Choice.THREE, Choice.FOUR]
63+
).to_dict()
64+
assert (all_enums_dict["choices"]) == ["ZERO", "ONE", "THREE", "FOUR"]
65+
66+
67+
def test_repeated_enum_with_single_value_to_dict():
68+
assert Test(choices=Choice.ONE).to_dict()["choices"] == ["ONE"]
69+
assert Test(choices=1).to_dict()["choices"] == ["ONE"]
70+
71+
72+
def test_repeated_enum_with_non_list_iterables_to_dict():
73+
assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]
74+
assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]
75+
assert Test(choices=(Choice.ONE, Choice.THREE)).to_dict()["choices"] == [
76+
"ONE",
77+
"THREE",
78+
]
79+
80+
def enum_generator():
81+
yield Choice.ONE
82+
yield Choice.THREE
83+
84+
assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]

tests/inputs/enums/enums.json

-3
This file was deleted.

tests/inputs/enums/enums.proto

-14
This file was deleted.

0 commit comments

Comments
 (0)