Skip to content

Commit 3ae161e

Browse files
authored
Cancel WebSocketTestSession on close (#2427)
* Cancel `WebSocketTestSession` on close * Undo some noise * Fix test * Undo pyproject * Undo anyio bump * Undo changes on test_authentication * Always call cancel scope
1 parent 13c66c9 commit 3ae161e

File tree

2 files changed

+110
-69
lines changed

2 files changed

+110
-69
lines changed

starlette/testclient.py

Lines changed: 59 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
from __future__ import annotations
2+
13
import contextlib
24
import inspect
35
import io
46
import json
57
import math
68
import queue
9+
import sys
710
import typing
811
import warnings
912
from concurrent.futures import Future
1013
from types import GeneratorType
1114
from urllib.parse import unquote, urljoin
1215

1316
import anyio
17+
import anyio.abc
1418
import anyio.from_thread
1519
from anyio.abc import ObjectReceiveStream, ObjectSendStream
1620
from anyio.streams.stapled import StapledObjectStream
@@ -19,6 +23,11 @@
1923
from starlette.types import ASGIApp, Message, Receive, Scope, Send
2024
from starlette.websockets import WebSocketDisconnect
2125

26+
if sys.version_info >= (3, 10): # pragma: no cover
27+
from typing import TypeGuard
28+
else: # pragma: no cover
29+
from typing_extensions import TypeGuard
30+
2231
try:
2332
import httpx
2433
except ModuleNotFoundError: # pragma: no cover
@@ -39,7 +48,7 @@
3948
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]]
4049

4150

42-
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
51+
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> TypeGuard[ASGI3App]:
4352
if inspect.isclass(app):
4453
return hasattr(app, "__await__")
4554
return is_async_callable(app)
@@ -64,7 +73,7 @@ class _AsyncBackend(typing.TypedDict):
6473

6574

6675
class _Upgrade(Exception):
67-
def __init__(self, session: "WebSocketTestSession") -> None:
76+
def __init__(self, session: WebSocketTestSession) -> None:
6877
self.session = session
6978

7079

@@ -79,16 +88,17 @@ def __init__(
7988
self.scope = scope
8089
self.accepted_subprotocol = None
8190
self.portal_factory = portal_factory
82-
self._receive_queue: "queue.Queue[Message]" = queue.Queue()
83-
self._send_queue: "queue.Queue[Message | BaseException]" = queue.Queue()
91+
self._receive_queue: queue.Queue[Message] = queue.Queue()
92+
self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
8493
self.extra_headers = None
8594

86-
def __enter__(self) -> "WebSocketTestSession":
95+
def __enter__(self) -> WebSocketTestSession:
8796
self.exit_stack = contextlib.ExitStack()
8897
self.portal = self.exit_stack.enter_context(self.portal_factory())
98+
self.should_close = anyio.Event()
8999

90100
try:
91-
_: "Future[None]" = self.portal.start_task_soon(self._run)
101+
_: Future[None] = self.portal.start_task_soon(self._run)
92102
self.send({"type": "websocket.connect"})
93103
message = self.receive()
94104
self._raise_on_close(message)
@@ -99,10 +109,14 @@ def __enter__(self) -> "WebSocketTestSession":
99109
self.extra_headers = message.get("headers", None)
100110
return self
101111

112+
async def _notify_close(self) -> None:
113+
self.should_close.set()
114+
102115
def __exit__(self, *args: typing.Any) -> None:
103116
try:
104117
self.close(1000)
105118
finally:
119+
self.portal.start_task_soon(self._notify_close)
106120
self.exit_stack.close()
107121
while not self._send_queue.empty():
108122
message = self._send_queue.get()
@@ -113,14 +127,22 @@ async def _run(self) -> None:
113127
"""
114128
The sub-thread in which the websocket session runs.
115129
"""
116-
scope = self.scope
117-
receive = self._asgi_receive
118-
send = self._asgi_send
119-
try:
120-
await self.app(scope, receive, send)
121-
except BaseException as exc:
122-
self._send_queue.put(exc)
123-
raise
130+
131+
async def run_app(tg: anyio.abc.TaskGroup) -> None:
132+
try:
133+
await self.app(self.scope, self._asgi_receive, self._asgi_send)
134+
except anyio.get_cancelled_exc_class():
135+
...
136+
except BaseException as exc:
137+
self._send_queue.put(exc)
138+
raise
139+
finally:
140+
tg.cancel_scope.cancel()
141+
142+
async with anyio.create_task_group() as tg:
143+
tg.start_soon(run_app, tg)
144+
await self.should_close.wait()
145+
tg.cancel_scope.cancel()
124146

125147
async def _asgi_receive(self) -> Message:
126148
while self._receive_queue.empty():
@@ -153,7 +175,7 @@ def send_json(self, data: typing.Any, mode: str = "text") -> None:
153175
else:
154176
self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
155177

156-
def close(self, code: int = 1000, reason: typing.Union[str, None] = None) -> None:
178+
def close(self, code: int = 1000, reason: str | None = None) -> None:
157179
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
158180

159181
def receive(self) -> Message:
@@ -172,8 +194,9 @@ def receive_bytes(self) -> bytes:
172194
self._raise_on_close(message)
173195
return typing.cast(bytes, message["bytes"])
174196

175-
def receive_json(self, mode: str = "text") -> typing.Any:
176-
assert mode in ["text", "binary"]
197+
def receive_json(
198+
self, mode: typing.Literal["text", "binary"] = "text"
199+
) -> typing.Any:
177200
message = self.receive()
178201
self._raise_on_close(message)
179202
if mode == "text":
@@ -191,7 +214,7 @@ def __init__(
191214
raise_server_exceptions: bool = True,
192215
root_path: str = "",
193216
*,
194-
app_state: typing.Dict[str, typing.Any],
217+
app_state: dict[str, typing.Any],
195218
) -> None:
196219
self.app = app
197220
self.raise_server_exceptions = raise_server_exceptions
@@ -217,7 +240,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
217240

218241
# Include the 'host' header.
219242
if "host" in request.headers:
220-
headers: typing.List[typing.Tuple[bytes, bytes]] = []
243+
headers: list[tuple[bytes, bytes]] = []
221244
elif port == default_port: # pragma: no cover
222245
headers = [(b"host", host.encode())]
223246
else: # pragma: no cover
@@ -229,7 +252,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
229252
for key, value in request.headers.multi_items()
230253
]
231254

