Skip to content

Commit c1a76a5

Browse files
authored
Serialize default values in oneofs when calling to_dict() or to_json() (#110)
* Serialize default values in oneofs when calling to_dict() or to_json() This change is consistent with the official protobuf implementation. If a default value is set when using a oneof, and then a message is translated from message -> JSON -> message, the default value is kept in tact. Also, if no default value is set, they remain null. * Some cleanup + testing for nested messages with oneofs * Cleanup oneof_enum test cases, they should be fixed This _should_ address: #63 * Include default value oneof fields when serializing to bytes This will cause oneof fields with default values to explicitly be sent to clients. Note that does not mean that all fields are serialized and sent to clients, just those that _could_ be null and are not. * Remove assignment when populating a sub-message within a proto Also, move setattr out one indentation level * Properly transform proto with empty string in oneof to bytes Also, updated tests to ensure that which_one_of picks up the set field * Formatting betterproto/__init__.py * Adding test cases demonstrating equivalent behaviour with google impl * Removing a temporary file I made locally * Adding some clarifying comments * Fixing tests for python38
1 parent 2745953 commit c1a76a5

File tree

7 files changed

+277
-27
lines changed

7 files changed

+277
-27
lines changed

src/betterproto/__init__.py

+63-20
Original file line numberDiff line numberDiff line change
@@ -583,18 +583,20 @@ def __bytes__(self) -> bytes:
583583
# Being selected in a a group means this field is the one that is
584584
# currently set in a `oneof` group, so it must be serialized even
585585
# if the value is the default zero value.
586-
selected_in_group = False
587-
if meta.group and self._group_current[meta.group] == field_name:
588-
selected_in_group = True
586+
selected_in_group = (
587+
meta.group and self._group_current[meta.group] == field_name
588+
)
589589

590-
serialize_empty = False
591-
if isinstance(value, Message) and value._serialized_on_wire:
592-
# Empty messages can still be sent on the wire if they were
593-
# set (or received empty).
594-
serialize_empty = True
590+
# Empty messages can still be sent on the wire if they were
591+
# set (or received empty).
592+
serialize_empty = isinstance(value, Message) and value._serialized_on_wire
593+
594+
include_default_value_for_oneof = self._include_default_value_for_oneof(
595+
field_name=field_name, meta=meta
596+
)
595597

596598
if value == self._get_field_default(field_name) and not (
597-
selected_in_group or serialize_empty
599+
selected_in_group or serialize_empty or include_default_value_for_oneof
598600
):
599601
# Default (zero) values are not serialized. Two exceptions are
600602
# if this is the selected oneof item or if we know we have to
@@ -623,6 +625,17 @@ def __bytes__(self) -> bytes:
623625
sv = _serialize_single(2, meta.map_types[1], v)
624626
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
625627
else:
628+
# If we have an empty string and we're including the default value for
629+
# a oneof, make sure we serialize it. This ensures that the byte string
630+
# output isn't simply an empty string. This also ensures that round trip
631+
# serialization will keep `which_one_of` calls consistent.
632+
if (
633+
isinstance(value, str)
634+
and value == ""
635+
and include_default_value_for_oneof
636+
):
637+
serialize_empty = True
638+
626639
output += _serialize_single(
627640
meta.number,
628641
meta.proto_type,
@@ -726,6 +739,13 @@ def _postprocess_single(
726739

727740
return value
728741

742+
def _include_default_value_for_oneof(
743+
self, field_name: str, meta: FieldMetadata
744+
) -> bool:
745+
return (
746+
meta.group is not None and self._group_current.get(meta.group) == field_name
747+
)
748+
729749
def parse(self: T, data: bytes) -> T:
730750
"""
731751
Parse the binary encoded Protobuf into this message instance. This
@@ -804,10 +824,22 @@ def to_dict(
804824
cased_name = casing(field_name).rstrip("_") # type: ignore
805825
if meta.proto_type == TYPE_MESSAGE:
806826
if isinstance(value, datetime):
807-
if value != DATETIME_ZERO or include_default_values:
827+
if (
828+
value != DATETIME_ZERO
829+
or include_default_values
830+
or self._include_default_value_for_oneof(
831+
field_name=field_name, meta=meta
832+
)
833+
):
808834
output[cased_name] = _Timestamp.timestamp_to_json(value)
809835
elif isinstance(value, timedelta):
810-
if value != timedelta(0) or include_default_values:
836+
if (
837+
value != timedelta(0)
838+
or include_default_values
839+
or self._include_default_value_for_oneof(
840+
field_name=field_name, meta=meta
841+
)
842+
):
811843
output[cased_name] = _Duration.delta_to_json(value)
812844
elif meta.wraps:
813845
if value is not None or include_default_values:
@@ -817,19 +849,28 @@ def to_dict(
817849
value = [i.to_dict(casing, include_default_values) for i in value]
818850
if value or include_default_values:
819851
output[cased_name] = value
820-
else:
821-
if value._serialized_on_wire or include_default_values:
822-
output[cased_name] = value.to_dict(
823-
casing, include_default_values
824-
)
825-
elif meta.proto_type == TYPE_MAP:
852+
elif (
853+
value._serialized_on_wire
854+
or include_default_values
855+
or self._include_default_value_for_oneof(
856+
field_name=field_name, meta=meta
857+
)
858+
):
859+
output[cased_name] = value.to_dict(casing, include_default_values,)
860+
elif meta.proto_type == "map":
826861
for k in value:
827862
if hasattr(value[k], "to_dict"):
828863
value[k] = value[k].to_dict(casing, include_default_values)
829864

830865
if value or include_default_values:
831866
output[cased_name] = value
832-
elif value != self._get_field_default(field_name) or include_default_values:
867+
elif (
868+
value != self._get_field_default(field_name)
869+
or include_default_values
870+
or self._include_default_value_for_oneof(
871+
field_name=field_name, meta=meta
872+
)
873+
):
833874
if meta.proto_type in INT_64_TYPES:
834875
if field_is_repeated:
835876
output[cased_name] = [str(n) for n in value]
@@ -888,6 +929,8 @@ def from_dict(self: T, value: dict) -> T:
888929
elif meta.wraps:
889930
setattr(self, field_name, value[key])
890931
else:
932+
# NOTE: `from_dict` mutates the underlying message, so no
933+
# assignment here is necessary.
891934
v.from_dict(value[key])
892935
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
893936
v = getattr(self, field_name)
@@ -913,8 +956,8 @@ def from_dict(self: T, value: dict) -> T:
913956
elif isinstance(v, str):
914957
v = enum_cls.from_string(v)
915958

916-
if v is not None:
917-
setattr(self, field_name, v)
959+
if v is not None:
960+
setattr(self, field_name, v)
918961
return self
919962

920963
def to_json(self, indent: Union[None, int, str] = None) -> str:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
syntax = "proto3";
2+
3+
message Foo{
4+
int64 bar = 1;
5+
}
6+
7+
message Test{
8+
oneof group{
9+
string string = 1;
10+
int64 integer = 2;
11+
Foo foo = 3;
12+
}
13+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pytest
2+
3+
from google.protobuf import json_format
4+
import betterproto
5+
from tests.output_betterproto.google_impl_behavior_equivalence import (
6+
Test,
7+
Foo,
8+
)
9+
from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import (
10+
Test as ReferenceTest,
11+
Foo as ReferenceFoo,
12+
)
13+
14+
15+
def test_oneof_serializes_similar_to_google_oneof():
16+
17+
tests = [
18+
(Test(string="abc"), ReferenceTest(string="abc")),
19+
(Test(integer=2), ReferenceTest(integer=2)),
20+
(Test(foo=Foo(bar=1)), ReferenceTest(foo=ReferenceFoo(bar=1))),
21+
# Default values should also behave the same within oneofs
22+
(Test(string=""), ReferenceTest(string="")),
23+
(Test(integer=0), ReferenceTest(integer=0)),
24+
(Test(foo=Foo(bar=0)), ReferenceTest(foo=ReferenceFoo(bar=0))),
25+
]
26+
for message, message_reference in tests:
27+
# NOTE: As of July 2020, MessageToJson inserts newlines in the output string so,
28+
# just compare dicts
29+
assert message.to_dict() == json_format.MessageToDict(message_reference)
30+
31+
32+
def test_bytes_are_the_same_for_oneof():
33+
34+
message = Test(string="")
35+
message_reference = ReferenceTest(string="")
36+
37+
message_bytes = bytes(message)
38+
message_reference_bytes = message_reference.SerializeToString()
39+
40+
assert message_bytes == message_reference_bytes
41+
42+
message2 = Test().parse(message_reference_bytes)
43+
message_reference2 = ReferenceTest()
44+
message_reference2.ParseFromString(message_reference_bytes)
45+
46+
assert message == message2
47+
assert message_reference == message_reference2
48+
49+
# None of these fields were explicitly set BUT they should not actually be null
50+
# themselves
51+
assert isinstance(message.foo, Foo)
52+
assert isinstance(message2.foo, Foo)
53+
54+
assert isinstance(message_reference.foo, ReferenceFoo)
55+
assert isinstance(message_reference2.foo, ReferenceFoo)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
syntax = "proto3";
2+
3+
import "google/protobuf/duration.proto";
4+
import "google/protobuf/timestamp.proto";
5+
import "google/protobuf/wrappers.proto";
6+
7+
message Message{
8+
int64 value = 1;
9+
}
10+
11+
message NestedMessage{
12+
int64 id = 1;
13+
oneof value_type{
14+
Message wrapped_message_value = 2;
15+
}
16+
}
17+
18+
message Test{
19+
oneof value_type {
20+
bool bool_value = 1;
21+
int64 int64_value = 2;
22+
google.protobuf.Timestamp timestamp_value = 3;
23+
google.protobuf.Duration duration_value = 4;
24+
Message wrapped_message_value = 5;
25+
NestedMessage wrapped_nested_message_value = 6;
26+
google.protobuf.BoolValue wrapped_bool_value = 7;
27+
}
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import pytest
2+
import datetime
3+
4+
import betterproto
5+
from tests.output_betterproto.oneof_default_value_serialization import (
6+
Test,
7+
Message,
8+
NestedMessage,
9+
)
10+
11+
12+
def assert_round_trip_serialization_works(message: Test) -> None:
13+
assert betterproto.which_one_of(message, "value_type") == betterproto.which_one_of(
14+
Test().from_json(message.to_json()), "value_type"
15+
)
16+
17+
18+
def test_oneof_default_value_serialization_works_for_all_values():
19+
"""
20+
Serialization from message with oneof set to default -> JSON -> message should keep
21+
default value field intact.
22+
"""
23+
24+
test_cases = [
25+
Test(bool_value=False),
26+
Test(int64_value=0),
27+
Test(
28+
timestamp_value=datetime.datetime(
29+
year=1970,
30+
month=1,
31+
day=1,
32+
hour=0,
33+
minute=0,
34+
tzinfo=datetime.timezone.utc,
35+
)
36+
),
37+
Test(duration_value=datetime.timedelta(0)),
38+
Test(wrapped_message_value=Message(value=0)),
39+
# NOTE: Do NOT use betterproto.BoolValue here, it will cause JSON serialization
40+
# errors.
41+
# TODO: Do we want to allow use of BoolValue directly within a wrapped field or
42+
# should we simply hard fail here?
43+
Test(wrapped_bool_value=False),
44+
]
45+
for message in test_cases:
46+
assert_round_trip_serialization_works(message)
47+
48+
49+
def test_oneof_no_default_values_passed():
50+
message = Test()
51+
assert (
52+
betterproto.which_one_of(message, "value_type")
53+
== betterproto.which_one_of(Test().from_json(message.to_json()), "value_type")
54+
== ("", None)
55+
)
56+
57+
58+
def test_oneof_nested_oneof_messages_are_serialized_with_defaults():
59+
"""
60+
Nested messages with oneofs should also be handled
61+
"""
62+
message = Test(
63+
wrapped_nested_message_value=NestedMessage(
64+
id=0, wrapped_message_value=Message(value=0)
65+
)
66+
)
67+
assert (
68+
betterproto.which_one_of(message, "value_type")
69+
== betterproto.which_one_of(Test().from_json(message.to_json()), "value_type")
70+
== (
71+
"wrapped_nested_message_value",
72+
NestedMessage(id=0, wrapped_message_value=Message(value=0)),
73+
)
74+
)

tests/inputs/oneof_enum/test_oneof_enum.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,36 @@
99
from tests.util import get_test_case_json_data
1010

1111

12-
@pytest.mark.xfail
1312
def test_which_one_of_returns_enum_with_default_value():
1413
"""
1514
returns first field when it is enum and set with default value
1615
"""
1716
message = Test()
1817
message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json"))
19-
assert message.move is None
18+
19+
assert message.move == Move(
20+
x=0, y=0
21+
) # Proto3 will default this as there is no null
2022
assert message.signal == Signal.PASS
2123
assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS)
2224

2325

24-
@pytest.mark.xfail
2526
def test_which_one_of_returns_enum_with_non_default_value():
2627
"""
2728
returns first field when it is enum and set with non default value
2829
"""
2930
message = Test()
3031
message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json"))
31-
assert message.move is None
32-
assert message.signal == Signal.PASS
32+
assert message.move == Move(
33+
x=0, y=0
34+
) # Proto3 will default this as there is no null
35+
assert message.signal == Signal.RESIGN
3336
assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN)
3437

3538

36-
@pytest.mark.xfail
3739
def test_which_one_of_returns_second_field_when_set():
3840
message = Test()
3941
message.from_json(get_test_case_json_data("oneof_enum"))
4042
assert message.move == Move(x=2, y=3)
41-
assert message.signal == 0
43+
assert message.signal == Signal.PASS
4244
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))

0 commit comments

Comments
 (0)