@@ -3090,6 +3090,97 @@ def wrapper(sock):
3090
3090
with self .tcp_server (run (server )) as srv :
3091
3091
self .loop .run_until_complete (client (srv .addr ))
3092
3092
3093
+ def test_first_data_after_wakeup (self ):
3094
+ if self .implementation == 'asyncio' :
3095
+ raise unittest .SkipTest ()
3096
+
3097
+ server_context = self ._create_server_ssl_context (
3098
+ self .ONLYCERT , self .ONLYKEY )
3099
+ client_context = self ._create_client_ssl_context ()
3100
+ loop = self .loop
3101
+ this = self
3102
+ fut = self .loop .create_future ()
3103
+
3104
+ def client (sock , addr ):
3105
+ try :
3106
+ sock .connect (addr )
3107
+
3108
+ incoming = ssl .MemoryBIO ()
3109
+ outgoing = ssl .MemoryBIO ()
3110
+ sslobj = client_context .wrap_bio (incoming , outgoing )
3111
+
3112
+ # Do handshake manually so that we could collect the last piece
3113
+ while True :
3114
+ try :
3115
+ sslobj .do_handshake ()
3116
+ break
3117
+ except ssl .SSLWantReadError :
3118
+ if outgoing .pending :
3119
+ sock .send (outgoing .read ())
3120
+ incoming .write (sock .recv (65536 ))
3121
+
3122
+ # Send the first data together with the last handshake payload
3123
+ sslobj .write (b'hello' )
3124
+ sock .send (outgoing .read ())
3125
+
3126
+ while True :
3127
+ try :
3128
+ incoming .write (sock .recv (65536 ))
3129
+ self .assertEqual (sslobj .read (1024 ), b'hello' )
3130
+ break
3131
+ except ssl .SSLWantReadError :
3132
+ pass
3133
+
3134
+ sock .close ()
3135
+
3136
+ except Exception as ex :
3137
+ loop .call_soon_threadsafe (fut .set_exception , ex )
3138
+ sock .close ()
3139
+ else :
3140
+ loop .call_soon_threadsafe (fut .set_result , None )
3141
+
3142
+ class EchoProto (asyncio .Protocol ):
3143
+ def connection_made (self , tr ):
3144
+ self .tr = tr
3145
+ # manually run the coroutine, in order to avoid accidental data
3146
+ coro = loop .start_tls (
3147
+ tr , self , server_context ,
3148
+ server_side = True ,
3149
+ ssl_handshake_timeout = this .TIMEOUT ,
3150
+ )
3151
+ waiter = coro .send (None )
3152
+
3153
+ def tls_started (_ ):
3154
+ try :
3155
+ coro .send (None )
3156
+ except StopIteration as e :
3157
+ # update self.tr to SSL transport as soon as we know it
3158
+ self .tr = e .value
3159
+
3160
+ waiter .add_done_callback (tls_started )
3161
+
3162
+ def data_received (self , data ):
3163
+ # This is a dumb protocol that writes back whatever it receives
3164
+ # regardless of whether self.tr is SSL or not
3165
+ self .tr .write (data )
3166
+
3167
+ async def run_main ():
3168
+ proto = EchoProto ()
3169
+
3170
+ server = await self .loop .create_server (
3171
+ lambda : proto , '127.0.0.1' , 0 )
3172
+ addr = server .sockets [0 ].getsockname ()
3173
+
3174
+ with self .tcp_client (lambda sock : client (sock , addr ),
3175
+ timeout = self .TIMEOUT ):
3176
+ await asyncio .wait_for (fut , timeout = self .TIMEOUT )
3177
+ proto .tr .close ()
3178
+
3179
+ server .close ()
3180
+ await server .wait_closed ()
3181
+
3182
+ self .loop .run_until_complete (run_main ())
3183
+
3093
3184
3094
3185
class Test_UV_TCPSSL (_TestSSL , tb .UVTestCase ):
3095
3186
pass
0 commit comments