Skip to content

Commit 06bc55a

Browse files
committed
Added support for infinite and nan floats/doubles
- Added support for the custom double values from the protobuf json spec: "Infinity", "-Infinity", and "NaN" - Added `infinite_floats` test data - Updated Message.__eq__ to consider nan values equal - Updated `test_message_json` and `test_binary_compatibility` to replace NaN float values in dictionaries before comparison (because two NaN values are not equal)
1 parent 59f5f88 commit 06bc55a

File tree

4 files changed

+146
-6
lines changed

4 files changed

+146
-6
lines changed

src/betterproto/__init__.py

+74-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import enum
33
import inspect
44
import json
5+
import math
56
import struct
67
import sys
78
import typing
@@ -117,6 +118,12 @@ def datetime_default_gen() -> datetime:
117118
DATETIME_ZERO = datetime_default_gen()
118119

119120

121+
# Special protobuf json doubles
122+
INFINITY = "Infinity"
123+
NEG_INFINITY = "-Infinity"
124+
NAN = "NaN"
125+
126+
120127
class Casing(enum.Enum):
121128
"""Casing constants for serialization."""
122129

@@ -373,6 +380,51 @@ def _serialize_single(
373380
return bytes(output)
374381

375382

383+
def _parse_float(value: Any) -> float:
384+
"""Parse the given value to a float
385+
386+
Parameters
387+
----------
388+
value : Any
389+
Value to parse
390+
391+
Returns
392+
-------
393+
float
394+
Parsed value
395+
"""
396+
if value == INFINITY:
397+
return float("inf")
398+
if value == NEG_INFINITY:
399+
return -float("inf")
400+
if value == NAN:
401+
return float("nan")
402+
return float(value)
403+
404+
405+
def _dump_float(value: float) -> Union[float, str]:
406+
"""Dump the given float to JSON
407+
408+
Parameters
409+
----------
410+
value : float
411+
Value to dump
412+
413+
Returns
414+
-------
415+
Union[float, str]
416+
Dumped valid, either a float or the strings
417+
"Infinity" or "-Infinity"
418+
"""
419+
if value == float("inf"):
420+
return INFINITY
421+
if value == -float("inf"):
422+
return NEG_INFINITY
423+
if value == float("nan"):
424+
return NAN
425+
return value
426+
427+
376428
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
377429
"""
378430
Decode a single varint value from a byte buffer. Returns the value and the
@@ -568,7 +620,18 @@ def __eq__(self, other) -> bool:
568620
other_val = other._get_field_default(field_name)
569621

570622
if self_val != other_val:
571-
return False
623+
# We consider two nan values to be the same for the
624+
# purposes of comparing messages (otherwise a message
625+
# is not equal to itself)
626+
if (
627+
isinstance(self_val, float)
628+
and isinstance(other_val, float)
629+
and math.isnan(self_val)
630+
and math.isnan(other_val)
631+
):
632+
continue
633+
else:
634+
return False
572635

573636
return True
574637

@@ -1011,6 +1074,11 @@ def to_dict(
10111074
else:
10121075
enum_class: Type[Enum] = field_types[field_name] # noqa
10131076
output[cased_name] = enum_class(value).name
1077+
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1078+
if field_is_repeated:
1079+
output[cased_name] = [_dump_float(n) for n in value]
1080+
else:
1081+
output[cased_name] = _dump_float(value)
10141082
else:
10151083
output[cased_name] = value
10161084
return output
@@ -1079,6 +1147,11 @@ def from_dict(self: T, value: Dict[str, Any]) -> T:
10791147
v = [enum_cls.from_string(e) for e in v]
10801148
elif isinstance(v, str):
10811149
v = enum_cls.from_string(v)
1150+
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1151+
if isinstance(value[key], list):
1152+
v = [_parse_float(n) for n in value[key]]
1153+
else:
1154+
v = _parse_float(value[key])
10821155

10831156
if v is not None:
10841157
setattr(self, field_name, v)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"positive": "Infinity",
3+
"negative": "-Infinity",
4+
"nan": "NaN"
5+
}
6+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
syntax = "proto3";
2+
3+
// Some documentation about the Test message.
4+
message Test {
5+
double positive = 1;
6+
double negative = 2;
7+
double nan = 3;
8+
}

tests/test_inputs.py

+58-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import importlib
22
import json
3+
import math
34
import os
45
import sys
56
from collections import namedtuple
67
from types import ModuleType
7-
from typing import Set
8+
from typing import Any, Dict, List, Set
89

910
import pytest
1011

@@ -69,6 +70,56 @@ def module_has_entry_point(module: ModuleType):
6970
return any(hasattr(module, attr) for attr in ["Test", "TestStub"])
7071

7172

73+
def list_replace_nans(l: List) -> List[Any]:
74+
"""Replace float("nan") in a list with the string "NaN"
75+
76+
Parameters
77+
----------
78+
l : List
79+
List to update
80+
81+
Returns
82+
-------
83+
List[Any]
84+
Updated list
85+
"""
86+
x = []
87+
for e in l:
88+
if isinstance(e, list):
89+
e = list_replace_nans(e)
90+
elif isinstance(e, dict):
91+
e = dict_replace_nans(e)
92+
elif isinstance(e, float) and math.isnan(e):
93+
e = betterproto.NAN
94+
x.append(e)
95+
return x
96+
97+
98+
def dict_replace_nans(d: Dict[Any, Any]) -> Dict[Any, Any]:
99+
"""Replace float("nan") in a dictionary with the string "NaN"
100+
101+
Parameters
102+
----------
103+
l : Dict[Any, Any]
104+
Dictionary to update
105+
106+
Returns
107+
-------
108+
Dict[Any, Any]
109+
Updated dictionary
110+
"""
111+
x = {}
112+
for k, v in d.items():
113+
if isinstance(v, dict):
114+
v = dict_replace_nans(v)
115+
elif isinstance(v, list):
116+
v = list_replace_nans(v)
117+
elif isinstance(v, float) and math.isnan(v):
118+
v = betterproto.NAN
119+
x[k] = v
120+
return x
121+
122+
72123
@pytest.fixture
73124
def test_data(request):
74125
test_case_name = request.param
@@ -131,7 +182,9 @@ def test_message_json(repeat, test_data: TestData) -> None:
131182
message.from_json(json_data)
132183
message_json = message.to_json(0)
133184

134-
assert json.loads(message_json) == json.loads(json_data)
185+
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
186+
json.loads(json_data)
187+
)
135188

136189

137190
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
@@ -162,6 +215,6 @@ def test_binary_compatibility(repeat, test_data: TestData) -> None:
162215
assert bytes(plugin_instance_from_binary) == reference_binary_output
163216

164217
assert plugin_instance_from_json == plugin_instance_from_binary
165-
assert (
166-
plugin_instance_from_json.to_dict() == plugin_instance_from_binary.to_dict()
167-
)
218+
assert dict_replace_nans(
219+
plugin_instance_from_json.to_dict()
220+
) == dict_replace_nans(plugin_instance_from_binary.to_dict())

0 commit comments

Comments
 (0)