Skip to content

Commit 17970f1

Browse files
authored
autonatv2: allow multiple concurrent requests per peer (#3187)
1 parent 898824c commit 17970f1

File tree

3 files changed

+128
-29
lines changed

3 files changed

+128
-29
lines changed

p2p/protocol/autonatv2/options.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ type autoNATSettings struct {
88
serverRPM int
99
serverPerPeerRPM int
1010
serverDialDataRPM int
11+
maxConcurrentRequestsPerPeer int
1112
dataRequestPolicy dataRequestPolicyFunc
1213
now func() time.Time
1314
amplificatonAttackPreventionDialWait time.Duration
@@ -20,6 +21,7 @@ func defaultSettings() *autoNATSettings {
2021
serverRPM: 60, // 1 every second
2122
serverPerPeerRPM: 12, // 1 every 5 seconds
2223
serverDialDataRPM: 12, // 1 every 5 seconds
24+
maxConcurrentRequestsPerPeer: 2,
2325
dataRequestPolicy: amplificationAttackPrevention,
2426
amplificatonAttackPreventionDialWait: 3 * time.Second,
2527
now: time.Now,
@@ -28,11 +30,12 @@ func defaultSettings() *autoNATSettings {
2830

2931
type AutoNATOption func(s *autoNATSettings) error
3032

31-
func WithServerRateLimit(rpm, perPeerRPM, dialDataRPM int) AutoNATOption {
33+
func WithServerRateLimit(rpm, perPeerRPM, dialDataRPM int, maxConcurrentRequestsPerPeer int) AutoNATOption {
3234
return func(s *autoNATSettings) error {
3335
s.serverRPM = rpm
3436
s.serverPerPeerRPM = perPeerRPM
3537
s.serverDialDataRPM = dialDataRPM
38+
s.maxConcurrentRequestsPerPeer = maxConcurrentRequestsPerPeer
3639
return nil
3740
}
3841
}

p2p/protocol/autonatv2/server.go

+29-21
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,11 @@ func newServer(host, dialer host.Host, s *autoNATSettings) *server {
6767
amplificatonAttackPreventionDialWait: s.amplificatonAttackPreventionDialWait,
6868
allowPrivateAddrs: s.allowPrivateAddrs,
6969
limiter: &rateLimiter{
70-
RPM: s.serverRPM,
71-
PerPeerRPM: s.serverPerPeerRPM,
72-
DialDataRPM: s.serverDialDataRPM,
73-
now: s.now,
70+
RPM: s.serverRPM,
71+
PerPeerRPM: s.serverPerPeerRPM,
72+
DialDataRPM: s.serverDialDataRPM,
73+
MaxConcurrentRequestsPerPeer: s.maxConcurrentRequestsPerPeer,
74+
now: s.now,
7475
},
7576
now: s.now,
7677
metricsTracer: s.metricsTracer,
@@ -391,16 +392,17 @@ type rateLimiter struct {
391392
RPM int
392393
// DialDataRPM is the rate limit for requests that require dial data
393394
DialDataRPM int
395+
// MaxConcurrentRequestsPerPeer is the maximum number of concurrent requests per peer
396+
MaxConcurrentRequestsPerPeer int
394397

395398
mu sync.Mutex
396399
closed bool
397400
reqs []entry
398401
peerReqs map[peer.ID][]time.Time
399402
dialDataReqs []time.Time
400-
// ongoingReqs tracks in progress requests. This is used to disallow multiple concurrent requests by the
401-
// same peer
402-
// TODO: Should we allow a few concurrent requests per peer?
403-
ongoingReqs map[peer.ID]struct{}
403+
// inProgressReqs tracks in progress requests. This is used to limit multiple
404+
// concurrent requests by the same peer.
405+
inProgressReqs map[peer.ID]int
404406

405407
now func() time.Time // for tests
406408
}
@@ -410,28 +412,31 @@ type entry struct {
410412
Time time.Time
411413
}
412414

415+
func (r *rateLimiter) init() {
416+
if r.peerReqs == nil {
417+
r.peerReqs = make(map[peer.ID][]time.Time)
418+
r.inProgressReqs = make(map[peer.ID]int)
419+
}
420+
}
421+
413422
func (r *rateLimiter) Accept(p peer.ID) bool {
414423
r.mu.Lock()
415424
defer r.mu.Unlock()
416425
if r.closed {
417426
return false
418427
}
419-
if r.peerReqs == nil {
420-
r.peerReqs = make(map[peer.ID][]time.Time)
421-
r.ongoingReqs = make(map[peer.ID]struct{})
422-
}
423-
428+
r.init()
424429
nw := r.now()
425430
r.cleanup(nw)
426431

427-
if _, ok := r.ongoingReqs[p]; ok {
432+
if r.inProgressReqs[p] >= r.MaxConcurrentRequestsPerPeer {
428433
return false
429434
}
430435
if len(r.reqs) >= r.RPM || len(r.peerReqs[p]) >= r.PerPeerRPM {
431436
return false
432437
}
433438

434-
r.ongoingReqs[p] = struct{}{}
439+
r.inProgressReqs[p]++
435440
r.reqs = append(r.reqs, entry{PeerID: p, Time: nw})
436441
r.peerReqs[p] = append(r.peerReqs[p], nw)
437442
return true
@@ -443,10 +448,7 @@ func (r *rateLimiter) AcceptDialDataRequest(p peer.ID) bool {
443448
if r.closed {
444449
return false
445450
}
446-
if r.peerReqs == nil {
447-
r.peerReqs = make(map[peer.ID][]time.Time)
448-
r.ongoingReqs = make(map[peer.ID]struct{})
449-
}
451+
r.init()
450452
nw := r.now()
451453
r.cleanup(nw)
452454
if len(r.dialDataReqs) >= r.DialDataRPM {
@@ -495,15 +497,21 @@ func (r *rateLimiter) cleanup(now time.Time) {
495497
func (r *rateLimiter) CompleteRequest(p peer.ID) {
496498
r.mu.Lock()
497499
defer r.mu.Unlock()
498-
delete(r.ongoingReqs, p)
500+
r.inProgressReqs[p]--
501+
if r.inProgressReqs[p] <= 0 {
502+
delete(r.inProgressReqs, p)
503+
if r.inProgressReqs[p] < 0 {
504+
log.Errorf("BUG: negative in progress requests for peer %s", p)
505+
}
506+
}
499507
}
500508

501509
func (r *rateLimiter) Close() {
502510
r.mu.Lock()
503511
defer r.mu.Unlock()
504512
r.closed = true
505513
r.peerReqs = nil
506-
r.ongoingReqs = nil
514+
r.inProgressReqs = nil
507515
r.dialDataReqs = nil
508516
}
509517

p2p/protocol/autonatv2/server_test.go

+95-7
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func TestServerDataRequest(t *testing.T) {
149149
}
150150
return false
151151
}),
152-
WithServerRateLimit(10, 10, 10),
152+
WithServerRateLimit(10, 10, 10, 2),
153153
withAmplificationAttackPreventionDialWait(0),
154154
)
155155
defer an.Close()
@@ -187,6 +187,69 @@ func TestServerDataRequest(t *testing.T) {
187187
_, err = c.GetReachability(context.Background(), []Request{{Addr: quicAddr, SendDialData: true}, {Addr: tcpAddr}})
188188
require.Error(t, err)
189189
}
190+
191+
func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) {
192+
const concurrentRequests = 5
193+
194+
// server will skip all tcp addresses
195+
dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP))
196+
197+
doneChan := make(chan struct{})
198+
an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy(
199+
// stall all allowed requests
200+
func(s network.Stream, dialAddr ma.Multiaddr) bool {
201+
<-doneChan
202+
return true
203+
}),
204+
WithServerRateLimit(10, 10, 10, concurrentRequests),
205+
withAmplificationAttackPreventionDialWait(0),
206+
)
207+
defer an.Close()
208+
defer an.host.Close()
209+
210+
c := newAutoNAT(t, nil, allowPrivateAddrs)
211+
defer c.Close()
212+
defer c.host.Close()
213+
214+
idAndWait(t, c, an)
215+
216+
errChan := make(chan error)
217+
const N = 10
218+
// num concurrentRequests will stall and N will fail
219+
for i := 0; i < concurrentRequests+N; i++ {
220+
go func() {
221+
_, err := c.GetReachability(context.Background(), []Request{{Addr: c.host.Addrs()[0], SendDialData: false}})
222+
errChan <- err
223+
}()
224+
}
225+
226+
// check N failures
227+
for i := 0; i < N; i++ {
228+
select {
229+
case err := <-errChan:
230+
require.Error(t, err)
231+
case <-time.After(10 * time.Second):
232+
t.Fatalf("expected %d errors: got: %d", N, i)
233+
}
234+
}
235+
236+
// check concurrentRequests failures, as we won't send dial data
237+
close(doneChan)
238+
for i := 0; i < concurrentRequests; i++ {
239+
select {
240+
case err := <-errChan:
241+
require.Error(t, err)
242+
case <-time.After(5 * time.Second):
243+
t.Fatalf("expected %d errors: got: %d", concurrentRequests, i)
244+
}
245+
}
246+
select {
247+
case err := <-errChan:
248+
t.Fatalf("expected no more errors: got: %v", err)
249+
default:
250+
}
251+
}
252+
190253
func TestServerDataRequestJitter(t *testing.T) {
191254
// server will skip all tcp addresses
192255
dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP))
@@ -198,7 +261,7 @@ func TestServerDataRequestJitter(t *testing.T) {
198261
}
199262
return false
200263
}),
201-
WithServerRateLimit(10, 10, 10),
264+
WithServerRateLimit(10, 10, 10, 2),
202265
withAmplificationAttackPreventionDialWait(5*time.Second),
203266
)
204267
defer an.Close()
@@ -238,7 +301,7 @@ func TestServerDataRequestJitter(t *testing.T) {
238301
}
239302

