Skip to content

Commit 4f48950

Browse files
committed
map enum int's into python enums (danielgtaylor#157)
1 parent ad8b917 commit 4f48950

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

src/betterproto/__init__.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -835,8 +835,8 @@ def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
835835
else:
836836
return t
837837
elif issubclass(t, Enum):
838-
# Enums always default to zero.
839-
return int
838+
cls = cls._cls_for(field)
839+
return lambda: cls(0) # Enums always default to zero.
840840
elif t is datetime:
841841
# Offsets are relative to 1970-01-01T00:00:00Z
842842
return datetime_default_gen
@@ -861,6 +861,10 @@ def _postprocess_single(
861861
elif meta.proto_type == TYPE_BOOL:
862862
# Booleans use a varint encoding, so convert it to true/false.
863863
value = value > 0
864+
elif meta.proto_type == TYPE_ENUM:
865+
# Convert enum ints to python enum instances
866+
cls = self._betterproto.cls_by_field[field_name]
867+
value = cls(value)
864868
elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
865869
fmt = _pack_fmt(meta.proto_type)
866870
value = struct.unpack(fmt, value)[0]

tests/inputs/enum/test_enum.py

+20
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,23 @@ def enum_generator():
8282
yield Choice.THREE
8383

8484
assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]
85+
86+
87+
def test_enum_mapped_on_parse():
88+
# test default value
89+
b = Test().parse(bytes(Test()))
90+
assert b.choice.name == Choice.ZERO.name
91+
assert b.choices == []
92+
93+
# test non default value
94+
a = Test().parse(bytes(Test(choice=Choice.ONE)))
95+
assert a.choice.name == Choice.ONE.name
96+
assert b.choices == []
97+
98+
# test repeated
99+
c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR])))
100+
assert c.choices[0].name == Choice.THREE.name
101+
assert c.choices[1].name == Choice.FOUR.name
102+
103+
# bonus: defaults after empty init are also mapped
104+
assert Test().choice.name == Choice.ZERO.name

0 commit comments

Comments
 (0)