Skip to content

Commit 82104fb

Browse files
authored
add app state check (#263)
* add app state check * fixes #246 * add test to cover the case * CRF: rename states and fix comment
1 parent 1eb6727 commit 82104fb

File tree

3 files changed

+78
-8
lines changed

3 files changed

+78
-8
lines changed

tests/test_tcp.py

+49
Original file line numberDiff line numberDiff line change
@@ -1601,6 +1601,55 @@ async def client(addr):
16011601
# exception or log an error, even if the handshake failed
16021602
self.assertEqual(messages, [])
16031603

1604+
def test_ssl_handshake_connection_lost(self):
1605+
# #246: make sure that no connection_lost() is called before
1606+
# connection_made() is called first
1607+
1608+
client_sslctx = self._create_client_ssl_context()
1609+
1610+
# silence error logger
1611+
self.loop.set_exception_handler(lambda loop, ctx: None)
1612+
1613+
connection_made_called = False
1614+
connection_lost_called = False
1615+
1616+
def server(sock):
1617+
sock.recv(1024)
1618+
# break the connection during handshake
1619+
sock.close()
1620+
1621+
class ClientProto(asyncio.Protocol):
1622+
def connection_made(self, transport):
1623+
nonlocal connection_made_called
1624+
connection_made_called = True
1625+
1626+
def connection_lost(self, exc):
1627+
nonlocal connection_lost_called
1628+
connection_lost_called = True
1629+
1630+
async def client(addr):
1631+
await self.loop.create_connection(
1632+
ClientProto,
1633+
*addr,
1634+
ssl=client_sslctx,
1635+
server_hostname=''),
1636+
1637+
with self.tcp_server(server,
1638+
max_clients=1,
1639+
backlog=1) as srv:
1640+
1641+
with self.assertRaises(ConnectionResetError):
1642+
self.loop.run_until_complete(client(srv.addr))
1643+
1644+
if connection_lost_called:
1645+
if connection_made_called:
1646+
self.fail("unexpected call to connection_lost()")
1647+
else:
1648+
self.fail("unexpected call to connection_lost() without"
1649+
"calling connection_made()")
1650+
elif connection_made_called:
1651+
self.fail("unexpected call to connection_made()")
1652+
16041653
def test_ssl_connect_accepted_socket(self):
16051654
server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
16061655
server_context.load_cert_chain(self.ONLYCERT, self.ONLYKEY)

uvloop/sslproto.pxd

+17-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,22 @@ cdef enum SSLProtocolState:
66
SHUTDOWN = 4
77

88

9+
cdef enum AppProtocolState:
10+
# This tracks the state of app protocol (https://git.io/fj59P):
11+
#
12+
# INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST
13+
#
14+
# * cm: connection_made()
15+
# * dr: data_received()
16+
# * er: eof_received()
17+
# * cl: connection_lost()
18+
19+
STATE_INIT = 0
20+
STATE_CON_MADE = 1
21+
STATE_EOF = 2
22+
STATE_CON_LOST = 3
23+
24+
925
cdef class _SSLProtocolTransport:
1026
cdef:
1127
object _loop
@@ -30,7 +46,6 @@ cdef class SSLProtocol:
3046
bint _app_transport_created
3147

3248
object _transport
33-
bint _call_connection_made
3449
object _ssl_handshake_timeout
3550
object _ssl_shutdown_timeout
3651

@@ -46,7 +61,7 @@ cdef class SSLProtocol:
4661
object _ssl_buffer_view
4762
SSLProtocolState _state
4863
size_t _conn_lost
49-
bint _eof_received
64+
AppProtocolState _app_state
5065

5166
bint _ssl_writing_paused
5267
bint _app_reading_paused

uvloop/sslproto.pyx

+12-6
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ cdef class SSLProtocol:
253253
self._app_transport_created = False
254254
# transport, ex: SelectorSocketTransport
255255
self._transport = None
256-
self._call_connection_made = call_connection_made
257256
self._ssl_handshake_timeout = ssl_handshake_timeout
258257
self._ssl_shutdown_timeout = ssl_shutdown_timeout
259258
# SSL and state machine
@@ -264,7 +263,10 @@ cdef class SSLProtocol:
264263
self._outgoing_read = self._outgoing.read
265264
self._state = UNWRAPPED
266265
self._conn_lost = 0 # Set when connection_lost called
267-
self._eof_received = False
266+
if call_connection_made:
267+
self._app_state = STATE_INIT
268+
else:
269+
self._app_state = STATE_CON_MADE
268270

269271
# Flow Control
270272

@@ -335,7 +337,10 @@ cdef class SSLProtocol:
335337
self._app_transport._closed = True
336338

337339
if self._state != DO_HANDSHAKE:
338-
self._loop.call_soon(self._app_protocol.connection_lost, exc)
340+
if self._app_state == STATE_CON_MADE or \
341+
self._app_state == STATE_EOF:
342+
self._app_state = STATE_CON_LOST
343+
self._loop.call_soon(self._app_protocol.connection_lost, exc)
339344
self._set_state(UNWRAPPED)
340345
self._transport = None
341346
self._app_transport = None
@@ -518,7 +523,8 @@ cdef class SSLProtocol:
518523
cipher=sslobj.cipher(),
519524
compression=sslobj.compression(),
520525
ssl_object=sslobj)
521-
if self._call_connection_made:
526+
if self._app_state == STATE_INIT:
527+
self._app_state = STATE_CON_MADE
522528
self._app_protocol.connection_made(self._get_app_transport())
523529
self._wakeup_waiter()
524530
self._do_read()
@@ -735,8 +741,8 @@ cdef class SSLProtocol:
735741

736742
cdef _call_eof_received(self):
737743
try:
738-
if not self._eof_received:
739-
self._eof_received = True
744+
if self._app_state == STATE_CON_MADE:
745+
self._app_state = STATE_EOF
740746
keep_open = self._app_protocol.eof_received()
741747
if keep_open:
742748
aio_logger.warning('returning true from eof_received() '

0 commit comments

Comments
 (0)