Skip to content

Commit cb11516

Browse files
committed
Stop requiring loop in WebSocketCommonProtocol.
Change tests to avoid passing a loop argument. Fix #988.
1 parent cc1254b commit cb11516

File tree

5 files changed

+59
-39
lines changed

5 files changed

+59
-39
lines changed

src/websockets/legacy/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -514,10 +514,11 @@ def __init__(
514514
legacy_recv: bool = kwargs.pop("legacy_recv", False)
515515

516516
# Backwards compatibility: the loop parameter used to be supported.
517-
loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None)
518-
if loop is None:
517+
_loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None)
518+
if _loop is None:
519519
loop = asyncio.get_event_loop()
520520
else:
521+
loop = _loop
521522
warnings.warn("remove loop argument", DeprecationWarning)
522523

523524
wsuri = parse_uri(uri)
@@ -543,7 +544,7 @@ def __init__(
543544
max_queue=max_queue,
544545
read_limit=read_limit,
545546
write_limit=write_limit,
546-
loop=loop,
547+
loop=_loop,
547548
host=wsuri.host,
548549
port=wsuri.port,
549550
secure=wsuri.secure,

src/websockets/legacy/protocol.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@ def __init__(
124124
if close_timeout is None:
125125
close_timeout = timeout
126126

127+
# Backwards compatibility: the loop parameter used to be supported.
128+
if loop is None:
129+
loop = asyncio.get_event_loop()
130+
else:
131+
warnings.warn("remove loop argument", DeprecationWarning)
132+
127133
self.ping_interval = ping_interval
128134
self.ping_timeout = ping_timeout
129135
self.close_timeout = close_timeout
@@ -145,8 +151,6 @@ def __init__(
145151
# Track if DEBUG is enabled. Shortcut logging calls if it isn't.
146152
self.debug = logger.isEnabledFor(logging.DEBUG)
147153

148-
assert loop is not None
149-
# Remove when dropping Python < 3.10 - use get_running_loop instead.
150154
self.loop = loop
151155

152156
self._host = host

src/websockets/legacy/server.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,10 +1024,11 @@ def __init__(
10241024
legacy_recv: bool = kwargs.pop("legacy_recv", False)
10251025

10261026
# Backwards compatibility: the loop parameter used to be supported.
1027-
loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None)
1028-
if loop is None:
1027+
_loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None)
1028+
if _loop is None:
10291029
loop = asyncio.get_event_loop()
10301030
else:
1031+
loop = _loop
10311032
warnings.warn("remove loop argument", DeprecationWarning)
10321033

10331034
ws_server = WebSocketServer(logger=logger, loop=loop)
@@ -1053,7 +1054,7 @@ def __init__(
10531054
max_queue=max_queue,
10541055
read_limit=read_limit,
10551056
write_limit=write_limit,
1056-
loop=loop,
1057+
loop=_loop,
10571058
legacy_recv=legacy_recv,
10581059
origins=origins,
10591060
extensions=extensions,

tests/legacy/test_client_server.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -213,17 +213,17 @@ def start_server(self, deprecation_warnings=None, **kwargs):
213213
kwargs.setdefault("compression", None)
214214
# Disable pings by default in tests.
215215
kwargs.setdefault("ping_interval", None)
216-
# Python 3.10 dislikes not having a running event loop
217-
if sys.version_info[:2] >= (3, 10): # pragma: no cover
218-
kwargs.setdefault("loop", self.loop)
219216

220217
with warnings.catch_warnings(record=True) as recorded_warnings:
221218
start_server = serve(handler, "localhost", 0, **kwargs)
222219
self.server = self.loop.run_until_complete(start_server)
223220

224221
expected_warnings = [] if deprecation_warnings is None else deprecation_warnings
225-
if sys.version_info[:2] >= (3, 10): # pragma: no cover
226-
expected_warnings += ["remove loop argument"]
222+
if (
223+
sys.version_info[:2] >= (3, 10)
224+
and "remove loop argument" not in expected_warnings
225+
): # pragma: no cover
226+
expected_warnings += ["There is no current event loop"]
227227
self.assertDeprecationWarnings(recorded_warnings, expected_warnings)
228228

229229
def start_redirecting_server(
@@ -234,10 +234,6 @@ def start_redirecting_server(
234234
deprecation_warnings=None,
235235
**kwargs,
236236
):
237-
# Python 3.10 dislikes not having a running event loop
238-
if sys.version_info[:2] >= (3, 10): # pragma: no cover
239-
kwargs.setdefault("loop", self.loop)
240-
241237
async def process_request(path, headers):
242238
server_uri = get_server_uri(self.server, self.secure, path)
243239
if force_insecure:
@@ -259,8 +255,11 @@ async def process_request(path, headers):
259255
self.redirecting_server = self.loop.run_until_complete(start_server)
260256

261257
expected_warnings = [] if deprecation_warnings is None else deprecation_warnings
262-
if sys.version_info[:2] >= (3, 10): # pragma: no cover
263-
expected_warnings += ["remove loop argument"]
258+
if (
259+
sys.version_info[:2] >= (3, 10)
260+
and "remove loop argument" not in expected_warnings
261+
): # pragma: no cover
262+
expected_warnings += ["There is no current event loop"]
264263
self.assertDeprecationWarnings(recorded_warnings, expected_warnings)
265264

266265
def start_client(
@@ -270,9 +269,6 @@ def start_client(
270269
kwargs.setdefault("compression", None)
271270
# Disable pings by default in tests.
272271
kwargs.setdefault("ping_interval", None)
273-
# Python 3.10 dislikes not having a running event loop
274-
if sys.version_info[:2] >= (3, 10): # pragma: no cover
275-
kwargs.setdefault("loop", self.loop)
276272

277273
secure = kwargs.get("ssl") is not None
278274
try:
@@ -286,8 +282,11 @@ def start_client(
286282
self.client = self.loop.run_until_complete(start_client)
287283

288284
expected_warnings = [] if deprecation_warnings is None else deprecation_warnings
289-
if sys.version_info[:2] >= (3, 10): # pragma: no cover
290-
expected_warnings += ["remove loop argument"]
285+
if (
286+
sys.version_info[:2] >= (3, 10)
287+
and "remove loop argument" not in expected_warnings
288+
): # pragma: no cover
289+
expected_warnings += ["There is no current event loop"]
291290
self.assertDeprecationWarnings(recorded_warnings, expected_warnings)
292291

293292
def stop_client(self):
@@ -409,8 +408,6 @@ def test_infinite_redirect(self):
409408
with temp_test_redirecting_server(
410409
self,
411410
http.HTTPStatus.FOUND,
412-
loop=self.loop,
413-
deprecation_warnings=["remove loop argument"],
414411
):
415412
self.server = self.redirecting_server
416413
with self.assertRaises(InvalidHandshake):
@@ -430,7 +427,7 @@ def test_redirect_missing_location(self):
430427
with temp_test_client(self):
431428
self.fail("Did not raise") # pragma: no cover
432429

433-
def test_explicit_event_loop(self):
430+
def test_loop_backwards_compatibility(self):
434431
with self.temp_server(
435432
loop=self.loop, deprecation_warnings=["remove loop argument"]
436433
):

tests/legacy/test_protocol.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import contextlib
3+
import sys
34
import unittest
45
import unittest.mock
56
import warnings
@@ -86,8 +87,9 @@ class CommonTests:
8687

8788
def setUp(self):
8889
super().setUp()
89-
# Disable pings to make it easier to test what frames are sent exactly.
90-
self.protocol = WebSocketCommonProtocol(ping_interval=None, loop=self.loop)
90+
with warnings.catch_warnings(record=True):
91+
# Disable pings to make it easier to test what frames are sent exactly.
92+
self.protocol = WebSocketCommonProtocol(ping_interval=None)
9193
self.transport = TransportMock()
9294
self.transport.setup_mock(self.loop, self.protocol)
9395

@@ -309,14 +311,29 @@ def assertCompletesWithin(self, min_time, max_time):
309311
# Test constructor.
310312

311313
def test_timeout_backwards_compatibility(self):
312-
with warnings.catch_warnings(record=True) as recorded_warnings:
313-
protocol = WebSocketCommonProtocol(timeout=5, loop=self.loop)
314+
with warnings.catch_warnings(record=True) as recorded:
315+
protocol = WebSocketCommonProtocol(timeout=5)
314316

315317
self.assertEqual(protocol.close_timeout, 5)
316318

317-
self.assertDeprecationWarnings(
318-
recorded_warnings, ["rename timeout to close_timeout"]
319-
)
319+
expected = ["rename timeout to close_timeout"]
320+
if sys.version_info[:2] >= (3, 10): # pragma: no cover
321+
expected += ["There is no current event loop"]
322+
323+
self.assertDeprecationWarnings(recorded, expected)
324+
325+
def test_loop_backwards_compatibility(self):
326+
loop = asyncio.new_event_loop()
327+
self.addCleanup(loop.close)
328+
329+
with warnings.catch_warnings(record=True) as recorded:
330+
protocol = WebSocketCommonProtocol(loop=loop)
331+
332+
self.assertEqual(protocol.loop, loop)
333+
334+
expected = ["remove loop argument"]
335+
336+
self.assertDeprecationWarnings(recorded, expected)
320337

321338
# Test public attributes.
322339

@@ -1116,11 +1133,11 @@ def restart_protocol_with_keepalive_ping(
11161133
self.transport.close()
11171134
self.loop.run_until_complete(self.protocol.close())
11181135
# copied from setUp, but enables keepalive pings
1119-
self.protocol = WebSocketCommonProtocol(
1120-
ping_interval=ping_interval,
1121-
ping_timeout=ping_timeout,
1122-
loop=self.loop,
1123-
)
1136+
with warnings.catch_warnings(record=True):
1137+
self.protocol = WebSocketCommonProtocol(
1138+
ping_interval=ping_interval,
1139+
ping_timeout=ping_timeout,
1140+
)
11241141
self.transport = TransportMock()
11251142
self.transport.setup_mock(self.loop, self.protocol)
11261143
self.protocol.is_client = initial_protocol.is_client

0 commit comments

Comments
 (0)