Skip to content

Commit 1a87ea4

Browse files
authored
Merge pull request #40 from boukeversteegh/pr/wrapper-as-output
Support using Google's wrapper types as RPC output values
2 parents 983e089 + 8f0caf1 commit 1a87ea4

File tree

7 files changed

+191
-43
lines changed

7 files changed

+191
-43
lines changed

betterproto/plugin.py

+31-25
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
#!/usr/bin/env python
22

33
import itertools
4-
import json
54
import os.path
6-
import re
75
import sys
86
import textwrap
9-
from typing import Any, List, Tuple
7+
from collections import defaultdict
8+
from typing import Dict, List, Optional, Type
109

1110
try:
1211
import black
@@ -24,44 +23,51 @@
2423
DescriptorProto,
2524
EnumDescriptorProto,
2625
FieldDescriptorProto,
27-
FileDescriptorProto,
28-
ServiceDescriptorProto,
2926
)
3027

3128
from betterproto.casing import safe_snake_case
3229

30+
import google.protobuf.wrappers_pb2 as google_wrappers
3331

34-
WRAPPER_TYPES = {
35-
"google.protobuf.DoubleValue": "float",
36-
"google.protobuf.FloatValue": "float",
37-
"google.protobuf.Int64Value": "int",
38-
"google.protobuf.UInt64Value": "int",
39-
"google.protobuf.Int32Value": "int",
40-
"google.protobuf.UInt32Value": "int",
41-
"google.protobuf.BoolValue": "bool",
42-
"google.protobuf.StringValue": "str",
43-
"google.protobuf.BytesValue": "bytes",
44-
}
32+
WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(lambda: None, {
33+
'google.protobuf.DoubleValue': google_wrappers.DoubleValue,
34+
'google.protobuf.FloatValue': google_wrappers.FloatValue,
35+
'google.protobuf.Int64Value': google_wrappers.Int64Value,
36+
'google.protobuf.UInt64Value': google_wrappers.UInt64Value,
37+
'google.protobuf.Int32Value': google_wrappers.Int32Value,
38+
'google.protobuf.UInt32Value': google_wrappers.UInt32Value,
39+
'google.protobuf.BoolValue': google_wrappers.BoolValue,
40+
'google.protobuf.StringValue': google_wrappers.StringValue,
41+
'google.protobuf.BytesValue': google_wrappers.BytesValue,
42+
})
4543

4644

47-
def get_ref_type(package: str, imports: set, type_name: str) -> str:
45+
def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str:
4846
"""
4947
Return a Python type name for a proto type reference. Adds the import if
50-
necessary.
48+
necessary. Unwraps well known type if required.
5149
"""
5250
# If the package name is a blank string, then this should still work
5351
# because by convention packages are lowercase and message/enum types are
5452
# pascal-cased. May require refactoring in the future.
5553
type_name = type_name.lstrip(".")
5654

57-
if type_name in WRAPPER_TYPES:
58-
return f"Optional[{WRAPPER_TYPES[type_name]}]"
55+
# Check if type is wrapper.
56+
wrapper_class = WRAPPER_TYPES[type_name]
5957

60-
if type_name == "google.protobuf.Duration":
61-
return "timedelta"
58+
if unwrap:
59+
if wrapper_class:
60+
wrapped_type = type(wrapper_class().value)
61+
return f"Optional[{wrapped_type.__name__}]"
6262

63-
if type_name == "google.protobuf.Timestamp":
64-
return "datetime"
63+
if type_name == "google.protobuf.Duration":
64+
return "timedelta"
65+
66+
if type_name == "google.protobuf.Timestamp":
67+
return "datetime"
68+
elif wrapper_class:
69+
imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}")
70+
return f"{wrapper_class.__name__}"
6571

6672
if type_name.startswith(package):
6773
parts = type_name.lstrip(package).lstrip(".").split(".")
@@ -379,7 +385,7 @@ def generate_code(request, response):
379385
).strip('"'),
380386
"input_message": input_message,
381387
"output": get_ref_type(
382-
package, output["imports"], method.output_type
388+
package, output["imports"], method.output_type, unwrap=False
383389
).strip('"'),
384390
"client_streaming": method.client_streaming,
385391
"server_streaming": method.server_streaming,

betterproto/tests/inputs/googletypes_response/googletypes_response.proto

+10-7
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,20 @@ syntax = "proto3";
22

33
import "google/protobuf/wrappers.proto";
44

5+
// Tests that wrapped values can be used directly as return values
6+
57
service Test {
6-
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
7-
rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value);
8+
rpc GetDouble (Input) returns (google.protobuf.DoubleValue);
9+
rpc GetFloat (Input) returns (google.protobuf.FloatValue);
810
rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
9-
rpc GetOutput (Input) returns (Output);
11+
rpc GetUInt64 (Input) returns (google.protobuf.UInt64Value);
12+
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
13+
rpc GetUInt32 (Input) returns (google.protobuf.UInt32Value);
14+
rpc GetBool (Input) returns (google.protobuf.BoolValue);
15+
rpc GetString (Input) returns (google.protobuf.StringValue);
16+
rpc GetBytes (Input) returns (google.protobuf.BytesValue);
1017
}
1118

1219
message Input {
1320

1421
}
15-
16-
message Output {
17-
google.protobuf.Int64Value int64 = 1;
18-
}
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,53 @@
1-
from typing import Optional
1+
from typing import Any, Callable, Optional
22

3+
import google.protobuf.wrappers_pb2 as wrappers
34
import pytest
45

6+
from betterproto.tests.mocks import MockChannel
57
from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import (
6-
TestStub
8+
TestStub,
79
)
810

