Commit 77a0f77

Fix (and test) threaded payment retries
The new in-`ChannelManager` retry logic performs retries as two separate steps, under two separate locks - first it calculates the amount that needs to be retried, then it actually sends it. Because the first step doesn't update the amount, a second thread may come along, calculate the same amount, and end up retrying duplicatively. Because we generally shouldn't ever be processing retries at the same time, the fix is trivial - simply take a lock at the top of the retry loop and hold it until we're done.
1 parent 0f03d48 commit 77a0f77
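
To make the race concrete, here is a minimal standalone sketch of the pattern, not LDK code; the `Retrier` type and its fields are invented for illustration. The retry amount is read under one lock and recorded under a later lock acquisition, so without a dedicated guard two threads can both observe the same pending amount and both send it; holding a single `retry_lock` across the whole calculate-then-send sequence, as this commit does, serializes retries.

use std::sync::Mutex;

// Stand-in for the per-payment state kept in `pending_outbound_payments`.
struct Retrier {
	pending_retry_msat: Mutex<u64>,
	// The fix: a guard taken for the whole calculate-then-send sequence.
	retry_lock: Mutex<()>,
}

impl Retrier {
	fn retry_payment(&self) -> Option<u64> {
		// Without this guard, two threads could both read the same pending
		// amount below, both "send" it, and over-pay the original HTLC.
		let _single_thread = self.retry_lock.lock().unwrap();

		// Step 1: under the payments lock, work out how much still needs retrying.
		let amt = *self.pending_retry_msat.lock().unwrap();
		if amt == 0 { return None; }

		// Step 2: re-take the lock to record the send (the real code drops the
		// payments lock before actually sending the HTLC).
		*self.pending_retry_msat.lock().unwrap() = 0;
		Some(amt)
	}
}

fn main() {
	let retrier = Retrier { pending_retry_msat: Mutex::new(100_000), retry_lock: Mutex::new(()) };
	assert_eq!(retrier.retry_payment(), Some(100_000));
	assert_eq!(retrier.retry_payment(), None); // nothing left to retry
}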

File tree

4 files changed: +181 -3 lines changed

lightning/src/ln/channelmanager.rs

Lines changed: 1 addition & 1 deletion
@@ -7715,7 +7715,7 @@ where
 
 			inbound_payment_key: expanded_inbound_key,
 			pending_inbound_payments: Mutex::new(pending_inbound_payments),
-			pending_outbound_payments: OutboundPayments { pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()) },
+			pending_outbound_payments: OutboundPayments { pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()), retry_lock: Mutex::new(()), },
 			pending_intercepted_htlcs: Mutex::new(pending_intercepted_htlcs.unwrap()),
 
 			forward_htlcs: Mutex::new(forward_htlcs),

lightning/src/ln/functional_test_utils.rs

Lines changed: 13 additions & 0 deletions
@@ -350,6 +350,19 @@ impl<'a, 'b, 'c> Node<'a, 'b, 'c> {
 	}
 }
 
+/// If we need an unsafe pointer to a `Node` (ie to reference it in a thread
+/// pre-std::thread::scope), this provides that with `Sync`. Note that accessing some of the fields
+/// in the `Node` are not safe to use (i.e. the ones behind an `Rc`), but that's left to the caller
+/// to figure out.
+pub struct NodePtr(pub *const Node<'static, 'static, 'static>);
+impl NodePtr {
+	pub fn from_node<'a, 'b: 'a, 'c: 'b>(node: &Node<'a, 'b, 'c>) -> Self {
+		Self((node as *const Node<'a, 'b, 'c>).cast())
+	}
+}
+unsafe impl Send for NodePtr {}
+unsafe impl Sync for NodePtr {}
+
 impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
 	fn drop(&mut self) {
 		if !panicking() {
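
As a standalone illustration of the trick `NodePtr` relies on (using a toy `Node` type, not the real test harness): wrapping the raw pointer in a newtype and asserting `Send + Sync` lets a pre-`std::thread::scope` spawned thread borrow a stack-allocated value, with the caller responsible for joining the thread before the value is dropped.

use std::sync::atomic::{AtomicU64, Ordering};

// Toy stand-in for the test harness `Node`; the non-Send/Sync fields are omitted.
struct Node { events_processed: AtomicU64 }

// Same trick as the real `NodePtr`: a raw pointer carries no lifetime, so the
// closure can be 'static, and the unsafe impls make it movable across threads.
struct NodePtr(*const Node);
unsafe impl Send for NodePtr {}
unsafe impl Sync for NodePtr {}

fn main() {
	let node = Node { events_processed: AtomicU64::new(0) };
	let ptr = NodePtr(&node as *const Node);
	let handle = std::thread::spawn(move || {
		// Safety: main() joins the thread before `node` goes out of scope.
		let node = unsafe { &*ptr.0 };
		node.events_processed.fetch_add(1, Ordering::SeqCst);
	});
	handle.join().unwrap();
	assert_eq!(node.events_processed.load(Ordering::SeqCst), 1);
}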

lightning/src/ln/outbound_payment.rs

Lines changed: 4 additions & 1 deletion
@@ -393,12 +393,14 @@ pub enum PaymentSendFailure {
 
 pub(super) struct OutboundPayments {
 	pub(super) pending_outbound_payments: Mutex<HashMap<PaymentId, PendingOutboundPayment>>,
+	pub(super) retry_lock: Mutex<()>,
 }
 
 impl OutboundPayments {
 	pub(super) fn new() -> Self {
 		Self {
-			pending_outbound_payments: Mutex::new(HashMap::new())
+			pending_outbound_payments: Mutex::new(HashMap::new()),
+			retry_lock: Mutex::new(()),
 		}
 	}
 
@@ -501,6 +503,7 @@ impl OutboundPayments {
 		FH: Fn() -> Vec<ChannelDetails>,
 		L::Target: Logger,
 	{
+		let _single_thread = self.retry_lock.lock().unwrap();
 		loop {
 			let mut outbounds = self.pending_outbound_payments.lock().unwrap();
 			let mut retry_id_route_params = None;
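
A note on the shape of this fix: `retry_lock` is a separate `Mutex<()>` rather than simply holding `pending_outbound_payments` for the whole loop, presumably because the payments map lock still has to be dropped around the actual send. Binding the guard to a local keeps it alive until the function returns, so concurrent callers block at the top. A minimal standalone sketch of that serialization effect, with invented names:

use std::sync::{Arc, Mutex};
use std::thread;

// The guard is bound to a local, so it is held for the whole function body and
// concurrent callers are fully serialized even though the inner lock is taken
// and dropped repeatedly.
fn retry_once(retry_lock: &Mutex<()>, log: &Mutex<Vec<&'static str>>) {
	let _single_thread = retry_lock.lock().unwrap();
	log.lock().unwrap().push("enter");
	// ... compute the retry amount, drop the inner lock, "send" ...
	log.lock().unwrap().push("exit");
}

fn main() {
	let retry_lock = Arc::new(Mutex::new(()));
	let log = Arc::new(Mutex::new(Vec::new()));
	let handles: Vec<_> = (0..4).map(|_| {
		let (lock, log) = (Arc::clone(&retry_lock), Arc::clone(&log));
		thread::spawn(move || retry_once(&lock, &log))
	}).collect();
	for handle in handles { handle.join().unwrap(); }
	// Entries never interleave: the log is always enter, exit, enter, exit, ...
	for pair in log.lock().unwrap().chunks(2) { assert_eq!(pair, ["enter", "exit"]); }
}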

lightning/src/ln/payment_tests.rs

Lines changed: 163 additions & 1 deletion
@@ -39,7 +39,7 @@ use crate::routing::gossip::NodeId;
 #[cfg(feature = "std")]
 use {
 	crate::util::time::tests::SinceEpoch,
-	std::time::{SystemTime, Duration}
+	std::time::{SystemTime, Instant, Duration}
 };
 
 #[test]
@@ -2616,3 +2616,165 @@ fn test_simple_partial_retry() {
 	expect_pending_htlcs_forwardable!(nodes[2]);
 	expect_payment_claimable!(nodes[2], payment_hash, payment_secret, amt_msat);
 }
+
+#[test]
+#[cfg(feature = "std")]
+fn test_threaded_payment_retries() {
+	// In the first version of the in-`ChannelManager` payment retries, retries weren't limited to
+	// a single thread and would happily let multiple threads run retries at the same time. Because
+	// retries are done by first calculating the amount we need to retry, then dropping the
+	// relevant lock, then actually sending, we would happily let multiple threads retry the same
+	// amount at the same time, overpaying our original HTLC!
+	let chanmon_cfgs = create_chanmon_cfgs(4);
+	let node_cfgs = create_node_cfgs(4, &chanmon_cfgs);
+	let node_chanmgrs = create_node_chanmgrs(4, &node_cfgs, &[None, None, None, None]);
+	let nodes = create_network(4, &node_cfgs, &node_chanmgrs);
+
+	// There is one mitigating guardrail when retrying payments - we can never over-pay by more
+	// than 10% of the original value. Thus, we want all our retries to be below that. In order to
+	// keep things simple, we route one HTLC for 0.1% of the payment over channel 1 and the rest
+	// out over channel 3+4. This will let us ignore 99% of the payment value and deal with only
+	// our channel.
+	let chan_1_scid = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 10_000_000, 0).0.contents.short_channel_id;
+	create_announced_chan_between_nodes_with_value(&nodes, 1, 3, 10_000_000, 0);
+	let chan_3_scid = create_announced_chan_between_nodes_with_value(&nodes, 0, 2, 10_000_000, 0).0.contents.short_channel_id;
+	let chan_4_scid = create_announced_chan_between_nodes_with_value(&nodes, 2, 3, 10_000_000, 0).0.contents.short_channel_id;
+
+	let amt_msat = 100_000_000;
+	let (_, payment_hash, _, payment_secret) = get_route_and_payment_hash!(&nodes[0], nodes[2], amt_msat);
+	#[cfg(feature = "std")]
+	let payment_expiry_secs = SystemTime::UNIX_EPOCH.elapsed().unwrap().as_secs() + 60 * 60;
+	#[cfg(not(feature = "std"))]
+	let payment_expiry_secs = 60 * 60;
+	let mut invoice_features = InvoiceFeatures::empty();
+	invoice_features.set_variable_length_onion_required();
+	invoice_features.set_payment_secret_required();
+	invoice_features.set_basic_mpp_optional();
+	let payment_params = PaymentParameters::from_node_id(nodes[1].node.get_our_node_id(), TEST_FINAL_CLTV)
+		.with_expiry_time(payment_expiry_secs as u64)
+		.with_features(invoice_features);
+	let mut route_params = RouteParameters {
+		payment_params,
+		final_value_msat: amt_msat,
+		final_cltv_expiry_delta: TEST_FINAL_CLTV,
+	};
+
+	let mut route = Route {
+		paths: vec![
+			vec![RouteHop {
+				pubkey: nodes[1].node.get_our_node_id(),
+				node_features: nodes[1].node.node_features(),
+				short_channel_id: chan_1_scid,
+				channel_features: nodes[1].node.channel_features(),
+				fee_msat: 0,
+				cltv_expiry_delta: 100,
+			}, RouteHop {
+				pubkey: nodes[3].node.get_our_node_id(),
+				node_features: nodes[2].node.node_features(),
+				short_channel_id: 42, // Set a random SCID which nodes[1] will fail as unknown
+				channel_features: nodes[2].node.channel_features(),
+				fee_msat: amt_msat / 1000,
+				cltv_expiry_delta: 100,
+			}],
+			vec![RouteHop {
+				pubkey: nodes[2].node.get_our_node_id(),
+				node_features: nodes[2].node.node_features(),
+				short_channel_id: chan_3_scid,
+				channel_features: nodes[2].node.channel_features(),
+				fee_msat: 100_000,
+				cltv_expiry_delta: 100,
+			}, RouteHop {
+				pubkey: nodes[3].node.get_our_node_id(),
+				node_features: nodes[3].node.node_features(),
+				short_channel_id: chan_4_scid,
+				channel_features: nodes[3].node.channel_features(),
+				fee_msat: amt_msat - amt_msat / 1000,
+				cltv_expiry_delta: 100,
+			}]
+		],
+		payment_params: Some(PaymentParameters::from_node_id(nodes[2].node.get_our_node_id(), TEST_FINAL_CLTV)),
+	};
+	nodes[0].router.expect_find_route(route_params.clone(), Ok(route.clone()));
+
+	nodes[0].node.send_payment_with_retry(payment_hash, &Some(payment_secret), PaymentId(payment_hash.0), route_params.clone(), Retry::Attempts(0xdeadbeef)).unwrap();
+	check_added_monitors!(nodes[0], 2);
+	let mut send_msg_events = nodes[0].node.get_and_clear_pending_msg_events();
+	assert_eq!(send_msg_events.len(), 2);
+	send_msg_events.retain(|msg|
+		if let MessageSendEvent::UpdateHTLCs { node_id, .. } = msg {
+			// Drop the commitment update for nodes[2], we can just let that one sit pending
+			// forever.
+			*node_id == nodes[1].node.get_our_node_id()
+		} else { panic!(); }
+	);
+
+	// from here on out, the retry `RouteParameters` amount will be amt/1000
+	route_params.final_value_msat /= 1000;
+	route.paths.pop();
+
+	let end_time = Instant::now() + Duration::from_secs(1);
+	macro_rules! thread_body { () => { {
+		// We really want std::thread::scope, but its not stable until 1.63. Until then, we get unsafe.
+		let node_ref = NodePtr::from_node(&nodes[0]);
+		move || {
+			let node_a = unsafe { &*node_ref.0 };
+			while Instant::now() < end_time {
+				node_a.node.get_and_clear_pending_events(); // wipe the PendingHTLCsForwardable
+				// Ignore if we have any pending events, just always pretend we just got a
+				// PendingHTLCsForwardable
+				node_a.node.process_pending_htlc_forwards();
+			}
+		}
+	} } }
+	let mut threads = Vec::new();
+	for _ in 0..16 { threads.push(std::thread::spawn(thread_body!())); }
+
+	// Back in the main thread, poll pending messages and make sure that we never have more than
+	// one HTLC pending at a time. Note that the commitment_signed_dance will fail horribly if
+	// there are HTLC messages shoved in while its running. This allows us to test that we never
+	// generate an additional update_add_htlc until we've fully failed the first.
+	let mut previously_failed_channels = Vec::new();
+	loop {
+		assert_eq!(send_msg_events.len(), 1);
+		let send_event = SendEvent::from_event(send_msg_events.pop().unwrap());
+		assert_eq!(send_event.msgs.len(), 1);
+
+		nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &send_event.msgs[0]);
+		commitment_signed_dance!(nodes[1], nodes[0], send_event.commitment_msg, false, true);
+
+		// Note that we only push one route into `expect_find_route` at a time, because that's all
+		// the retries (should) need. If the bug is reintroduced "real" routes may be selected, but
+		// we should still ultimately fail for the same reason - because we're trying to send too
+		// many HTLCs at once.
+		let mut new_route_params = route_params.clone();
+		previously_failed_channels.push(route.paths[0][1].short_channel_id);
+		new_route_params.payment_params.previously_failed_channels = previously_failed_channels.clone();
+		route.paths[0][1].short_channel_id += 1;
+		nodes[0].router.expect_find_route(new_route_params, Ok(route.clone()));
+
+		let bs_fail_updates = get_htlc_update_msgs!(nodes[1], nodes[0].node.get_our_node_id());
+		nodes[0].node.handle_update_fail_htlc(&nodes[1].node.get_our_node_id(), &bs_fail_updates.update_fail_htlcs[0]);
+		// The "normal" commitment_signed_dance delivers the final RAA and then calls
+		// `check_added_monitors` to ensure only the one RAA-generated monitor update was created.
+		// This races with our other threads which may generate an add-HTLCs commitment update via
+		// `process_pending_htlc_forwards`. Instead, we defer the monitor update check until after
+		// *we've* called `process_pending_htlc_forwards` when its guaranteed to have two updates.
+		let last_raa = commitment_signed_dance!(nodes[0], nodes[1], bs_fail_updates.commitment_signed, false, true, false, true);
+		nodes[0].node.handle_revoke_and_ack(&nodes[1].node.get_our_node_id(), &last_raa);
+
+		let cur_time = Instant::now();
+		if cur_time > end_time {
+			for thread in threads.drain(..) { thread.join().unwrap(); }
+		}
+
+		// Make sure we have some events to handle when we go around...
+		nodes[0].node.get_and_clear_pending_events(); // wipe the PendingHTLCsForwardable
+		nodes[0].node.process_pending_htlc_forwards();
+		send_msg_events = nodes[0].node.get_and_clear_pending_msg_events();
+		check_added_monitors!(nodes[0], 2);
+
+		if cur_time > end_time {
+			break;
+		}
+	}
+}
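
As the `thread_body!` comment notes, `NodePtr` exists only because `std::thread::scope` wasn't stable until Rust 1.63. For comparison, a rough sketch (using a toy stand-in for the harness, not the real test) of how the same spawn-and-poll loop could look with scoped threads, which borrow from the stack safely and are joined automatically when the scope ends:

use std::time::{Duration, Instant};

fn main() {
	// Stand-in for the test harness nodes; with scoped threads the closures can
	// borrow this directly, so no raw pointer or unsafe Send/Sync impls are needed.
	let nodes = vec![String::from("node_a")];
	let end_time = Instant::now() + Duration::from_millis(10);
	std::thread::scope(|s| {
		for _ in 0..16 {
			s.spawn(|| {
				while Instant::now() < end_time {
					// In the real test this would call process_pending_htlc_forwards().
					let _ = nodes[0].len();
				}
			});
		}
	});
	// All 16 threads have been joined by the time scope() returns, and `nodes`
	// was only ever borrowed, never aliased through a raw pointer.
	assert_eq!(nodes.len(), 1);
}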
