1
+ from __future__ import annotations
2
+
1
3
import contextlib
2
4
import inspect
3
5
import io
4
6
import json
5
7
import math
6
8
import queue
9
+ import sys
7
10
import typing
8
11
import warnings
9
12
from concurrent .futures import Future
10
13
from types import GeneratorType
11
14
from urllib .parse import unquote , urljoin
12
15
13
16
import anyio
17
+ import anyio .abc
14
18
import anyio .from_thread
15
19
from anyio .abc import ObjectReceiveStream , ObjectSendStream
16
20
from anyio .streams .stapled import StapledObjectStream
19
23
from starlette .types import ASGIApp , Message , Receive , Scope , Send
20
24
from starlette .websockets import WebSocketDisconnect
21
25
26
+ if sys .version_info >= (3 , 10 ): # pragma: no cover
27
+ from typing import TypeGuard
28
+ else : # pragma: no cover
29
+ from typing_extensions import TypeGuard
30
+
22
31
try :
23
32
import httpx
24
33
except ModuleNotFoundError : # pragma: no cover
39
48
_RequestData = typing .Mapping [str , typing .Union [str , typing .Iterable [str ]]]
40
49
41
50
42
- def _is_asgi3 (app : typing .Union [ASGI2App , ASGI3App ]) -> bool :
51
+ def _is_asgi3 (app : typing .Union [ASGI2App , ASGI3App ]) -> TypeGuard [ ASGI3App ] :
43
52
if inspect .isclass (app ):
44
53
return hasattr (app , "__await__" )
45
54
return is_async_callable (app )
@@ -64,7 +73,7 @@ class _AsyncBackend(typing.TypedDict):
64
73
65
74
66
75
class _Upgrade (Exception ):
67
- def __init__ (self , session : " WebSocketTestSession" ) -> None :
76
+ def __init__ (self , session : WebSocketTestSession ) -> None :
68
77
self .session = session
69
78
70
79
@@ -79,16 +88,17 @@ def __init__(
79
88
self .scope = scope
80
89
self .accepted_subprotocol = None
81
90
self .portal_factory = portal_factory
82
- self ._receive_queue : " queue.Queue[Message]" = queue .Queue ()
83
- self ._send_queue : " queue.Queue[Message | BaseException]" = queue .Queue ()
91
+ self ._receive_queue : queue .Queue [Message ] = queue .Queue ()
92
+ self ._send_queue : queue .Queue [Message | BaseException ] = queue .Queue ()
84
93
self .extra_headers = None
85
94
86
- def __enter__ (self ) -> " WebSocketTestSession" :
95
+ def __enter__ (self ) -> WebSocketTestSession :
87
96
self .exit_stack = contextlib .ExitStack ()
88
97
self .portal = self .exit_stack .enter_context (self .portal_factory ())
98
+ self .should_close = anyio .Event ()
89
99
90
100
try :
91
- _ : " Future[None]" = self .portal .start_task_soon (self ._run )
101
+ _ : Future [None ] = self .portal .start_task_soon (self ._run )
92
102
self .send ({"type" : "websocket.connect" })
93
103
message = self .receive ()
94
104
self ._raise_on_close (message )
@@ -99,10 +109,14 @@ def __enter__(self) -> "WebSocketTestSession":
99
109
self .extra_headers = message .get ("headers" , None )
100
110
return self
101
111
112
+ async def _notify_close (self ) -> None :
113
+ self .should_close .set ()
114
+
102
115
def __exit__ (self , * args : typing .Any ) -> None :
103
116
try :
104
117
self .close (1000 )
105
118
finally :
119
+ self .portal .start_task_soon (self ._notify_close )
106
120
self .exit_stack .close ()
107
121
while not self ._send_queue .empty ():
108
122
message = self ._send_queue .get ()
@@ -113,14 +127,22 @@ async def _run(self) -> None:
113
127
"""
114
128
The sub-thread in which the websocket session runs.
115
129
"""
116
- scope = self .scope
117
- receive = self ._asgi_receive
118
- send = self ._asgi_send
119
- try :
120
- await self .app (scope , receive , send )
121
- except BaseException as exc :
122
- self ._send_queue .put (exc )
123
- raise
130
+
131
+ async def run_app (tg : anyio .abc .TaskGroup ) -> None :
132
+ try :
133
+ await self .app (self .scope , self ._asgi_receive , self ._asgi_send )
134
+ except anyio .get_cancelled_exc_class ():
135
+ ...
136
+ except BaseException as exc :
137
+ self ._send_queue .put (exc )
138
+ raise
139
+ finally :
140
+ tg .cancel_scope .cancel ()
141
+
142
+ async with anyio .create_task_group () as tg :
143
+ tg .start_soon (run_app , tg )
144
+ await self .should_close .wait ()
145
+ tg .cancel_scope .cancel ()
124
146
125
147
async def _asgi_receive (self ) -> Message :
126
148
while self ._receive_queue .empty ():
@@ -153,7 +175,7 @@ def send_json(self, data: typing.Any, mode: str = "text") -> None:
153
175
else :
154
176
self .send ({"type" : "websocket.receive" , "bytes" : text .encode ("utf-8" )})
155
177
156
- def close (self , code : int = 1000 , reason : typing . Union [ str , None ] = None ) -> None :
178
+ def close (self , code : int = 1000 , reason : str | None = None ) -> None :
157
179
self .send ({"type" : "websocket.disconnect" , "code" : code , "reason" : reason })
158
180
159
181
def receive (self ) -> Message :
@@ -172,8 +194,9 @@ def receive_bytes(self) -> bytes:
172
194
self ._raise_on_close (message )
173
195
return typing .cast (bytes , message ["bytes" ])
174
196
175
- def receive_json (self , mode : str = "text" ) -> typing .Any :
176
- assert mode in ["text" , "binary" ]
197
+ def receive_json (
198
+ self , mode : typing .Literal ["text" , "binary" ] = "text"
199
+ ) -> typing .Any :
177
200
message = self .receive ()
178
201
self ._raise_on_close (message )
179
202
if mode == "text" :
@@ -191,7 +214,7 @@ def __init__(
191
214
raise_server_exceptions : bool = True ,
192
215
root_path : str = "" ,
193
216
* ,
194
- app_state : typing . Dict [str , typing .Any ],
217
+ app_state : dict [str , typing .Any ],
195
218
) -> None :
196
219
self .app = app
197
220
self .raise_server_exceptions = raise_server_exceptions
@@ -217,7 +240,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
217
240
218
241
# Include the 'host' header.
219
242
if "host" in request .headers :
220
- headers : typing . List [ typing . Tuple [bytes , bytes ]] = []
243
+ headers : list [ tuple [bytes , bytes ]] = []
221
244
elif port == default_port : # pragma: no cover
222
245
headers = [(b"host" , host .encode ())]
223
246
else : # pragma: no cover
@@ -229,7 +252,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
229
252
for key , value in request .headers .multi_items ()
230
253
]
231
254
232
- scope : typing . Dict [str , typing .Any ]
255
+ scope : dict [str , typing .Any ]
233
256
234
257
if scheme in {"ws" , "wss" }:
235
258
subprotocol = request .headers .get ("sec-websocket-protocol" , None )
@@ -272,7 +295,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
272
295
request_complete = False
273
296
response_started = False
274
297
response_complete : anyio .Event
275
- raw_kwargs : typing . Dict [str , typing .Any ] = {"stream" : io .BytesIO ()}
298
+ raw_kwargs : dict [str , typing .Any ] = {"stream" : io .BytesIO ()}
276
299
template = None
277
300
context = None
278
301
@@ -363,26 +386,25 @@ async def send(message: Message) -> None:
363
386
364
387
class TestClient (httpx .Client ):
365
388
__test__ = False
366
- task : " Future[None]"
367
- portal : typing . Optional [ anyio .abc .BlockingPortal ] = None
389
+ task : Future [None ]
390
+ portal : anyio .abc .BlockingPortal | None = None
368
391
369
392
def __init__ (
370
393
self ,
371
394
app : ASGIApp ,
372
395
base_url : str = "http://testserver" ,
373
396
raise_server_exceptions : bool = True ,
374
397
root_path : str = "" ,
375
- backend : str = "asyncio" ,
376
- backend_options : typing .Optional [ typing . Dict [str , typing .Any ]] = None ,
377
- cookies : httpx ._types .CookieTypes = None ,
378
- headers : typing .Dict [str , str ] = None ,
398
+ backend : typing . Literal [ "asyncio" , "trio" ] = "asyncio" ,
399
+ backend_options : typing .Dict [str , typing .Any ] | None = None ,
400
+ cookies : httpx ._types .CookieTypes | None = None ,
401
+ headers : typing .Dict [str , str ] | None = None ,
379
402
follow_redirects : bool = True ,
380
403
) -> None :
381
404
self .async_backend = _AsyncBackend (
382
405
backend = backend , backend_options = backend_options or {}
383
406
)
384
407
if _is_asgi3 (app ):
385
- app = typing .cast (ASGI3App , app )
386
408
asgi_app = app
387
409
else :
388
410
app = typing .cast (ASGI2App , app ) # type: ignore[assignment]
@@ -419,13 +441,11 @@ def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, No
419
441
yield portal
420
442
421
443
def _choose_redirect_arg (
422
- self ,
423
- follow_redirects : typing .Optional [bool ],
424
- allow_redirects : typing .Optional [bool ],
425
- ) -> typing .Union [bool , httpx ._client .UseClientDefault ]:
426
- redirect : typing .Union [
427
- bool , httpx ._client .UseClientDefault
428
- ] = httpx ._client .USE_CLIENT_DEFAULT
444
+ self , follow_redirects : bool | None , allow_redirects : bool | None
445
+ ) -> bool | httpx ._client .UseClientDefault :
446
+ redirect : bool | httpx ._client .UseClientDefault = (
447
+ httpx ._client .USE_CLIENT_DEFAULT
448
+ )
429
449
if allow_redirects is not None :
430
450
message = (
431
451
"The `allow_redirects` argument is deprecated. "
@@ -709,7 +729,10 @@ def delete( # type: ignore[override]
709
729
)
710
730
711
731
def websocket_connect (
712
- self , url : str , subprotocols : typing .Sequence [str ] = None , ** kwargs : typing .Any
732
+ self ,
733
+ url : str ,
734
+ subprotocols : typing .Sequence [str ] | None = None ,
735
+ ** kwargs : typing .Any ,
713
736
) -> "WebSocketTestSession" :
714
737
url = urljoin ("ws://testserver" , url )
715
738
headers = kwargs .get ("headers" , {})
0 commit comments