Skip to content

Commit 0373fdc

Browse files
roblablakalzoo
authored andcommitted
Properly support optional enums
1 parent 9fff31d commit 0373fdc

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

src/betterproto/__init__.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -834,8 +834,9 @@ def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
834834
# This is some kind of list (repeated) field.
835835
return list
836836
elif t.__origin__ is Union and t.__args__[1] is type(None):
837-
# This is an optional (wrapped) field. For setting the default we
838-
# really don't care what kind of field it is.
837+
# This is an optional field (either wrapped, or using proto3
838+
# field presence). For setting the default we really don't care
839+
# what kind of field it is.
839840
return type(None)
840841
else:
841842
return t
@@ -1009,6 +1010,7 @@ def to_dict(
10091010
defaults = self._betterproto.default_gen
10101011
for field_name, meta in self._betterproto.meta_by_field_name.items():
10111012
field_is_repeated = defaults[field_name] is list
1013+
field_is_optional = defaults[field_name] is type(None)
10121014
value = getattr(self, field_name)
10131015
cased_name = casing(field_name).rstrip("_") # type: ignore
10141016
if meta.proto_type == TYPE_MESSAGE:
@@ -1096,6 +1098,9 @@ def to_dict(
10961098
output[cased_name] = [enum_class(value).name]
10971099
elif value is None:
10981100
output[cased_name] = None
1101+
elif field_is_optional:
1102+
enum_class = field_types[field_name].__args__[0]
1103+
output[cased_name] = enum_class(value).name
10991104
else:
11001105
enum_class = field_types[field_name] # noqa
11011106
output[cased_name] = enum_class(value).name
@@ -1133,6 +1138,9 @@ def from_dict(self: T, value: Dict[str, Any]) -> T:
11331138
if value[key] is not None:
11341139
if meta.proto_type == TYPE_MESSAGE:
11351140
v = getattr(self, field_name)
1141+
if value[key] is None and self._get_field_default(key) == None:
1142+
# Setting an optional value to None.
1143+
setattr(self, field_name, None)
11361144
if isinstance(v, list):
11371145
cls = self._betterproto.cls_by_field[field_name]
11381146
if cls == datetime:
@@ -1152,6 +1160,9 @@ def from_dict(self: T, value: Dict[str, Any]) -> T:
11521160
setattr(self, field_name, v)
11531161
elif meta.wraps:
11541162
setattr(self, field_name, value[key])
1163+
elif v is None:
1164+
cls = self._betterproto.cls_by_field[field_name]
1165+
setattr(self, field_name, cls().from_dict(value[key]))
11551166
else:
11561167
# NOTE: `from_dict` mutates the underlying message, so no
11571168
# assignment here is necessary.

0 commit comments

Comments
 (0)