Skip to content

Support using Google's wrapper types as RPC output values #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 31 additions & 25 deletions betterproto/plugin.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#!/usr/bin/env python

import itertools
import json
import os.path
import re
import sys
import textwrap
from typing import Any, List, Tuple
from collections import defaultdict
from typing import Dict, List, Optional, Type

try:
import black
Expand All @@ -24,44 +23,51 @@
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
FileDescriptorProto,
ServiceDescriptorProto,
)

from betterproto.casing import safe_snake_case

import google.protobuf.wrappers_pb2 as google_wrappers

WRAPPER_TYPES = {
"google.protobuf.DoubleValue": "float",
"google.protobuf.FloatValue": "float",
"google.protobuf.Int64Value": "int",
"google.protobuf.UInt64Value": "int",
"google.protobuf.Int32Value": "int",
"google.protobuf.UInt32Value": "int",
"google.protobuf.BoolValue": "bool",
"google.protobuf.StringValue": "str",
"google.protobuf.BytesValue": "bytes",
}
WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(lambda: None, {
'google.protobuf.DoubleValue': google_wrappers.DoubleValue,
'google.protobuf.FloatValue': google_wrappers.FloatValue,
'google.protobuf.Int64Value': google_wrappers.Int64Value,
'google.protobuf.UInt64Value': google_wrappers.UInt64Value,
'google.protobuf.Int32Value': google_wrappers.Int32Value,
'google.protobuf.UInt32Value': google_wrappers.UInt32Value,
'google.protobuf.BoolValue': google_wrappers.BoolValue,
'google.protobuf.StringValue': google_wrappers.StringValue,
'google.protobuf.BytesValue': google_wrappers.BytesValue,
})


