Skip to content

Commit 6536181

Browse files
crflynnGobot1234
andauthored
Add to/from_pydict methods (danielgtaylor#203)
* add to/from_pydict methods * Remove unnecessary method call * Fix formatting Co-authored-by: James Hilton-Balfe <[email protected]>
1 parent 85e4be9 commit 6536181

File tree

2 files changed

+226
-1
lines changed

2 files changed

+226
-1
lines changed

src/betterproto/__init__.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,136 @@ def from_json(self: T, value: Union[str, bytes]) -> T:
13061306
"""
13071307
return self.from_dict(json.loads(value))
13081308

1309+
def to_pydict(
1310+
self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
1311+
) -> Dict[str, Any]:
1312+
"""
1313+
Returns a python dict representation of this object.
1314+
1315+
Parameters
1316+
-----------
1317+
casing: :class:`Casing`
1318+
The casing to use for key values. Default is :attr:`Casing.CAMEL` for
1319+
compatibility purposes.
1320+
include_default_values: :class:`bool`
1321+
If ``True`` will include the default values of fields. Default is ``False``.
1322+
E.g. an ``int32`` field will be included with a value of ``0`` if this is
1323+
set to ``True``, otherwise this would be ignored.
1324+
1325+
Returns
1326+
--------
1327+
Dict[:class:`str`, Any]
1328+
The python dict representation of this object.
1329+
"""
1330+
output: Dict[str, Any] = {}
1331+
defaults = self._betterproto.default_gen
1332+
for field_name, meta in self._betterproto.meta_by_field_name.items():
1333+
field_is_repeated = defaults[field_name] is list
1334+
value = getattr(self, field_name)
1335+
cased_name = casing(field_name).rstrip("_") # type: ignore
1336+
if meta.proto_type == TYPE_MESSAGE:
1337+
if isinstance(value, datetime):
1338+
if (
1339+
value != DATETIME_ZERO
1340+
or include_default_values
1341+
or self._include_default_value_for_oneof(
1342+
field_name=field_name, meta=meta
1343+
)
1344+
):
1345+
output[cased_name] = value
1346+
elif isinstance(value, timedelta):
1347+
if (
1348+
value != timedelta(0)
1349+
or include_default_values
1350+
or self._include_default_value_for_oneof(
1351+
field_name=field_name, meta=meta
1352+
)
1353+
):
1354+
output[cased_name] = value
1355+
elif meta.wraps:
1356+
if value is not None or include_default_values:
1357+
output[cased_name] = value
1358+
elif field_is_repeated:
1359+
# Convert each item.
1360+
value = [i.to_pydict(casing, include_default_values) for i in value]
1361+
if value or include_default_values:
1362+
output[cased_name] = value
1363+
elif (
1364+
value._serialized_on_wire
1365+
or include_default_values
1366+
or self._include_default_value_for_oneof(
1367+
field_name=field_name, meta=meta
1368+
)
1369+
):
1370+
output[cased_name] = value.to_pydict(casing, include_default_values)
1371+
elif meta.proto_type == TYPE_MAP:
1372+
for k in value:
1373+
if hasattr(value[k], "to_pydict"):
1374+
value[k] = value[k].to_pydict(casing, include_default_values)
1375+
1376+
if value or include_default_values:
1377+
output[cased_name] = value
1378+
elif (
1379+
value != self._get_field_default(field_name)
1380+
or include_default_values
1381+
or self._include_default_value_for_oneof(
1382+
field_name=field_name, meta=meta
1383+
)
1384+
):
1385+
output[cased_name] = value
1386+
return output
1387+
1388+
def from_pydict(self: T, value: Dict[str, Any]) -> T:
1389+
"""
1390+
Parse the key/value pairs into the current message instance. This returns the
1391+
instance itself and is therefore assignable and chainable.
1392+
1393+
Parameters
1394+
-----------
1395+
value: Dict[:class:`str`, Any]
1396+
The dictionary to parse from.
1397+
1398+
Returns
1399+
--------
1400+
:class:`Message`
1401+
The initialized message.
1402+
"""
1403+
self._serialized_on_wire = True
1404+
for key in value:
1405+
field_name = safe_snake_case(key)
1406+
meta = self._betterproto.meta_by_field_name.get(field_name)
1407+
if not meta:
1408+
continue
1409+
1410+
if value[key] is not None:
1411+
if meta.proto_type == TYPE_MESSAGE:
1412+
v = getattr(self, field_name)
1413+
if isinstance(v, list):
1414+
cls = self._betterproto.cls_by_field[field_name]
1415+
for item in value[key]:
1416+
v.append(cls().from_pydict(item))
1417+
elif isinstance(v, datetime):
1418+
v = value[key]
1419+
elif isinstance(v, timedelta):
1420+
v = value[key]
1421+
elif meta.wraps:
1422+
v = value[key]
1423+
else:
1424+
# NOTE: `from_pydict` mutates the underlying message, so no
1425+
# assignment here is necessary.
1426+
v.from_pydict(value[key])
1427+
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
1428+
v = getattr(self, field_name)
1429+
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
1430+
for k in value[key]:
1431+
v[k] = cls().from_pydict(value[key][k])
1432+
else:
1433+
v = value[key]
1434+
1435+
if v is not None:
1436+
setattr(self, field_name, v)
1437+
return self
1438+
13091439
def is_set(self, name: str) -> bool:
13101440
"""
13111441
Check if field with the given name has been set.

tests/test_features.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
deepcopy,
44
)
55
from dataclasses import dataclass
6-
from datetime import datetime
6+
from datetime import (
7+
datetime,
8+
timedelta,
9+
)
710
from inspect import (
811
Parameter,
912
signature,
@@ -79,6 +82,7 @@ class Foo(betterproto.Message):
7982
foo = Foo(name="foo", child=Bar(name="bar"))
8083

8184
assert foo.to_dict() == {"name": "foo", "child": {"name": "bar"}}
85+
assert foo.to_pydict() == {"name": "foo", "child": {"name": "bar"}}
8286

8387

8488
def test_enum_as_int_json():
@@ -98,6 +102,11 @@ class Foo(betterproto.Message):
98102
foo.bar = 1
99103
assert foo.to_dict() == {"bar": "ONE"}
100104

105+
# Similar expectations for pydict
106+
foo = Foo().from_pydict({"bar": 1})
107+
assert foo.bar == TestEnum.ONE
108+
assert foo.to_pydict() == {"bar": TestEnum.ONE}
109+
101110

102111
def test_unknown_fields():
103112
@dataclass
@@ -188,13 +197,25 @@ class CasingTest(betterproto.Message):
188197
"snakeCase": 3,
189198
"kabobCase": 4,
190199
}
200+
assert test.to_pydict() == {
201+
"pascalCase": 1,
202+
"camelCase": 2,
203+
"snakeCase": 3,
204+
"kabobCase": 4,
205+
}
191206

