Skip to content

Commit 2ef14a6

Browse files
authored
[PR #8641/0a88bab backport][3.10] Fix WebSocket ping tasks being prematurely garbage collected (#8646)
1 parent 68e8496 commit 2ef14a6

File tree

4 files changed

+91
-12
lines changed

4 files changed

+91
-12
lines changed

CHANGES/8641.bugfix.rst

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fixed WebSocket ping tasks being prematurely garbage collected -- by :user:`bdraco`.
2+
3+
There was a small risk that WebSocket ping tasks would be prematurely garbage collected because the event loop only holds a weak reference to the task. The garbage collection risk has been fixed by holding a strong reference to the task. Additionally, the task is now scheduled eagerly with Python 3.12+ to increase the chance it can be completed immediately and avoid having to hold any references to the task.

aiohttp/client_ws.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def __init__(
7272
self._exception: Optional[BaseException] = None
7373
self._compress = compress
7474
self._client_notakeover = client_notakeover
75+
self._ping_task: Optional[asyncio.Task[None]] = None
7576

7677
self._reset_heartbeat()
7778

@@ -80,6 +81,9 @@ def _cancel_heartbeat(self) -> None:
8081
if self._heartbeat_cb is not None:
8182
self._heartbeat_cb.cancel()
8283
self._heartbeat_cb = None
84+
if self._ping_task is not None:
85+
self._ping_task.cancel()
86+
self._ping_task = None
8387

8488
def _cancel_pong_response_cb(self) -> None:
8589
if self._pong_response_cb is not None:
@@ -118,11 +122,6 @@ def _send_heartbeat(self) -> None:
118122
)
119123
return
120124

121-
# fire-and-forget a task is not perfect but maybe ok for
122-
# sending ping. Otherwise we need a long-living heartbeat
123-
# task in the class.
124-
loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]
125-
126125
conn = self._conn
127126
timeout_ceil_threshold = (
128127
conn._connector._timeout_ceil_threshold if conn is not None else 5
@@ -131,6 +130,22 @@ def _send_heartbeat(self) -> None:
131130
self._cancel_pong_response_cb()
132131
self._pong_response_cb = loop.call_at(when, self._pong_not_received)
133132

133+
if sys.version_info >= (3, 12):
134+
# Optimization for Python 3.12, try to send the ping
135+
# immediately to avoid having to schedule
136+
# the task on the event loop.
137+
ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
138+
else:
139+
ping_task = loop.create_task(self._writer.ping())
140+
141+
if not ping_task.done():
142+
self._ping_task = ping_task
143+
ping_task.add_done_callback(self._ping_task_done)
144+
145+
def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
146+
"""Callback for when the ping task completes."""
147+
self._ping_task = None
148+
134149
def _pong_not_received(self) -> None:
135150
if not self._closed:
136151
self._set_closed()

aiohttp/web_ws.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,16 @@ def __init__(
9696
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
9797
self._compress = compress
9898
self._max_msg_size = max_msg_size
99+
self._ping_task: Optional[asyncio.Task[None]] = None
99100

100101
def _cancel_heartbeat(self) -> None:
101102
self._cancel_pong_response_cb()
102103
if self._heartbeat_cb is not None:
103104
self._heartbeat_cb.cancel()
104105
self._heartbeat_cb = None
106+
if self._ping_task is not None:
107+
self._ping_task.cancel()
108+
self._ping_task = None
105109

106110
def _cancel_pong_response_cb(self) -> None:
107111
if self._pong_response_cb is not None:
@@ -141,11 +145,6 @@ def _send_heartbeat(self) -> None:
141145
)
142146
return
143147

144-
# fire-and-forget a task is not perfect but maybe ok for
145-
# sending ping. Otherwise we need a long-living heartbeat
146-
# task in the class.
147-
loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]
148-
149148
req = self._req
150149
timeout_ceil_threshold = (
151150
req._protocol._timeout_ceil_threshold if req is not None else 5
@@ -154,6 +153,22 @@ def _send_heartbeat(self) -> None:
154153
self._cancel_pong_response_cb()
155154
self._pong_response_cb = loop.call_at(when, self._pong_not_received)
156155

156+
if sys.version_info >= (3, 12):
157+
# Optimization for Python 3.12, try to send the ping
158+
# immediately to avoid having to schedule
159+
# the task on the event loop.
160+
ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
161+
else:
162+
ping_task = loop.create_task(self._writer.ping())
163+
164+
if not ping_task.done():
165+
self._ping_task = ping_task
166+
ping_task.add_done_callback(self._ping_task_done)
167+
168+
def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
169+
"""Callback for when the ping task completes."""
170+
self._ping_task = None
171+
157172
def _pong_not_received(self) -> None:
158173
if self._req is not None and self._req.transport is not None:
159174
self._set_closed()

tests/test_client_ws_functional.py

+48-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import sys
33
from typing import Any, NoReturn
4+
from unittest import mock
45

56
import pytest
67

@@ -727,8 +728,53 @@ async def handler(request):
727728
assert isinstance(msg.data, ServerTimeoutError)
728729

729730

730-
async def test_send_recv_compress(aiohttp_client: Any) -> None:
731-
async def handler(request):
731+
async def test_close_websocket_while_ping_inflight(
732+
aiohttp_client: AiohttpClient,
733+
) -> None:
734+
"""Test closing the websocket while a ping is in-flight."""
735+
ping_received = False
736+
737+
async def handler(request: web.Request) -> NoReturn:
738+
nonlocal ping_received
739+
ws = web.WebSocketResponse(autoping=False)
740+
await ws.prepare(request)
741+
msg = await ws.receive()
742+
assert msg.type is aiohttp.WSMsgType.BINARY
743+
msg = await ws.receive()
744+
ping_received = msg.type is aiohttp.WSMsgType.PING
745+
await ws.receive()
746+
assert False
747+
748+
app = web.Application()
749+
app.router.add_route("GET", "/", handler)
750+
751+
client = await aiohttp_client(app)
752+
resp = await client.ws_connect("/", heartbeat=0.1)
753+
await resp.send_bytes(b"ask")
754+
755+
cancelled = False
756+
ping_stated = False
757+
758+
async def delayed_ping() -> None:
759+
nonlocal cancelled, ping_stated
760+
ping_stated = True
761+
try:
762+
await asyncio.sleep(1)
763+
except asyncio.CancelledError:
764+
cancelled = True
765+
raise
766+
767+
with mock.patch.object(resp._writer, "ping", delayed_ping):
768+
await asyncio.sleep(0.1)
769+
770+
await resp.close()
771+
await asyncio.sleep(0)
772+
assert ping_stated is True
773+
assert cancelled is True
774+
775+
776+
async def test_send_recv_compress(aiohttp_client: AiohttpClient) -> None:
777+
async def handler(request: web.Request) -> web.WebSocketResponse:
732778
ws = web.WebSocketResponse()
733779
await ws.prepare(request)
734780

0 commit comments

Comments
 (0)