Skip to content

Commit 6476aad

Browse files
fantix1st1
authored andcommitted
fix missing data on EOF in flushing
* when EOF is received and data is still pending in incoming buffer, the data will be lost before this fix * also removed sleep from a recent-written test
1 parent 695a520 commit 6476aad

File tree

3 files changed

+47
-71
lines changed

3 files changed

+47
-71
lines changed

Diff for: tests/test_tcp.py

+29-41
Original file line numberDiff line numberDiff line change
@@ -2606,35 +2606,6 @@ def server(sock):
26062606
self.assertEqual(len(data), CHUNK * SIZE)
26072607
sock.close()
26082608

2609-
def openssl_server(sock):
2610-
conn = openssl_ssl.Connection(sslctx_openssl, sock)
2611-
conn.set_accept_state()
2612-
2613-
while True:
2614-
try:
2615-
data = conn.recv(16384)
2616-
self.assertEqual(data, b'ping')
2617-
break
2618-
except openssl_ssl.WantReadError:
2619-
pass
2620-
2621-
# use renegotiation to queue data in peer _write_backlog
2622-
conn.renegotiate()
2623-
conn.send(b'pong')
2624-
2625-
data_size = 0
2626-
while True:
2627-
try:
2628-
chunk = conn.recv(16384)
2629-
if not chunk:
2630-
break
2631-
data_size += len(chunk)
2632-
except openssl_ssl.WantReadError:
2633-
pass
2634-
except openssl_ssl.ZeroReturnError:
2635-
break
2636-
self.assertEqual(data_size, CHUNK * SIZE)
2637-
26382609
def run(meth):
26392610
def wrapper(sock):
26402611
try:
@@ -2652,12 +2623,18 @@ async def client(addr):
26522623
*addr,
26532624
ssl=client_sslctx,
26542625
server_hostname='')
2626+
sslprotocol = writer.get_extra_info('uvloop.sslproto')
26552627
writer.write(b'ping')
26562628
data = await reader.readexactly(4)
26572629
self.assertEqual(data, b'pong')
2630+
2631+
sslprotocol.pause_writing()
26582632
for _ in range(SIZE):
26592633
writer.write(b'x' * CHUNK)
2634+
26602635
writer.close()
2636+
sslprotocol.resume_writing()
2637+
26612638
await self.wait_closed(writer)
26622639
try:
26632640
data = await reader.read()
@@ -2669,9 +2646,6 @@ async def client(addr):
26692646
with self.tcp_server(run(server)) as srv:
26702647
self.loop.run_until_complete(client(srv.addr))
26712648

2672-
with self.tcp_server(run(openssl_server)) as srv:
2673-
self.loop.run_until_complete(client(srv.addr))
2674-
26752649
def test_remote_shutdown_receives_trailing_data(self):
26762650
if self.implementation == 'asyncio':
26772651
raise unittest.SkipTest()
@@ -2892,20 +2866,26 @@ async def client(addr, ctx):
28922866
self.assertIsNone(ctx())
28932867

28942868
def test_shutdown_timeout_handler_not_set(self):
2869+
if self.implementation == 'asyncio':
2870+
# asyncio cannot receive EOF after resume_reading()
2871+
raise unittest.SkipTest()
2872+
28952873
loop = self.loop
2874+
eof = asyncio.Event()
2875+
extra = None
28962876

28972877
def server(sock):
28982878
sslctx = self._create_server_ssl_context(self.ONLYCERT,
28992879
self.ONLYKEY)
29002880
sock = sslctx.wrap_socket(sock, server_side=True)
29012881
sock.send(b'hello')
29022882
assert sock.recv(1024) == b'world'
2903-
time.sleep(0.1)
2904-
sock.send(b'extra bytes' * 1)
2883+
sock.send(b'extra bytes')
29052884
# sending EOF here
29062885
sock.shutdown(socket.SHUT_WR)
2886+
loop.call_soon_threadsafe(eof.set)
29072887
# make sure we have enough time to reproduce the issue
2908-
time.sleep(0.1)
2888+
assert sock.recv(1024) == b''
29092889
sock.close()
29102890

29112891
class Protocol(asyncio.Protocol):
@@ -2917,20 +2897,28 @@ def connection_made(self, transport):
29172897
self.transport = transport
29182898