def get_ref_type(package: str, imports: set, type_name: str) -> str:
def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if
necessary.
necessary. Unwraps well known type if required.
"""
# If the package name is a blank string, then this should still work
# because by convention packages are lowercase and message/enum types are
# pascal-cased. May require refactoring in the future.
type_name = type_name.lstrip(".")

if type_name in WRAPPER_TYPES:
return f"Optional[{WRAPPER_TYPES[type_name]}]"
# Check if type is wrapper.
wrapper_class = WRAPPER_TYPES[type_name]

if type_name == "google.protobuf.Duration":
return "timedelta"
if unwrap:
if wrapper_class:
wrapped_type = type(wrapper_class().value)
return f"Optional[{wrapped_type.__name__}]"

if type_name == "google.protobuf.Timestamp":
return "datetime"
if type_name == "google.protobuf.Duration":
return "timedelta"

if type_name == "google.protobuf.Timestamp":
return "datetime"
elif wrapper_class:
imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}")
return f"{wrapper_class.__name__}"

if type_name.startswith(package):
parts = type_name.lstrip(package).lstrip(".").split(".")
Expand Down Expand Up @@ -379,7 +385,7 @@ def generate_code(request, response):
).strip('"'),
"input_message": input_message,
"output": get_ref_type(
package, output["imports"], method.output_type
package, output["imports"], method.output_type, unwrap=False
).strip('"'),
"client_streaming": method.client_streaming,
"server_streaming": method.server_streaming,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@ syntax = "proto3";

import "google/protobuf/wrappers.proto";

// Tests that wrapped values can be used directly as return values

service Test {
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value);
rpc GetDouble (Input) returns (google.protobuf.DoubleValue);
rpc GetFloat (Input) returns (google.protobuf.FloatValue);
rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
rpc GetOutput (Input) returns (Output);
rpc GetUInt64 (Input) returns (google.protobuf.UInt64Value);
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
rpc GetUInt32 (Input) returns (google.protobuf.UInt32Value);
rpc GetBool (Input) returns (google.protobuf.BoolValue);
rpc GetString (Input) returns (google.protobuf.StringValue);
rpc GetBytes (Input) returns (google.protobuf.BytesValue);
}

message Input {

}

message Output {
google.protobuf.Int64Value int64 = 1;
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,53 @@
from typing import Optional
from typing import Any, Callable, Optional

import google.protobuf.wrappers_pb2 as wrappers
import pytest

from betterproto.tests.mocks import MockChannel
from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import (
TestStub
TestStub,
)

test_cases = [
(TestStub.get_double, wrappers.DoubleValue, 2.5),
(TestStub.get_float, wrappers.FloatValue, 2.5),
(TestStub.get_int64, wrappers.Int64Value, -64),
(TestStub.get_u_int64, wrappers.UInt64Value, 64),
(TestStub.get_int32, wrappers.Int32Value, -32),
(TestStub.get_u_int32, wrappers.UInt32Value, 32),
(TestStub.get_bool, wrappers.BoolValue, True),
(TestStub.get_string, wrappers.StringValue, "string"),
(TestStub.get_bytes, wrappers.BytesValue, bytes(0xFF)[0:4]),
]

class TestStubChild(TestStub):
async def _unary_unary(self, route, request, response_type, **kwargs):
self.response_type = response_type

@pytest.mark.asyncio
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_channel_receives_wrapped_type(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
):
wrapped_value = wrapper_class()
wrapped_value.value = value
channel = MockChannel(responses=[wrapped_value])
service = TestStub(channel)

await service_method(service)

assert channel.requests[0]["response_type"] != Optional[type(value)]
assert channel.requests[0]["response_type"] == type(wrapped_value)


@pytest.mark.asyncio
async def test():
pytest.skip("todo")
stub = TestStubChild(None)
await stub.get_int64()
assert stub.response_type != Optional[int]
@pytest.mark.xfail
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_service_unwraps_response(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
):
wrapped_value = wrapper_class()
wrapped_value.value = value
service = TestStub(MockChannel(responses=[wrapped_value]))

response_value = await service_method(service)

assert type(response_value) == value
assert type(response_value) == type(value)
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
syntax = "proto3";

import "google/protobuf/wrappers.proto";

// Tests that wrapped values are supported as part of output message
service Test {
rpc getOutput (Input) returns (Output);
}

message Input {

}

message Output {
google.protobuf.DoubleValue double_value = 1;
google.protobuf.FloatValue float_value = 2;
google.protobuf.Int64Value int64_value = 3;
google.protobuf.UInt64Value uint64_value = 4;
google.protobuf.Int32Value int32_value = 5;
google.protobuf.UInt32Value uint32_value = 6;
google.protobuf.BoolValue bool_value = 7;
google.protobuf.StringValue string_value = 8;
google.protobuf.BytesValue bytes_value = 9;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest

from betterproto.tests.mocks import MockChannel
from betterproto.tests.output_betterproto.googletypes_response_embedded.googletypes_response_embedded import (
Output,
TestStub,
)


@pytest.mark.asyncio
async def test_service_passes_through_unwrapped_values_embedded_in_response():
"""
We do not not need to implement value unwrapping for embedded well-known types,
as this is already handled by grpclib. This test merely shows that this is the case.
"""
output = Output(
double_value=10.0,
float_value=12.0,
int64_value=-13,
uint64_value=14,
int32_value=-15,
uint32_value=16,
bool_value=True,
string_value="string",
bytes_value=bytes(0xFF)[0:4],
)

service = TestStub(MockChannel(responses=[output]))
response = await service.get_output()

assert response.double_value == 10.0
assert response.float_value == 12.0
assert response.int64_value == -13
assert response.uint64_value == 14
assert response.int32_value == -15
assert response.uint32_value == 16
assert response.bool_value
assert response.string_value == "string"
assert response.bytes_value == bytes(0xFF)[0:4]
39 changes: 39 additions & 0 deletions betterproto/tests/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import List

from grpclib.client import Channel


class MockChannel(Channel):
# noinspection PyMissingConstructor
def __init__(self, responses: List) -> None:
self.responses = responses
self.requests = []

def request(self, route, cardinality, request, response_type, **kwargs):
self.requests.append(
{
"route": route,
"cardinality": cardinality,
"request": request,
"response_type": response_type,
}
)
return MockStream(self.responses)


class MockStream:
def __init__(self, responses: List) -> None:
super().__init__()
self.responses = responses

async def recv_message(self):
return self.responses.pop(0)

async def send_message(self, *args, **kwargs):
pass

async def __aexit__(self, exc_type, exc_val, exc_tb):
return True

async def __aenter__(self):
return self
6 changes: 5 additions & 1 deletion betterproto/tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
from google.protobuf.json_format import Parse


excluded_test_cases = {"googletypes_response", "service"}
excluded_test_cases = {
"googletypes_response",
"googletypes_response_embedded",
"service",
}
test_case_names = {*get_directories(inputs_path)} - excluded_test_cases

plugin_output_package = "betterproto.tests.output_betterproto"
Expand Down