Skip to content

Commit 0b0345f

Browse files
committed
Moved pickling tests into their own file
1 parent 0b6bd5f commit 0b0345f

File tree

3 files changed

+173
-38
lines changed

3 files changed

+173
-38
lines changed

src/betterproto/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1138,7 +1138,7 @@ def __getstate__(self) -> bytes:
11381138
def __setstate__(self: T, pickled_bytes: bytes) -> T:
11391139
return self.parse(pickled_bytes)
11401140

1141-
def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
1141+
def __reduce__(self) -> Tuple[Any, ...]:
11421142
return (self.__class__.FromString, (bytes(self),))
11431143

11441144
@classmethod

tests/test_features.py

-37
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import json
2-
import pickle
32
import sys
43
from copy import (
54
copy,
@@ -741,39 +740,3 @@ def test_equality_comparison():
741740
assert msg == TestMessage(value=True)
742741
assert msg != 1
743742
assert msg != TestMessage(value=False)
744-
745-
746-
@dataclass
747-
class PickleMessage(betterproto.Message):
748-
foo: bool = betterproto.bool_field(1)
749-
bar: Optional[int] = betterproto.int32_field(2, optional=True)
750-
751-
752-
def test_default_pickle():
753-
deserialized = pickle.loads(pickle.dumps(PickleMessage()))
754-
755-
assert deserialized.foo is False
756-
assert deserialized.bar is None
757-
758-
759-
def test_pickle_with_set_values():
760-
msg = PickleMessage(foo=True, bar=42)
761-
deserialized = pickle.loads(pickle.dumps(msg))
762-
763-
assert deserialized.foo is True
764-
assert deserialized.bar == 42
765-
766-
767-
def test_pickling_recursive_message():
768-
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
769-
770-
msg = RecursiveMessage()
771-
deserialized = pickle.loads(pickle.dumps(msg))
772-
773-
assert deserialized.child == RecursiveMessage()
774-
775-
# Lazily-created zero-value children must not affect equality.
776-
assert deserialized == RecursiveMessage()
777-
778-
# Lazily-created zero-value children must not affect serialization.
779-
assert bytes(deserialized) == b""

tests/test_pickling.py

+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import pickle
2+
from copy import (
3+
copy,
4+
deepcopy,
5+
)
6+
from dataclasses import dataclass
7+
from typing import (
8+
Dict,
9+
List,
10+
)
11+
from unittest.mock import ANY
12+
13+
import betterproto
14+
from betterproto.lib.google import protobuf as google
15+
16+
17+
def unpickled(message):
18+
return pickle.loads(pickle.dumps(message))
19+
20+
21+
@dataclass(eq=False, repr=False)
22+
class Fe(betterproto.Message):
23+
abc: str = betterproto.string_field(1)
24+
25+
26+
@dataclass(eq=False, repr=False)
27+
class Fi(betterproto.Message):
28+
abc: str = betterproto.string_field(1)
29+
30+
31+
@dataclass(eq=False, repr=False)
32+
class Fo(betterproto.Message):
33+
abc: str = betterproto.string_field(1)
34+
35+
36+
@dataclass(eq=False, repr=False)
37+
class NestedData(betterproto.Message):
38+
struct_foo: Dict[str, "google.Struct"] = betterproto.map_field(
39+
1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
40+
)
41+
map_str_any_bar: Dict[str, "google.Any"] = betterproto.map_field(
42+
2, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
43+
)
44+
45+
46+
@dataclass(eq=False, repr=False)
47+
class Complex(betterproto.Message):
48+
foo_str: str = betterproto.string_field(1)
49+
fe: "Fe" = betterproto.message_field(3, group="grp")
50+
fi: "Fi" = betterproto.message_field(4, group="grp")
51+
fo: "Fo" = betterproto.message_field(5, group="grp")
52+
nested_data: "NestedData" = betterproto.message_field(6)
53+
mapping: Dict[str, "google.Any"] = betterproto.map_field(
54+
7, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
55+
)
56+
57+
58+
def test_pickling_complex_message():
59+
msg = Complex(
60+
foo_str="yep",
61+
fe=Fe(abc="1"),
62+
nested_data=NestedData(
63+
struct_foo={
64+
"foo": google.Struct(
65+
fields={
66+
"hello": google.Value(
67+
list_value=google.ListValue(
68+
values=[google.Value(string_value="world")]
69+
)
70+
)
71+
}
72+
),
73+
},
74+
map_str_any_bar={
75+
"key": google.Any(value=b"value"),
76+
},
77+
),
78+
mapping={
79+
"message": google.Any(value=bytes(Fi(abc="hi"))),
80+
"string": google.Any(value=b"howdy"),
81+
},
82+
)
83+
deser = unpickled(msg)
84+
assert msg == deser
85+
assert msg.fe.abc == "1"
86+
assert msg.is_set("fi") is not True
87+
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
88+
assert msg.mapping["string"].value.decode() == "howdy"
89+
assert (
90+
msg.nested_data.struct_foo["foo"]
91+
.fields["hello"]
92+
.list_value.values[0]
93+
.string_value
94+
== "world"
95+
)
96+
97+
98+
def test_recursive_message():
99+
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
100+
101+
msg = RecursiveMessage()
102+
msg = unpickled(msg)
103+
104+
assert msg.child == RecursiveMessage()
105+
106+
# Lazily-created zero-value children must not affect equality.
107+
assert msg == RecursiveMessage()
108+
109+
# Lazily-created zero-value children must not affect serialization.
110+
assert bytes(msg) == b""
111+
112+
113+
def test_recursive_message_defaults():
114+
from tests.output_betterproto.recursivemessage import (
115+
Intermediate,
116+
Test as RecursiveMessage,
117+
)
118+
119+
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
120+
msg = unpickled(msg)
121+
122+
# set values are as expected
123+
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
124+
125+
# lazy initialized works modifies the message
126+
assert msg != RecursiveMessage(
127+
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
128+
)
129+
msg.child.child.name = "jude"
130+
assert msg == RecursiveMessage(
131+
name="bob",
132+
intermediate=Intermediate(42),
133+
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
134+
)
135+
136+
# lazily initialization recurses as needed
137+
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
138+
assert msg.intermediate.child.intermediate == Intermediate()
139+
140+
141+
@dataclass
142+
class Spam(betterproto.Message):
143+
foo: bool = betterproto.bool_field(1)
144+
bar: int = betterproto.int32_field(2)
145+
baz: List[str] = betterproto.string_field(3)
146+
147+
148+
def test_copyability():
149+
msg = Spam(bar=12, baz=["hello"])
150+
msg = unpickled(msg)
151+
152+
copied = copy(msg)
153+
assert msg == copied
154+
assert msg is not copied
155+
assert msg.baz is copied.baz
156+
157+
deepcopied = deepcopy(msg)
158+
assert msg == deepcopied
159+
assert msg is not deepcopied
160+
assert msg.baz is not deepcopied.baz
161+
162+
163+
def test_equality_comparison():
164+
from tests.output_betterproto.bool import Test as TestMessage
165+
166+
msg = TestMessage(value=True)
167+
msg = unpickled(msg)
168+
assert msg == msg
169+
assert msg == ANY
170+
assert msg == TestMessage(value=True)
171+
assert msg != 1
172+
assert msg != TestMessage(value=False)

0 commit comments

Comments
 (0)