Skip to content

Commit 135d060

Browse files
committed
create_server() now makes a strong ref to the Server object.
Fixes #81. Also makes Server objects weak-refable.
1 parent 45e89cf commit 135d060

File tree

5 files changed

+78
-10
lines changed

5 files changed

+78
-10
lines changed

tests/test_tcp.py

+45
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import ssl
77
import sys
88
import threading
9+
import weakref
910

1011
from uvloop import _testbase as tb
1112

@@ -262,6 +263,50 @@ async def runner():
262263

263264
self.loop.run_until_complete(runner())
264265

266+
def test_create_server_7(self):
267+
# Test that create_server() stores a hard ref to the server object
268+
# somewhere in the loop. In asyncio it so happens that
269+
# loop.sock_accept() has a reference to the server object so it
270+
# never gets GCed.
271+
272+
class Proto(asyncio.Protocol):
273+
def connection_made(self, tr):
274+
self.tr = tr
275+
self.tr.write(b'hello')
276+
277+
async def test():
278+
port = tb.find_free_port()
279+
srv = await self.loop.create_server(Proto, '127.0.0.1', port)
280+
wsrv = weakref.ref(srv)
281+
del srv
282+
283+
gc.collect()
284+
gc.collect()
285+
gc.collect()
286+
287+
s = socket.socket(socket.AF_INET)
288+
with s:
289+
s.setblocking(False)
290+
await self.loop.sock_connect(s, ('127.0.0.1', port))
291+
d = await self.loop.sock_recv(s, 100)
292+
self.assertEqual(d, b'hello')
293+
294+
srv = wsrv()
295+
srv.close()
296+
await srv.wait_closed()
297+
del srv
298+
299+
# Let all transports shutdown.
300+
await asyncio.sleep(0.1, loop=self.loop)
301+
302+
gc.collect()
303+
gc.collect()
304+
gc.collect()
305+
306+
self.assertIsNone(wsrv())
307+
308+
self.loop.run_until_complete(test())
309+
265310
def test_create_connection_1(self):
266311
CNT = 0
267312
TOTAL_CNT = 100

uvloop/loop.pxd

+2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ cdef class Loop:
5252
set _queued_streams
5353
Py_ssize_t _ready_len
5454

55+
set _servers
56+
5557
object _transports
5658
dict _fd_to_reader_fileobj
5759
dict _fd_to_writer_fileobj

uvloop/loop.pyx

+3
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ cdef class Loop:
148148
# Set to True when `loop.shutdown_asyncgens` is called.
149149
self._asyncgens_shutdown_called = False
150150

151+
self._servers = set()
152+
151153
def __init__(self):
152154
self.set_debug((not sys_ignore_environment
153155
and bool(os_environ.get('PYTHONASYNCIODEBUG'))))
@@ -1509,6 +1511,7 @@ cdef class Loop:
15091511

15101512
server._add_server(tcp)
15111513

1514+
server._ref()
15121515
return server
15131516

15141517
async def create_connection(self, protocol_factory, host=None, port=None, *,

uvloop/server.pxd

+4
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@ cdef class Server:
44
list _waiters
55
int _active_count
66
Loop _loop
7+
object __weakref__
78

89
cdef _add_server(self, UVStreamServer srv)
910
cdef _wakeup(self)
1011

1112
cdef _attach(self)
1213
cdef _detach(self)
14+
15+
cdef _ref(self)
16+
cdef _unref(self)

uvloop/server.pyx

+24-10
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ cdef class Server:
2727
if self._active_count == 0 and self._servers is None:
2828
self._wakeup()
2929

30+
cdef _ref(self):
31+
# Keep the server object alive while it's not explicitly closed.
32+
self._loop._servers.add(self)
33+
34+
cdef _unref(self):
35+
self._loop._servers.discard(self)
36+
3037
# Public API
3138

3239
def __repr__(self):
@@ -40,25 +47,32 @@ cdef class Server:
4047
await waiter
4148

4249
def close(self):
50+
cdef list servers
51+
4352
if self._servers is None:
4453
return
4554

46-
cdef list servers = self._servers
47-
self._servers = None
55+
try:
56+
servers = self._servers
57+
self._servers = None
4858

49-
for server in servers:
50-
(<UVStreamServer>server)._close()
59+
for server in servers:
60+
(<UVStreamServer>server)._close()
5161

52-
if self._active_count == 0:
53-
self._wakeup()
62+
if self._active_count == 0:
63+
self._wakeup()
64+
finally:
65+
self._unref()
5466

5567
property sockets:
5668
def __get__(self):
5769
cdef list sockets = []
5870

59-
for server in self._servers:
60-
sockets.append(
61-
(<UVStreamServer>server)._get_socket()
62-
)
71+
# Guard against `self._servers is None`
72+
if self._servers:
73+
for server in self._servers:
74+
sockets.append(
75+
(<UVStreamServer>server)._get_socket()
76+
)
6377

6478
return sockets

0 commit comments

Comments
 (0)