@@ -1655,6 +1655,72 @@ async def client(addr):
1655
1655
self .loop .run_until_complete (
1656
1656
asyncio .wait_for (client (srv .addr ), loop = self .loop , timeout = 10 ))
1657
1657
1658
+ def test_create_connection_memory_leak (self ):
1659
+ if self .implementation == 'asyncio' :
1660
+ raise unittest .SkipTest ()
1661
+
1662
+ HELLO_MSG = b'1' * self .PAYLOAD_SIZE
1663
+
1664
+ server_context = self ._create_server_ssl_context (
1665
+ self .ONLYCERT , self .ONLYKEY )
1666
+ client_context = self ._create_client_ssl_context ()
1667
+
1668
+ def serve (sock ):
1669
+ sock .settimeout (self .TIMEOUT )
1670
+
1671
+ sock .starttls (server_context , server_side = True )
1672
+
1673
+ sock .sendall (b'O' )
1674
+ data = sock .recv_all (len (HELLO_MSG ))
1675
+ self .assertEqual (len (data ), len (HELLO_MSG ))
1676
+
1677
+ sock .unwrap ()
1678
+ sock .close ()
1679
+
1680
+ class ClientProto (asyncio .Protocol ):
1681
+ def __init__ (self , on_data , on_eof ):
1682
+ self .on_data = on_data
1683
+ self .on_eof = on_eof
1684
+ self .con_made_cnt = 0
1685
+
1686
+ def connection_made (proto , tr ):
1687
+ # XXX: We assume user stores the transport in protocol
1688
+ proto .tr = tr
1689
+ proto .con_made_cnt += 1
1690
+ # Ensure connection_made gets called only once.
1691
+ self .assertEqual (proto .con_made_cnt , 1 )
1692
+
1693
+ def data_received (self , data ):
1694
+ self .on_data .set_result (data )
1695
+
1696
+ def eof_received (self ):
1697
+ self .on_eof .set_result (True )
1698
+
1699
+ async def client (addr ):
1700
+ await asyncio .sleep (0.5 , loop = self .loop )
1701
+
1702
+ on_data = self .loop .create_future ()
1703
+ on_eof = self .loop .create_future ()
1704
+
1705
+ tr , proto = await self .loop .create_connection (
1706
+ lambda : ClientProto (on_data , on_eof ), * addr ,
1707
+ ssl = client_context )
1708
+
1709
+ self .assertEqual (await on_data , b'O' )
1710
+ tr .write (HELLO_MSG )
1711
+ await on_eof
1712
+
1713
+ tr .close ()
1714
+
1715
+ with self .tcp_server (serve , timeout = self .TIMEOUT ) as srv :
1716
+ self .loop .run_until_complete (
1717
+ asyncio .wait_for (client (srv .addr ), loop = self .loop , timeout = 10 ))
1718
+
1719
+ # No garbage is left for SSL client from loop.create_connection, even
1720
+ # if user stores the SSLTransport in corresponding protocol instance
1721
+ client_context = weakref .ref (client_context )
1722
+ self .assertIsNone (client_context ())
1723
+
1658
1724
def test_start_tls_client_buf_proto_1 (self ):
1659
1725
if self .implementation == 'asyncio' :
1660
1726
raise unittest .SkipTest ()
0 commit comments