Skip to content

Commit 97fd72a

Browse files
feat(kad): add limit option for getting providers
1 parent d0da3a0 commit 97fd72a

File tree

3 files changed

+158
-28
lines changed

3 files changed

+158
-28
lines changed

protocols/kad/src/behaviour.rs

+71-25
Original file line numberDiff line numberDiff line change
@@ -919,17 +919,23 @@ where
919919
///
920920
/// The result of this operation is delivered in a
921921
/// reported via [`KademliaEvent::OutboundQueryCompleted{QueryResult::GetProviders}`].
922-
pub fn get_providers(&mut self, key: record::Key) -> QueryId {
922+
pub fn get_providers(&mut self, key: record::Key, limit: ProviderLimit) -> QueryId {
923923
let providers = self
924924
.store
925925
.providers(&key)
926926
.into_iter()
927927
.filter(|p| !p.is_expired(Instant::now()))
928-
.map(|p| p.provider)
929-
.collect();
928+
.map(|p| p.provider);
929+
930+
let providers = match limit {
931+
ProviderLimit::None => providers.collect(),
932+
ProviderLimit::N(limit) => providers.take(limit.into()).collect(),
933+
};
934+
930935
let info = QueryInfo::GetProviders {
931936
key: key.clone(),
932937
providers,
938+
limit,
933939
};
934940
let target = kbucket::Key::new(key);
935941
let peers = self.kbuckets.closest_keys(&target);
@@ -1258,17 +1264,19 @@ where
12581264
})),
12591265
}),
12601266

1261-
QueryInfo::GetProviders { key, providers } => {
1262-
Some(KademliaEvent::OutboundQueryCompleted {
1263-
id: query_id,
1264-
stats: result.stats,
1265-
result: QueryResult::GetProviders(Ok(GetProvidersOk {
1266-
key,
1267-
providers,
1268-
closest_peers: result.peers.collect(),
1269-
})),
1270-
})
1271-
}
1267+
QueryInfo::GetProviders {
1268+
key,
1269+
providers,
1270+
limit: _,
1271+
} => Some(KademliaEvent::OutboundQueryCompleted {
1272+
id: query_id,
1273+
stats: result.stats,
1274+
result: QueryResult::GetProviders(Ok(GetProvidersOk {
1275+
key,
1276+
providers,
1277+
closest_peers: result.peers.collect(),
1278+
})),
1279+
}),
12721280

12731281
QueryInfo::AddProvider {
12741282
context,
@@ -1553,17 +1561,19 @@ where
15531561
})),
15541562
}),
15551563

1556-
QueryInfo::GetProviders { key, providers } => {
1557-
Some(KademliaEvent::OutboundQueryCompleted {
1558-
id: query_id,
1559-
stats: result.stats,
1560-
result: QueryResult::GetProviders(Err(GetProvidersError::Timeout {
1561-
key,
1562-
providers,
1563-
closest_peers: result.peers.collect(),
1564-
})),
1565-
})
1566-
}
1564+
QueryInfo::GetProviders {
1565+
key,
1566+
providers,
1567+
limit: _,
1568+
} => Some(KademliaEvent::OutboundQueryCompleted {
1569+
id: query_id,
1570+
stats: result.stats,
1571+
result: QueryResult::GetProviders(Err(GetProvidersError::Timeout {
1572+
key,
1573+
providers,
1574+
closest_peers: result.peers.collect(),
1575+
})),
1576+
}),
15671577
}
15681578
}
15691579

