11
11
from starlette .types import Message , Receive , Scope , Send
12
12
from starlette .websockets import WebSocket , WebSocketDisconnect , WebSocketState
13
13
14
+ TestClientFactory = Callable [..., TestClient ]
14
15
15
- def test_websocket_url (test_client_factory : Callable [..., TestClient ]):
16
+
17
+ def test_websocket_url (test_client_factory : TestClientFactory ) -> None :
16
18
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
17
19
websocket = WebSocket (scope , receive = receive , send = send )
18
20
await websocket .accept ()
@@ -25,7 +27,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
25
27
assert data == {"url" : "ws://testserver/123?a=abc" }
26
28
27
29
28
- def test_websocket_binary_json (test_client_factory : Callable [..., TestClient ]) :
30
+ def test_websocket_binary_json (test_client_factory : TestClientFactory ) -> None :
29
31
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
30
32
websocket = WebSocket (scope , receive = receive , send = send )
31
33
await websocket .accept ()
@@ -41,8 +43,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
41
43
42
44
43
45
def test_websocket_ensure_unicode_on_send_json (
44
- test_client_factory : Callable [..., TestClient ] ,
45
- ):
46
+ test_client_factory : TestClientFactory ,
47
+ ) -> None :
46
48
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
47
49
websocket = WebSocket (scope , receive = receive , send = send )
48
50
@@ -58,7 +60,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
58
60
assert data == '{"test":"数据"}'
59
61
60
62
61
- def test_websocket_query_params (test_client_factory : Callable [..., TestClient ]) :
63
+ def test_websocket_query_params (test_client_factory : TestClientFactory ) -> None :
62
64
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
63
65
websocket = WebSocket (scope , receive = receive , send = send )
64
66
query_params = dict (websocket .query_params )
@@ -76,7 +78,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
76
78
any (module in sys .modules for module in ("brotli" , "brotlicffi" )),
77
79
reason = 'urllib3 includes "br" to the "accept-encoding" headers.' ,
78
80
)
79
- def test_websocket_headers (test_client_factory : Callable [..., TestClient ]) :
81
+ def test_websocket_headers (test_client_factory : TestClientFactory ) -> None :
80
82
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
81
83
websocket = WebSocket (scope , receive = receive , send = send )
82
84
headers = dict (websocket .headers )
@@ -99,7 +101,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
99
101
assert data == {"headers" : expected_headers }
100
102
101
103
102
- def test_websocket_port (test_client_factory : Callable [..., TestClient ]) :
104
+ def test_websocket_port (test_client_factory : TestClientFactory ) -> None :
103
105
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
104
106
websocket = WebSocket (scope , receive = receive , send = send )
105
107
await websocket .accept ()
@@ -113,8 +115,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
113
115
114
116
115
117
def test_websocket_send_and_receive_text (
116
- test_client_factory : Callable [..., TestClient ] ,
117
- ):
118
+ test_client_factory : TestClientFactory ,
119
+ ) -> None :
118
120
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
119
121
websocket = WebSocket (scope , receive = receive , send = send )
120
122
await websocket .accept ()
@@ -130,8 +132,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
130
132
131
133
132
134
def test_websocket_send_and_receive_bytes (
133
- test_client_factory : Callable [..., TestClient ] ,
134
- ):
135
+ test_client_factory : TestClientFactory ,
136
+ ) -> None :
135
137
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
136
138
websocket = WebSocket (scope , receive = receive , send = send )
137
139
await websocket .accept ()
@@ -147,8 +149,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
147
149
148
150
149
151
def test_websocket_send_and_receive_json (
150
- test_client_factory : Callable [..., TestClient ] ,
151
- ):
152
+ test_client_factory : TestClientFactory ,
153
+ ) -> None :
152
154
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
153
155
websocket = WebSocket (scope , receive = receive , send = send )
154
156
await websocket .accept ()
@@ -163,7 +165,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
163
165
assert data == {"message" : {"hello" : "world" }}
164
166
165
167
166
- def test_websocket_iter_text (test_client_factory : Callable [..., TestClient ]) :
168
+ def test_websocket_iter_text (test_client_factory : TestClientFactory ) -> None :
167
169
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
168
170
websocket = WebSocket (scope , receive = receive , send = send )
169
171
await websocket .accept ()
@@ -177,7 +179,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
177
179
assert data == "Message was: Hello, world!"
178
180
179
181
180
- def test_websocket_iter_bytes (test_client_factory : Callable [..., TestClient ]) :
182
+ def test_websocket_iter_bytes (test_client_factory : TestClientFactory ) -> None :
181
183
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
182
184
websocket = WebSocket (scope , receive = receive , send = send )
183
185
await websocket .accept ()
@@ -191,7 +193,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
191
193
assert data == b"Message was: Hello, world!"
192
194
193
195
194
- def test_websocket_iter_json (test_client_factory : Callable [..., TestClient ]) :
196
+ def test_websocket_iter_json (test_client_factory : TestClientFactory ) -> None :
195
197
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
196
198
websocket = WebSocket (scope , receive = receive , send = send )
197
199
await websocket .accept ()
@@ -205,17 +207,17 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
205
207
assert data == {"message" : {"hello" : "world" }}
206
208
207
209
208
- def test_websocket_concurrency_pattern (test_client_factory : Callable [..., TestClient ]) :
210
+ def test_websocket_concurrency_pattern (test_client_factory : TestClientFactory ) -> None :
209
211
stream_send : ObjectSendStream [MutableMapping [str , Any ]]
210
212
stream_receive : ObjectReceiveStream [MutableMapping [str , Any ]]
211
213
stream_send , stream_receive = anyio .create_memory_object_stream ()
212
214
213
- async def reader (websocket : WebSocket ):
215
+ async def reader (websocket : WebSocket ) -> None :
214
216
async with stream_send :
215
217
async for data in websocket .iter_json ():
216
218
await stream_send .send (data )
217
219
218
- async def writer (websocket : WebSocket ):
220
+ async def writer (websocket : WebSocket ) -> None :
219
221
async with stream_receive :
220
222
async for message in stream_receive :
221
223
await websocket .send_json (message )
@@ -235,7 +237,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
235
237
assert data == {"hello" : "world" }
236
238
237
239
238
- def test_client_close (test_client_factory : Callable [..., TestClient ]) :
240
+ def test_client_close (test_client_factory : TestClientFactory ) -> None :
239
241
close_code = None
240
242
close_reason = None
241
243
@@ -257,7 +259,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
257
259
258
260
259
261
@pytest .mark .anyio
260
- async def test_client_disconnect_on_send ():
262
+ async def test_client_disconnect_on_send () -> None :
261
263
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
262
264
websocket = WebSocket (scope , receive = receive , send = send )
263
265
await websocket .accept ()
@@ -278,7 +280,7 @@ async def send(message: Message) -> None:
278
280
assert ctx .value .code == status .WS_1006_ABNORMAL_CLOSURE
279
281
280
282
281
- def test_application_close (test_client_factory : Callable [..., TestClient ]) :
283
+ def test_application_close (test_client_factory : TestClientFactory ) -> None :
282
284
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
283
285
websocket = WebSocket (scope , receive = receive , send = send )
284
286
await websocket .accept ()
@@ -291,7 +293,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
291
293
assert exc .value .code == status .WS_1001_GOING_AWAY
292
294
293
295
294
- def test_rejected_connection (test_client_factory : Callable [..., TestClient ]) :
296
+ def test_rejected_connection (test_client_factory : TestClientFactory ) -> None :
295
297
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
296
298
websocket = WebSocket (scope , receive = receive , send = send )
297
299
msg = await websocket .receive ()
@@ -305,7 +307,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
305
307
assert exc .value .code == status .WS_1001_GOING_AWAY
306
308
307
309
308
- def test_send_denial_response (test_client_factory : Callable [..., TestClient ]) :
310
+ def test_send_denial_response (test_client_factory : TestClientFactory ) -> None :
309
311
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
310
312
websocket = WebSocket (scope , receive = receive , send = send )
311
313
msg = await websocket .receive ()
@@ -321,7 +323,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
321
323
assert exc .value .content == b"foo"
322
324
323
325
324
- def test_send_response_multi (test_client_factory : Callable [..., TestClient ]) :
326
+ def test_send_response_multi (test_client_factory : TestClientFactory ) -> None :
325
327
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
326
328
websocket = WebSocket (scope , receive = receive , send = send )
327
329
msg = await websocket .receive ()
@@ -356,7 +358,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
356
358
assert exc .value .headers ["foo" ] == "bar"
357
359
358
360
359
- def test_send_response_unsupported (test_client_factory : Callable [..., TestClient ]) :
361
+ def test_send_response_unsupported (test_client_factory : TestClientFactory ) -> None :
360
362
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
361
363
del scope ["extensions" ]["websocket.http.response" ]
362
364
websocket = WebSocket (scope , receive = receive , send = send )
@@ -377,7 +379,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
377
379
assert exc .value .code == status .WS_1000_NORMAL_CLOSURE
378
380
379
381
380
- def test_send_response_duplicate_start (test_client_factory : Callable [..., TestClient ]) :
382
+ def test_send_response_duplicate_start (test_client_factory : TestClientFactory ) -> None :
381
383
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
382
384
websocket = WebSocket (scope , receive = receive , send = send )
383
385
msg = await websocket .receive ()
@@ -410,7 +412,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
410
412
pass # pragma: no cover
411
413
412
414
413
- def test_subprotocol (test_client_factory : Callable [..., TestClient ]) :
415
+ def test_subprotocol (test_client_factory : TestClientFactory ) -> None :
414
416
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
415
417
websocket = WebSocket (scope , receive = receive , send = send )
416
418
assert websocket ["subprotocols" ] == ["soap" , "wamp" ]
@@ -422,7 +424,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
422
424
assert websocket .accepted_subprotocol == "wamp"
423
425
424
426
425
- def test_additional_headers (test_client_factory : Callable [..., TestClient ]) :
427
+ def test_additional_headers (test_client_factory : TestClientFactory ) -> None :
426
428
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
427
429
websocket = WebSocket (scope , receive = receive , send = send )
428
430
await websocket .accept (headers = [(b"additional" , b"header" )])
@@ -433,7 +435,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
433
435
assert websocket .extra_headers == [(b"additional" , b"header" )]
434
436
435
437
436
- def test_no_additional_headers (test_client_factory : Callable [..., TestClient ]) :
438
+ def test_no_additional_headers (test_client_factory : TestClientFactory ) -> None :
437
439
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
438
440
websocket = WebSocket (scope , receive = receive , send = send )
439
441
await websocket .accept ()
@@ -444,7 +446,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
444
446
assert websocket .extra_headers == []
445
447
446
448
447
- def test_websocket_exception (test_client_factory : Callable [..., TestClient ]) :
449
+ def test_websocket_exception (test_client_factory : TestClientFactory ) -> None :
448
450
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
449
451
assert False
450
452
@@ -454,7 +456,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
454
456
pass # pragma: no cover
455
457
456
458
457
- def test_duplicate_close (test_client_factory : Callable [..., TestClient ]) :
459
+ def test_duplicate_close (test_client_factory : TestClientFactory ) -> None :
458
460
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
459
461
websocket = WebSocket (scope , receive = receive , send = send )
460
462
await websocket .accept ()
@@ -467,7 +469,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
467
469
pass # pragma: no cover
468
470
469
471
470
- def test_duplicate_disconnect (test_client_factory : Callable [..., TestClient ]) :
472
+ def test_duplicate_disconnect (test_client_factory : TestClientFactory ) -> None :
471
473
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
472
474
websocket = WebSocket (scope , receive = receive , send = send )
473
475
await websocket .accept ()
@@ -481,7 +483,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
481
483
websocket .close ()
482
484
483
485
484
- def test_websocket_scope_interface ():
486
+ def test_websocket_scope_interface () -> None :
485
487
"""
486
488
A WebSocket can be instantiated with a scope, and presents a `Mapping`
487
489
interface.
@@ -513,7 +515,7 @@ async def mock_send(message: Message) -> None:
513
515
assert {websocket } == {websocket }
514
516
515
517
516
- def test_websocket_close_reason (test_client_factory : Callable [..., TestClient ] ) -> None :
518
+ def test_websocket_close_reason (test_client_factory : TestClientFactory ) -> None :
517
519
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
518
520
websocket = WebSocket (scope , receive = receive , send = send )
519
521
await websocket .accept ()
@@ -527,7 +529,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
527
529
assert exc .value .reason == "Going Away"
528
530
529
531
530
- def test_send_json_invalid_mode (test_client_factory : Callable [..., TestClient ]) :
532
+ def test_send_json_invalid_mode (test_client_factory : TestClientFactory ) -> None :
531
533
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
532
534
websocket = WebSocket (scope , receive = receive , send = send )
533
535
await websocket .accept ()
@@ -539,7 +541,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
539
541
pass # pragma: no cover
540
542
541
543
542
- def test_receive_json_invalid_mode (test_client_factory : Callable [..., TestClient ]) :
544
+ def test_receive_json_invalid_mode (test_client_factory : TestClientFactory ) -> None :
543
545
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
544
546
websocket = WebSocket (scope , receive = receive , send = send )
545
547
await websocket .accept ()
@@ -551,7 +553,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
551
553
pass # pragma: nocover
552
554
553
555
554
- def test_receive_text_before_accept (test_client_factory : Callable [..., TestClient ]) :
556
+ def test_receive_text_before_accept (test_client_factory : TestClientFactory ) -> None :
555
557
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
556
558
websocket = WebSocket (scope , receive = receive , send = send )
557
559
await websocket .receive_text ()
@@ -562,7 +564,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
562
564
pass # pragma: nocover
563
565
564
566
565
- def test_receive_bytes_before_accept (test_client_factory : Callable [..., TestClient ]) :
567
+ def test_receive_bytes_before_accept (test_client_factory : TestClientFactory ) -> None :
566
568
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
567
569
websocket = WebSocket (scope , receive = receive , send = send )
568
570
await websocket .receive_bytes ()
@@ -573,7 +575,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
573
575
pass # pragma: nocover
574
576
575
577
576
- def test_receive_json_before_accept (test_client_factory : Callable [..., TestClient ]) :
578
+ def test_receive_json_before_accept (test_client_factory : TestClientFactory ) -> None :
577
579
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
578
580
websocket = WebSocket (scope , receive = receive , send = send )
579
581
await websocket .receive_json ()
@@ -584,7 +586,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
584
586
pass # pragma: no cover
585
587
586
588
587
- def test_send_before_accept (test_client_factory : Callable [..., TestClient ]) :
589
+ def test_send_before_accept (test_client_factory : TestClientFactory ) -> None :
588
590
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
589
591
websocket = WebSocket (scope , receive = receive , send = send )
590
592
await websocket .send ({"type" : "websocket.send" })
@@ -595,7 +597,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
595
597
pass # pragma: nocover
596
598
597
599
598
- def test_send_wrong_message_type (test_client_factory : Callable [..., TestClient ]) :
600
+ def test_send_wrong_message_type (test_client_factory : TestClientFactory ) -> None :
599
601
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
600
602
websocket = WebSocket (scope , receive = receive , send = send )
601
603
await websocket .send ({"type" : "websocket.accept" })
@@ -607,7 +609,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
607
609
pass # pragma: no cover
608
610
609
611
610
- def test_receive_before_accept (test_client_factory : Callable [..., TestClient ]) :
612
+ def test_receive_before_accept (test_client_factory : TestClientFactory ) -> None :
611
613
async def app (scope : Scope , receive : Receive , send : Send ) -> None :
612
614
websocket = WebSocket (scope , receive = receive , send = send )
613
615
await websocket .accept ()
@@ -620,8 +622,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
620
622
websocket .send ({"type" : "websocket.send" })
621
623
622
624
623
- def test_receive_wrong_message_type (test_client_factory : Callable [..., TestClient ]) :
624
- async def app (scope : Scope , receive : Receive , send : Send ):
625
+ def test_receive_wrong_message_type (test_client_factory : TestClientFactory ) -> None :
626
+ async def app (scope : Scope , receive : Receive , send : Send ) -> None :
625
627
websocket = WebSocket (scope , receive = receive , send = send )
626
628
await websocket .accept ()
627
629
await websocket .receive ()
0 commit comments