Skip to content

Commit 4e78fe9

Browse files
authored
Merge branch 'client-streaming' into client-streaming-tests
2 parents 0814729 + 50bb67b commit 4e78fe9

File tree

2 files changed

+18
-21
lines changed

2 files changed

+18
-21
lines changed

betterproto/grpc/util/async_channel.py

+5-12
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,10 @@ class AsyncChannel(AsyncIterable[T]):
8181
"""
8282

8383
def __init__(
84-
self,
85-
source: Union[Iterable[T], AsyncIterable[T]] = tuple(),
86-
*,
87-
buffer_limit: int = 0,
88-
close: bool = False,
84+
self, *, buffer_limit: int = 0, close: bool = False,
8985
):
9086
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
9187
self._closed = False
92-
self._sending_task = (
93-
asyncio.ensure_future(self.send_from(source, close)) if source else None
94-
)
9588
self._waiting_recievers: int = 0
9689
# Track whether flush has been invoked so it can only happen once
9790
self._flushed = False
@@ -132,7 +125,7 @@ def done(self) -> bool:
132125

133126
async def send_from(
134127
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
135-
):
128+
) -> "AsyncChannel[T]":
136129
"""
137130
Iterates the given [Async]Iterable and sends all the resulting items.
138131
If close is set to True then subsequent send calls will be rejected with a
@@ -153,15 +146,17 @@ async def send_from(
153146
if close:
154147
# Complete the closing process
155148
self.close()
149+
return self
156150

157-
async def send(self, item: T):
151+
async def send(self, item: T) -> "AsyncChannel[T]":
158152
"""
159153
Send a single item over this channel.
160154
:param item: The item to send
161155
"""
162156
if self._closed:
163157
raise ChannelClosed("Cannot send through a closed channel")
164158
await self._queue.put(item)
159+
return self
165160

166161
async def recieve(self) -> Optional[T]:
167162
"""
@@ -185,8 +180,6 @@ def close(self):
185180
"""
186181
Close this channel to new items
187182
"""
188-
if self._sending_task is not None:
189-
self._sending_task.cancel()
190183
self._closed = True
191184
asyncio.ensure_future(self._flush_queue())
192185

betterproto/tests/grpc/test_grpclib_client.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from betterproto.tests.output_betterproto.service.service import (
23
DoThingResponse,
34
DoThingRequest,
@@ -129,7 +130,10 @@ async def test_async_gen_for_stream_stream_request():
129130
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things
130131
# immediately and we'll use it to send more_things later, after recieving some
131132
# results
132-
request_chan = AsyncChannel(GetThingRequest(name) for name in some_things)
133+
request_chan = AsyncChannel()
134+
send_initial_requests = asyncio.ensure_future(
135+
request_chan.send_from(GetThingRequest(name) for name in some_things)
136+
)
133137
response_index = 0
134138
async for response in client.get_different_things(request_chan):
135139
assert response.name == expected_things[response_index]
@@ -138,13 +142,13 @@ async def test_async_gen_for_stream_stream_request():
138142
if more_things:
139143
# Send some more requests as we recieve reponses to be sure coordination of
140144
# send/recieve events doesn't matter
141-
another_response = await request_chan.send(
142-
GetThingRequest(more_things.pop(0))
143-
)
144-
if another_response is not None:
145-
assert another_response.name == expected_things[response_index]
146-
assert another_response.version == response_index
147-
response_index += 1
145+
await request_chan.send(GetThingRequest(more_things.pop(0)))
146+
elif not send_initial_requests.done():
147+
# Make sure the sending task it completed
148+
await send_initial_requests
148149
else:
149150
# No more things to send make sure channel is closed
150-
await request_chan.close()
151+
request_chan.close()
152+
assert response_index == len(
153+
expected_things
154+
), "Didn't recieve all exptected responses"

0 commit comments

Comments
 (0)