Skip to content

Commit 01be886

Browse files
committed
Added test for SSL over SSL.
1 parent c178c6b commit 01be886

File tree

2 files changed

+140
-1
lines changed

2 files changed

+140
-1
lines changed

tests/test_tcp.py

+138
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,144 @@ async def start_server():
12561256
for client in clients:
12571257
client.stop()
12581258

1259+
def test_create_server_ssl_over_ssl(self):
1260+
if self.implementation == 'asyncio':
1261+
raise unittest.SkipTest('asyncio does not support SSL over SSL')
1262+
1263+
CNT = 0 # number of clients that were successful
1264+
TOTAL_CNT = 25 # total number of clients that test will create
1265+
TIMEOUT = 10.0 # timeout for this test
1266+
1267+
A_DATA = b'A' * 1024 * 1024
1268+
B_DATA = b'B' * 1024 * 1024
1269+
1270+
sslctx_1 = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
1271+
client_sslctx_1 = self._create_client_ssl_context()
1272+
sslctx_2 = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
1273+
client_sslctx_2 = self._create_client_ssl_context()
1274+
1275+
clients = []
1276+
1277+
async def handle_client(reader, writer):
1278+
nonlocal CNT
1279+
1280+
# hack reader and writer to call start_tls()
1281+
transport = writer._transport
1282+
writer._transport = None
1283+
reader._transport = None
1284+
1285+
transport = await self.loop.start_tls(
1286+
transport, writer._protocol, sslctx_2, server_side=True)
1287+
1288+
# restore with new transport
1289+
writer._transport = transport
1290+
reader._transport = transport
1291+
1292+
data = await reader.readexactly(len(A_DATA))
1293+
self.assertEqual(data, A_DATA)
1294+
writer.write(b'OK')
1295+
1296+
data = await reader.readexactly(len(B_DATA))
1297+
self.assertEqual(data, B_DATA)
1298+
writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
1299+
1300+
await writer.drain()
1301+
writer.close()
1302+
1303+
CNT += 1
1304+
1305+
async def test_client(addr):
1306+
fut = asyncio.Future(loop=self.loop)
1307+
1308+
def prog(sock):
1309+
try:
1310+
sock.connect(addr)
1311+
sock.starttls(client_sslctx_1)
1312+
1313+
# because wrap_socket() doesn't work correctly on
1314+
# SSLSocket, we have to do the 2nd level SSL manually
1315+
incoming = ssl.MemoryBIO()
1316+
outgoing = ssl.MemoryBIO()
1317+
sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)
1318+
1319+
def do(func):
1320+
while True:
1321+
try:
1322+
rv = func()
1323+
break
1324+
except ssl.SSLWantReadError:
1325+
if outgoing.pending:
1326+
sock.send(outgoing.read())
1327+
incoming.write(sock.recv(65536))
1328+
if outgoing.pending:
1329+
sock.send(outgoing.read())
1330+
return rv
1331+
1332+
do(sslobj.do_handshake)
1333+
1334+
do(lambda: sslobj.write(A_DATA))
1335+
data = do(lambda: sslobj.read(2))
1336+
self.assertEqual(data, b'OK')
1337+
1338+
do(lambda: sslobj.write(B_DATA))
1339+
data = b''
1340+
while data != b'SPAM':
1341+
data += do(lambda: sslobj.read(4))
1342+
self.assertEqual(data, b'SPAM')
1343+
1344+
do(sslobj.unwrap)
1345+
sock.close()
1346+
1347+
except Exception as ex:
1348+
self.loop.call_soon_threadsafe(fut.set_exception, ex)
1349+
else:
1350+
self.loop.call_soon_threadsafe(fut.set_result, None)
1351+
1352+
client = self.tcp_client(prog)
1353+
client.start()
1354+
clients.append(client)
1355+
1356+
await fut
1357+
1358+
async def start_server():
1359+
extras = {}
1360+
if self.implementation != 'asyncio' or self.PY37:
1361+
extras = dict(ssl_handshake_timeout=10.0)
1362+
1363+
srv = await asyncio.start_server(
1364+
handle_client,
1365+
'127.0.0.1', 0,
1366+
family=socket.AF_INET,
1367+
ssl=sslctx_1,
1368+
loop=self.loop,
1369+
**extras)
1370+
1371+
try:
1372+
srv_socks = srv.sockets
1373+
self.assertTrue(srv_socks)
1374+
1375+
addr = srv_socks[0].getsockname()
1376+
1377+
tasks = []
1378+
for _ in range(TOTAL_CNT):
1379+
tasks.append(test_client(addr))
1380+
1381+
await asyncio.wait_for(
1382+
asyncio.gather(*tasks, loop=self.loop),
1383+
TIMEOUT, loop=self.loop)
1384+
1385+
finally:
1386+
self.loop.call_soon(srv.close)
1387+
await srv.wait_closed()
1388+
1389+
with self._silence_eof_received_warning():
1390+
self.loop.run_until_complete(start_server())
1391+
1392+
self.assertEqual(CNT, TOTAL_CNT)
1393+
1394+
for client in clients:
1395+
client.stop()
1396+
12591397
def test_create_connection_ssl_1(self):
12601398
if self.implementation == 'asyncio':
12611399
# Don't crash on asyncio errors

uvloop/loop.pyx

+2-1
Original file line numberDiff line numberDiff line change
@@ -1534,7 +1534,8 @@ cdef class Loop:
15341534
f'sslcontext is expected to be an instance of ssl.SSLContext, '
15351535
f'got {sslcontext!r}')
15361536

1537-
if not isinstance(transport, (TCPTransport, UnixTransport)):
1537+
if not isinstance(transport, (TCPTransport, UnixTransport,
1538+
_SSLProtocolTransport)):
15381539
raise TypeError(
15391540
f'transport {transport!r} is not supported by start_tls()')
15401541

0 commit comments

Comments
 (0)