Skip to content

Commit ce3951d

Browse files
authored
Protect Libp2p Connections (#229)
* feat(requestmanager): add connection protection
* refactor(testutil): extract TestConnManager
* feat(responsemanager): add connection holding
  Also uncovered a bug in early cancellations, resolved by using the state pattern from requestmanager.
* refactor(graphsync): change string to unique tag
  Make the tag for request IDs unique to graphsync.
1 parent 5c5f1e8 commit ce3951d

14 files changed

+246
-77
lines changed

benchmarks/testnet/virtual.go

+5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
delay "github.com/ipfs/go-ipfs-delay"
1111
mockrouting "github.com/ipfs/go-ipfs-routing/mock"
12+
"github.com/libp2p/go-libp2p-core/connmgr"
1213
"github.com/libp2p/go-libp2p-core/peer"
1314
tnet "github.com/libp2p/go-libp2p-testing/net"
1415
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
@@ -255,6 +256,10 @@ func (nc *networkClient) DisconnectFrom(_ context.Context, p peer.ID) error {
255256
return nil
256257
}
257258

259+
// ConnectionManager returns a libp2p NullConnMgr for the virtual test network client.
// NOTE(review): NullConnMgr is presumably a no-op implementation, so protecting
// connections has no effect on this network — confirm against the connmgr package.
func (nc *networkClient) ConnectionManager() gsnet.ConnManager {
260+
return &connmgr.NullConnMgr{}
261+
}
262+
258263
func (rq *receiverQueue) enqueue(m *message) {
259264
rq.lk.Lock()
260265
defer rq.lk.Unlock()

graphsync.go

+5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ import (
1414
// RequestID is a unique identifier for a GraphSync request.
1515
type RequestID int32
1616

17+
// Tag returns the string "graphsync-request-<id>", which marks this request ID as belonging to graphsync when tagging libp2p connections.
18+
func (r RequestID) Tag() string {
19+
return fmt.Sprintf("graphsync-request-%d", r)
20+
}
21+
1722
// Priority a priority for a GraphSync request.
1823
type Priority int32
1924

impl/graphsync.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,11 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork,
179179

180180
asyncLoader := asyncloader.New(ctx, linkSystem, requestAllocator)
181181
requestQueue := taskqueue.NewTaskQueue(ctx)
182-
requestManager := requestmanager.New(ctx, asyncLoader, linkSystem, outgoingRequestHooks, incomingResponseHooks, networkErrorListeners, requestQueue)
182+
requestManager := requestmanager.New(ctx, asyncLoader, linkSystem, outgoingRequestHooks, incomingResponseHooks, networkErrorListeners, requestQueue, network.ConnectionManager())
183183
requestExecutor := executor.NewExecutor(requestManager, incomingBlockHooks, asyncLoader.AsyncLoad)
184184
responseAssembler := responseassembler.New(ctx, peerManager)
185185
peerTaskQueue := peertaskqueue.New()
186-
responseManager := responsemanager.New(ctx, linkSystem, responseAssembler, peerTaskQueue, requestQueuedHooks, incomingRequestHooks, outgoingBlockHooks, requestUpdatedHooks, completedResponseListeners, requestorCancelledListeners, blockSentListeners, networkErrorListeners, gsConfig.maxInProgressIncomingRequests)
186+
responseManager := responsemanager.New(ctx, linkSystem, responseAssembler, peerTaskQueue, requestQueuedHooks, incomingRequestHooks, outgoingBlockHooks, requestUpdatedHooks, completedResponseListeners, requestorCancelledListeners, blockSentListeners, networkErrorListeners, gsConfig.maxInProgressIncomingRequests, network.ConnectionManager())
187187
graphSync := &GraphSync{
188188
network: network,
189189
linkSystem: linkSystem,

network/interface.go

+8
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ type GraphSyncNetwork interface {
3131
ConnectTo(context.Context, peer.ID) error
3232

3333
NewMessageSender(context.Context, peer.ID) (MessageSender, error)
34+
35+
ConnectionManager() ConnManager
36+
}
37+
38+
// ConnManager provides the methods needed to protect and unprotect connections.
// NOTE(review): this looks like a subset of the libp2p connection manager interface — confirm.
39+
type ConnManager interface {
40+
// Protect presumably marks the peer's connection as protected under the given tag — verify against libp2p connmgr docs.
Protect(peer.ID, string)
41+
// Unprotect presumably removes the given tag's protection from the peer.
// NOTE(review): the meaning of the bool result is not visible here — confirm before relying on it.
Unprotect(peer.ID, string) bool
3442
}
3543

3644
// MessageSender is an interface to send messages to a peer

network/libp2p_impl.go

+4
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,10 @@ func (gsnet *libp2pGraphSyncNetwork) handleNewStream(s network.Stream) {
151151
}
152152
}
153153

154+
// ConnectionManager returns the connection manager of the underlying host.
func (gsnet *libp2pGraphSyncNetwork) ConnectionManager() ConnManager {
155+
return gsnet.host.ConnManager()
156+
}
157+
154158
type libp2pGraphSyncNotifee libp2pGraphSyncNetwork
155159

156160
func (nn *libp2pGraphSyncNotifee) libp2pGraphSyncNetwork() *libp2pGraphSyncNetwork {

requestmanager/client.go

+4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
gsmsg "github.com/ipfs/go-graphsync/message"
2424
"github.com/ipfs/go-graphsync/messagequeue"
2525
"github.com/ipfs/go-graphsync/metadata"
26+
"github.com/ipfs/go-graphsync/network"
2627
"github.com/ipfs/go-graphsync/notifications"
2728
"github.com/ipfs/go-graphsync/requestmanager/executor"
2829
"github.com/ipfs/go-graphsync/requestmanager/hooks"
@@ -94,6 +95,7 @@ type RequestManager struct {
9495
asyncLoader AsyncLoader
9596
disconnectNotif *pubsub.PubSub
9697
linkSystem ipld.LinkSystem
98+
connManager network.ConnManager
9799

98100
// don't touch outside of run loop
99101
nextRequestID graphsync.RequestID
@@ -126,6 +128,7 @@ func New(ctx context.Context,
126128
responseHooks ResponseHooks,
127129
networkErrorListeners *listeners.NetworkErrorListeners,
128130
requestQueue taskqueue.TaskQueue,
131+
connManager network.ConnManager,
129132
) *RequestManager {
130133
ctx, cancel := context.WithCancel(ctx)
131134
return &RequestManager{
@@ -141,6 +144,7 @@ func New(ctx context.Context,
141144
responseHooks: responseHooks,
142145
networkErrorListeners: networkErrorListeners,
143146
requestQueue: requestQueue,
147+
connManager: connManager,
144148
}
145149
}
146150

requestmanager/requestmanager_test.go

+87-63
Original file line numberDiff line numberDiff line change
@@ -29,68 +29,6 @@ import (
2929
"github.com/ipfs/go-graphsync/testutil"
3030
)
3131

32-
type requestRecord struct {
33-
gsr gsmsg.GraphSyncRequest
34-
p peer.ID
35-
}
36-
37-
type fakePeerHandler struct {
38-
requestRecordChan chan requestRecord
39-
}
40-
41-
func (fph *fakePeerHandler) AllocateAndBuildMessage(p peer.ID, blkSize uint64,
42-
requestBuilder func(b *gsmsg.Builder), notifees []notifications.Notifee) {
43-
builder := gsmsg.NewBuilder(gsmsg.Topic(0))
44-
requestBuilder(builder)
45-
message, err := builder.Build()
46-
if err != nil {
47-
panic(err)
48-
}
49-
fph.requestRecordChan <- requestRecord{
50-
gsr: message.Requests()[0],
51-
p: p,
52-
}
53-
}
54-
55-
func readNNetworkRequests(ctx context.Context,
56-
t *testing.T,
57-
requestRecordChan <-chan requestRecord,
58-
count int) []requestRecord {
59-
requestRecords := make([]requestRecord, 0, count)
60-
for i := 0; i < count; i++ {
61-
var rr requestRecord
62-
testutil.AssertReceive(ctx, t, requestRecordChan, &rr, fmt.Sprintf("did not receive request %d", i))
63-
requestRecords = append(requestRecords, rr)
64-
}
65-
// because of the simultaneous request queues it's possible for the requests to go to the network layer out of order
66-
// if the requests are queued at a near identical time
67-
sort.Slice(requestRecords, func(i, j int) bool {
68-
return requestRecords[i].gsr.ID() < requestRecords[j].gsr.ID()
69-
})
70-
return requestRecords
71-
}
72-
73-
func metadataForBlocks(blks []blocks.Block, present bool) metadata.Metadata {
74-
md := make(metadata.Metadata, 0, len(blks))
75-
for _, block := range blks {
76-
md = append(md, metadata.Item{
77-
Link: block.Cid(),
78-
BlockPresent: present,
79-
})
80-
}
81-
return md
82-
}
83-
84-
func encodedMetadataForBlocks(t *testing.T, blks []blocks.Block, present bool) graphsync.ExtensionData {
85-
md := metadataForBlocks(blks, present)
86-
metadataEncoded, err := metadata.EncodeMetadata(md)
87-
require.NoError(t, err, "did not encode metadata")
88-
return graphsync.ExtensionData{
89-
Name: graphsync.ExtensionMetadata,
90-
Data: metadataEncoded,
91-
}
92-
}
93-
9432
func TestNormalSimultaneousFetch(t *testing.T) {
9533
ctx := context.Background()
9634
td := newTestData(ctx, t)
@@ -106,6 +44,8 @@ func TestNormalSimultaneousFetch(t *testing.T) {
10644

10745
requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2)
10846

47+
td.tcm.AssertProtected(t, peers[0])
48+
td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().Tag(), requestRecords[1].gsr.ID().Tag())
10949
require.Equal(t, peers[0], requestRecords[0].p)
11050
require.Equal(t, peers[0], requestRecords[1].p)
11151
require.False(t, requestRecords[0].gsr.IsCancel())
@@ -148,6 +88,10 @@ func TestNormalSimultaneousFetch(t *testing.T) {
14888
td.blockChain.VerifyWholeChain(requestCtx, returnedResponseChan1)
14989
blockChain2.VerifyResponseRange(requestCtx, returnedResponseChan2, 0, 3)
15090

91+
td.tcm.AssertProtected(t, peers[0])
92+
td.tcm.RefuteProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().Tag())
93+
td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[1].gsr.ID().Tag())
94+
15195
moreBlocks := blockChain2.RemainderBlocks(3)
15296
moreMetadata := metadataForBlocks(moreBlocks, true)
15397
moreMetadataEncoded, err := metadata.EncodeMetadata(moreMetadata)
@@ -170,6 +114,8 @@ func TestNormalSimultaneousFetch(t *testing.T) {
170114
blockChain2.VerifyRemainder(requestCtx, returnedResponseChan2, 3)
171115
testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan1)
172116
testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan2)
117+
118+
td.tcm.RefuteProtected(t, peers[0])
173119
}
174120

175121
func TestCancelRequestInProgress(t *testing.T) {
@@ -187,6 +133,9 @@ func TestCancelRequestInProgress(t *testing.T) {
187133

188134
requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2)
189135

136+
td.tcm.AssertProtected(t, peers[0])
137+
td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().Tag(), requestRecords[1].gsr.ID().Tag())
138+
190139
firstBlocks := td.blockChain.Blocks(0, 3)
191140
firstMetadata := encodedMetadataForBlocks(t, firstBlocks, true)
192141
firstResponses := []gsmsg.GraphSyncResponse{
@@ -224,6 +173,8 @@ func TestCancelRequestInProgress(t *testing.T) {
224173
require.Len(t, errors, 1)
225174
_, ok := errors[0].(graphsync.RequestClientCancelledErr)
226175
require.True(t, ok)
176+
177+
td.tcm.RefuteProtected(t, peers[0])
227178
}
228179
func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) {
229180
ctx := context.Background()
@@ -246,6 +197,9 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) {
246197

247198
requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)
248199

200+
td.tcm.AssertProtected(t, peers[0])
201+
td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().Tag())
202+
249203
go func() {
250204
firstBlocks := td.blockChain.Blocks(0, 3)
251205
firstMetadata := encodedMetadataForBlocks(t, firstBlocks, true)
@@ -267,6 +221,8 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) {
267221
require.True(t, rr.gsr.IsCancel())
268222
require.Equal(t, requestRecords[0].gsr.ID(), rr.gsr.ID())
269223

224+
td.tcm.RefuteProtected(t, peers[0])
225+
270226
errors := testutil.CollectErrors(requestCtx, t, returnedErrorChan1)
271227
require.Len(t, errors, 1)
272228
_, ok := errors[0].(graphsync.RequestClientCancelledErr)
@@ -321,13 +277,17 @@ func TestFailedRequest(t *testing.T) {
321277
returnedResponseChan, returnedErrorChan := td.requestManager.NewRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector())
322278

323279
rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0]
280+
td.tcm.AssertProtected(t, peers[0])
281+
td.tcm.AssertProtectedWithTags(t, peers[0], rr.gsr.ID().Tag())
282+
324283
failedResponses := []gsmsg.GraphSyncResponse{
325284
gsmsg.NewResponse(rr.gsr.ID(), graphsync.RequestFailedContentNotFound),
326285
}
327286
td.requestManager.ProcessResponses(peers[0], failedResponses, nil)
328287

329288
testutil.VerifySingleTerminalError(requestCtx, t, returnedErrorChan)
330289
testutil.VerifyEmptyResponse(requestCtx, t, returnedResponseChan)
290+
td.tcm.RefuteProtected(t, peers[0])
331291
}
332292

333293
func TestLocallyFulfilledFirstRequestFailsLater(t *testing.T) {
@@ -962,10 +922,73 @@ func TestPauseResumeExternal(t *testing.T) {
962922
testutil.VerifyEmptyErrors(ctx, t, returnedErrorChan)
963923
}
964924

925+
// requestRecord pairs a graphsync request with the peer it was addressed to,
// as captured by fakePeerHandler.
type requestRecord struct {
926+
gsr gsmsg.GraphSyncRequest
927+
p peer.ID
928+
}
929+
930+
// fakePeerHandler is a test double that publishes every request it is asked
// to build onto requestRecordChan so tests can inspect it.
type fakePeerHandler struct {
931+
requestRecordChan chan requestRecord
932+
}
933+
934+
// AllocateAndBuildMessage runs requestBuilder against a fresh message builder,
// builds the message, and sends the message's first request together with the
// target peer p on requestRecordChan. It panics if the build fails.
// blkSize and notifees are not used by this fake.
func (fph *fakePeerHandler) AllocateAndBuildMessage(p peer.ID, blkSize uint64,
935+
requestBuilder func(b *gsmsg.Builder), notifees []notifications.Notifee) {
936+
builder := gsmsg.NewBuilder(gsmsg.Topic(0))
937+
requestBuilder(builder)
938+
message, err := builder.Build()
939+
if err != nil {
940+
panic(err)
941+
}
942+
fph.requestRecordChan <- requestRecord{
943+
// only the first request of the built message is recorded
gsr: message.Requests()[0],
944+
p: p,
945+
}
946+
}
947+
948+
// readNNetworkRequests receives count request records from requestRecordChan
// (through testutil.AssertReceive, with a message naming the index of the
// request that was not received) and returns them sorted by ascending request ID.
func readNNetworkRequests(ctx context.Context,
949+
t *testing.T,
950+
requestRecordChan <-chan requestRecord,
951+
count int) []requestRecord {
952+
requestRecords := make([]requestRecord, 0, count)
953+
for i := 0; i < count; i++ {
954+
var rr requestRecord
955+
testutil.AssertReceive(ctx, t, requestRecordChan, &rr, fmt.Sprintf("did not receive request %d", i))
956+
requestRecords = append(requestRecords, rr)
957+
}
958+
// because of the simultaneous request queues it's possible for the requests to go to the network layer out of order
959+
// if the requests are queued at a near identical time
960+
sort.Slice(requestRecords, func(i, j int) bool {
961+
return requestRecords[i].gsr.ID() < requestRecords[j].gsr.ID()
962+
})
963+
return requestRecords
964+
}
965+
966+
// metadataForBlocks builds a metadata.Metadata with one Item per block in blks:
// each item links to the block's CID and carries the same BlockPresent value, present.
func metadataForBlocks(blks []blocks.Block, present bool) metadata.Metadata {
967+
md := make(metadata.Metadata, 0, len(blks))
968+
for _, block := range blks {
969+
md = append(md, metadata.Item{
970+
Link: block.Cid(),
971+
BlockPresent: present,
972+
})
973+
}
974+
return md
975+
}
976+
977+
// encodedMetadataForBlocks builds the metadata for blks with metadataForBlocks,
// encodes it with metadata.EncodeMetadata (requiring that encoding succeeds), and
// wraps the encoded bytes in an ExtensionData named graphsync.ExtensionMetadata.
func encodedMetadataForBlocks(t *testing.T, blks []blocks.Block, present bool) graphsync.ExtensionData {
978+
md := metadataForBlocks(blks, present)
979+
metadataEncoded, err := metadata.EncodeMetadata(md)
980+
require.NoError(t, err, "did not encode metadata")
981+
return graphsync.ExtensionData{
982+
Name: graphsync.ExtensionMetadata,
983+
Data: metadataEncoded,
984+
}
985+
}
986+
965987
type testData struct {
966988
requestRecordChan chan requestRecord
967989
fph *fakePeerHandler
968990
fal *testloader.FakeAsyncLoader
991+
tcm *testutil.TestConnManager
969992
requestHooks *hooks.OutgoingRequestHooks
970993
responseHooks *hooks.IncomingResponseHooks
971994
blockHooks *hooks.IncomingBlockHooks
@@ -989,13 +1012,14 @@ func newTestData(ctx context.Context, t *testing.T) *testData {
9891012
td.requestRecordChan = make(chan requestRecord, 3)
9901013
td.fph = &fakePeerHandler{td.requestRecordChan}
9911014
td.fal = testloader.NewFakeAsyncLoader()
1015+
td.tcm = testutil.NewTestConnManager()
9921016
td.requestHooks = hooks.NewRequestHooks()
9931017
td.responseHooks = hooks.NewResponseHooks()
9941018
td.blockHooks = hooks.NewBlockHooks()
9951019
td.networkErrorListeners = listeners.NewNetworkErrorListeners()
9961020
td.taskqueue = taskqueue.NewTaskQueue(ctx)
9971021
lsys := cidlink.DefaultLinkSystem()
998-
td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue)
1022+
td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue, td.tcm)
9991023
td.executor = executor.NewExecutor(td.requestManager, td.blockHooks, td.fal.AsyncLoad)
10001024
td.requestManager.SetDelegate(td.fph)
10011025
td.requestManager.Startup()

requestmanager/server.go

+2
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ func (rm *RequestManager) newRequest(p peer.ID, root ipld.Link, selector ipld.No
8787
requestStatus.lastResponse.Store(gsmsg.NewResponse(request.ID(), graphsync.RequestAcknowledged))
8888
rm.inProgressRequestStatuses[request.ID()] = requestStatus
8989

90+
rm.connManager.Protect(p, requestID.Tag())
9091
rm.requestQueue.PushTask(p, peertask.Task{Topic: requestID, Priority: math.MaxInt32, Work: 1})
9192
return request, requestStatus.inProgressChan, requestStatus.inProgressErr
9293
}
@@ -151,6 +152,7 @@ func (rm *RequestManager) terminateRequest(requestID graphsync.RequestID, ipr *i
151152
case <-rm.ctx.Done():
152153
}
153154
}
155+
rm.connManager.Unprotect(ipr.p, requestID.Tag())
154156
delete(rm.inProgressRequestStatuses, requestID)
155157
ipr.cancelFn()
156158
rm.asyncLoader.CleanupRequest(requestID)

0 commit comments

Comments
 (0)