Skip to content

Commit eff2ef8

Browse files
committed
Finish implementation and testing of client
Including stream_unary and stream_stream call methods. Also: improve organisation of relevant tests; fix some generated type annotations; add an AsyncChannel utility because it is useful.
1 parent 09f8219 commit eff2ef8

File tree

10 files changed

+495
-331
lines changed

10 files changed

+495
-331
lines changed

betterproto/grpc/grpclib_client.py

+51-17
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from abc import ABC
2+
import asyncio
23
import grpclib.const
34
from typing import (
4-
AsyncGenerator,
5+
Any,
56
AsyncIterator,
67
Collection,
78
Iterator,
@@ -16,17 +17,18 @@
1617

1718
if TYPE_CHECKING:
1819
from grpclib._protocols import IProtoMessage
19-
from grpclib.client import Channel
20+
from grpclib.client import Channel, Stream
2021
from grpclib.metadata import Deadline
2122

2223

2324
_Value = Union[str, bytes]
2425
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
26+
_MessageSource = Union[Iterator["IProtoMessage"], AsyncIterator["IProtoMessage"]]
2527

2628

2729
class ServiceStub(ABC):
2830
"""
29-
Base class for async gRPC service stubs.
31+
Base class for async gRPC clients.
3032
"""
3133

3234
def __init__(
@@ -86,7 +88,7 @@ async def _unary_stream(
8688
timeout: Optional[float] = None,
8789
deadline: Optional["Deadline"] = None,
8890
metadata: Optional[_MetadataLike] = None,
89-
) -> AsyncGenerator[T, None]:
91+
) -> AsyncIterator[T]:
9092
"""Make a unary request and return the stream response iterator."""
9193
async with self.channel.request(
9294
route,
@@ -102,34 +104,66 @@ async def _unary_stream(
102104
async def _stream_unary(
    self,
    route: str,
    request_iterator: _MessageSource,
    request_type: Type[ST],
    response_type: Type[T],
    *,
    timeout: Optional[float] = None,
    deadline: Optional["Deadline"] = None,
    metadata: Optional[_MetadataLike] = None,
) -> T:
    """Send a stream of request messages and return the single response."""
    # Resolve the per-call overrides (timeout/deadline/metadata) up front so
    # the channel request is opened with a single, explicit set of kwargs.
    request_kwargs = self.__resolve_request_kwargs(timeout, deadline, metadata)
    async with self.channel.request(
        route,
        grpclib.const.Cardinality.STREAM_UNARY,
        request_type,
        response_type,
        **request_kwargs,
    ) as stream:
        # Push every request message, then wait for the server's reply.
        await self._send_messages(stream, request_iterator)
        response = await stream.recv_message()
        assert response is not None
        return response
119127

120128
async def _stream_stream(
    self,
    route: str,
    request_iterator: _MessageSource,
    request_type: Type[ST],
    response_type: Type[T],
    *,
    timeout: Optional[float] = None,
    deadline: Optional["Deadline"] = None,
    metadata: Optional[_MetadataLike] = None,
) -> AsyncIterator[T]:
    """
    Make a stream request and return an AsyncIterator to iterate over response
    messages.

    :param route: the gRPC route of the method being called
    :param request_iterator: a sync or async iterable of request messages
    :param request_type: the protobuf type of the request messages
    :param response_type: the protobuf type of the response messages
    :param timeout: optional per-call timeout in seconds
    :param deadline: optional absolute deadline for the call
    :param metadata: optional metadata to send with the request
    """
    async with self.channel.request(
        route,
        grpclib.const.Cardinality.STREAM_STREAM,
        request_type,
        response_type,
        **self.__resolve_request_kwargs(timeout, deadline, metadata),
    ) as stream:
        # Open the request immediately so responses can start arriving
        # while the request messages are pushed from a concurrent task.
        await stream.send_request()
        sending_task = asyncio.ensure_future(
            self._send_messages(stream, request_iterator)
        )
        try:
            async for response in stream:
                yield response
        except BaseException:
            # Was a bare ``except:`` (E722). ``BaseException`` keeps the
            # exact behaviour — CancelledError/GeneratorExit must also stop
            # the sender, which should not outlive the response iteration.
            sending_task.cancel()
            raise
160+
161+
@staticmethod
async def _send_messages(stream, messages: _MessageSource):
    """Forward every item from ``messages`` into ``stream``, then half-close.

    Accepts either a synchronous or an asynchronous iterable of messages.
    """
    send = stream.send_message
    # Duck-type on ``__aiter__`` so any async-iterable object is accepted,
    # not just ones registered with the typing/collections ABCs.
    if hasattr(messages, "__aiter__"):
        async for item in messages:
            await send(item)
    else:
        for item in messages:
            await send(item)
    # Half-close the request side so the server knows sending is finished.
    await stream.end()

betterproto/grpc/util/__init__.py

Whitespace-only changes.
+204
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
import asyncio
from typing import (
    AsyncIterable,
    AsyncIterator,
    Iterable,
    Optional,
    TypeVar,
    Union,
)

T = TypeVar("T")


class ChannelClosed(Exception):
    """
    An exception raised on an attempt to send through a closed channel
    """

    pass


class ChannelDone(Exception):
    """
    An exception raised on an attempt to receive from a channel that is both
    closed and empty.
    """

    pass


class AsyncChannel(AsyncIterable[T]):
    """
    A buffered async channel for sending items between coroutines with FIFO
    semantics.

    This makes decoupled bidirectional streaming gRPC requests easy if used
    like:

    .. code-block:: python
        client = GeneratedStub(grpclib_chan)
        # The channel can be initialised with items to send immediately
        request_chan = AsyncChannel([RequestObject(...), RequestObject(...)])
        async for response in client.rpc_call(request_chan):
            # The response iterator will remain active until the connection
            # is closed
            ...
            # More items can be sent at any time
            await request_chan.send(RequestObject(...))
            ...
            # The channel must be closed to complete the gRPC connection
            request_chan.close()

    Items can be sent through the channel by either:
    - providing an iterable to the constructor
    - providing an iterable to the send_from method
    - passing them to the send method one at a time

    Items can be received from the channel by either:
    - iterating over the channel with a for loop to get all items
    - calling the recieve method to get one item at a time

    If the channel is empty then receivers will wait until either an item
    appears or the channel is closed.

    Once the channel is closed then subsequent attempts to send through the
    channel will fail with a ChannelClosed exception.

    When the channel is closed and empty then it is done, and further attempts
    to receive from it will fail with a ChannelDone exception.

    If multiple coroutines receive from the channel concurrently, each item
    sent will be received by only one of the receivers.

    :param source:
        An optional iterable with items that should be sent through the
        channel immediately.
    :param buffer_limit:
        Limit the number of items that can be buffered in the channel. A value
        less than 1 implies no limit. If the channel is full then attempts to
        send more items will result in the sender waiting until an item is
        received from the channel.
    :param close:
        If set to True then the channel will automatically close after
        exhausting source or immediately if no source is provided.
    """

    def __init__(
        self,
        source: Union[Iterable[T], AsyncIterable[T]] = tuple(),
        *,
        buffer_limit: int = 0,
        close: bool = False,
    ):
        self._queue: asyncio.Queue = asyncio.Queue(buffer_limit)
        self._closed = False
        # Background task draining ``source`` into the queue, if one is given.
        self._sending_task = (
            asyncio.ensure_future(self.send_from(source, close)) if source else None
        )
        self._waiting_recievers: int = 0
        # Track whether flush has been invoked so it can only happen once
        self._flushed = False

    def __aiter__(self) -> AsyncIterator[T]:
        return self

    async def __anext__(self) -> T:
        # BUGFIX: ``done`` is a method, and the original tested the bound
        # method object (``if self.done:``), which is always truthy — every
        # iteration stopped immediately. Call it instead.
        if self.done():
            raise StopAsyncIteration
        self._waiting_recievers += 1
        try:
            result = await self._queue.get()
            if result is self.__flush:
                # Flush sentinel pushed after close to wake blocked receivers.
                raise StopAsyncIteration
            # BUGFIX: the original never returned ``result`` and therefore
            # yielded ``None`` for every item received.
            return result
        finally:
            self._waiting_recievers -= 1
            self._queue.task_done()

    def closed(self) -> bool:
        """
        Returns True if this channel is closed and no-longer accepting new
        items.
        """
        return self._closed

    def done(self) -> bool:
        """
        Check if this channel is done.

        :return: True if this channel is closed and has been drained of items,
            in which case any further attempts to receive an item from this
            channel will raise a ChannelDone exception.
        """
        # After close the channel is not yet done until there is at least one
        # waiting reciever per enqueued item.
        return self._closed and self._queue.qsize() <= self._waiting_recievers

    async def send_from(
        self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
    ):
        """
        Iterates the given [Async]Iterable and sends all the resulting items.
        If close is set to True then subsequent send calls will be rejected
        with a ChannelClosed exception.

        :param source: an iterable of items to send
        :param close:
            if True then the channel will be closed after the source has been
            exhausted
        """
        if self._closed:
            raise ChannelClosed("Cannot send through a closed channel")
        if isinstance(source, AsyncIterable):
            async for item in source:
                await self._queue.put(item)
        else:
            for item in source:
                await self._queue.put(item)
        if close:
            # Complete the closing process
            self.close()

    async def send(self, item: T):
        """
        Send a single item over this channel.

        :param item: The item to send
        """
        if self._closed:
            raise ChannelClosed("Cannot send through a closed channel")
        await self._queue.put(item)

    async def recieve(self) -> Optional[T]:
        """
        Returns the next item from this channel when it becomes available,
        or None if the channel is closed before another item is sent.

        (Name kept as ``recieve`` for backward compatibility with callers.)

        :return: An item from the channel
        """
        # BUGFIX: same bound-method truthiness bug as ``__anext__`` — the
        # original ``if self.done:`` always raised ChannelDone.
        if self.done():
            raise ChannelDone("Cannot recieve from a closed channel")
        self._waiting_recievers += 1
        try:
            result = await self._queue.get()
            if result is self.__flush:
                return None
            return result
        finally:
            self._waiting_recievers -= 1
            self._queue.task_done()

    def close(self):
        """
        Close this channel to new items
        """
        if self._sending_task is not None:
            self._sending_task.cancel()
        self._closed = True
        # Fire-and-forget: unblock any receivers that are already waiting.
        asyncio.ensure_future(self._flush_queue())

    async def _flush_queue(self):
        """
        To be called after the channel is closed. Pushes a number of
        self.__flush objects to the queue to ensure no waiting consumers get
        deadlocked.
        """
        if not self._flushed:
            self._flushed = True
            deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize())
            for _ in range(deadlocked_recievers):
                await self._queue.put(self.__flush)

    # A special signal object for flushing the queue when the channel is closed
    __flush = object()

betterproto/plugin.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -344,11 +344,12 @@ def generate_code(request, response):
344344
}
345345
)
346346

347-
if method.server_streaming:
348-
output["typing_imports"].add("AsyncGenerator")
349-
350347
if method.client_streaming:
351-
output["typing_imports"].add("Iterator")
348+
output["typing_imports"].add("AsyncIterable")
349+
output["typing_imports"].add("Iterable")
350+
output["typing_imports"].add("Union")
351+
if method.server_streaming:
352+
output["typing_imports"].add("AsyncIterator")
352353

353354
output["services"].append(data)
354355

betterproto/templates/template.py.j2

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
7777
{%- endif -%}
7878
{%- else -%}
7979
{# Client streaming: need a request iterator instead #}
80-
, request_iterator: Iterator["{{ method.input }}"]
80+
, request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]]
8181
{%- endif -%}
82-
) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
82+
) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}:
8383
{% if method.comment %}
8484
{{ method.comment }}
8585

betterproto/tests/grpc/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)