192207
assert test.to_dict(casing=betterproto.Casing.SNAKE) == {
193208
"pascal_case": 1,
194209
"camel_case": 2,
195210
"snake_case": 3,
196211
"kabob_case": 4,
197212
}
213+
assert test.to_pydict(casing=betterproto.Casing.SNAKE) == {
214+
"pascal_case": 1,
215+
"camel_case": 2,
216+
"snake_case": 3,
217+
"kabob_case": 4,
218+
}
198219

199220

200221
def test_optional_flag():
@@ -230,6 +251,15 @@ class TestMessage(betterproto.Message):
230251
"someBool": False,
231252
}
232253

254+
test = TestMessage().from_pydict({})
255+
256+
assert test.to_pydict(include_default_values=True) == {
257+
"someInt": 0,
258+
"someDouble": 0.0,
259+
"someStr": "",
260+
"someBool": False,
261+
}
262+
233263
# All default values
234264
test = TestMessage().from_dict(
235265
{"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False}
@@ -242,6 +272,17 @@ class TestMessage(betterproto.Message):
242272
"someBool": False,
243273
}
244274

275+
test = TestMessage().from_pydict(
276+
{"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False}
277+
)
278+
279+
assert test.to_pydict(include_default_values=True) == {
280+
"someInt": 0,
281+
"someDouble": 0.0,
282+
"someStr": "",
283+
"someBool": False,
284+
}
285+
245286
# Some default and some other values
246287
@dataclass
247288
class TestMessage2(betterproto.Message):
@@ -278,6 +319,30 @@ class TestMessage2(betterproto.Message):
278319
"someDefaultBool": False,
279320
}
280321

322+
test = TestMessage2().from_pydict(
323+
{
324+
"someInt": 2,
325+
"someDouble": 1.2,
326+
"someStr": "hello",
327+
"someBool": True,
328+
"someDefaultInt": 0,
329+
"someDefaultDouble": 0.0,
330+
"someDefaultStr": "",
331+
"someDefaultBool": False,
332+
}
333+
)
334+
335+
assert test.to_pydict(include_default_values=True) == {
336+
"someInt": 2,
337+
"someDouble": 1.2,
338+
"someStr": "hello",
339+
"someBool": True,
340+
"someDefaultInt": 0,
341+
"someDefaultDouble": 0.0,
342+
"someDefaultStr": "",
343+
"someDefaultBool": False,
344+
}
345+
281346
# Nested messages
282347
@dataclass
283348
class TestChildMessage(betterproto.Message):
@@ -297,6 +362,36 @@ class TestParentMessage(betterproto.Message):
297362
"someMessage": {"someOtherInt": 0},
298363
}
299364

365+
test = TestParentMessage().from_pydict({"someInt": 0, "someDouble": 1.2})
366+
367+
assert test.to_pydict(include_default_values=True) == {
368+
"someInt": 0,
369+
"someDouble": 1.2,
370+
"someMessage": {"someOtherInt": 0},
371+
}
372+
373+
374+
def test_to_dict_datetime_values():
375+
@dataclass
376+
class TestDatetimeMessage(betterproto.Message):
377+
bar: datetime = betterproto.message_field(1)
378+
baz: timedelta = betterproto.message_field(2)
379+
380+
test = TestDatetimeMessage().from_dict(
381+
{"bar": "2020-01-01T00:00:00Z", "baz": "86400.000s"}
382+
)
383+
384+
assert test.to_dict() == {"bar": "2020-01-01T00:00:00Z", "baz": "86400.000s"}
385+
386+
test = TestDatetimeMessage().from_pydict(
387+
{"bar": datetime(year=2020, month=1, day=1), "baz": timedelta(days=1)}
388+
)
389+
390+
assert test.to_pydict() == {
391+
"bar": datetime(year=2020, month=1, day=1),
392+
"baz": timedelta(days=1),
393+
}
394+
300395

301396
def test_oneof_default_value_set_causes_writes_wire():
302397
@dataclass

0 commit comments

Comments
 (0)