11+
test_cases = [
12+
(TestStub.get_double, wrappers.DoubleValue, 2.5),
13+
(TestStub.get_float, wrappers.FloatValue, 2.5),
14+
(TestStub.get_int64, wrappers.Int64Value, -64),
15+
(TestStub.get_u_int64, wrappers.UInt64Value, 64),
16+
(TestStub.get_int32, wrappers.Int32Value, -32),
17+
(TestStub.get_u_int32, wrappers.UInt32Value, 32),
18+
(TestStub.get_bool, wrappers.BoolValue, True),
19+
(TestStub.get_string, wrappers.StringValue, "string"),
20+
(TestStub.get_bytes, wrappers.BytesValue, bytes(0xFF)[0:4]),
21+
]
922

10-
class TestStubChild(TestStub):
11-
async def _unary_unary(self, route, request, response_type, **kwargs):
12-
self.response_type = response_type
23+
24+
@pytest.mark.asyncio
25+
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
26+
async def test_channel_receives_wrapped_type(
27+
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
28+
):
29+
wrapped_value = wrapper_class()
30+
wrapped_value.value = value
31+
channel = MockChannel(responses=[wrapped_value])
32+
service = TestStub(channel)
33+
34+
await service_method(service)
35+
36+
assert channel.requests[0]["response_type"] != Optional[type(value)]
37+
assert channel.requests[0]["response_type"] == type(wrapped_value)
1338

1439

1540
@pytest.mark.asyncio
16-
async def test():
17-
pytest.skip("todo")
18-
stub = TestStubChild(None)
19-
await stub.get_int64()
20-
assert stub.response_type != Optional[int]
41+
@pytest.mark.xfail
42+
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
43+
async def test_service_unwraps_response(
44+
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
45+
):
46+
wrapped_value = wrapper_class()
47+
wrapped_value.value = value
48+
service = TestStub(MockChannel(responses=[wrapped_value]))
49+
50+
response_value = await service_method(service)
51+
52+
assert type(response_value) == value
53+
assert type(response_value) == type(value)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
syntax = "proto3";
2+
3+
import "google/protobuf/wrappers.proto";
4+
5+
// Tests that wrapped values are supported as part of output message
6+
service Test {
7+
rpc getOutput (Input) returns (Output);
8+
}
9+
10+
message Input {
11+
12+
}
13+
14+
message Output {
15+
google.protobuf.DoubleValue double_value = 1;
16+
google.protobuf.FloatValue float_value = 2;
17+
google.protobuf.Int64Value int64_value = 3;
18+
google.protobuf.UInt64Value uint64_value = 4;
19+
google.protobuf.Int32Value int32_value = 5;
20+
google.protobuf.UInt32Value uint32_value = 6;
21+
google.protobuf.BoolValue bool_value = 7;
22+
google.protobuf.StringValue string_value = 8;
23+
google.protobuf.BytesValue bytes_value = 9;
24+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pytest
2+
3+
from betterproto.tests.mocks import MockChannel
4+
from betterproto.tests.output_betterproto.googletypes_response_embedded.googletypes_response_embedded import (
5+
Output,
6+
TestStub,
7+
)
8+
9+
10+
@pytest.mark.asyncio
11+
async def test_service_passes_through_unwrapped_values_embedded_in_response():
12+
"""
13+
We do not not need to implement value unwrapping for embedded well-known types,
14+
as this is already handled by grpclib. This test merely shows that this is the case.
15+
"""
16+
output = Output(
17+
double_value=10.0,
18+
float_value=12.0,
19+
int64_value=-13,
20+
uint64_value=14,
21+
int32_value=-15,
22+
uint32_value=16,
23+
bool_value=True,
24+
string_value="string",
25+
bytes_value=bytes(0xFF)[0:4],
26+
)
27+
28+
service = TestStub(MockChannel(responses=[output]))
29+
response = await service.get_output()
30+
31+
assert response.double_value == 10.0
32+
assert response.float_value == 12.0
33+
assert response.int64_value == -13
34+
assert response.uint64_value == 14
35+
assert response.int32_value == -15
36+
assert response.uint32_value == 16
37+
assert response.bool_value
38+
assert response.string_value == "string"
39+
assert response.bytes_value == bytes(0xFF)[0:4]

betterproto/tests/mocks.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import List
2+
3+
from grpclib.client import Channel
4+
5+
6+
class MockChannel(Channel):
7+
# noinspection PyMissingConstructor
8+
def __init__(self, responses: List) -> None:
9+
self.responses = responses
10+
self.requests = []
11+
12+
def request(self, route, cardinality, request, response_type, **kwargs):
13+
self.requests.append(
14+
{
15+
"route": route,
16+
"cardinality": cardinality,
17+
"request": request,
18+
"response_type": response_type,
19+
}
20+
)
21+
return MockStream(self.responses)
22+
23+
24+
class MockStream:
25+
def __init__(self, responses: List) -> None:
26+
super().__init__()
27+
self.responses = responses
28+
29+
async def recv_message(self):
30+
return self.responses.pop(0)
31+
32+
async def send_message(self, *args, **kwargs):
33+
pass
34+
35+
async def __aexit__(self, exc_type, exc_val, exc_tb):
36+
return True
37+
38+
async def __aenter__(self):
39+
return self

betterproto/tests/test_inputs.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
from google.protobuf.json_format import Parse
1717

1818

19-
excluded_test_cases = {"googletypes_response", "service"}
19+
excluded_test_cases = {
20+
"googletypes_response",
21+
"googletypes_response_embedded",
22+
"service",
23+
}
2024
test_case_names = {*get_directories(inputs_path)} - excluded_test_cases
2125

2226
plugin_output_package = "betterproto.tests.output_betterproto"

0 commit comments

Comments
 (0)