@@ -1891,15 +1891,12 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
1891
1891
let flush_read_disabled = self . gossip_processing_backlog_lifted . swap ( false , Ordering :: Relaxed ) ;
1892
1892
1893
1893
let mut peers_to_disconnect = HashMap :: new ( ) ;
1894
- let mut events_generated = self . message_handler . chan_handler . get_and_clear_pending_msg_events ( ) ;
1895
- events_generated. append ( & mut self . message_handler . route_handler . get_and_clear_pending_msg_events ( ) ) ;
1896
-
1897
1894
{
1898
- // TODO: There are some DoS attacks here where you can flood someone's outbound send
1899
- // buffer by doing things like announcing channels on another node. We should be willing to
1900
- // drop optional-ish messages when send buffers get full!
1901
-
1902
1895
let peers_lock = self . peers . read ( ) . unwrap ( ) ;
1896
+
1897
+ let mut events_generated = self . message_handler . chan_handler . get_and_clear_pending_msg_events ( ) ;
1898
+ events_generated. append ( & mut self . message_handler . route_handler . get_and_clear_pending_msg_events ( ) ) ;
1899
+
1903
1900
let peers = & * peers_lock;
1904
1901
macro_rules! get_peer_for_forwarding {
1905
1902
( $node_id: expr) => {
@@ -2520,12 +2517,11 @@ mod tests {
2520
2517
2521
2518
use crate :: prelude:: * ;
2522
2519
use crate :: sync:: { Arc , Mutex } ;
2523
- use core:: convert:: Infallible ;
2524
- use core:: sync:: atomic:: { AtomicBool , Ordering } ;
2520
+ use core:: sync:: atomic:: { AtomicBool , AtomicUsize , Ordering } ;
2525
2521
2526
2522
#[ derive( Clone ) ]
2527
2523
struct FileDescriptor {
2528
- fd : u16 ,
2524
+ fd : u32 ,
2529
2525
outbound_data : Arc < Mutex < Vec < u8 > > > ,
2530
2526
disconnect : Arc < AtomicBool > ,
2531
2527
}
@@ -2560,24 +2556,44 @@ mod tests {
2560
2556
2561
2557
struct TestCustomMessageHandler {
2562
2558
features : InitFeatures ,
2559
+ peer_counter : AtomicUsize ,
2560
+ send_messages : Option < PublicKey > ,
2561
+ }
2562
+
2563
+ impl crate :: ln:: wire:: Type for u64 {
2564
+ fn type_id ( & self ) -> u16 { 4242 }
2563
2565
}
2564
2566
2565
2567
impl wire:: CustomMessageReader for TestCustomMessageHandler {
2566
- type CustomMessage = Infallible ;
2567
- fn read < R : io:: Read > ( & self , _: u16 , _: & mut R ) -> Result < Option < Self :: CustomMessage > , msgs:: DecodeError > {
2568
- Ok ( None )
2568
+ type CustomMessage = u64 ;
2569
+ fn read < R : io:: Read > ( & self , msg_type : u16 , reader : & mut R ) -> Result < Option < Self :: CustomMessage > , msgs:: DecodeError > {
2570
+ assert ! ( self . send_messages. is_some( ) ) ;
2571
+ assert_eq ! ( msg_type, 4242 ) ;
2572
+ let mut msg = [ 0u8 ; 8 ] ;
2573
+ reader. read_exact ( & mut msg) . unwrap ( ) ;
2574
+ Ok ( Some ( u64:: from_be_bytes ( msg) ) )
2569
2575
}
2570
2576
}
2571
2577
2572
2578
impl CustomMessageHandler for TestCustomMessageHandler {
2573
- fn handle_custom_message ( & self , _: Infallible , _: & PublicKey ) -> Result < ( ) , LightningError > {
2574
- unreachable ! ( ) ;
2579
+ fn handle_custom_message ( & self , msg : u64 , _: & PublicKey ) -> Result < ( ) , LightningError > {
2580
+ assert_eq ! ( self . peer_counter. load( Ordering :: Acquire ) as u64 , msg) ;
2581
+ Ok ( ( ) )
2575
2582
}
2576
2583
2577
- fn get_and_clear_pending_msg ( & self ) -> Vec < ( PublicKey , Self :: CustomMessage ) > { Vec :: new ( ) }
2584
+ fn get_and_clear_pending_msg ( & self ) -> Vec < ( PublicKey , Self :: CustomMessage ) > {
2585
+ if let Some ( peer_node_id) = & self . send_messages {
2586
+ vec ! [ ( * peer_node_id, self . peer_counter. load( Ordering :: Acquire ) as u64 ) ; 1000 ]
2587
+ } else { Vec :: new ( ) }
2588
+ }
2578
2589
2579
- fn peer_disconnected ( & self , _: & PublicKey ) { }
2580
- fn peer_connected ( & self , _: & PublicKey , _: & msgs:: Init , _: bool ) -> Result < ( ) , ( ) > { Ok ( ( ) ) }
2590
+ fn peer_disconnected ( & self , _: & PublicKey ) {
2591
+ self . peer_counter . fetch_sub ( 1 , Ordering :: AcqRel ) ;
2592
+ }
2593
+ fn peer_connected ( & self , _: & PublicKey , _: & msgs:: Init , _: bool ) -> Result < ( ) , ( ) > {
2594
+ self . peer_counter . fetch_add ( 2 , Ordering :: AcqRel ) ;
2595
+ Ok ( ( ) )
2596
+ }
2581
2597
2582
2598
fn provided_node_features ( & self ) -> NodeFeatures { NodeFeatures :: empty ( ) }
2583
2599
@@ -2600,7 +2616,9 @@ mod tests {
2600
2616
chan_handler : test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) ,
2601
2617
logger : test_utils:: TestLogger :: new ( ) ,
2602
2618
routing_handler : test_utils:: TestRoutingMessageHandler :: new ( ) ,
2603
- custom_handler : TestCustomMessageHandler { features } ,
2619
+ custom_handler : TestCustomMessageHandler {
2620
+ features, peer_counter : AtomicUsize :: new ( 0 ) , send_messages : None ,
2621
+ } ,
2604
2622
node_signer : test_utils:: TestNodeSigner :: new ( node_secret) ,
2605
2623
}
2606
2624
) ;
@@ -2623,7 +2641,9 @@ mod tests {
2623
2641
chan_handler : test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) ,
2624
2642
logger : test_utils:: TestLogger :: new ( ) ,
2625
2643
routing_handler : test_utils:: TestRoutingMessageHandler :: new ( ) ,
2626
- custom_handler : TestCustomMessageHandler { features } ,
2644
+ custom_handler : TestCustomMessageHandler {
2645
+ features, peer_counter : AtomicUsize :: new ( 0 ) , send_messages : None ,
2646
+ } ,
2627
2647
node_signer : test_utils:: TestNodeSigner :: new ( node_secret) ,
2628
2648
}
2629
2649
) ;
@@ -2643,7 +2663,9 @@ mod tests {
2643
2663
chan_handler : test_utils:: TestChannelMessageHandler :: new ( network) ,
2644
2664
logger : test_utils:: TestLogger :: new ( ) ,
2645
2665
routing_handler : test_utils:: TestRoutingMessageHandler :: new ( ) ,
2646
- custom_handler : TestCustomMessageHandler { features } ,
2666
+ custom_handler : TestCustomMessageHandler {
2667
+ features, peer_counter : AtomicUsize :: new ( 0 ) , send_messages : None ,
2668
+ } ,
2647
2669
node_signer : test_utils:: TestNodeSigner :: new ( node_secret) ,
2648
2670
}
2649
2671
) ;
@@ -3191,4 +3213,100 @@ mod tests {
3191
3213
thread_c. join ( ) . unwrap ( ) ;
3192
3214
assert ! ( cfg[ 0 ] . chan_handler. message_fetch_counter. load( Ordering :: Acquire ) >= 1 ) ;
3193
3215
}
3216
+
3217
+ #[ test]
3218
+ #[ cfg( feature = "std" ) ]
3219
+ fn test_rapid_connect_events_order_multithreaded ( ) {
3220
+ // Previously, outbound messages held in `process_events` could race with peer
3221
+ // disconnection, allowing a message intended for a peer before disconnection to be sent
3222
+ // to the same peer after disconnection. Here we stress the handling of such messages by
3223
+ // connecting two peers repeatedly in a loop with a `CustomMessageHandler` set to stream
3224
+ // custom messages with a "connection id" to each other. That "connection id" (just the
3225
+ // number of reconnections seen) should always line up across both peers, which we assert
3226
+ // in the message handler.
3227
+ let mut cfg = create_peermgr_cfgs ( 2 ) ;
3228
+ cfg[ 0 ] . custom_handler . send_messages =
3229
+ Some ( cfg[ 1 ] . node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ) ;
3230
+ cfg[ 1 ] . custom_handler . send_messages =
3231
+ Some ( cfg[ 1 ] . node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ) ;
3232
+ let cfg = Arc :: new ( cfg) ;
3233
+ // Until we have std::thread::scoped we have to unsafe { turn off the borrow checker }.
3234
+ let mut peers = create_network ( 2 , unsafe { & * ( & * cfg as * const _ ) as & ' static _ } ) ;
3235
+ let peer_a = Arc :: new ( peers. pop ( ) . unwrap ( ) ) ;
3236
+ let peer_b = Arc :: new ( peers. pop ( ) . unwrap ( ) ) ;
3237
+
3238
+ let exit_flag = Arc :: new ( AtomicBool :: new ( false ) ) ;
3239
+ macro_rules! spawn_thread { ( $id: expr) => { {
3240
+ let thread_peer_a = Arc :: clone( & peer_a) ;
3241
+ let thread_peer_b = Arc :: clone( & peer_b) ;
3242
+ let thread_exit = Arc :: clone( & exit_flag) ;
3243
+ std:: thread:: spawn( move || {
3244
+ let id_a = thread_peer_a. node_signer. get_node_id( Recipient :: Node ) . unwrap( ) ;
3245
+ let mut fd_a = FileDescriptor {
3246
+ fd: $id, outbound_data: Arc :: new( Mutex :: new( Vec :: new( ) ) ) ,
3247
+ disconnect: Arc :: new( AtomicBool :: new( false ) ) ,
3248
+ } ;
3249
+ let addr_a = SocketAddress :: TcpIpV4 { addr: [ 127 , 0 , 0 , 1 ] , port: 1000 } ;
3250
+ let mut fd_b = FileDescriptor {
3251
+ fd: $id, outbound_data: Arc :: new( Mutex :: new( Vec :: new( ) ) ) ,
3252
+ disconnect: Arc :: new( AtomicBool :: new( false ) ) ,
3253
+ } ;
3254
+ let addr_b = SocketAddress :: TcpIpV4 { addr: [ 127 , 0 , 0 , 1 ] , port: 1001 } ;
3255
+ let initial_data = thread_peer_b. new_outbound_connection( id_a, fd_b. clone( ) , Some ( addr_a. clone( ) ) ) . unwrap( ) ;
3256
+ thread_peer_a. new_inbound_connection( fd_a. clone( ) , Some ( addr_b. clone( ) ) ) . unwrap( ) ;
3257
+ if thread_peer_a. read_event( & mut fd_a, & initial_data) . is_err( ) {
3258
+ thread_peer_b. socket_disconnected( & fd_b) ;
3259
+ return ;
3260
+ }
3261
+
3262
+ loop {
3263
+ if thread_exit. load( Ordering :: Relaxed ) {
3264
+ thread_peer_a. socket_disconnected( & fd_a) ;
3265
+ thread_peer_b. socket_disconnected( & fd_b) ;
3266
+ return ;
3267
+ }
3268
+ if fd_a. disconnect. load( Ordering :: Relaxed ) { return ; }
3269
+ if fd_b. disconnect. load( Ordering :: Relaxed ) { return ; }
3270
+
3271
+ let data_a = fd_a. outbound_data. lock( ) . unwrap( ) . split_off( 0 ) ;
3272
+ if !data_a. is_empty( ) {
3273
+ if thread_peer_b. read_event( & mut fd_b, & data_a) . is_err( ) {
3274
+ thread_peer_a. socket_disconnected( & fd_a) ;
3275
+ return ;
3276
+ }
3277
+ }
3278
+
3279
+ let data_b = fd_b. outbound_data. lock( ) . unwrap( ) . split_off( 0 ) ;
3280
+ if !data_b. is_empty( ) {
3281
+ if thread_peer_a. read_event( & mut fd_a, & data_b) . is_err( ) {
3282
+ thread_peer_b. socket_disconnected( & fd_b) ;
3283
+ return ;
3284
+ }
3285
+ }
3286
+ }
3287
+ } )
3288
+ } } }
3289
+
3290
+ let mut threads = Vec :: new ( ) ;
3291
+ {
3292
+ let thread_peer_a = Arc :: clone ( & peer_a) ;
3293
+ let thread_peer_b = Arc :: clone ( & peer_b) ;
3294
+ let thread_exit = Arc :: clone ( & exit_flag) ;
3295
+ threads. push ( std:: thread:: spawn ( move || {
3296
+ while !thread_exit. load ( Ordering :: Relaxed ) {
3297
+ thread_peer_a. process_events ( ) ;
3298
+ thread_peer_b. process_events ( ) ;
3299
+ }
3300
+ } ) ) ;
3301
+ }
3302
+ for i in 0 ..1000 {
3303
+ threads. push ( spawn_thread ! ( i) ) ;
3304
+ }
3305
+ exit_flag. store ( true , Ordering :: Relaxed ) ;
3306
+ for thread in threads {
3307
+ thread. join ( ) . unwrap ( ) ;
3308
+ }
3309
+ assert_eq ! ( peer_a. peers. read( ) . unwrap( ) . len( ) , 0 ) ;
3310
+ assert_eq ! ( peer_b. peers. read( ) . unwrap( ) . len( ) , 0 ) ;
3311
+ }
3194
3312
}
0 commit comments