Skip to content

Commit 179d934

Browse files
TechNiickScirlat DanutKludex
authored
Add type hints to test_websockets.py (#2494)
Co-authored-by: Scirlat Danut <[email protected]> Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent 7cc2ec0 commit 179d934

File tree

1 file changed

+47
-45
lines changed

1 file changed

+47
-45
lines changed

tests/test_websockets.py

+47-45
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
from starlette.types import Message, Receive, Scope, Send
1212
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
1313

14+
TestClientFactory = Callable[..., TestClient]
1415

15-
def test_websocket_url(test_client_factory: Callable[..., TestClient]):
16+
17+
def test_websocket_url(test_client_factory: TestClientFactory) -> None:
1618
async def app(scope: Scope, receive: Receive, send: Send) -> None:
1719
websocket = WebSocket(scope, receive=receive, send=send)
1820
await websocket.accept()
@@ -25,7 +27,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
2527
assert data == {"url": "ws://testserver/123?a=abc"}
2628

2729

28-
def test_websocket_binary_json(test_client_factory: Callable[..., TestClient]):
30+
def test_websocket_binary_json(test_client_factory: TestClientFactory) -> None:
2931
async def app(scope: Scope, receive: Receive, send: Send) -> None:
3032
websocket = WebSocket(scope, receive=receive, send=send)
3133
await websocket.accept()
@@ -41,8 +43,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
4143

4244

4345
def test_websocket_ensure_unicode_on_send_json(
44-
test_client_factory: Callable[..., TestClient],
45-
):
46+
test_client_factory: TestClientFactory,
47+
) -> None:
4648
async def app(scope: Scope, receive: Receive, send: Send) -> None:
4749
websocket = WebSocket(scope, receive=receive, send=send)
4850

@@ -58,7 +60,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
5860
assert data == '{"test":"数据"}'
5961

6062

61-
def test_websocket_query_params(test_client_factory: Callable[..., TestClient]):
63+
def test_websocket_query_params(test_client_factory: TestClientFactory) -> None:
6264
async def app(scope: Scope, receive: Receive, send: Send) -> None:
6365
websocket = WebSocket(scope, receive=receive, send=send)
6466
query_params = dict(websocket.query_params)
@@ -76,7 +78,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
7678
any(module in sys.modules for module in ("brotli", "brotlicffi")),
7779
reason='urllib3 includes "br" to the "accept-encoding" headers.',
7880
)
79-
def test_websocket_headers(test_client_factory: Callable[..., TestClient]):
81+
def test_websocket_headers(test_client_factory: TestClientFactory) -> None:
8082
async def app(scope: Scope, receive: Receive, send: Send) -> None:
8183
websocket = WebSocket(scope, receive=receive, send=send)
8284
headers = dict(websocket.headers)
@@ -99,7 +101,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
99101
assert data == {"headers": expected_headers}
100102

101103

102-
def test_websocket_port(test_client_factory: Callable[..., TestClient]):
104+
def test_websocket_port(test_client_factory: TestClientFactory) -> None:
103105
async def app(scope: Scope, receive: Receive, send: Send) -> None:
104106
websocket = WebSocket(scope, receive=receive, send=send)
105107
await websocket.accept()
@@ -113,8 +115,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
113115

114116

115117
def test_websocket_send_and_receive_text(
116-
test_client_factory: Callable[..., TestClient],
117-
):
118+
test_client_factory: TestClientFactory,
119+
) -> None:
118120
async def app(scope: Scope, receive: Receive, send: Send) -> None:
119121
websocket = WebSocket(scope, receive=receive, send=send)
120122
await websocket.accept()
@@ -130,8 +132,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
130132

131133

132134
def test_websocket_send_and_receive_bytes(
133-
test_client_factory: Callable[..., TestClient],
134-
):
135+
test_client_factory: TestClientFactory,
136+
) -> None:
135137
async def app(scope: Scope, receive: Receive, send: Send) -> None:
136138
websocket = WebSocket(scope, receive=receive, send=send)
137139
await websocket.accept()
@@ -147,8 +149,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
147149

148150

149151
def test_websocket_send_and_receive_json(
150-
test_client_factory: Callable[..., TestClient],
151-
):
152+
test_client_factory: TestClientFactory,
153+
) -> None:
152154
async def app(scope: Scope, receive: Receive, send: Send) -> None:
153155
websocket = WebSocket(scope, receive=receive, send=send)
154156
await websocket.accept()
@@ -163,7 +165,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
163165
assert data == {"message": {"hello": "world"}}
164166

