Skip to content

Commit 78418f7

Browse files
authored
[PR #9566/22f0831 backport][3.11] Refactor WebSocketWriter to remove high level protocol functions (#9569)
1 parent a3b8129 commit 78418f7

File tree

6 files changed

+56
-49
lines changed

6 files changed

+56
-49
lines changed

aiohttp/_websocket/writer.py

+9-12
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@
3434

3535

3636
class WebSocketWriter:
37+
"""WebSocket writer.
38+
39+
The writer is responsible for sending messages to the client. It is
40+
created by the protocol when a connection is established. The writer
41+
should avoid implementing any application logic and should only be
42+
concerned with the low-level details of the WebSocket protocol.
43+
"""
44+
3745
def __init__(
3846
self,
3947
protocol: BaseProtocol,
@@ -45,6 +53,7 @@ def __init__(
4553
compress: int = 0,
4654
notakeover: bool = False,
4755
) -> None:
56+
"""Initialize a WebSocket writer."""
4857
self.protocol = protocol
4958
self.transport = transport
5059
self.use_mask = use_mask
@@ -155,18 +164,6 @@ def _make_compress_obj(self, compress: int) -> ZLibCompressor:
155164
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
156165
)
157166

158-
async def pong(self, message: Union[bytes, str] = b"") -> None:
159-
"""Send pong message."""
160-
if isinstance(message, str):
161-
message = message.encode("utf-8")
162-
await self.send_frame(message, WSMsgType.PONG)
163-
164-
async def ping(self, message: Union[bytes, str] = b"") -> None:
165-
"""Send ping message."""
166-
if isinstance(message, str):
167-
message = message.encode("utf-8")
168-
await self.send_frame(message, WSMsgType.PING)
169-
170167
async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
171168
"""Close the websocket, sending the specified code and message."""
172169
if isinstance(message, str):

aiohttp/client_ws.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,14 @@ def _send_heartbeat(self) -> None:
140140
self._cancel_pong_response_cb()
141141
self._pong_response_cb = loop.call_at(when, self._pong_not_received)
142142

143+
coro = self._writer.send_frame(b"", WSMsgType.PING)
143144
if sys.version_info >= (3, 12):
144145
# Optimization for Python 3.12, try to send the ping
145146
# immediately to avoid having to schedule
146147
# the task on the event loop.
147-
ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
148+
ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
148149
else:
149-
ping_task = loop.create_task(self._writer.ping())
150+
ping_task = loop.create_task(coro)
150151

151152
if not ping_task.done():
152153
self._ping_task = ping_task
@@ -224,10 +225,10 @@ def exception(self) -> Optional[BaseException]:
224225
return self._exception
225226

226227
async def ping(self, message: bytes = b"") -> None:
227-
await self._writer.ping(message)
228+
await self._writer.send_frame(message, WSMsgType.PING)
228229

229230
async def pong(self, message: bytes = b"") -> None:
230-
await self._writer.pong(message)
231+
await self._writer.send_frame(message, WSMsgType.PONG)
231232

232233
async def send_frame(
233234
self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None

aiohttp/web_ws.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,14 @@ def _send_heartbeat(self) -> None:
153153
self._cancel_pong_response_cb()
154154
self._pong_response_cb = loop.call_at(when, self._pong_not_received)
155155

156+
coro = self._writer.send_frame(b"", WSMsgType.PING)
156157
if sys.version_info >= (3, 12):
157158
# Optimization for Python 3.12, try to send the ping
158159
# immediately to avoid having to schedule
159160
# the task on the event loop.
160-
ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
161+
ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
161162
else:
162-
ping_task = loop.create_task(self._writer.ping())
163+
ping_task = loop.create_task(coro)
163164

164165
if not ping_task.done():
165166
self._ping_task = ping_task
@@ -371,13 +372,13 @@ def exception(self) -> Optional[BaseException]:
371372
async def ping(self, message: bytes = b"") -> None:
372373
if self._writer is None:
373374
raise RuntimeError("Call .prepare() first")
374-
await self._writer.ping(message)
375+
await self._writer.send_frame(message, WSMsgType.PING)
375376

376377
async def pong(self, message: bytes = b"") -> None:
377378
# unsolicited pong
378379
if self._writer is None:
379380
raise RuntimeError("Call .prepare() first")
380-
await self._writer.pong(message)
381+
await self._writer.send_frame(message, WSMsgType.PONG)
381382

382383
async def send_frame(
383384
self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None

tests/test_client_ws_functional.py

+18-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import sys
3-
from typing import Any, NoReturn
3+
from typing import Any, NoReturn, Optional
44
from unittest import mock
55

66
import pytest
@@ -704,9 +704,11 @@ async def handler(request: web.Request) -> NoReturn:
704704
assert resp._conn is not None
705705
with mock.patch.object(
706706
resp._conn.transport, "write", side_effect=ClientConnectionResetError
707-
), mock.patch.object(resp._writer, "ping", wraps=resp._writer.ping) as ping:
707+
), mock.patch.object(
708+
resp._writer, "send_frame", wraps=resp._writer.send_frame
709+
) as send_frame:
708710
await resp.receive()
709-
ping_count = ping.call_count
711+
ping_count = send_frame.call_args_list.count(mock.call(b"", WSMsgType.PING))
710712
# Connection should be closed roughly after 1.5x heartbeat.
711713
await asyncio.sleep(0.2)
712714
assert ping_count == 1
@@ -842,7 +844,7 @@ async def handler(request):
842844

843845

844846
async def test_close_websocket_while_ping_inflight(
845-
aiohttp_client: AiohttpClient,
847+
aiohttp_client: AiohttpClient, loop: asyncio.AbstractEventLoop
846848
) -> None:
847849
"""Test closing the websocket while a ping is in-flight."""
848850
ping_received = False
@@ -866,23 +868,27 @@ async def handler(request: web.Request) -> NoReturn:
866868
await resp.send_bytes(b"ask")
867869

868870
cancelled = False
869-
ping_stated = False
870-
871-
async def delayed_ping() -> None:
872-
nonlocal cancelled, ping_stated
873-
ping_stated = True
871+
ping_started = loop.create_future()
872+
873+
async def delayed_send_frame(
874+
message: bytes, opcode: int, compress: Optional[int] = None
875+
) -> None:
876+
assert opcode == WSMsgType.PING
877+
nonlocal cancelled, ping_started
878+
ping_started.set_result(None)
874879
try:
875880
await asyncio.sleep(1)
876881
except asyncio.CancelledError:
877882
cancelled = True
878883
raise
879884

880-
with mock.patch.object(resp._writer, "ping", delayed_ping):
881-
await asyncio.sleep(0.1)
885+
with mock.patch.object(resp._writer, "send_frame", delayed_send_frame):
886+
async with async_timeout.timeout(1):
887+
await ping_started
882888

883889
await resp.close()
884890
await asyncio.sleep(0)
885-
assert ping_stated is True
891+
assert ping_started.result() is None
886892
assert cancelled is True
887893

888894

tests/test_web_websocket_functional.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ async def handler(request):
423423
ws = web.WebSocketResponse()
424424
await ws.prepare(request)
425425

426-
await ws.ping("data")
426+
await ws.ping(b"data")
427427
await ws.receive()
428428
closed.set_result(None)
429429
return ws
@@ -460,7 +460,7 @@ async def handler(request):
460460

461461
ws = await client.ws_connect("/", autoping=False)
462462

463-
await ws.ping("data")
463+
await ws.ping(b"data")
464464
msg = await ws.receive()
465465
assert msg.type == WSMsgType.PONG
466466
assert msg.data == b"data"
@@ -478,7 +478,7 @@ async def handler(request):
478478

479479
msg = await ws.receive()
480480
assert msg.type == WSMsgType.PING
481-
await ws.pong("data")
481+
await ws.pong(b"data")
482482

483483
msg = await ws.receive()
484484
assert msg.type == WSMsgType.CLOSE
@@ -493,7 +493,7 @@ async def handler(request):
493493

494494
ws = await client.ws_connect("/", autoping=False)
495495

496-
await ws.ping("data")
496+
await ws.ping(b"data")
497497
msg = await ws.receive()
498498
assert msg.type == WSMsgType.PONG
499499
assert msg.data == b"data"
@@ -741,12 +741,14 @@ async def handler(request: web.Request) -> NoReturn:
741741
with mock.patch.object(
742742
ws_server._req.transport, "write", side_effect=ConnectionResetError
743743
), mock.patch.object(
744-
ws_server._writer, "ping", wraps=ws_server._writer.ping
745-
) as ping:
744+
ws_server._writer, "send_frame", wraps=ws_server._writer.send_frame
745+
) as send_frame:
746746
try:
747747
await ws_server.receive()
748748
finally:
749-
ping_count = ping.call_count
749+
ping_count = send_frame.call_args_list.count(
750+
mock.call(b"", WSMsgType.PING)
751+
)
750752
assert False
751753

752754
app = web.Application()
@@ -990,7 +992,7 @@ async def handler(request):
990992
msg = await ws.receive()
991993
assert msg.type == WSMsgType.PING
992994
await asyncio.sleep(0)
993-
await ws.pong("data")
995+
await ws.pong(b"data")
994996

995997
msg = await ws.receive()
996998
assert msg.type == WSMsgType.CLOSE
@@ -1006,7 +1008,7 @@ async def handler(request):
10061008
ws = await client.ws_connect("/", autoping=False)
10071009

10081010
await asyncio.sleep(0)
1009-
await ws.ping("data")
1011+
await ws.ping(b"data")
10101012

10111013
msg = await ws.receive()
10121014
assert msg.type == WSMsgType.PONG
@@ -1036,7 +1038,7 @@ async def handler(request):
10361038
msg = await ws.receive()
10371039
assert msg.type == WSMsgType.PING
10381040
await asyncio.sleep(0)
1039-
await ws.pong("data")
1041+
await ws.pong(b"data")
10401042

10411043
msg = await ws.receive()
10421044
assert msg.type == WSMsgType.CLOSE
@@ -1052,7 +1054,7 @@ async def handler(request):
10521054
ws = await client.ws_connect("/", autoping=False)
10531055

10541056
await timed_out
1055-
await ws.ping("data")
1057+
await ws.ping(b"data")
10561058

10571059
msg = await ws.receive()
10581060
assert msg.type == WSMsgType.PONG

tests/test_websocket_writer.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ def writer(protocol, transport):
2929
return WebSocketWriter(protocol, transport, use_mask=False)
3030

3131

32-
async def test_pong(writer) -> None:
33-
await writer.pong()
34-
writer.transport.write.assert_called_with(b"\x8a\x00")
32+
async def test_pong(writer: WebSocketWriter) -> None:
33+
await writer.send_frame(b"", WSMsgType.PONG)
34+
writer.transport.write.assert_called_with(b"\x8a\x00") # type: ignore[attr-defined]
3535

3636

37-
async def test_ping(writer) -> None:
38-
await writer.ping()
39-
writer.transport.write.assert_called_with(b"\x89\x00")
37+
async def test_ping(writer: WebSocketWriter) -> None:
38+
await writer.send_frame(b"", WSMsgType.PING)
39+
writer.transport.write.assert_called_with(b"\x89\x00") # type: ignore[attr-defined]
4040

4141

4242
async def test_send_text(writer: WebSocketWriter) -> None:

0 commit comments

Comments
 (0)