@@ -2331,6 +2341,31 @@ where
23312341
{
23322342
query.on_success(&peer_id, vec![])
23332343
}
2344+
2345+
if let QueryInfo::GetProviders {
2346+
key: _,
2347+
providers,
2348+
limit,
2349+
} = &query.inner.info
2350+
{
2351+
match limit {
2352+
ProviderLimit::None => {
2353+
// No limit, so wait for enough peers to respond.
2354+
}
2355+
ProviderLimit::N(n) => {
2356+
// Check if we have enough providers.
2357+
if usize::from(*n) <= providers.len() {
2358+
debug!(
2359+
"found enough providers {}/{}, finishing",
2360+
providers.len(),
2361+
n
2362+
);
2363+
query.finish();
2364+
}
2365+
}
2366+
}
2367+
}
2368+
23342369
if self.connected_peers.contains(&peer_id) {
23352370
self.queued_events
23362371
.push_back(NetworkBehaviourAction::NotifyHandler {
@@ -2363,6 +2398,15 @@ where
23632398
}
23642399
}
23652400

2401+
/// Specifies the number of provider records fetched.
2402+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
2403+
pub enum ProviderLimit {
2404+
/// No limit on the number of records.
2405+
None,
2406+
/// Finishes the query as soon as this many records have been found.
2407+
N(NonZeroUsize),
2408+
}
2409+
23662410
/// A quorum w.r.t. the configured replication factor specifies the minimum
23672411
/// number of distinct nodes that must be successfully contacted in order
23682412
/// for a query to succeed.
@@ -2862,6 +2906,8 @@ pub enum QueryInfo {
28622906
key: record::Key,
28632907
/// The found providers.
28642908
providers: HashSet<PeerId>,
2909+
/// The limit of how many providers to find,
2910+
limit: ProviderLimit,
28652911
},
28662912

28672913
/// A (repeated) query initiated by [`Kademlia::start_providing`].

protocols/kad/src/behaviour/test.rs

+86-2
Original file line numberDiff line numberDiff line change
@@ -1333,7 +1333,7 @@ fn network_behaviour_inject_address_change() {
13331333
}
13341334

13351335
#[test]
1336-
fn get_providers() {
1336+
fn get_providers_single() {
13371337
fn prop(key: record::Key) {
13381338
let (_, mut single_swarm) = build_node();
13391339
single_swarm
@@ -1352,7 +1352,9 @@ fn get_providers() {
13521352
}
13531353
});
13541354

1355-
let query_id = single_swarm.behaviour_mut().get_providers(key.clone());
1355+
let query_id = single_swarm
1356+
.behaviour_mut()
1357+
.get_providers(key.clone(), ProviderLimit::None);
13561358

13571359
block_on(async {
13581360
match single_swarm.next().await.unwrap() {
@@ -1379,3 +1381,85 @@ fn get_providers() {
13791381
}
13801382
QuickCheck::new().tests(10).quickcheck(prop as fn(_))
13811383
}
1384+
1385+
fn get_providers_limit<const N: usize>() {
1386+
fn prop<const N: usize>(key: record::Key) {
1387+
let mut swarms = build_nodes(3);
1388+
1389+
// Let first peer know of second peer and second peer know of third peer.
1390+
for i in 0..2 {
1391+
let (peer_id, address) = (
1392+
Swarm::local_peer_id(&swarms[i + 1].1).clone(),
1393+
swarms[i + 1].0.clone(),
1394+
);
1395+
swarms[i].1.behaviour_mut().add_address(&peer_id, address);
1396+
}
1397+
1398+
// Drop the swarm addresses.
1399+
let mut swarms = swarms
1400+
.into_iter()
1401+
.map(|(_addr, swarm)| swarm)
1402+
.collect::<Vec<_>>();
1403+
1404+
// Provide the content on peer 2 and 3.
1405+
for i in 1..3 {
1406+
swarms[i]
1407+
.behaviour_mut()
1408+
.start_providing(key.clone())
1409+
.expect("could not provide");
1410+
}
1411+
1412+
// Query with expecting a single provider.
1413+
let query_id = swarms[0]
1414+
.behaviour_mut()
1415+
.get_providers(key.clone(), ProviderLimit::N(N.try_into().unwrap()));
1416+
1417+
block_on(poll_fn(move |ctx| {
1418+
for (i, swarm) in swarms.iter_mut().enumerate() {
1419+
loop {
1420+
match swarm.poll_next_unpin(ctx) {
1421+
Poll::Ready(Some(SwarmEvent::Behaviour(
1422+
KademliaEvent::OutboundQueryCompleted {
1423+
id,
1424+
result:
1425+
QueryResult::GetProviders(Ok(GetProvidersOk {
1426+
key: found_key,
1427+
providers,
1428+
..
1429+
})),
1430+
..
1431+
},
1432+
))) if i == 0 && id == query_id => {
1433+
// There are a total of 2 providers.
1434+
assert_eq!(providers.len(), std::cmp::min(N, 2));
1435+
assert_eq!(key, found_key);
1436+
// Providers should be either 2 or 3
1437+
assert_ne!(swarm.local_peer_id(), providers.iter().next().unwrap());
1438+
return Poll::Ready(());
1439+
}
1440+
Poll::Ready(..) => {}
1441+
Poll::Pending => break,
1442+
}
1443+
}
1444+
}
1445+
Poll::Pending
1446+
}));
1447+
}
1448+
1449+
QuickCheck::new().tests(10).quickcheck(prop::<N> as fn(_))
1450+
}
1451+
1452+
#[test]
1453+
fn get_providers_limit_n_1() {
1454+
get_providers_limit::<1>();
1455+
}
1456+
1457+
#[test]
1458+
fn get_providers_limit_n_2() {
1459+
get_providers_limit::<1>();
1460+
}
1461+
1462+
#[test]
1463+
fn get_providers_limit_n_5() {
1464+
get_providers_limit::<5>();
1465+
}

protocols/kad/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ pub use behaviour::{
6767
};
6868
pub use behaviour::{
6969
Kademlia, KademliaBucketInserts, KademliaCaching, KademliaConfig, KademliaEvent,
70-
KademliaStoreInserts, Quorum,
70+
KademliaStoreInserts, ProviderLimit, Quorum,
7171
};
7272
pub use protocol::KadConnectionType;
7373
pub use query::QueryId;

0 commit comments

Comments
 (0)