165167

166-
def test_websocket_iter_text(test_client_factory: Callable[..., TestClient]):
168+
def test_websocket_iter_text(test_client_factory: TestClientFactory) -> None:
167169
async def app(scope: Scope, receive: Receive, send: Send) -> None:
168170
websocket = WebSocket(scope, receive=receive, send=send)
169171
await websocket.accept()
@@ -177,7 +179,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
177179
assert data == "Message was: Hello, world!"
178180

179181

180-
def test_websocket_iter_bytes(test_client_factory: Callable[..., TestClient]):
182+
def test_websocket_iter_bytes(test_client_factory: TestClientFactory) -> None:
181183
async def app(scope: Scope, receive: Receive, send: Send) -> None:
182184
websocket = WebSocket(scope, receive=receive, send=send)
183185
await websocket.accept()
@@ -191,7 +193,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
191193
assert data == b"Message was: Hello, world!"
192194

193195

194-
def test_websocket_iter_json(test_client_factory: Callable[..., TestClient]):
196+
def test_websocket_iter_json(test_client_factory: TestClientFactory) -> None:
195197
async def app(scope: Scope, receive: Receive, send: Send) -> None:
196198
websocket = WebSocket(scope, receive=receive, send=send)
197199
await websocket.accept()
@@ -205,17 +207,17 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
205207
assert data == {"message": {"hello": "world"}}
206208

207209

208-
def test_websocket_concurrency_pattern(test_client_factory: Callable[..., TestClient]):
210+
def test_websocket_concurrency_pattern(test_client_factory: TestClientFactory) -> None:
209211
stream_send: ObjectSendStream[MutableMapping[str, Any]]
210212
stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
211213
stream_send, stream_receive = anyio.create_memory_object_stream()
212214

213-
async def reader(websocket: WebSocket):
215+
async def reader(websocket: WebSocket) -> None:
214216
async with stream_send:
215217
async for data in websocket.iter_json():
216218
await stream_send.send(data)
217219

218-
async def writer(websocket: WebSocket):
220+
async def writer(websocket: WebSocket) -> None:
219221
async with stream_receive:
220222
async for message in stream_receive:
221223
await websocket.send_json(message)
@@ -235,7 +237,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
235237
assert data == {"hello": "world"}
236238

237239

238-
def test_client_close(test_client_factory: Callable[..., TestClient]):
240+
def test_client_close(test_client_factory: TestClientFactory) -> None:
239241
close_code = None
240242
close_reason = None
241243

@@ -257,7 +259,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
257259

258260

259261
@pytest.mark.anyio
260-
async def test_client_disconnect_on_send():
262+
async def test_client_disconnect_on_send() -> None:
261263
async def app(scope: Scope, receive: Receive, send: Send) -> None:
262264
websocket = WebSocket(scope, receive=receive, send=send)
263265
await websocket.accept()
@@ -278,7 +280,7 @@ async def send(message: Message) -> None:
278280
assert ctx.value.code == status.WS_1006_ABNORMAL_CLOSURE
279281

280282

281-
def test_application_close(test_client_factory: Callable[..., TestClient]):
283+
def test_application_close(test_client_factory: TestClientFactory) -> None:
282284
async def app(scope: Scope, receive: Receive, send: Send) -> None:
283285
websocket = WebSocket(scope, receive=receive, send=send)
284286
await websocket.accept()
@@ -291,7 +293,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
291293
assert exc.value.code == status.WS_1001_GOING_AWAY
292294

293295

294-
def test_rejected_connection(test_client_factory: Callable[..., TestClient]):
296+
def test_rejected_connection(test_client_factory: TestClientFactory) -> None:
295297
async def app(scope: Scope, receive: Receive, send: Send) -> None:
296298
websocket = WebSocket(scope, receive=receive, send=send)
297299
msg = await websocket.receive()
@@ -305,7 +307,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
305307
assert exc.value.code == status.WS_1001_GOING_AWAY
306308

307309

308-
def test_send_denial_response(test_client_factory: Callable[..., TestClient]):
310+
def test_send_denial_response(test_client_factory: TestClientFactory) -> None:
309311
async def app(scope: Scope, receive: Receive, send: Send) -> None:
310312
websocket = WebSocket(scope, receive=receive, send=send)
311313
msg = await websocket.receive()
@@ -321,7 +323,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
321323
assert exc.value.content == b"foo"
322324

323325

