Skip to content

Commit d260f07

Browse files
authored
Client and Service Stubs take 1 request parameter, not one for each field (#311)
1 parent 6dd7baa commit d260f07

File tree

13 files changed

+140
-193
lines changed

13 files changed

+140
-193
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ output
1717
.venv
1818
.asv
1919
venv
20+
.devcontainer

CHANGELOG.md

+26
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,32 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
- Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`.
99

10+
## [Unreleased]
11+
12+
- **Breaking**: Client and Service Stubs no longer pack and unpack the input message fields as parameters.
13+
14+
Update your client calls and server handlers as follows:
15+
16+
Clients before:
17+
```py
18+
response = await service.echo(value="hello", extra_times=1)
19+
```
20+
Clients after:
21+
```py
22+
response = await service.echo(EchoRequest(value="hello", extra_times=1))
23+
```
24+
Servers before:
25+
```py
26+
async def echo(self, value: str, extra_times: int) -> EchoResponse:
27+
```
28+
Servers after:
29+
```py
30+
async def echo(self, echo_request: EchoRequest) -> EchoResponse:
31+
# Use echo_request.value
32+
# Use echo_request.extra_times
33+
```
34+
35+
1036
## [2.0.0b4] - 2022-01-03
1137

1238
- **Breaking**: the minimum Python version has been bumped to `3.6.2`

README.md

+8-8
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,10 @@ from grpclib.client import Channel
177177
async def main():
178178
channel = Channel(host="127.0.0.1", port=50051)
179179
service = echo.EchoStub(channel)
180-
response = await service.echo(value="hello", extra_times=1)
180+
response = await service.echo(echo.EchoRequest(value="hello", extra_times=1))
181181
print(response)
182182

183-
async for response in service.echo_stream(value="hello", extra_times=1):
183+
async for response in service.echo_stream(echo.EchoRequest(value="hello", extra_times=1)):
184184
print(response)
185185

186186
# don't forget to close the channel when done!
@@ -206,18 +206,18 @@ service methods:
206206

207207
```python
208208
import asyncio
209-
from echo import EchoBase, EchoResponse, EchoStreamResponse
209+
from echo import EchoBase, EchoRequest, EchoResponse, EchoStreamResponse
210210
from grpclib.server import Server
211211
from typing import AsyncIterator
212212

213213

214214
class EchoService(EchoBase):
215-
async def echo(self, value: str, extra_times: int) -> "EchoResponse":
216-
return EchoResponse([value for _ in range(extra_times)])
215+
async def echo(self, echo_request: "EchoRequest") -> "EchoResponse":
216+
return EchoResponse([echo_request.value for _ in range(echo_request.extra_times)])
217217

218-
async def echo_stream(self, value: str, extra_times: int) -> AsyncIterator["EchoStreamResponse"]:
219-
for _ in range(extra_times):
220-
yield EchoStreamResponse(value)
218+
async def echo_stream(self, echo_request: "EchoRequest") -> AsyncIterator["EchoStreamResponse"]:
219+
for _ in range(echo_request.extra_times):
220+
yield EchoStreamResponse(echo_request.value)
221221

222222

223223
async def main():

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ omit = ["betterproto/tests/*"]
111111
legacy_tox_ini = """
112112
[tox]
113113
isolated_build = true
114-
envlist = py36, py37, py38
114+
envlist = py36, py37, py38, py310
115115
116116
[testenv]
117117
whitelist_externals = poetry

src/betterproto/grpc/grpclib_server.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC
22
from collections.abc import AsyncIterable
3-
from typing import Callable, Any, Dict
3+
from typing import Any, Callable, Dict
44

55
import grpclib
66
import grpclib.server
@@ -15,10 +15,10 @@ async def _call_rpc_handler_server_stream(
1515
self,
1616
handler: Callable,
1717
stream: grpclib.server.Stream,
18-
request_kwargs: Dict[str, Any],
18+
request: Any,
1919
) -> None:
2020

21-
response_iter = handler(**request_kwargs)
21+
response_iter = handler(request)
2222
# check if response is actually an AsyncIterator
2323
# this might be false if the method just returns without
2424
# yielding at least once

src/betterproto/plugin/models.py

+20-49
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,15 @@
3131

3232

3333
import builtins
34+
import re
35+
import textwrap
36+
from dataclasses import dataclass, field
37+
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
38+
3439
import betterproto
3540
from betterproto import which_one_of
3641
from betterproto.casing import sanitize_name
37-
from betterproto.compile.importing import (
38-
get_type_reference,
39-
parse_source_type_name,
40-
)
42+
from betterproto.compile.importing import get_type_reference, parse_source_type_name
4143
from betterproto.compile.naming import (
4244
pythonize_class_name,
4345
pythonize_field_name,
@@ -46,21 +48,15 @@
4648
from betterproto.lib.google.protobuf import (
4749
DescriptorProto,
4850
EnumDescriptorProto,
49-
FileDescriptorProto,
50-
MethodDescriptorProto,
5151
Field,
5252
FieldDescriptorProto,
53-
FieldDescriptorProtoType,
5453
FieldDescriptorProtoLabel,
54+
FieldDescriptorProtoType,
55+
FileDescriptorProto,
56+
MethodDescriptorProto,
5557
)
5658
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
5759

58-
59-
import re
60-
import textwrap
61-
from dataclasses import dataclass, field
62-
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
63-
6460
from ..casing import sanitize_name
6561
from ..compile.importing import get_type_reference, parse_source_type_name
6662
from ..compile.naming import (
@@ -69,7 +65,6 @@
6965
pythonize_method_name,
7066
)
7167

72-
7368
# Create a unique placeholder to deal with
7469
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
7570
PLACEHOLDER = object()
@@ -675,12 +670,8 @@ def __post_init__(self) -> None:
675670
self.parent.methods.append(self)
676671

677672
# Check for imports
678-
if self.py_input_message:
679-
for f in self.py_input_message.fields:
680-
f.add_imports_to(self.output_file)
681673
if "Optional" in self.py_output_message_type:
682674
self.output_file.typing_imports.add("Optional")
683-
self.mutable_default_args # ensure this is called before rendering
684675

685676
# Check for Async imports
686677
if self.client_streaming:
@@ -694,37 +685,6 @@ def __post_init__(self) -> None:
694685

695686
super().__post_init__() # check for unset fields
696687

697-
@property
698-
def mutable_default_args(self) -> Dict[str, str]:
699-
"""Handle mutable default arguments.
700-
701-
Returns a list of tuples containing the name and default value
702-
for arguments to this message who's default value is mutable.
703-
The defaults are swapped out for None and replaced back inside
704-
the method's body.
705-
Reference:
706-
https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
707-
708-
Returns
709-
-------
710-
Dict[str, str]
711-
Name and actual default value (as a string)
712-
for each argument with mutable default values.
713-
"""
714-
mutable_default_args = {}
715-
716-
if self.py_input_message:
717-
for f in self.py_input_message.fields:
718-
if (
719-
not self.client_streaming
720-
and f.default_value_string != "None"
721-
and f.mutable
722-
):
723-
mutable_default_args[f.py_name] = f.default_value_string
724-
self.output_file.typing_imports.add("Optional")
725-
726-
return mutable_default_args
727-
728688
@property
729689
def py_name(self) -> str:
730690
"""Pythonized method name."""
@@ -782,6 +742,17 @@ def py_input_message_type(self) -> str:
782742
source_type=self.proto_obj.input_type,
783743
).strip('"')
784744

745+
@property
746+
def py_input_message_param(self) -> str:
747+
"""Param name corresponding to py_input_message_type.
748+
749+
Returns
750+
-------
751+
str
752+
Param name corresponding to py_input_message_type.
753+
"""
754+
return pythonize_field_name(self.py_input_message_type)
755+
785756
@property
786757
def py_output_message_type(self) -> str:
787758
"""String representation of the Python type corresponding to the

src/betterproto/templates/template.py.j2

+11-58
Original file line numberDiff line numberDiff line change
@@ -79,59 +79,29 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
7979
{% for method in service.methods %}
8080
async def {{ method.py_name }}(self
8181
{%- if not method.client_streaming -%}
82-
{%- if method.py_input_message and method.py_input_message.fields -%}, *,
83-
{%- for field in method.py_input_message.fields -%}
84-
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
85-
Optional[{{ field.annotation }}]
86-
{%- else -%}
87-
{{ field.annotation }}
88-
{%- endif -%} =
89-
{%- if field.py_name not in method.mutable_default_args -%}
90-
{{ field.default_value_string }}
91-
{%- else -%}
92-
None
93-
{% endif -%}
94-
{%- if not loop.last %}, {% endif -%}
95-
{%- endfor -%}
96-
{%- endif -%}
82+
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
9783
{%- else -%}
9884
{# Client streaming: need a request iterator instead #}
99-
, request_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
85+
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
10086
{%- endif -%}
10187
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
10288
{% if method.comment %}
10389
{{ method.comment }}
10490

10591
{% endif %}
106-
{%- for py_name, zero in method.mutable_default_args.items() %}
107-
{{ py_name }} = {{ py_name }} or {{ zero }}
108-
{% endfor %}
109-
110-
{% if not method.client_streaming %}
111-
request = {{ method.py_input_message_type }}()
112-
{% for field in method.py_input_message.fields %}
113-
{% if field.field_type == 'message' %}
114-
if {{ field.py_name }} is not None:
115-
request.{{ field.py_name }} = {{ field.py_name }}
116-
{% else %}
117-
request.{{ field.py_name }} = {{ field.py_name }}
118-
{% endif %}
119-
{% endfor %}
120-
{% endif %}
121-
12292
{% if method.server_streaming %}
12393
{% if method.client_streaming %}
12494
async for response in self._stream_stream(
12595
"{{ method.route }}",
126-
request_iterator,
96+
{{ method.py_input_message_param }}_iterator,
12797
{{ method.py_input_message_type }},
12898
{{ method.py_output_message_type.strip('"') }},
12999
):
130100
yield response
131101
{% else %}{# i.e. not client streaming #}
132102
async for response in self._unary_stream(
133103
"{{ method.route }}",
134-
request,
104+
{{ method.py_input_message_param }},
135105
{{ method.py_output_message_type.strip('"') }},
136106
):
137107
yield response
@@ -141,14 +111,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
141111
{% if method.client_streaming %}
142112
return await self._stream_unary(
143113
"{{ method.route }}",
144-
request_iterator,
114+
{{ method.py_input_message_param }}_iterator,
145115
{{ method.py_input_message_type }},
146116
{{ method.py_output_message_type.strip('"') }}
147117
)
148118
{% else %}{# i.e. not client streaming #}
149119
return await self._unary_unary(
150120
"{{ method.route }}",
151-
request,
121+
{{ method.py_input_message_param }},
152122
{{ method.py_output_message_type.strip('"') }}
153123
)
154124
{% endif %}{# client streaming #}
@@ -167,19 +137,10 @@ class {{ service.py_name }}Base(ServiceBase):
167137
{% for method in service.methods %}
168138
async def {{ method.py_name }}(self
169139
{%- if not method.client_streaming -%}
170-
{%- if method.py_input_message and method.py_input_message.fields -%},
171-
{%- for field in method.py_input_message.fields -%}
172-
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
173-
Optional[{{ field.annotation }}]
174-
{%- else -%}
175-
{{ field.annotation }}
176-
{%- endif -%}
177-
{%- if not loop.last %}, {% endif -%}
178-
{%- endfor -%}
179-
{%- endif -%}
140+
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
180141
{%- else -%}
181142
{# Client streaming: need a request iterator instead #}
182-
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
143+
, {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
183144
{%- endif -%}
184145
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
185146
{% if method.comment %}
@@ -194,25 +155,17 @@ class {{ service.py_name }}Base(ServiceBase):
194155
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
195156
{% if not method.client_streaming %}
196157
request = await stream.recv_message()
197-
198-
request_kwargs = {
199-
{% for field in method.py_input_message.fields %}
200-
"{{ field.py_name }}": request.{{ field.py_name }},
201-
{% endfor %}
202-
}
203-
204158
{% else %}
205-
request_kwargs = {"request_iterator": stream.__aiter__()}
159+
request = stream.__aiter__()
206160
{% endif %}
207-
208161
{% if not method.server_streaming %}
209-
response = await self.{{ method.py_name }}(**request_kwargs)
162+
response = await self.{{ method.py_name }}(request)
210163
await stream.send_message(response)
211164
{% else %}
212165
await self._call_rpc_handler_server_stream(
213166
self.{{ method.py_name }},
214167
stream,
215-
request_kwargs,
168+
request,
216169
)
217170
{% endif %}
218171

0 commit comments

Comments
 (0)