Skip to content

Commit c3db325

Browse files
authored
Merge pull request #1 from boukeversteegh/client-streaming-tests
Client streaming tests
2 parents e1ccd54 + 4e78fe9 commit c3db325

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import asyncio
2+
from dataclasses import dataclass
3+
from typing import AsyncIterator
4+
5+
import pytest
6+
7+
import betterproto
8+
from betterproto.grpc.util.async_channel import AsyncChannel
9+
10+
11+
@dataclass
12+
class Message(betterproto.Message):
13+
body: str = betterproto.string_field(1)
14+
15+
16+
@pytest.fixture
17+
def expected_responses():
18+
return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")]
19+
20+
21+
class ClientStub:
22+
async def connect(self, requests: AsyncIterator):
23+
await asyncio.sleep(0.1)
24+
async for request in requests:
25+
await asyncio.sleep(0.1)
26+
yield request
27+
await asyncio.sleep(0.1)
28+
yield Message("Done")
29+
30+
31+
async def to_list(generator: AsyncIterator):
32+
lis = []
33+
async for value in generator:
34+
lis.append(value)
35+
return lis
36+
37+
38+
@pytest.fixture
39+
def client():
40+
# channel = Channel(host='127.0.0.1', port=50051)
41+
# return ClientStub(channel)
42+
return ClientStub()
43+
44+
45+
@pytest.mark.asyncio
46+
async def test_from_list_close_automatically(client, expected_responses):
47+
requests = AsyncChannel(
48+
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
49+
)
50+
51+
responses = client.connect(requests)
52+
53+
assert await to_list(responses) == expected_responses
54+
55+
56+
@pytest.mark.asyncio
57+
async def test_from_list_close_manually_immediately(client, expected_responses):
58+
requests = AsyncChannel(
59+
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=False
60+
)
61+
62+
requests.close()
63+
64+
responses = client.connect(requests)
65+
66+
assert await to_list(responses) == expected_responses
67+
68+
69+
@pytest.mark.asyncio
70+
async def test_from_list_close_manually_after_connect(client, expected_responses):
71+
requests = AsyncChannel(
72+
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=False
73+
)
74+
75+
responses = client.connect(requests)
76+
77+
requests.close()
78+
79+
assert await to_list(responses) == expected_responses
80+
81+
82+
@pytest.mark.asyncio
83+
async def test_send_from_before_connect_and_close_automatically(
84+
client, expected_responses
85+
):
86+
requests = AsyncChannel()
87+
88+
await requests.send_from(
89+
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
90+
)
91+
92+
responses = client.connect(requests)
93+
94+
assert await to_list(responses) == expected_responses
95+
96+
97+
@pytest.mark.asyncio
98+
async def test_send_from_after_connect_and_close_automatically(
99+
client, expected_responses
100+
):
101+
requests = AsyncChannel()
102+
103+
responses = client.connect(requests)
104+
105+
await requests.send_from(
106+
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
107+
)
108+
109+
assert await to_list(responses) == expected_responses
110+
111+
112+
@pytest.mark.asyncio
113+
async def test_send_from_close_manually_immediately(client, expected_responses):
114+
requests = AsyncChannel()
115+
116+
responses = client.connect(requests)
117+
118+
await requests.send_from(
119+
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=False
120+
)
121+
122+
requests.close()
123+
124+
assert await to_list(responses) == expected_responses
125+
126+
127+
@pytest.mark.asyncio
128+
async def test_send_individually_and_close_before_connect(client, expected_responses):
129+
requests = AsyncChannel()
130+
131+
await requests.send(Message(body="Hello world 1"))
132+
await requests.send(Message(body="Hello world 2"))
133+
requests.close()
134+
135+
responses = client.connect(requests)
136+
137+
assert await to_list(responses) == expected_responses
138+
139+
140+
@pytest.mark.asyncio
141+
async def test_send_individually_and_close_after_connect(client, expected_responses):
142+
requests = AsyncChannel()
143+
144+
await requests.send(Message(body="Hello world 1"))
145+
await requests.send(Message(body="Hello world 2"))
146+
147+
responses = client.connect(requests)
148+
149+
requests.close()
150+
151+
assert await to_list(responses) == expected_responses

0 commit comments

Comments
 (0)