232-
scope: typing.Dict[str, typing.Any]
255+
scope: dict[str, typing.Any]
233256

234257
if scheme in {"ws", "wss"}:
235258
subprotocol = request.headers.get("sec-websocket-protocol", None)
@@ -272,7 +295,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
272295
request_complete = False
273296
response_started = False
274297
response_complete: anyio.Event
275-
raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()}
298+
raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
276299
template = None
277300
context = None
278301

@@ -363,26 +386,25 @@ async def send(message: Message) -> None:
363386

364387
class TestClient(httpx.Client):
365388
__test__ = False
366-
task: "Future[None]"
367-
portal: typing.Optional[anyio.abc.BlockingPortal] = None
389+
task: Future[None]
390+
portal: anyio.abc.BlockingPortal | None = None
368391

369392
def __init__(
370393
self,
371394
app: ASGIApp,
372395
base_url: str = "http://testserver",
373396
raise_server_exceptions: bool = True,
374397
root_path: str = "",
375-
backend: str = "asyncio",
376-
backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
377-
cookies: httpx._types.CookieTypes = None,
378-
headers: typing.Dict[str, str] = None,
398+
backend: typing.Literal["asyncio", "trio"] = "asyncio",
399+
backend_options: typing.Dict[str, typing.Any] | None = None,
400+
cookies: httpx._types.CookieTypes | None = None,
401+
headers: typing.Dict[str, str] | None = None,
379402
follow_redirects: bool = True,
380403
) -> None:
381404
self.async_backend = _AsyncBackend(
382405
backend=backend, backend_options=backend_options or {}
383406
)
384407
if _is_asgi3(app):
385-
app = typing.cast(ASGI3App, app)
386408
asgi_app = app
387409
else:
388410
app = typing.cast(ASGI2App, app) # type: ignore[assignment]
@@ -419,13 +441,11 @@ def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, No
419441
yield portal
420442

421443
def _choose_redirect_arg(
422-
self,
423-
follow_redirects: typing.Optional[bool],
424-
allow_redirects: typing.Optional[bool],
425-
) -> typing.Union[bool, httpx._client.UseClientDefault]:
426-
redirect: typing.Union[
427-
bool, httpx._client.UseClientDefault
428-
] = httpx._client.USE_CLIENT_DEFAULT
444+
self, follow_redirects: bool | None, allow_redirects: bool | None
445+
) -> bool | httpx._client.UseClientDefault:
446+
redirect: bool | httpx._client.UseClientDefault = (
447+
httpx._client.USE_CLIENT_DEFAULT
448+
)
429449
if allow_redirects is not None:
430450
message = (
431451
"The `allow_redirects` argument is deprecated. "
@@ -709,7 +729,10 @@ def delete( # type: ignore[override]
709729
)
710730

711731
def websocket_connect(
712-
self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
732+
self,
733+
url: str,
734+
subprotocols: typing.Sequence[str] | None = None,
735+
**kwargs: typing.Any,
713736
) -> "WebSocketTestSession":
714737
url = urljoin("ws://testserver", url)
715738
headers = kwargs.get("headers", {})

0 commit comments

Comments
 (0)