Skip to content

Commit b181025

Browse files
f fix mpp with multiple matching intro nodes
1 parent 60c31cd commit b181025

File tree

1 file changed

+67
-6
lines changed

1 file changed

+67
-6
lines changed

lightning/src/routing/router.rs

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,7 @@ impl Payee {
770770
_ => None,
771771
}
772772
}
773+
// This method must always return the hints in the same order.
773774
fn blinded_route_hints(&self) -> &[(BlindedPayInfo, BlindedPath)] {
774775
match self {
775776
Self::Blinded { route_hints, .. } => &route_hints[..],
@@ -2307,12 +2308,13 @@ where L::Target: Logger {
23072308
for results_vec in selected_paths {
23082309
let mut hops = Vec::with_capacity(results_vec.len());
23092310
for res in results_vec { hops.push(res?); }
2310-
let blinded_path = payment_params.payee.blinded_route_hints().iter()
2311-
.find(|(_, p)| {
2312-
let intro_node_idx = if p.blinded_hops.len() == 1 { hops.len() - 1 }
2313-
else { hops.len().saturating_sub(2) };
2314-
p.introduction_node_id == hops[intro_node_idx].pubkey
2315-
}).map(|(_, p)| p.clone());
2311+
2312+
let blinded_path = payment_params.payee.blinded_route_hints().iter().enumerate()
2313+
.find(|(idx, (_, p))| {
2314+
DUMMY_BLINDED_SCID + *idx as u64 == hops[hops.len() - 1].short_channel_id ||
2315+
(p.blinded_hops.len() == 1 && p.introduction_node_id == hops[hops.len() - 1].pubkey)
2316+
})
2317+
.map(|(_, (_, p))| p.clone());
23162318
let blinded_tail = if let Some(BlindedPath { blinded_hops, blinding_point, .. }) = blinded_path {
23172319
let num_blinded_hops = blinded_hops.len();
23182320
Some(BlindedTail {
@@ -6355,6 +6357,65 @@ mod tests {
63556357
_ => panic!("Expected error")
63566358
}
63576359
}
6360+
6361+
#[test]
6362+
fn matching_intro_node_paths_provided() {
6363+
// Check that if multiple blinded paths with the same intro node are provided in payment
6364+
// parameters, we'll return the correct paths in the resulting MPP route.
6365+
let (secp_ctx, network, _, _, logger) = build_graph();
6366+
let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
6367+
let network_graph = network.read_only();
6368+
6369+
let scorer = ln_test_utils::TestScorer::new();
6370+
let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
6371+
let random_seed_bytes = keys_manager.get_secure_random_bytes();
6372+
let config = UserConfig::default();
6373+
6374+
let bolt12_features: Bolt12InvoiceFeatures = channelmanager::provided_invoice_features(&config).to_context();
6375+
let blinded_path_1 = BlindedPath {
6376+
introduction_node_id: nodes[2],
6377+
blinding_point: ln_test_utils::pubkey(42),
6378+
blinded_hops: vec![
6379+
BlindedHop { blinded_node_id: ln_test_utils::pubkey(42 as u8), encrypted_payload: Vec::new() },
6380+
BlindedHop { blinded_node_id: ln_test_utils::pubkey(42 as u8), encrypted_payload: Vec::new() }
6381+
],
6382+
};
6383+
let blinded_payinfo_1 = BlindedPayInfo {
6384+
fee_base_msat: 0,
6385+
fee_proportional_millionths: 0,
6386+
htlc_minimum_msat: 0,
6387+
htlc_maximum_msat: 30_000,
6388+
cltv_expiry_delta: 0,
6389+
features: BlindedHopFeatures::empty(),
6390+
};
6391+
6392+
let mut blinded_path_2 = blinded_path_1.clone();
6393+
blinded_path_2.blinding_point = ln_test_utils::pubkey(43);
6394+
let mut blinded_payinfo_2 = blinded_payinfo_1.clone();
6395+
blinded_payinfo_2.htlc_maximum_msat = 70_000;
6396+
6397+
let blinded_hints = vec![
6398+
(blinded_payinfo_1.clone(), blinded_path_1.clone()),
6399+
(blinded_payinfo_2.clone(), blinded_path_2.clone()),
6400+
];
6401+
let payment_params = PaymentParameters::blinded(blinded_hints.clone())
6402+
.with_bolt12_features(bolt12_features.clone()).unwrap();
6403+
6404+
let route = get_route(&our_id, &payment_params, &network_graph, None,
6405+
100_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap();
6406+
assert_eq!(route.paths.len(), 2);
6407+
let mut total_amount_paid_msat = 0;
6408+
for path in route.paths.into_iter() {
6409+
assert_eq!(path.hops.last().unwrap().pubkey, nodes[2]);
6410+
if let Some(bt) = &path.blinded_tail {
6411+
assert_eq!(bt.blinding_point,
6412+
blinded_hints.iter().find(|(p, _)| p.htlc_maximum_msat == path.final_value_msat())
6413+
.map(|(_, bp)| bp.blinding_point).unwrap());
6414+
} else { panic!(); }
6415+
total_amount_paid_msat += path.final_value_msat();
6416+
}
6417+
assert_eq!(total_amount_paid_msat, 100_000);
6418+
}
63586419
}
63596420

63606421
#[cfg(all(test, not(feature = "no-std")))]

0 commit comments

Comments
 (0)