324-
def test_send_response_multi(test_client_factory: Callable[..., TestClient]):
326+
def test_send_response_multi(test_client_factory: TestClientFactory) -> None:
325327
async def app(scope: Scope, receive: Receive, send: Send) -> None:
326328
websocket = WebSocket(scope, receive=receive, send=send)
327329
msg = await websocket.receive()
@@ -356,7 +358,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
356358
assert exc.value.headers["foo"] == "bar"
357359

358360

359-
def test_send_response_unsupported(test_client_factory: Callable[..., TestClient]):
361+
def test_send_response_unsupported(test_client_factory: TestClientFactory) -> None:
360362
async def app(scope: Scope, receive: Receive, send: Send) -> None:
361363
del scope["extensions"]["websocket.http.response"]
362364
websocket = WebSocket(scope, receive=receive, send=send)
@@ -377,7 +379,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
377379
assert exc.value.code == status.WS_1000_NORMAL_CLOSURE
378380

379381

380-
def test_send_response_duplicate_start(test_client_factory: Callable[..., TestClient]):
382+
def test_send_response_duplicate_start(test_client_factory: TestClientFactory) -> None:
381383
async def app(scope: Scope, receive: Receive, send: Send) -> None:
382384
websocket = WebSocket(scope, receive=receive, send=send)
383385
msg = await websocket.receive()
@@ -410,7 +412,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
410412
pass # pragma: no cover
411413

412414

413-
def test_subprotocol(test_client_factory: Callable[..., TestClient]):
415+
def test_subprotocol(test_client_factory: TestClientFactory) -> None:
414416
async def app(scope: Scope, receive: Receive, send: Send) -> None:
415417
websocket = WebSocket(scope, receive=receive, send=send)
416418
assert websocket["subprotocols"] == ["soap", "wamp"]
@@ -422,7 +424,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
422424
assert websocket.accepted_subprotocol == "wamp"
423425

424426

425-
def test_additional_headers(test_client_factory: Callable[..., TestClient]):
427+
def test_additional_headers(test_client_factory: TestClientFactory) -> None:
426428
async def app(scope: Scope, receive: Receive, send: Send) -> None:
427429
websocket = WebSocket(scope, receive=receive, send=send)
428430
await websocket.accept(headers=[(b"additional", b"header")])
@@ -433,7 +435,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
433435
assert websocket.extra_headers == [(b"additional", b"header")]
434436

435437

436-
def test_no_additional_headers(test_client_factory: Callable[..., TestClient]):
438+
def test_no_additional_headers(test_client_factory: TestClientFactory) -> None:
437439
async def app(scope: Scope, receive: Receive, send: Send) -> None:
438440
websocket = WebSocket(scope, receive=receive, send=send)
439441
await websocket.accept()
@@ -444,7 +446,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
444446
assert websocket.extra_headers == []
445447

446448

447-
def test_websocket_exception(test_client_factory: Callable[..., TestClient]):
449+
def test_websocket_exception(test_client_factory: TestClientFactory) -> None:
448450
async def app(scope: Scope, receive: Receive, send: Send) -> None:
449451
assert False
450452

@@ -454,7 +456,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
454456
pass # pragma: no cover
455457

456458

457-
def test_duplicate_close(test_client_factory: Callable[..., TestClient]):
459+
def test_duplicate_close(test_client_factory: TestClientFactory) -> None:
458460
async def app(scope: Scope, receive: Receive, send: Send) -> None:
459461
websocket = WebSocket(scope, receive=receive, send=send)
460462
await websocket.accept()
@@ -467,7 +469,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
467469
pass # pragma: no cover
468470

469471

470-
def test_duplicate_disconnect(test_client_factory: Callable[..., TestClient]):
472+
def test_duplicate_disconnect(test_client_factory: TestClientFactory) -> None:
471473
async def app(scope: Scope, receive: Receive, send: Send) -> None:
472474
websocket = WebSocket(scope, receive=receive, send=send)
473475
await websocket.accept()
@@ -481,7 +483,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
481483
websocket.close()
482484

483485

