Skip to content

Commit 7c5ee47

Browse files
authored
Added support for infinite and nan floats/doubles (#215)
- Added support for the custom double values from the protobuf json spec: "Infinity", "-Infinity", and "NaN" - Added `infinite_floats` test data - Updated Message.__eq__ to consider nan values equal - Updated `test_message_json` and `test_binary_compatibility` to replace NaN float values in dictionaries before comparison (because two NaN values are not equal)
1 parent bb646fe commit 7c5ee47

File tree

4 files changed

+154
-10
lines changed

4 files changed

+154
-10
lines changed

src/betterproto/__init__.py

+74-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import enum
33
import inspect
44
import json
5+
import math
56
import struct
67
import sys
78
import typing
@@ -113,6 +114,12 @@ def datetime_default_gen() -> datetime:
113114
DATETIME_ZERO = datetime_default_gen()
114115

115116

117+
# Special protobuf json doubles
118+
INFINITY = "Infinity"
119+
NEG_INFINITY = "-Infinity"
120+
NAN = "NaN"
121+
122+
116123
class Casing(enum.Enum):
117124
"""Casing constants for serialization."""
118125

@@ -369,6 +376,51 @@ def _serialize_single(
369376
return bytes(output)
370377

371378

379+
def _parse_float(value: Any) -> float:
    """Convert a JSON-decoded value into a Python float.

    The protobuf JSON mapping spells the non-finite doubles as the strings
    "Infinity", "-Infinity" and "NaN"; those are matched explicitly and
    anything else is handed to the built-in ``float`` constructor.

    Parameters
    ----------
    value : Any
        Value to parse

    Returns
    -------
    float
        Parsed value
    """
    # Pair each special JSON token with the float it stands for; the tokens
    # are checked in the same order as the protobuf spec lists them here.
    special_values = (
        (INFINITY, float("inf")),
        (NEG_INFINITY, -float("inf")),
        (NAN, float("nan")),
    )
    for token, parsed in special_values:
        if value == token:
            return parsed
    return float(value)
399+
400+
401+
def _dump_float(value: float) -> Union[float, str]:
    """Dump the given float to its JSON representation

    Finite values are returned unchanged. The non-finite values are
    replaced by the string tokens used by the protobuf JSON mapping.

    Parameters
    ----------
    value : float
        Value to dump

    Returns
    -------
    Union[float, str]
        Dumped value, either the float itself or one of the strings
        "Infinity", "-Infinity" or "NaN"
    """
    if value == float("inf"):
        return INFINITY
    if value == -float("inf"):
        return NEG_INFINITY
    # NaN never compares equal to anything (including itself), so an
    # equality test against float("nan") can never succeed; math.isnan is
    # the only reliable check. The isinstance guard keeps non-float inputs
    # passing through untouched instead of raising inside math.isnan.
    if isinstance(value, float) and math.isnan(value):
        return NAN
    return value
422+
423+
372424
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
373425
"""
374426
Decode a single varint value from a byte buffer. Returns the value and the
@@ -564,7 +616,18 @@ def __eq__(self, other) -> bool:
564616
other_val = other._get_field_default(field_name)
565617

566618
if self_val != other_val:
567-
return False
619+
# We consider two nan values to be the same for the
620+
# purposes of comparing messages (otherwise a message
621+
# is not equal to itself)
622+
if (
623+
isinstance(self_val, float)
624+
and isinstance(other_val, float)
625+
and math.isnan(self_val)
626+
and math.isnan(other_val)
627+
):
628+
continue
629+
else:
630+
return False
568631

569632
return True
570633

@@ -1015,6 +1078,11 @@ def to_dict(
10151078
else:
10161079
enum_class: Type[Enum] = field_types[field_name] # noqa
10171080
output[cased_name] = enum_class(value).name
1081+
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1082+
if field_is_repeated:
1083+
output[cased_name] = [_dump_float(n) for n in value]
1084+
else:
1085+
output[cased_name] = _dump_float(value)
10181086
else:
10191087
output[cased_name] = value
10201088
return output
@@ -1090,6 +1158,11 @@ def from_dict(self: T, value: Dict[str, Any]) -> T:
10901158
v = [enum_cls.from_string(e) for e in v]
10911159
elif isinstance(v, str):
10921160
v = enum_cls.from_string(v)
1161+
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1162+
if isinstance(value[key], list):
1163+
v = [_parse_float(n) for n in value[key]]
1164+
else:
1165+
v = _parse_float(value[key])
10931166

10941167
if v is not None:
10951168
setattr(self, field_name, v)

tests/inputs/float/float.json

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"positive": "Infinity",
3+
"negative": "-Infinity",
4+
"nan": "NaN",
5+
"three": 3.0,
6+
"threePointOneFour": 3.14,
7+
"negThree": -3.0,
8+
"negThreePointOneFour": -3.14
9+
}

tests/inputs/float/float.proto

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
syntax = "proto3";
2+
3+
// Some documentation about the Test message.
4+
message Test {
5+
double positive = 1;
6+
double negative = 2;
7+
double nan = 3;
8+
double three = 4;
9+
double three_point_one_four = 5;
10+
double neg_three = 6;
11+
double neg_three_point_one_four = 7;
12+
}

tests/test_inputs.py

+59-9
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import importlib
22
import json
3+
import math
34
import os
45
import sys
56
from collections import namedtuple
67
from types import ModuleType
7-
from typing import Set
8+
from typing import Any, Dict, List, Set
89

910
import pytest
1011

@@ -69,6 +70,55 @@ def module_has_entry_point(module: ModuleType):
6970
return any(hasattr(module, attr) for attr in ["Test", "TestStub"])
7071

7172

73+
def list_replace_nans(items: List) -> List[Any]:
    """Replace float("nan") in a list with the string "NaN"

    Nested lists and dictionaries are processed recursively. Every other
    element is copied into the result unchanged, so the returned list has
    the same length and ordering as the input.

    Parameters
    ----------
    items : List
        List to update

    Returns
    -------
    List[Any]
        Updated list
    """
    result = []
    for item in items:
        if isinstance(item, list):
            result.append(list_replace_nans(item))
        elif isinstance(item, dict):
            result.append(dict_replace_nans(item))
        elif isinstance(item, float) and math.isnan(item):
            result.append(betterproto.NAN)
        else:
            # Anything that is not a container or a NaN must be kept as-is;
            # dropping it would make unrelated lists compare equal.
            result.append(item)
    return result
95+
96+
97+
def dict_replace_nans(input_dict: Dict[Any, Any]) -> Dict[Any, Any]:
    """Return a copy of a dictionary with float("nan") swapped for "NaN"

    Nested dictionaries and lists are rebuilt recursively; all other
    values are carried over unchanged under their original keys.

    Parameters
    ----------
    input_dict : Dict[Any, Any]
        Dictionary to update

    Returns
    -------
    Dict[Any, Any]
        Updated dictionary
    """

    def _normalise(entry: Any) -> Any:
        # Containers are rebuilt recursively, NaN floats become the string
        # token, everything else passes straight through.
        if isinstance(entry, dict):
            return dict_replace_nans(entry)
        if isinstance(entry, list):
            return list_replace_nans(entry)
        if isinstance(entry, float) and math.isnan(entry):
            return betterproto.NAN
        return entry

    return {name: _normalise(entry) for name, entry in input_dict.items()}
120+
121+
72122
@pytest.fixture
73123
def test_data(request):
74124
test_case_name = request.param
@@ -81,7 +131,6 @@ def test_data(request):
81131
reference_module_root = os.path.join(
82132
*reference_output_package.split("."), test_case_name
83133
)
84-
85134
sys.path.append(reference_module_root)
86135

87136
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}")
@@ -132,7 +181,9 @@ def test_message_json(repeat, test_data: TestData) -> None:
132181
message.from_json(json_sample)
133182
message_json = message.to_json(0)
134183

135-
assert json.loads(message_json) == json.loads(json_sample)
184+
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
185+
json.loads(json_sample)
186+
)
136187

137188

138189
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
@@ -156,14 +207,13 @@ def test_binary_compatibility(repeat, test_data: TestData) -> None:
156207
reference_binary_output
157208
)
158209

159-
# # Generally this can't be relied on, but here we are aiming to match the
160-
# # existing Python implementation and aren't doing anything tricky.
161-
# # https://developers.google.com/protocol-buffers/docs/encoding#implications
210+
# Generally this can't be relied on, but here we are aiming to match the
211+
# existing Python implementation and aren't doing anything tricky.
212+
# https://developers.google.com/protocol-buffers/docs/encoding#implications
162213
assert bytes(plugin_instance_from_json) == reference_binary_output
163214
assert bytes(plugin_instance_from_binary) == reference_binary_output
164215

165216
assert plugin_instance_from_json == plugin_instance_from_binary
166-
assert (
217+
assert dict_replace_nans(
167218
plugin_instance_from_json.to_dict()
168-
== plugin_instance_from_binary.to_dict()
169-
)
219+
) == dict_replace_nans(plugin_instance_from_binary.to_dict())

0 commit comments

Comments
 (0)