Skip to content

Commit 6760665

Browse files
authored
Option to Reject requests by default (#58)
* refactor(hooks): refactor default validation as hook, add unregister option * feat(graphsync): add disable default validation option * fix(responsemanager): fix mutex unlocking cover case where unlocking was not happening
1 parent b3cc648 commit 6760665

File tree

8 files changed

+246
-145
lines changed

8 files changed

+246
-145
lines changed

graphsync.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ type OnRequestReceivedHook func(p peer.ID, request RequestData, hookActions Requ
154154
// If it returns an error processing is halted and the original request is cancelled.
155155
type OnResponseReceivedHook func(p peer.ID, responseData ResponseData) error
156156

157+
// UnregisterHookFunc is a function call to unregister a hook that was previously registered
158+
type UnregisterHookFunc func()
159+
157160
// GraphExchange is a protocol that can exchange IPLD graphs based on a selector
158161
type GraphExchange interface {
159162
// Request initiates a new GraphSync request to the given peer using the given selector spec.
@@ -163,8 +166,8 @@ type GraphExchange interface {
163166
// If overrideDefaultValidation is set to true, then if the hook does not error,
164167
// it is considered to have "validated" the request -- and that validation supersedes
165168
// the normal validation of requests Graphsync does (i.e. all selectors can be accepted)
166-
RegisterRequestReceivedHook(hook OnRequestReceivedHook) error
169+
RegisterRequestReceivedHook(hook OnRequestReceivedHook) UnregisterHookFunc
167170

168171
// RegisterResponseReceivedHook adds a hook that runs when a response is received
169-
RegisterResponseReceivedHook(OnResponseReceivedHook) error
172+
RegisterResponseReceivedHook(OnResponseReceivedHook) UnregisterHookFunc
170173
}

impl/graphsync.go

+50-31
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ import (
44
"context"
55

66
"github.com/ipfs/go-graphsync"
7-
"github.com/ipfs/go-graphsync/requestmanager/asyncloader"
8-
97
gsmsg "github.com/ipfs/go-graphsync/message"
108
"github.com/ipfs/go-graphsync/messagequeue"
119
gsnet "github.com/ipfs/go-graphsync/network"
1210
"github.com/ipfs/go-graphsync/peermanager"
1311
"github.com/ipfs/go-graphsync/requestmanager"
12+
"github.com/ipfs/go-graphsync/requestmanager/asyncloader"
1413
"github.com/ipfs/go-graphsync/responsemanager"
1514
"github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager"
15+
"github.com/ipfs/go-graphsync/selectorvalidator"
1616
logging "github.com/ipfs/go-log"
1717
"github.com/ipfs/go-peertaskqueue"
1818
ipld "github.com/ipld/go-ipld-prime"
@@ -21,26 +21,41 @@ import (
2121

2222
var log = logging.Logger("graphsync")
2323

24+
const maxRecursionDepth = 100
25+
2426
// GraphSync is an instance of a GraphSync exchange that implements
2527
// the graphsync protocol.
2628
type GraphSync struct {
27-
network gsnet.GraphSyncNetwork
28-
loader ipld.Loader
29-
storer ipld.Storer
30-
requestManager *requestmanager.RequestManager
31-
responseManager *responsemanager.ResponseManager
32-
asyncLoader *asyncloader.AsyncLoader
33-
peerResponseManager *peerresponsemanager.PeerResponseManager
34-
peerTaskQueue *peertaskqueue.PeerTaskQueue
35-
peerManager *peermanager.PeerMessageManager
36-
ctx context.Context
37-
cancel context.CancelFunc
29+
network gsnet.GraphSyncNetwork
30+
loader ipld.Loader
31+
storer ipld.Storer
32+
requestManager *requestmanager.RequestManager
33+
responseManager *responsemanager.ResponseManager
34+
asyncLoader *asyncloader.AsyncLoader
35+
peerResponseManager *peerresponsemanager.PeerResponseManager
36+
peerTaskQueue *peertaskqueue.PeerTaskQueue
37+
peerManager *peermanager.PeerMessageManager
38+
ctx context.Context
39+
cancel context.CancelFunc
40+
unregisterDefaultValidator graphsync.UnregisterHookFunc
41+
}
42+
43+
// Option defines the functional option type that can be used to configure
44+
// graphsync instances
45+
type Option func(*GraphSync)
46+
47+
// RejectAllRequestsByDefault means that without hooks registered
48+
// that perform their own request validation, all requests are rejected
49+
func RejectAllRequestsByDefault() Option {
50+
return func(gs *GraphSync) {
51+
gs.unregisterDefaultValidator()
52+
}
3853
}
3954

4055
// New creates a new GraphSync Exchange on the given network,
4156
// and the given link loader+storer.
4257
func New(parent context.Context, network gsnet.GraphSyncNetwork,
43-
loader ipld.Loader, storer ipld.Storer) graphsync.GraphExchange {
58+
loader ipld.Loader, storer ipld.Storer, options ...Option) graphsync.GraphExchange {
4459
ctx, cancel := context.WithCancel(parent)
4560

4661
createMessageQueue := func(ctx context.Context, p peer.ID) peermanager.PeerQueue {
@@ -55,18 +70,24 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork,
5570
}
5671
peerResponseManager := peerresponsemanager.New(ctx, createdResponseQueue)
5772
responseManager := responsemanager.New(ctx, loader, peerResponseManager, peerTaskQueue)
73+
unregisterDefaultValidator := responseManager.RegisterHook(selectorvalidator.SelectorValidator(maxRecursionDepth))
5874
graphSync := &GraphSync{
59-
network: network,
60-
loader: loader,
61-
storer: storer,
62-
asyncLoader: asyncLoader,
63-
requestManager: requestManager,
64-
peerManager: peerManager,
65-
peerTaskQueue: peerTaskQueue,
66-
peerResponseManager: peerResponseManager,
67-
responseManager: responseManager,
68-
ctx: ctx,
69-
cancel: cancel,
75+
network: network,
76+
loader: loader,
77+
storer: storer,
78+
asyncLoader: asyncLoader,
79+
requestManager: requestManager,
80+
peerManager: peerManager,
81+
peerTaskQueue: peerTaskQueue,
82+
peerResponseManager: peerResponseManager,
83+
responseManager: responseManager,
84+
ctx: ctx,
85+
cancel: cancel,
86+
unregisterDefaultValidator: unregisterDefaultValidator,
87+
}
88+
89+
for _, option := range options {
90+
option(graphSync)
7091
}
7192

7293
asyncLoader.Startup()
@@ -86,15 +107,13 @@ func (gs *GraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, sel
86107
// If overrideDefaultValidation is set to true, then if the hook does not error,
87108
// it is considered to have "validated" the request -- and that validation supersedes
88109
// the normal validation of requests Graphsync does (i.e. all selectors can be accepted)
89-
func (gs *GraphSync) RegisterRequestReceivedHook(hook graphsync.OnRequestReceivedHook) error {
90-
gs.responseManager.RegisterHook(hook)
91-
return nil
110+
func (gs *GraphSync) RegisterRequestReceivedHook(hook graphsync.OnRequestReceivedHook) graphsync.UnregisterHookFunc {
111+
return gs.responseManager.RegisterHook(hook)
92112
}
93113

94114
// RegisterResponseReceivedHook adds a hook that runs when a response is received
95-
func (gs *GraphSync) RegisterResponseReceivedHook(hook graphsync.OnResponseReceivedHook) error {
96-
gs.requestManager.RegisterHook(hook)
97-
return nil
115+
func (gs *GraphSync) RegisterResponseReceivedHook(hook graphsync.OnResponseReceivedHook) graphsync.UnregisterHookFunc {
116+
return gs.requestManager.RegisterHook(hook)
98117
}
99118

100119
type graphSyncReceiver GraphSync

impl/graphsync_test.go

+31-14
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,14 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
9999
var receivedRequestData []byte
100100
// initialize graphsync on second node to response to requests
101101
gsnet := td.GraphSyncHost2()
102-
err := gsnet.RegisterRequestReceivedHook(
102+
gsnet.RegisterRequestReceivedHook(
103103
func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
104104
var has bool
105105
receivedRequestData, has = requestData.Extension(td.extensionName)
106106
require.True(t, has, "did not have expected extension")
107107
hookActions.SendExtensionData(td.extensionResponse)
108108
},
109109
)
110-
require.NoError(t, err, "error registering extension")
111110

112111
blockChainLength := 100
113112
blockChain := testutil.SetupBlockChain(ctx, t, td.loader2, td.storer2, 100, blockChainLength)
@@ -117,7 +116,7 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
117116
message := gsmsg.New()
118117
message.AddRequest(gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32), td.extension))
119118
// send request across network
120-
err = td.gsnet1.SendMessage(ctx, td.host2.ID(), message)
119+
err := td.gsnet1.SendMessage(ctx, td.host2.ID(), message)
121120
require.NoError(t, err)
122121
// read the values sent back to requestor
123122
var received gsmsg.GraphSyncMessage
@@ -150,6 +149,27 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
150149
require.Equal(t, td.extensionResponseData, receivedExtensions[0], "did not return correct extension data")
151150
}
152151

152+
func TestRejectRequestsByDefault(t *testing.T) {
153+
// create network
154+
ctx := context.Background()
155+
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
156+
defer cancel()
157+
td := newGsTestData(ctx, t)
158+
159+
requestor := td.GraphSyncHost1()
160+
// setup responder to disable default validation, meaning all requests are rejected
161+
_ = td.GraphSyncHost2(RejectAllRequestsByDefault())
162+
163+
blockChainLength := 5
164+
blockChain := testutil.SetupBlockChain(ctx, t, td.loader2, td.storer2, 5, blockChainLength)
165+
166+
// send request across network
167+
progressChan, errChan := requestor.Request(ctx, td.host2.ID(), blockChain.TipLink, blockChain.Selector(), td.extension)
168+
169+
testutil.VerifyEmptyResponse(ctx, t, progressChan)
170+
testutil.VerifySingleTerminalError(ctx, t, errChan)
171+
}
172+
153173
func TestGraphsyncRoundTrip(t *testing.T) {
154174
// create network
155175
ctx := context.Background()
@@ -170,17 +190,16 @@ func TestGraphsyncRoundTrip(t *testing.T) {
170190
var receivedResponseData []byte
171191
var receivedRequestData []byte
172192

173-
err := requestor.RegisterResponseReceivedHook(
193+
requestor.RegisterResponseReceivedHook(
174194
func(p peer.ID, responseData graphsync.ResponseData) error {
175195
data, has := responseData.Extension(td.extensionName)
176196
if has {
177197
receivedResponseData = data
178198
}
179199
return nil
180200
})
181-
require.NoError(t, err, "Error setting up extension")
182201

183-
err = responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
202+
responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
184203
var has bool
185204
receivedRequestData, has = requestData.Extension(td.extensionName)
186205
if !has {
@@ -189,7 +208,6 @@ func TestGraphsyncRoundTrip(t *testing.T) {
189208
hookActions.SendExtensionData(td.extensionResponse)
190209
}
191210
})
192-
require.NoError(t, err, "Error setting up extension")
193211

194212
progressChan, errChan := requestor.Request(ctx, td.host2.ID(), blockChain.TipLink, blockChain.Selector(), td.extension)
195213

@@ -342,15 +360,14 @@ func TestUnixFSFetch(t *testing.T) {
342360
requestor := New(ctx, td.gsnet1, loader1, storer1)
343361
responder := New(ctx, td.gsnet2, loader2, storer2)
344362
extensionName := graphsync.ExtensionName("Free for all")
345-
err = responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
363+
responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
346364
hookActions.ValidateRequest()
347365
hookActions.SendExtensionData(graphsync.ExtensionData{
348366
Name: extensionName,
349367
Data: nil,
350368
})
351369
})
352-
require.NoError(t, err)
353-
370+
354371
// make a go-ipld-prime link for the root UnixFS node
355372
clink := cidlink.Link{Cid: nd.Cid()}
356373

@@ -443,13 +460,13 @@ func newGsTestData(ctx context.Context, t *testing.T) *gsTestData {
443460
return td
444461
}
445462

446-
func (td *gsTestData) GraphSyncHost1() graphsync.GraphExchange {
447-
return New(td.ctx, td.gsnet1, td.loader1, td.storer1)
463+
func (td *gsTestData) GraphSyncHost1(options ...Option) graphsync.GraphExchange {
464+
return New(td.ctx, td.gsnet1, td.loader1, td.storer1, options...)
448465
}
449466

450-
func (td *gsTestData) GraphSyncHost2() graphsync.GraphExchange {
467+
func (td *gsTestData) GraphSyncHost2(options ...Option) graphsync.GraphExchange {
451468

452-
return New(td.ctx, td.gsnet2, td.loader2, td.storer2)
469+
return New(td.ctx, td.gsnet2, td.loader2, td.storer2, options...)
453470
}
454471

455472
type receivedMessage struct {

requestmanager/requestmanager.go

+32-4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type inProgressRequestStatus struct {
3434
}
3535

3636
type responseHook struct {
37+
key uint64
3738
hook graphsync.OnResponseReceivedHook
3839
}
3940

@@ -65,6 +66,7 @@ type RequestManager struct {
6566
// dont touch out side of run loop
6667
nextRequestID graphsync.RequestID
6768
inProgressRequestStatuses map[graphsync.RequestID]*inProgressRequestStatus
69+
responseHookNextKey uint64
6870
responseHooks []responseHook
6971
}
7072

@@ -201,12 +203,25 @@ func (rm *RequestManager) ProcessResponses(p peer.ID, responses []gsmsg.GraphSyn
201203
}
202204
}
203205

206+
type registerHookMessage struct {
207+
hook graphsync.OnResponseReceivedHook
208+
unregisterHookChan chan graphsync.UnregisterHookFunc
209+
}
210+
204211
// RegisterHook registers an extension to processincoming responses
205212
func (rm *RequestManager) RegisterHook(
206-
hook graphsync.OnResponseReceivedHook) {
213+
hook graphsync.OnResponseReceivedHook) graphsync.UnregisterHookFunc {
214+
response := make(chan graphsync.UnregisterHookFunc)
215+
select {
216+
case rm.messages <- &registerHookMessage{hook, response}:
217+
case <-rm.ctx.Done():
218+
return nil
219+
}
207220
select {
208-
case rm.messages <- &responseHook{hook}:
221+
case unregister := <-response:
222+
return unregister
209223
case <-rm.ctx.Done():
224+
return nil
210225
}
211226
}
212227

@@ -285,8 +300,21 @@ func (prm *processResponseMessage) handle(rm *RequestManager) {
285300
rm.processTerminations(filteredResponses)
286301
}
287302

288-
func (rh *responseHook) handle(rm *RequestManager) {
289-
rm.responseHooks = append(rm.responseHooks, *rh)
303+
func (rhm *registerHookMessage) handle(rm *RequestManager) {
304+
rh := responseHook{rm.responseHookNextKey, rhm.hook}
305+
rm.responseHookNextKey++
306+
rm.responseHooks = append(rm.responseHooks, rh)
307+
select {
308+
case rhm.unregisterHookChan <- func() {
309+
for i, matchHook := range rm.responseHooks {
310+
if rh.key == matchHook.key {
311+
rm.responseHooks = append(rm.responseHooks[:i], rm.responseHooks[i+1:]...)
312+
return
313+
}
314+
}
315+
}:
316+
case <-rm.ctx.Done():
317+
}
290318
}
291319

292320
func (rm *RequestManager) filterResponsesForPeer(responses []gsmsg.GraphSyncResponse, p peer.ID) []gsmsg.GraphSyncResponse {

0 commit comments

Comments
 (0)