484-
def test_websocket_scope_interface():
486+
def test_websocket_scope_interface() -> None:
485487
"""
486488
A WebSocket can be instantiated with a scope, and presents a `Mapping`
487489
interface.
@@ -513,7 +515,7 @@ async def mock_send(message: Message) -> None:
513515
assert {websocket} == {websocket}
514516

515517

516-
def test_websocket_close_reason(test_client_factory: Callable[..., TestClient]) -> None:
518+
def test_websocket_close_reason(test_client_factory: TestClientFactory) -> None:
517519
async def app(scope: Scope, receive: Receive, send: Send) -> None:
518520
websocket = WebSocket(scope, receive=receive, send=send)
519521
await websocket.accept()
@@ -527,7 +529,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
527529
assert exc.value.reason == "Going Away"
528530

529531

530-
def test_send_json_invalid_mode(test_client_factory: Callable[..., TestClient]):
532+
def test_send_json_invalid_mode(test_client_factory: TestClientFactory) -> None:
531533
async def app(scope: Scope, receive: Receive, send: Send) -> None:
532534
websocket = WebSocket(scope, receive=receive, send=send)
533535
await websocket.accept()
@@ -539,7 +541,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
539541
pass # pragma: no cover
540542

541543

542-
def test_receive_json_invalid_mode(test_client_factory: Callable[..., TestClient]):
544+
def test_receive_json_invalid_mode(test_client_factory: TestClientFactory) -> None:
543545
async def app(scope: Scope, receive: Receive, send: Send) -> None:
544546
websocket = WebSocket(scope, receive=receive, send=send)
545547
await websocket.accept()
@@ -551,7 +553,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
551553
pass # pragma: nocover
552554

553555

554-
def test_receive_text_before_accept(test_client_factory: Callable[..., TestClient]):
556+
def test_receive_text_before_accept(test_client_factory: TestClientFactory) -> None:
555557
async def app(scope: Scope, receive: Receive, send: Send) -> None:
556558
websocket = WebSocket(scope, receive=receive, send=send)
557559
await websocket.receive_text()
@@ -562,7 +564,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
562564
pass # pragma: nocover
563565

564566

565-
def test_receive_bytes_before_accept(test_client_factory: Callable[..., TestClient]):
567+
def test_receive_bytes_before_accept(test_client_factory: TestClientFactory) -> None:
566568
async def app(scope: Scope, receive: Receive, send: Send) -> None:
567569
websocket = WebSocket(scope, receive=receive, send=send)
568570
await websocket.receive_bytes()
@@ -573,7 +575,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
573575
pass # pragma: nocover
574576

575577

576-
def test_receive_json_before_accept(test_client_factory: Callable[..., TestClient]):
578+
def test_receive_json_before_accept(test_client_factory: TestClientFactory) -> None:
577579
async def app(scope: Scope, receive: Receive, send: Send) -> None:
578580
websocket = WebSocket(scope, receive=receive, send=send)
579581
await websocket.receive_json()
@@ -584,7 +586,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
584586
pass # pragma: no cover
585587

586588

587-
def test_send_before_accept(test_client_factory: Callable[..., TestClient]):
589+
def test_send_before_accept(test_client_factory: TestClientFactory) -> None:
588590
async def app(scope: Scope, receive: Receive, send: Send) -> None:
589591
websocket = WebSocket(scope, receive=receive, send=send)
590592
await websocket.send({"type": "websocket.send"})
@@ -595,7 +597,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
595597
pass # pragma: nocover
596598

597599

598-
def test_send_wrong_message_type(test_client_factory: Callable[..., TestClient]):
600+
def test_send_wrong_message_type(test_client_factory: TestClientFactory) -> None:
599601
async def app(scope: Scope, receive: Receive, send: Send) -> None:
600602
websocket = WebSocket(scope, receive=receive, send=send)
601603
await websocket.send({"type": "websocket.accept"})
@@ -607,7 +609,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
607609
pass # pragma: no cover
608610

609611

610-
def test_receive_before_accept(test_client_factory: Callable[..., TestClient]):
612+
def test_receive_before_accept(test_client_factory: TestClientFactory) -> None:
611613
async def app(scope: Scope, receive: Receive, send: Send) -> None:
612614
websocket = WebSocket(scope, receive=receive, send=send)
613615
await websocket.accept()
@@ -620,8 +622,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
620622
websocket.send({"type": "websocket.send"})
621623

622624

623-
def test_receive_wrong_message_type(test_client_factory: Callable[..., TestClient]):
624-
async def app(scope: Scope, receive: Receive, send: Send):
625+
def test_receive_wrong_message_type(test_client_factory: TestClientFactory) -> None:
626+
async def app(scope: Scope, receive: Receive, send: Send) -> None:
625627
websocket = WebSocket(scope, receive=receive, send=send)
626628
await websocket.accept()
627629
await websocket.receive()

0 commit comments

Comments
 (0)