29192899
def data_received(self, data):
2920-
self.transport.write(b'world')
2921-
# pause reading would make incoming data stay in the sslobj
2922-
self.transport.pause_reading()
2923-
# resume for AIO to pass
2924-
loop.call_later(0.2, self.transport.resume_reading)
2900+
if data == b'hello':
2901+
self.transport.write(b'world')
2902+
# pause reading would make incoming data stay in the sslobj
2903+
self.transport.pause_reading()
2904+
else:
2905+
nonlocal extra
2906+
extra = data
29252907

29262908
def connection_lost(self, exc):
2927-
self.fut.set_result(None)
2909+
if exc is None:
2910+
self.fut.set_result(None)
2911+
else:
2912+
self.fut.set_exception(exc)
29282913

29292914
async def client(addr):
29302915
ctx = self._create_client_ssl_context()
29312916
tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
2917+
await eof.wait()
2918+
tr.resume_reading()
29322919
await pr.fut
29332920
tr.close()
2921+
assert extra == b'extra bytes'
29342922

29352923
with self.tcp_server(server) as srv:
29362924
loop.run_until_complete(client(srv.addr))

Diff for: uvloop/sslproto.pxd

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ cdef class SSLProtocol:
6565

6666
bint _ssl_writing_paused
6767
bint _app_reading_paused
68+
bint _eof_received
6869

6970
size_t _incoming_high_water
7071
size_t _incoming_low_water

Diff for: uvloop/sslproto.pyx

+17-30
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ cdef class SSLProtocol:
278278
self._incoming_high_water = 0
279279
self._incoming_low_water = 0
280280
self._set_read_buffer_limits()
281+
self._eof_received = False
281282

282283
self._app_writing_paused = False
283284
self._outgoing_high_water = 0
@@ -391,6 +392,7 @@ cdef class SSLProtocol:
391392
will close itself. If it returns a true value, closing the
392393
transport is up to the protocol.
393394
"""
395+
self._eof_received = True
394396
try:
395397
if self._loop.get_debug():
396398
aio_logger.debug("%r received EOF", self)
@@ -400,9 +402,10 @@ cdef class SSLProtocol:
400402

401403
elif self._state == WRAPPED:
402404
self._set_state(FLUSHING)
403-
self._do_write()
404-
self._set_state(SHUTDOWN)
405-
self._do_shutdown()
405+
if self._app_reading_paused:
406+
return True
407+
else:
408+
self._do_flush()
406409

407410
elif self._state == FLUSHING:
408411
self._do_write()
@@ -412,11 +415,14 @@ cdef class SSLProtocol:
412415
elif self._state == SHUTDOWN:
413416
self._do_shutdown()
414417

415-
finally:
418+
except Exception:
416419
self._transport.close()
420+
raise
417421

418422
cdef _get_extra_info(self, name, default=None):
419-
if name in self._extra:
423+
if name == 'uvloop.sslproto':
424+
return self
425+
elif name in self._extra:
420426
return self._extra[name]
421427
elif self._transport is not None:
422428
return self._transport.get_extra_info(name, default)
@@ -555,33 +561,14 @@ cdef class SSLProtocol:
555561
aio_TimeoutError('SSL shutdown timed out'))
556562

557563
cdef _do_flush(self):
558-
if self._write_backlog:
559-
try:
560-
while True:
561-
# data is discarded when FLUSHING
562-
chunk_size = len(self._sslobj_read(SSL_READ_MAX_SIZE))
563-
if not chunk_size:
564-
# close_notify
565-
break
566-
except ssl_SSLAgainErrors as exc:
567-
pass
568-
except ssl_SSLError as exc:
569-
self._on_shutdown_complete(exc)
570-
return
571-
572-
try:
573-
self._do_write()
574-
except Exception as exc:
575-
self._on_shutdown_complete(exc)
576-
return
577-
578-
if not self._write_backlog:
579-
self._set_state(SHUTDOWN)
580-
self._do_shutdown()
564+
self._do_read()
565+
self._set_state(SHUTDOWN)
566+
self._do_shutdown()
581567

582568
cdef _do_shutdown(self):
583569
try:
584-
self._sslobj.unwrap()
570+
if not self._eof_received:
571+
self._sslobj.unwrap()
585572
except ssl_SSLAgainErrors as exc:
586573
self._process_outgoing()
587574
except ssl_SSLError as exc:
@@ -655,7 +642,7 @@ cdef class SSLProtocol:
655642
# Incoming flow
656643

657644
cdef _do_read(self):
658-
if self._state != WRAPPED:
645+
if self._state != WRAPPED and self._state != FLUSHING:
659646
return
660647
try:
661648
if not self._app_reading_paused:

0 commit comments

Comments
 (0)