@@ -13,20 +13,13 @@ class Message(betterproto.Message):
13
13
body : str = betterproto .string_field (1 )
14
14
15
15
16
- async def to_list (generator : AsyncIterator ):
17
- lis = []
18
- async for value in generator :
19
- lis .append (value )
20
- return lis
21
-
22
-
23
16
@pytest .fixture
24
17
def expected_responses ():
25
18
return [Message ("Hello world 1" ), Message ("Hello world 2" ), Message ("Done" )]
26
19
27
20
28
21
class ClientStub :
29
- async def connect (self , requests ):
22
+ async def connect (self , requests : AsyncIterator ):
30
23
await asyncio .sleep (0.1 )
31
24
async for request in requests :
32
25
await asyncio .sleep (0.1 )
@@ -35,6 +28,13 @@ async def connect(self, requests):
35
28
yield Message ("Done" )
36
29
37
30
31
+ async def to_list (generator : AsyncIterator ):
32
+ lis = []
33
+ async for value in generator :
34
+ lis .append (value )
35
+ return lis
36
+
37
+
38
38
@pytest .fixture
39
39
def client ():
40
40
# channel = Channel(host='127.0.0.1', port=50051)
@@ -122,3 +122,30 @@ async def test_send_from_close_manually_immediately(client, expected_responses):
122
122
requests .close ()
123
123
124
124
assert await to_list (responses ) == expected_responses
125
+
126
+
127
+ @pytest .mark .asyncio
128
+ async def test_send_individually_and_close_before_connect (client , expected_responses ):
129
+ requests = AsyncChannel ()
130
+
131
+ await requests .send (Message (body = "Hello world 1" ))
132
+ await requests .send (Message (body = "Hello world 2" ))
133
+ requests .close ()
134
+
135
+ responses = client .connect (requests )
136
+
137
+ assert await to_list (responses ) == expected_responses
138
+
139
+
140
+ @pytest .mark .asyncio
141
+ async def test_send_individually_and_close_after_connect (client , expected_responses ):
142
+ requests = AsyncChannel ()
143
+
144
+ await requests .send (Message (body = "Hello world 1" ))
145
+ await requests .send (Message (body = "Hello world 2" ))
146
+
147
+ responses = client .connect (requests )
148
+
149
+ requests .close ()
150
+
151
+ assert await to_list (responses ) == expected_responses
0 commit comments