240303
func TestServerDial(t *testing.T) {
241-
an := newAutoNAT(t, nil, WithServerRateLimit(10, 10, 10), allowPrivateAddrs)
304+
an := newAutoNAT(t, nil, WithServerRateLimit(10, 10, 10, 2), allowPrivateAddrs)
242305
defer an.Close()
243306
defer an.host.Close()
244307

@@ -295,7 +358,7 @@ func TestServerDial(t *testing.T) {
295358

296359
func TestRateLimiter(t *testing.T) {
297360
cl := test.NewMockClock()
298-
r := rateLimiter{RPM: 3, PerPeerRPM: 2, DialDataRPM: 1, now: cl.Now}
361+
r := rateLimiter{RPM: 3, PerPeerRPM: 2, DialDataRPM: 1, now: cl.Now, MaxConcurrentRequestsPerPeer: 1}
299362

300363
require.True(t, r.Accept("peer1"))
301364

@@ -333,12 +396,37 @@ func TestRateLimiter(t *testing.T) {
333396

334397
cl.AdvanceBy(10 * time.Second)
335398
require.True(t, r.Accept("peer3"))
399+
400+
}
401+
402+
func TestRateLimiterConcurrentRequests(t *testing.T) {
403+
const N = 5
404+
const Peers = 5
405+
for concurrentRequests := 1; concurrentRequests <= N; concurrentRequests++ {
406+
cl := test.NewMockClock()
407+
r := rateLimiter{RPM: 10 * Peers * N, PerPeerRPM: 10 * Peers * N, DialDataRPM: 10 * Peers * N, now: cl.Now, MaxConcurrentRequestsPerPeer: concurrentRequests}
408+
for p := 0; p < Peers; p++ {
409+
for i := 0; i < concurrentRequests; i++ {
410+
require.True(t, r.Accept(peer.ID(fmt.Sprintf("peer-%d", p))))
411+
}
412+
require.False(t, r.Accept(peer.ID(fmt.Sprintf("peer-%d", p))))
413+
// Now complete the requests
414+
for i := 0; i < concurrentRequests; i++ {
415+
r.CompleteRequest(peer.ID(fmt.Sprintf("peer-%d", p)))
416+
}
417+
// Now we should be able to accept new requests
418+
for i := 0; i < concurrentRequests; i++ {
419+
require.True(t, r.Accept(peer.ID(fmt.Sprintf("peer-%d", p))))
420+
}
421+
require.False(t, r.Accept(peer.ID(fmt.Sprintf("peer-%d", p))))
422+
}
423+
}
336424
}
337425

338426
func TestRateLimiterStress(t *testing.T) {
339427
cl := test.NewMockClock()
340428
for i := 0; i < 10; i++ {
341-
r := rateLimiter{RPM: 20 + i, PerPeerRPM: 10 + i, DialDataRPM: i, now: cl.Now}
429+
r := rateLimiter{RPM: 20 + i, PerPeerRPM: 10 + i, DialDataRPM: i, MaxConcurrentRequestsPerPeer: 1, now: cl.Now}
342430

343431
peers := make([]peer.ID, 10+i)
344432
for i := 0; i < len(peers); i++ {
@@ -386,7 +474,7 @@ func TestRateLimiterStress(t *testing.T) {
386474
require.Equal(t, len(r.peerReqs), 1)
387475
require.Equal(t, len(r.peerReqs[peers[0]]), 1)
388476
require.Equal(t, len(r.dialDataReqs), 0)
389-
require.Equal(t, len(r.ongoingReqs), 1)
477+
require.Equal(t, len(r.inProgressReqs), 1)
390478
}
391479
}
392480

@@ -433,7 +521,7 @@ func TestReadDialData(t *testing.T) {
433521
}
434522

435523
func FuzzServerDialRequest(f *testing.F) {
436-
a := newAutoNAT(f, nil, allowPrivateAddrs, WithServerRateLimit(math.MaxInt32, math.MaxInt32, math.MaxInt32))
524+
a := newAutoNAT(f, nil, allowPrivateAddrs, WithServerRateLimit(math.MaxInt32, math.MaxInt32, math.MaxInt32, 2))
437525
c := newAutoNAT(f, nil)
438526
idAndWait(f, c, a)
439527
// reduce the streamTimeout before running this. TODO: fix this

0 commit comments

Comments
 (0)