Commit c58938b
Fix (and test) threaded payment retries
The new in-`ChannelManager` retries logic performs retries as two separate steps, under two separate locks: first it calculates the amount that needs to be retried, then it actually sends it. Because the first step doesn't update the pending amount, a second thread may come along, calculate the same amount, and end up sending a duplicate retry. Because we generally shouldn't ever be processing retries concurrently anyway, the fix is trivial: simply take a lock at the top of the retry loop and hold it until we're done.
1 parent 69b7225 commit c58938b
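
A minimal sketch of the locking pattern the commit message describes, with a hypothetical `Retrier` type and `send` callback standing in for LDK's `OutboundPayments` and its send path (this is not the actual LDK code):

	use std::collections::HashMap;
	use std::sync::Mutex;

	// Hypothetical stand-in for `OutboundPayments`: `pending` tracks the msat amount
	// still to retry per payment id, `retry_lock` serializes entire retry passes.
	struct Retrier {
		pending: Mutex<HashMap<u64, u64>>,
		retry_lock: Mutex<()>,
	}

	impl Retrier {
		fn retry_payments(&self, send: impl Fn(u64, u64)) {
			// Held for the whole pass. Without it, two threads can both read the
			// same pending amount (step 1), drop the `pending` lock, and then both
			// send that amount (step 2), duplicating the retry.
			let _single_thread = self.retry_lock.lock().unwrap();
			loop {
				// Step 1: compute what to retry, then drop the `pending` lock so
				// sending doesn't hold it (mirroring the two-step structure).
				let (id, amt) = {
					let pending = self.pending.lock().unwrap();
					match pending.iter().find(|(_, amt)| **amt > 0) {
						Some((&id, &amt)) => (id, amt),
						None => return,
					}
				};
				// Step 2: send, then record that nothing is left to retry.
				send(id, amt);
				self.pending.lock().unwrap().insert(id, 0);
			}
		}
	}

	fn main() {
		let retrier = Retrier {
			pending: Mutex::new(HashMap::from([(1u64, 250_u64)])),
			retry_lock: Mutex::new(()),
		};
		retrier.retry_payments(|id, amt| println!("retrying payment {id} for {amt} msat"));
	}

The commit's actual change is exactly such a `retry_lock: Mutex<()>` field added to `OutboundPayments` and locked at the top of the retry loop, as shown in the diffs below.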

File tree: 4 files changed (+182 -3 lines)

lightning/src/ln/channelmanager.rs
Lines changed: 1 addition & 1 deletion

@@ -7587,7 +7587,7 @@ where
 			inbound_payment_key: expanded_inbound_key,
 			pending_inbound_payments: Mutex::new(pending_inbound_payments),
-			pending_outbound_payments: OutboundPayments { pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()) },
+			pending_outbound_payments: OutboundPayments { pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()), retry_lock: Mutex::new(()), },
 			pending_intercepted_htlcs: Mutex::new(pending_intercepted_htlcs.unwrap()),

 			forward_htlcs: Mutex::new(forward_htlcs),

lightning/src/ln/functional_test_utils.rs
Lines changed: 13 additions & 0 deletions

@@ -350,6 +350,19 @@ impl<'a, 'b, 'c> Node<'a, 'b, 'c> {
 	}
 }

+/// If we need an unsafe pointer to a `Node` (i.e. to reference it in a thread
+/// pre-std::thread::scope), this provides that with `Sync`. Note that some of the fields
+/// in the `Node` are not safe to access (i.e. the ones behind an `Rc`), but that's left to the
+/// caller to figure out.
+pub struct NodePtr(pub *const Node<'static, 'static, 'static>);
+impl NodePtr {
+	pub fn from_node<'a, 'b: 'a, 'c: 'b>(node: &Node<'a, 'b, 'c>) -> Self {
+		Self((node as *const Node<'a, 'b, 'c>).cast())
+	}
+}
+unsafe impl Send for NodePtr {}
+unsafe impl Sync for NodePtr {}
+
 impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
 	fn drop(&mut self) {
 		if !panicking() {
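
For context, here is a standalone sketch of the same raw-pointer trick outside the test harness; `Data` and `DataPtr` are hypothetical stand-ins for `Node` and `NodePtr`. Before `std::thread::scope` (stabilized in Rust 1.63), `std::thread::spawn` requires a `'static` closure, so a borrowed value can only cross the thread boundary via a raw pointer whose pointee the caller guarantees outlives the thread:

	struct Data {
		value: u64,
	}

	// Raw-pointer wrapper so the (non-'static) borrow can be moved into a spawned
	// thread; Send/Sync are sound here only because the caller joins the thread
	// before the pointee is dropped.
	struct DataPtr(*const Data);
	unsafe impl Send for DataPtr {}
	unsafe impl Sync for DataPtr {}

	fn main() {
		let data = Data { value: 42 };
		let ptr = DataPtr(&data as *const Data);
		let handle = std::thread::spawn(move || {
			// SAFETY: `data` outlives this thread because main joins it below.
			let data_ref = unsafe { &*ptr.0 };
			assert_eq!(data_ref.value, 42);
		});
		handle.join().unwrap();
	}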

lightning/src/ln/outbound_payment.rs
Lines changed: 4 additions & 1 deletion

@@ -392,12 +392,14 @@ pub enum PaymentSendFailure {

 pub(super) struct OutboundPayments {
 	pub(super) pending_outbound_payments: Mutex<HashMap<PaymentId, PendingOutboundPayment>>,
+	pub(super) retry_lock: Mutex<()>,
 }

 impl OutboundPayments {
 	pub(super) fn new() -> Self {
 		Self {
-			pending_outbound_payments: Mutex::new(HashMap::new())
+			pending_outbound_payments: Mutex::new(HashMap::new()),
+			retry_lock: Mutex::new(()),
 		}
 	}

@@ -478,6 +480,7 @@ impl OutboundPayments {
 		FH: Fn() -> Vec<ChannelDetails>,
 		L::Target: Logger,
 	{
+		let _single_thread = self.retry_lock.lock().unwrap();
 		loop {
 			let mut outbounds = self.pending_outbound_payments.lock().unwrap();
 			let mut retry_id_route_params = None;

lightning/src/ln/payment_tests.rs
Lines changed: 164 additions & 1 deletion

@@ -39,7 +39,7 @@ use crate::routing::gossip::NodeId;
 #[cfg(feature = "std")]
 use {
 	crate::util::time::tests::SinceEpoch,
-	std::time::{SystemTime, Duration}
+	std::time::{SystemTime, Instant, Duration}
 };

 #[test]

@@ -2575,3 +2575,166 @@ fn test_simple_partial_retry() {
 	expect_pending_htlcs_forwardable!(nodes[2]);
 	expect_payment_claimable!(nodes[2], payment_hash, payment_secret, amt_msat);
 }
+
+#[test]
+#[cfg(feature = "std")]
+fn test_threaded_payment_retries() {
+	// In the first version of the in-`ChannelManager` payment retries, retries weren't limited to
+	// a single thread and would happily let multiple threads run retries at the same time. Because
+	// retries are done by first calculating the amount we need to retry, then dropping the
+	// relevant lock, then actually sending, we would happily let multiple threads retry the same
+	// amount at the same time, overpaying our original HTLC!
+	let chanmon_cfgs = create_chanmon_cfgs(4);
+	let node_cfgs = create_node_cfgs(4, &chanmon_cfgs);
+	let node_chanmgrs = create_node_chanmgrs(4, &node_cfgs, &[None, None, None, None]);
+	let nodes = create_network(4, &node_cfgs, &node_chanmgrs);
+
+	// There is one mitigating guardrail when retrying payments - we can never over-pay by more
+	// than 10% of the original value. Thus, we want all our retries to be below that. In order to
+	// keep things simple, we route one HTLC for 0.1% of the payment over channel 1 and the rest
+	// out over channel 3+4. This will let us ignore 99% of the payment value and deal with only
+	// our channel.
+	let chan_1_scid = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 10_000_000, 0).0.contents.short_channel_id;
+	create_announced_chan_between_nodes_with_value(&nodes, 1, 3, 10_000_000, 0);
+	let chan_3_scid = create_announced_chan_between_nodes_with_value(&nodes, 0, 2, 10_000_000, 0).0.contents.short_channel_id;
+	let chan_4_scid = create_announced_chan_between_nodes_with_value(&nodes, 2, 3, 10_000_000, 0).0.contents.short_channel_id;
+
+	let amt_msat = 100_000_000;
+	let (_, payment_hash, _, payment_secret) = get_route_and_payment_hash!(&nodes[0], nodes[2], amt_msat);
+	#[cfg(feature = "std")]
+	let payment_expiry_secs = SystemTime::UNIX_EPOCH.elapsed().unwrap().as_secs() + 60 * 60;
+	#[cfg(not(feature = "std"))]
+	let payment_expiry_secs = 60 * 60;
+	let mut invoice_features = InvoiceFeatures::empty();
+	invoice_features.set_variable_length_onion_required();
+	invoice_features.set_payment_secret_required();
+	invoice_features.set_basic_mpp_optional();
+	let payment_params = PaymentParameters::from_node_id(nodes[1].node.get_our_node_id(), TEST_FINAL_CLTV)
+		.with_expiry_time(payment_expiry_secs as u64)
+		.with_features(invoice_features);
+	let mut route_params = RouteParameters {
+		payment_params,
+		final_value_msat: amt_msat,
+		final_cltv_expiry_delta: TEST_FINAL_CLTV,
+	};
+
+	let mut route = Route {
+		paths: vec![
+			vec![RouteHop {
+				pubkey: nodes[1].node.get_our_node_id(),
+				node_features: nodes[1].node.node_features(),
+				short_channel_id: chan_1_scid,
+				channel_features: nodes[1].node.channel_features(),
+				fee_msat: 0,
+				cltv_expiry_delta: 100,
+			}, RouteHop {
+				pubkey: nodes[3].node.get_our_node_id(),
+				node_features: nodes[2].node.node_features(),
+				short_channel_id: 42, // Set a random SCID which nodes[1] will fail as unknown
+				channel_features: nodes[2].node.channel_features(),
+				fee_msat: amt_msat / 1000,
+				cltv_expiry_delta: 100,
+			}],
+			vec![RouteHop {
+				pubkey: nodes[2].node.get_our_node_id(),
+				node_features: nodes[2].node.node_features(),
+				short_channel_id: chan_3_scid,
+				channel_features: nodes[2].node.channel_features(),
+				fee_msat: 100_000,
+				cltv_expiry_delta: 100,
+			}, RouteHop {
+				pubkey: nodes[3].node.get_our_node_id(),
+				node_features: nodes[3].node.node_features(),
+				short_channel_id: chan_4_scid,
+				channel_features: nodes[3].node.channel_features(),
+				fee_msat: amt_msat - amt_msat / 1000,
+				cltv_expiry_delta: 100,
+			}]
+		],
+		payment_params: Some(PaymentParameters::from_node_id(nodes[2].node.get_our_node_id(), TEST_FINAL_CLTV)),
+	};
+	nodes[0].router.expect_find_route(route_params.clone(), Ok(route.clone()));
+
+	nodes[0].node.send_payment_with_retry(payment_hash, &Some(payment_secret), PaymentId(payment_hash.0), route_params.clone(), Retry::Attempts(0xdeadbeef)).unwrap();
+	check_added_monitors!(nodes[0], 2);
+	let mut send_msg_events = nodes[0].node.get_and_clear_pending_msg_events();
+	assert_eq!(send_msg_events.len(), 2);
+	send_msg_events.retain(|msg|
+		if let MessageSendEvent::UpdateHTLCs { node_id, .. } = msg {
+			// Drop the commitment update for nodes[2]; we can just let that one sit pending
+			// forever.
+			*node_id == nodes[1].node.get_our_node_id()
+		} else { panic!(); }
+	);
+
+	// From here on out, the retry `RouteParameters` amount will be amt/1000
+	route_params.final_value_msat /= 1000;
+	route.paths.pop();
+
+	let end_time = Instant::now() + Duration::from_secs(1);
+	macro_rules! thread_body { () => { {
+		// We really want std::thread::scope, but it's not stable until 1.63. Until then, we get unsafe.
+		let node_ref = NodePtr::from_node(&nodes[0]);
+		move || {
+			let node_a = unsafe { &*node_ref.0 };
+			while Instant::now() < end_time {
+				node_a.node.get_and_clear_pending_events(); // wipe the PendingHTLCsForwardable
+				// Ignore if we have any pending events, just always pretend we just got a
+				// PendingHTLCsForwardable
+				node_a.node.process_pending_htlc_forwards();
+			}
+		}
+	} } }
+	let mut threads = Vec::new();
+	for _ in 0..16 { threads.push(std::thread::spawn(thread_body!())); }
+
+	// Back in the main thread, poll pending messages and make sure that we never have more than
+	// one HTLC pending at a time. Note that the commitment_signed_dance will fail horribly if
+	// there are HTLC messages shoved in while it's running. This allows us to test that we never
+	// generate an additional update_add_htlc until we've fully failed the first.
+	let mut previously_failed_channels = Vec::new();
+	loop {
+		assert_eq!(send_msg_events.len(), 1);
+		let send_event = SendEvent::from_event(send_msg_events.pop().unwrap());
+		assert_eq!(send_event.msgs.len(), 1);
+
+		nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &send_event.msgs[0]);
+		commitment_signed_dance!(nodes[1], nodes[0], send_event.commitment_msg, false, true);
+
+		// Note that we only push one route into `expect_find_route` at a time, because that's all
+		// the retries (should) need. If the bug is reintroduced "real" routes may be selected, but
+		// we should still ultimately fail for the same reason - because we're trying to send too
+		// many HTLCs at once.
+		let mut new_route_params = route_params.clone();
+		previously_failed_channels.push(route.paths[0][1].short_channel_id);
+		new_route_params.payment_params.previously_failed_channels = previously_failed_channels.clone();
+		route.paths[0][1].short_channel_id += 1;
+		nodes[0].router.expect_find_route(new_route_params, Ok(route.clone()));
+
+		let bs_fail_updates = get_htlc_update_msgs!(nodes[1], nodes[0].node.get_our_node_id());
+		nodes[0].node.handle_update_fail_htlc(&nodes[1].node.get_our_node_id(), &bs_fail_updates.update_fail_htlcs[0]);
+		// The "normal" commitment_signed_dance delivers the final RAA and then calls
+		// `check_added_monitors` to ensure only the one RAA-generated monitor update was created.
+		// This races with our other threads which may generate an add-HTLCs commitment update via
+		// `process_pending_htlc_forwards`. Instead, we defer the monitor update check until after
+		// *we've* called `process_pending_htlc_forwards` when it's guaranteed to have two updates.
+		let last_raa = commitment_signed_dance!(nodes[0], nodes[1], bs_fail_updates.commitment_signed, false, true, false, true);
+		nodes[0].node.handle_revoke_and_ack(&nodes[1].node.get_our_node_id(), &last_raa);
+
+		let cur_time = Instant::now();
+		if cur_time > end_time {
+			for thread in threads.drain(..) { thread.join().unwrap(); }
+		}
+
+		// Make sure we have some events to handle when we go around...
+		nodes[0].node.get_and_clear_pending_events(); // wipe the PendingHTLCsForwardable
+		nodes[0].node.process_pending_htlc_forwards();
+		send_msg_events = nodes[0].node.get_and_clear_pending_msg_events();
+		check_added_monitors!(nodes[0], 2);
+
+		if cur_time > end_time {
+			break;
+		}
+	}
+}
