Skip to content

Commit 50bb67b

Browse files
committed
Fix bugs and remove footgun feature in AsyncChannel
1 parent c8229e5 commit 50bb67b

File tree

2 files changed

+22
-24
lines changed

2 files changed

+22
-24
lines changed

betterproto/grpc/util/async_channel.py

+9-15
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,10 @@ class AsyncChannel(AsyncIterable[T]):
8181
"""
8282

8383
def __init__(
84-
self,
85-
source: Union[Iterable[T], AsyncIterable[T]] = tuple(),
86-
*,
87-
buffer_limit: int = 0,
88-
close: bool = False,
84+
self, *, buffer_limit: int = 0, close: bool = False,
8985
):
9086
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
9187
self._closed = False
92-
self._sending_task = (
93-
asyncio.ensure_future(self.send_from(source, close)) if source else None
94-
)
9588
self._waiting_recievers: int = 0
9689
# Track whether flush has been invoked so it can only happen once
9790
self._flushed = False
@@ -100,13 +93,14 @@ def __aiter__(self) -> AsyncIterator[T]:
10093
return self
10194

10295
async def __anext__(self) -> T:
103-
if self.done:
96+
if self.done():
10497
raise StopAsyncIteration
10598
self._waiting_recievers += 1
10699
try:
107100
result = await self._queue.get()
108101
if result is self.__flush:
109102
raise StopAsyncIteration
103+
return result
110104
finally:
111105
self._waiting_recievers -= 1
112106
self._queue.task_done()
@@ -131,7 +125,7 @@ def done(self) -> bool:
131125

132126
async def send_from(
133127
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
134-
):
128+
) -> "AsyncChannel[T]":
135129
"""
136130
Iterates the given [Async]Iterable and sends all the resulting items.
137131
If close is set to True then subsequent send calls will be rejected with a
@@ -151,24 +145,26 @@ async def send_from(
151145
await self._queue.put(item)
152146
if close:
153147
# Complete the closing process
154-
await self.close()
148+
self.close()
149+
return self
155150

156-
async def send(self, item: T):
151+
async def send(self, item: T) -> "AsyncChannel[T]":
157152
"""
158153
Send a single item over this channel.
159154
:param item: The item to send
160155
"""
161156
if self._closed:
162157
raise ChannelClosed("Cannot send through a closed channel")
163158
await self._queue.put(item)
159+
return self
164160

165161
async def recieve(self) -> Optional[T]:
166162
"""
167163
Returns the next item from this channel when it becomes available,
168164
or None if the channel is closed before another item is sent.
169165
:return: An item from the channel
170166
"""
171-
if self.done:
167+
if self.done():
172168
raise ChannelDone("Cannot recieve from a closed channel")
173169
self._waiting_recievers += 1
174170
try:
@@ -184,8 +180,6 @@ def close(self):
184180
"""
185181
Close this channel to new items
186182
"""
187-
if self._sending_task is not None:
188-
self._sending_task.cancel()
189183
self._closed = True
190184
asyncio.ensure_future(self._flush_queue())
191185

betterproto/tests/grpc/test_grpclib_client.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from betterproto.tests.output_betterproto.service.service import (
23
DoThingResponse,
34
DoThingRequest,
@@ -129,7 +130,10 @@ async def test_async_gen_for_stream_stream_request():
129130
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things
130131
# immediately and we'll use it to send more_things later, after recieving some
131132
# results
132-
request_chan = AsyncChannel(GetThingRequest(name) for name in some_things)
133+
request_chan = AsyncChannel()
134+
send_initial_requests = asyncio.ensure_future(
135+
request_chan.send_from(GetThingRequest(name) for name in some_things)
136+
)
133137
response_index = 0
134138
async for response in client.get_different_things(request_chan):
135139
assert response.name == expected_things[response_index]
@@ -138,13 +142,13 @@ async def test_async_gen_for_stream_stream_request():
138142
if more_things:
139143
# Send some more requests as we recieve reponses to be sure coordination of
140144
# send/recieve events doesn't matter
141-
another_response = await request_chan.send(
142-
GetThingRequest(more_things.pop(0))
143-
)
144-
if another_response is not None:
145-
assert another_response.name == expected_things[response_index]
146-
assert another_response.version == response_index
147-
response_index += 1
145+
await request_chan.send(GetThingRequest(more_things.pop(0)))
146+
elif not send_initial_requests.done():
147+
# Make sure the sending task it completed
148+
await send_initial_requests
148149
else:
149150
# No more things to send make sure channel is closed
150-
await request_chan.close()
151+
request_chan.close()
152+
assert response_index == len(
153+
expected_things
154+
), "Didn't recieve all exptected responses"

0 commit comments

Comments
 (0)