@@ -1256,6 +1256,144 @@ async def start_server():
1256
1256
for client in clients :
1257
1257
client .stop ()
1258
1258
1259
+ def test_create_server_ssl_over_ssl (self ):
1260
+ if self .implementation == 'asyncio' :
1261
+ raise unittest .SkipTest ('asyncio does not support SSL over SSL' )
1262
+
1263
+ CNT = 0 # number of clients that were successful
1264
+ TOTAL_CNT = 25 # total number of clients that test will create
1265
+ TIMEOUT = 10.0 # timeout for this test
1266
+
1267
+ A_DATA = b'A' * 1024 * 1024
1268
+ B_DATA = b'B' * 1024 * 1024
1269
+
1270
+ sslctx_1 = self ._create_server_ssl_context (self .ONLYCERT , self .ONLYKEY )
1271
+ client_sslctx_1 = self ._create_client_ssl_context ()
1272
+ sslctx_2 = self ._create_server_ssl_context (self .ONLYCERT , self .ONLYKEY )
1273
+ client_sslctx_2 = self ._create_client_ssl_context ()
1274
+
1275
+ clients = []
1276
+
1277
+ async def handle_client (reader , writer ):
1278
+ nonlocal CNT
1279
+
1280
+ # hack reader and writer to call start_tls()
1281
+ transport = writer ._transport
1282
+ writer ._transport = None
1283
+ reader ._transport = None
1284
+
1285
+ transport = await self .loop .start_tls (
1286
+ transport , writer ._protocol , sslctx_2 , server_side = True )
1287
+
1288
+ # restore with new transport
1289
+ writer ._transport = transport
1290
+ reader ._transport = transport
1291
+
1292
+ data = await reader .readexactly (len (A_DATA ))
1293
+ self .assertEqual (data , A_DATA )
1294
+ writer .write (b'OK' )
1295
+
1296
+ data = await reader .readexactly (len (B_DATA ))
1297
+ self .assertEqual (data , B_DATA )
1298
+ writer .writelines ([b'SP' , bytearray (b'A' ), memoryview (b'M' )])
1299
+
1300
+ await writer .drain ()
1301
+ writer .close ()
1302
+
1303
+ CNT += 1
1304
+
1305
+ async def test_client (addr ):
1306
+ fut = asyncio .Future (loop = self .loop )
1307
+
1308
+ def prog (sock ):
1309
+ try :
1310
+ sock .connect (addr )
1311
+ sock .starttls (client_sslctx_1 )
1312
+
1313
+ # because wrap_socket() doesn't work correctly on
1314
+ # SSLSocket, we have to do the 2nd level SSL manually
1315
+ incoming = ssl .MemoryBIO ()
1316
+ outgoing = ssl .MemoryBIO ()
1317
+ sslobj = client_sslctx_2 .wrap_bio (incoming , outgoing )
1318
+
1319
+ def do (func ):
1320
+ while True :
1321
+ try :
1322
+ rv = func ()
1323
+ break
1324
+ except ssl .SSLWantReadError :
1325
+ if outgoing .pending :
1326
+ sock .send (outgoing .read ())
1327
+ incoming .write (sock .recv (65536 ))
1328
+ if outgoing .pending :
1329
+ sock .send (outgoing .read ())
1330
+ return rv
1331
+
1332
+ do (sslobj .do_handshake )
1333
+
1334
+ do (lambda : sslobj .write (A_DATA ))
1335
+ data = do (lambda : sslobj .read (2 ))
1336
+ self .assertEqual (data , b'OK' )
1337
+
1338
+ do (lambda : sslobj .write (B_DATA ))
1339
+ data = b''
1340
+ while data != b'SPAM' :
1341
+ data += do (lambda : sslobj .read (4 ))
1342
+ self .assertEqual (data , b'SPAM' )
1343
+
1344
+ do (sslobj .unwrap )
1345
+ sock .close ()
1346
+
1347
+ except Exception as ex :
1348
+ self .loop .call_soon_threadsafe (fut .set_exception , ex )
1349
+ else :
1350
+ self .loop .call_soon_threadsafe (fut .set_result , None )
1351
+
1352
+ client = self .tcp_client (prog )
1353
+ client .start ()
1354
+ clients .append (client )
1355
+
1356
+ await fut
1357
+
1358
+ async def start_server ():
1359
+ extras = {}
1360
+ if self .implementation != 'asyncio' or self .PY37 :
1361
+ extras = dict (ssl_handshake_timeout = 10.0 )
1362
+
1363
+ srv = await asyncio .start_server (
1364
+ handle_client ,
1365
+ '127.0.0.1' , 0 ,
1366
+ family = socket .AF_INET ,
1367
+ ssl = sslctx_1 ,
1368
+ loop = self .loop ,
1369
+ ** extras )
1370
+
1371
+ try :
1372
+ srv_socks = srv .sockets
1373
+ self .assertTrue (srv_socks )
1374
+
1375
+ addr = srv_socks [0 ].getsockname ()
1376
+
1377
+ tasks = []
1378
+ for _ in range (TOTAL_CNT ):
1379
+ tasks .append (test_client (addr ))
1380
+
1381
+ await asyncio .wait_for (
1382
+ asyncio .gather (* tasks , loop = self .loop ),
1383
+ TIMEOUT , loop = self .loop )
1384
+
1385
+ finally :
1386
+ self .loop .call_soon (srv .close )
1387
+ await srv .wait_closed ()
1388
+
1389
+ with self ._silence_eof_received_warning ():
1390
+ self .loop .run_until_complete (start_server ())
1391
+
1392
+ self .assertEqual (CNT , TOTAL_CNT )
1393
+
1394
+ for client in clients :
1395
+ client .stop ()
1396
+
1259
1397
def test_create_connection_ssl_1 (self ):
1260
1398
if self .implementation == 'asyncio' :
1261
1399
# Don't crash on asyncio errors
0 commit comments