diff --git a/graphsync.go b/graphsync.go
index bae45ed8..80102c10 100644
--- a/graphsync.go
+++ b/graphsync.go
@@ -486,25 +486,15 @@ type GraphExchange interface {
 	// RegisterReceiverNetworkErrorListener adds a listener for when errors occur receiving data over the wire
 	RegisterReceiverNetworkErrorListener(listener OnReceiverNetworkErrorListener) UnregisterHookFunc
 
-	// UnpauseRequest unpauses a request that was paused in a block hook based request ID
-	// Can also send extensions with unpause
-	UnpauseRequest(RequestID, ...ExtensionData) error
-
-	// PauseRequest pauses an in progress request (may take 1 or more blocks to process)
-	PauseRequest(RequestID) error
+	// Pause pauses an in progress request or response (may take 1 or more blocks to process)
+	Pause(context.Context, RequestID) error
 
-	// UnpauseResponse unpauses a response that was paused in a block hook based on peer ID and request ID
+	// Unpause unpauses a request or response that was paused
 	// Can also send extensions with unpause
-	UnpauseResponse(peer.ID, RequestID, ...ExtensionData) error
-
-	// PauseResponse pauses an in progress response (may take 1 or more blocks to process)
-	PauseResponse(peer.ID, RequestID) error
-
-	// CancelResponse cancels an in progress response
-	CancelResponse(peer.ID, RequestID) error
+	Unpause(context.Context, RequestID, ...ExtensionData) error
 
-	// CancelRequest cancels an in progress request
-	CancelRequest(context.Context, RequestID) error
+	// Cancel cancels an in progress request or response
+	Cancel(context.Context, RequestID) error
 
 	// Stats produces insight on the current state of a graphsync exchange
 	Stats() Stats
diff --git a/impl/graphsync.go b/impl/graphsync.go
index 0f546b24..be70f32c 100644
--- a/impl/graphsync.go
+++ b/impl/graphsync.go
@@ -2,6 +2,7 @@ package graphsync
 
 import (
 	"context"
+	"errors"
 	"time"
 
 	logging "github.com/ipfs/go-log/v2"
@@ -296,6 +297,7 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork,
 	responseManager.Startup()
 	responseQueue.Startup(gsConfig.maxInProgressIncomingRequests, queryExecutor)
 	network.SetDelegate((*graphSyncReceiver)(graphSync))
+
 	return graphSync
 }
 
@@ -402,35 +404,32 @@ func (gs *GraphSync) RegisterReceiverNetworkErrorListener(listener graphsync.OnR
 	return gs.receiverErrorListeners.Register(listener)
 }
 
-// UnpauseRequest unpauses a request that was paused in a block hook based request ID
-// Can also send extensions with unpause
-func (gs *GraphSync) UnpauseRequest(requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
-	return gs.requestManager.UnpauseRequest(requestID, extensions...)
-}
-
-// PauseRequest pauses an in progress request (may take 1 or more blocks to process)
-func (gs *GraphSync) PauseRequest(requestID graphsync.RequestID) error {
-	return gs.requestManager.PauseRequest(requestID)
-}
-
-// UnpauseResponse unpauses a response that was paused in a block hook based on peer ID and request ID
-func (gs *GraphSync) UnpauseResponse(p peer.ID, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
-	return gs.responseManager.UnpauseResponse(p, requestID, extensions...)
-}
-
-// PauseResponse pauses an in progress response (may take 1 or more blocks to process)
-func (gs *GraphSync) PauseResponse(p peer.ID, requestID graphsync.RequestID) error {
-	return gs.responseManager.PauseResponse(p, requestID)
+// Pause pauses an in progress request or response
+func (gs *GraphSync) Pause(ctx context.Context, requestID graphsync.RequestID) error {
+	var reqNotFound graphsync.RequestNotFoundErr
+	if err := gs.requestManager.PauseRequest(ctx, requestID); !errors.As(err, &reqNotFound) {
+		return err
+	}
+	return gs.responseManager.PauseResponse(ctx, requestID)
 }
 
-// CancelResponse cancels an in progress response
-func (gs *GraphSync) CancelResponse(p peer.ID, requestID graphsync.RequestID) error {
-	return gs.responseManager.CancelResponse(p, requestID)
+// Unpause unpauses a request or response that was paused
+// Can also send extensions with unpause
+func (gs *GraphSync) Unpause(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
+	var reqNotFound graphsync.RequestNotFoundErr
+	if err := gs.requestManager.UnpauseRequest(ctx, requestID, extensions...); !errors.As(err, &reqNotFound) {
+		return err
+	}
+	return gs.responseManager.UnpauseResponse(ctx, requestID, extensions...)
 }
 
-// CancelRequest cancels an in progress request
-func (gs *GraphSync) CancelRequest(ctx context.Context, requestID graphsync.RequestID) error {
-	return gs.requestManager.CancelRequest(ctx, requestID)
+// Cancel cancels an in progress request or response
+func (gs *GraphSync) Cancel(ctx context.Context, requestID graphsync.RequestID) error {
+	var reqNotFound graphsync.RequestNotFoundErr
+	if err := gs.requestManager.CancelRequest(ctx, requestID); !errors.As(err, &reqNotFound) {
+		return err
+	}
+	return gs.responseManager.CancelResponse(ctx, requestID)
 }
 
 // Stats produces insight on the current state of a graphsync exchange
diff --git a/impl/graphsync_test.go b/impl/graphsync_test.go
index 921f4215..524ac3b8 100644
--- a/impl/graphsync_test.go
+++ b/impl/graphsync_test.go
@@ -723,7 +723,7 @@ func TestPauseResume(t *testing.T) {
 	require.Len(t, responderPeerState.IncomingState.Diagnostics(), 0)
 
 	requestID := <-requestIDChan
-	err := responder.UnpauseResponse(td.host1.ID(), requestID)
+	err := responder.Unpause(ctx, requestID)
 	require.NoError(t, err)
 
 	blockChain.VerifyRemainder(ctx, progressChan, stopPoint)
@@ -793,7 +793,7 @@ func TestPauseResumeRequest(t *testing.T) {
 	testutil.AssertDoesReceiveFirst(t, timer.C, "should pause request", progressChan)
 
 	requestID := <-requestIDChan
-	err := requestor.UnpauseRequest(requestID, td.extensionUpdate)
+	err := requestor.Unpause(ctx, requestID, td.extensionUpdate)
 	require.NoError(t, err)
 
 	blockChain.VerifyRemainder(ctx, progressChan, stopPoint)
@@ -1092,7 +1092,7 @@ func TestNetworkDisconnect(t *testing.T) {
 	require.NoError(t, td.mn.DisconnectPeers(td.host1.ID(), td.host2.ID()))
 	require.NoError(t, td.mn.UnlinkPeers(td.host1.ID(), td.host2.ID()))
 	requestID := <-requestIDChan
-	err := responder.UnpauseResponse(td.host1.ID(), requestID)
+	err := responder.Unpause(ctx, requestID)
 	require.NoError(t, err)
 
 	testutil.AssertReceive(ctx, t, networkError, &err, "should receive network error")
diff --git a/requestmanager/client.go b/requestmanager/client.go
index 5002f516..79453d0b 100644
--- a/requestmanager/client.go
+++ b/requestmanager/client.go
@@ -292,9 +292,9 @@ func (rm *RequestManager) ProcessResponses(p peer.ID,
 
 // UnpauseRequest unpauses a request that was paused in a block hook based request ID
 // Can also send extensions with unpause
-func (rm *RequestManager) UnpauseRequest(requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
+func (rm *RequestManager) UnpauseRequest(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
 	response := make(chan error, 1)
-	rm.send(&unpauseRequestMessage{requestID, extensions, response}, nil)
+	rm.send(&unpauseRequestMessage{requestID, extensions, response}, ctx.Done())
 	select {
 	case <-rm.ctx.Done():
 		return errors.New("context cancelled")
@@ -304,9 +304,9 @@ func (rm *RequestManager) UnpauseRequest(requestID graphsync.RequestID, extensio
 }
 
 // PauseRequest pauses an in progress request (may take 1 or more blocks to process)
-func (rm *RequestManager) PauseRequest(requestID graphsync.RequestID) error {
+func (rm *RequestManager) PauseRequest(ctx context.Context, requestID graphsync.RequestID) error {
 	response := make(chan error, 1)
-	rm.send(&pauseRequestMessage{requestID, response}, nil)
+	rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
 	select {
 	case <-rm.ctx.Done():
 		return errors.New("context cancelled")
diff --git a/requestmanager/requestmanager_test.go b/requestmanager/requestmanager_test.go
index 6e5fce3c..5446f32d 100644
--- a/requestmanager/requestmanager_test.go
+++ b/requestmanager/requestmanager_test.go
@@ -816,7 +816,7 @@ func TestPauseResume(t *testing.T) {
 
 	// attempt to unpause while request is not paused (note: hook on second block will keep it from
 	// reaching pause point)
-	err := td.requestManager.UnpauseRequest(rr.gsr.ID())
+	err := td.requestManager.UnpauseRequest(ctx, rr.gsr.ID())
 	require.EqualError(t, err, "request is not paused")
 	close(holdForResumeAttempt)
 	// verify responses sent read ONLY for blocks BEFORE the pause
@@ -834,7 +834,7 @@ func TestPauseResume(t *testing.T) {
 	td.fal.CleanupRequest(peers[0], rr.gsr.ID())
 
 	// unpause
-	err = td.requestManager.UnpauseRequest(rr.gsr.ID(), td.extension1, td.extension2)
+	err = td.requestManager.UnpauseRequest(ctx, rr.gsr.ID(), td.extension1, td.extension2)
 	require.NoError(t, err)
 
 	// verify the correct new request with Do-no-send-cids & other extensions
@@ -875,7 +875,7 @@ func TestPauseResumeExternal(t *testing.T) {
 	hook := func(p peer.ID, responseData graphsync.ResponseData, blockData graphsync.BlockData, hookActions graphsync.IncomingBlockHookActions) {
 		blocksReceived++
 		if blocksReceived == pauseAt {
-			err := td.requestManager.PauseRequest(responseData.RequestID())
+			err := td.requestManager.PauseRequest(ctx, responseData.RequestID())
 			require.NoError(t, err)
 			close(holdForPause)
 		}
@@ -909,7 +909,7 @@ func TestPauseResumeExternal(t *testing.T) {
 	td.fal.CleanupRequest(peers[0], rr.gsr.ID())
 
 	// unpause
-	err := td.requestManager.UnpauseRequest(rr.gsr.ID(), td.extension1, td.extension2)
+	err := td.requestManager.UnpauseRequest(ctx, rr.gsr.ID(), td.extension1, td.extension2)
 	require.NoError(t, err)
 
 	// verify the correct new request with Do-no-send-cids & other extensions
diff --git a/requestmanager/server.go b/requestmanager/server.go
index 796a3fbd..9660ab8e 100644
--- a/requestmanager/server.go
+++ b/requestmanager/server.go
@@ -233,7 +233,7 @@ func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID, onTermina
 	if !ok {
 		if onTerminated != nil {
 			select {
-			case onTerminated <- graphsync.RequestNotFoundErr{}:
+			case onTerminated <- &graphsync.RequestNotFoundErr{}:
 			case <-rm.ctx.Done():
 			}
 		}
diff --git a/responsemanager/client.go b/responsemanager/client.go
index 08e4bfcd..38531a4c 100644
--- a/responsemanager/client.go
+++ b/responsemanager/client.go
@@ -33,6 +33,7 @@ type inProgressResponseStatus struct {
 	ctx            context.Context
 	span           trace.Span
 	cancelFn       func()
+	peer           peer.ID
 	request        gsmsg.GraphSyncRequest
 	loader         ipld.BlockReadOpener
 	traverser      ipldutil.Traverser
@@ -43,11 +44,6 @@ type inProgressResponseStatus struct {
 	responseStream responseassembler.ResponseStream
 }
 
-type responseKey struct {
-	p         peer.ID
-	requestID graphsync.RequestID
-}
-
 // RequestHooks is an interface for processing request hooks
 type RequestHooks interface {
 	ProcessRequestHooks(p peer.ID, request graphsync.RequestData) hooks.RequestResult
@@ -107,7 +103,7 @@ type ResponseManager struct {
 	blockSentListeners    BlockSentListeners
 	networkErrorListeners NetworkErrorListeners
 	messages              chan responseManagerMessage
-	inProgressResponses   map[responseKey]*inProgressResponseStatus
+	inProgressResponses   map[graphsync.RequestID]*inProgressResponseStatus
 	connManager           network.ConnManager
 	// maximum number of links to traverse per request. A value of zero = infinity, or no limit
 	maxLinksPerRequest uint64
@@ -144,7 +140,7 @@ func New(ctx context.Context,
 		blockSentListeners:    blockSentListeners,
 		networkErrorListeners: networkErrorListeners,
 		messages:              messages,
-		inProgressResponses:   make(map[responseKey]*inProgressResponseStatus),
+		inProgressResponses:   make(map[graphsync.RequestID]*inProgressResponseStatus),
 		connManager:           connManager,
 		maxLinksPerRequest:    maxLinksPerRequest,
 		responseQueue:         responseQueue,
@@ -158,9 +154,9 @@ func (rm *ResponseManager) ProcessRequests(ctx context.Context, p peer.ID, reque
 }
 
 // UnpauseResponse unpauses a response that was previously paused
-func (rm *ResponseManager) UnpauseResponse(p peer.ID, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
+func (rm *ResponseManager) UnpauseResponse(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
 	response := make(chan error, 1)
-	rm.send(&unpauseRequestMessage{p, requestID, response, extensions}, nil)
+	rm.send(&unpauseRequestMessage{requestID, response, extensions}, ctx.Done())
 	select {
 	case <-rm.ctx.Done():
 		return errors.New("context cancelled")
@@ -170,9 +166,9 @@ func (rm *ResponseManager) UnpauseResponse(p peer.ID, requestID graphsync.Reques
 }
 
 // PauseResponse pauses an in progress response (may take 1 or more blocks to process)
-func (rm *ResponseManager) PauseResponse(p peer.ID, requestID graphsync.RequestID) error {
+func (rm *ResponseManager) PauseResponse(ctx context.Context, requestID graphsync.RequestID) error {
 	response := make(chan error, 1)
-	rm.send(&pauseRequestMessage{p, requestID, response}, nil)
+	rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
 	select {
 	case <-rm.ctx.Done():
 		return errors.New("context cancelled")
@@ -182,9 +178,9 @@ func (rm *ResponseManager) PauseResponse(p peer.ID, requestID graphsync.RequestI
 }
 
 // CancelResponse cancels an in progress response
-func (rm *ResponseManager) CancelResponse(p peer.ID, requestID graphsync.RequestID) error {
+func (rm *ResponseManager) CancelResponse(ctx context.Context, requestID graphsync.RequestID) error {
 	response := make(chan error, 1)
-	rm.send(&errorRequestMessage{p, requestID, queryexecutor.ErrCancelledByCommand, response}, nil)
+	rm.send(&errorRequestMessage{requestID, queryexecutor.ErrCancelledByCommand, response}, ctx.Done())
 	select {
 	case <-rm.ctx.Done():
 		return errors.New("context cancelled")
@@ -204,19 +200,19 @@ func (rm *ResponseManager) synchronize() {
 }
 
 // StartTask starts the given task from the peer task queue
-func (rm *ResponseManager) StartTask(task *peertask.Task, responseTaskChan chan<- queryexecutor.ResponseTask) {
-	rm.send(&startTaskRequest{task, responseTaskChan}, nil)
+func (rm *ResponseManager) StartTask(task *peertask.Task, p peer.ID, responseTaskChan chan<- queryexecutor.ResponseTask) {
+	rm.send(&startTaskRequest{task, p, responseTaskChan}, nil)
 }
 
 // GetUpdates is called to read pending updates for a task and clear them
-func (rm *ResponseManager) GetUpdates(p peer.ID, requestID graphsync.RequestID, updatesChan chan<- []gsmsg.GraphSyncRequest) {
-	rm.send(&responseUpdateRequest{responseKey{p, requestID}, updatesChan}, nil)
+func (rm *ResponseManager) GetUpdates(requestID graphsync.RequestID, updatesChan chan<- []gsmsg.GraphSyncRequest) {
+	rm.send(&responseUpdateRequest{requestID, updatesChan}, nil)
 }
 
 // FinishTask marks a task from the task queue as done
-func (rm *ResponseManager) FinishTask(task *peertask.Task, err error) {
+func (rm *ResponseManager) FinishTask(task *peertask.Task, p peer.ID, err error) {
 	done := make(chan struct{}, 1)
-	rm.send(&finishTaskRequest{task, err, done}, nil)
+	rm.send(&finishTaskRequest{task, p, err, done}, nil)
 	select {
 	case <-rm.ctx.Done():
 	case <-done:
@@ -224,9 +220,9 @@ func (rm *ResponseManager) FinishTask(task *peertask.Task, err error) {
 }
 
 // CloseWithNetworkError closes a request due to a network error
-func (rm *ResponseManager) CloseWithNetworkError(p peer.ID, requestID graphsync.RequestID) {
+func (rm *ResponseManager) CloseWithNetworkError(requestID graphsync.RequestID) {
 	done := make(chan error, 1)
-	rm.send(&errorRequestMessage{p, requestID, queryexecutor.ErrNetworkError, done}, nil)
+	rm.send(&errorRequestMessage{requestID, queryexecutor.ErrNetworkError, done}, nil)
 	select {
 	case <-rm.ctx.Done():
 	case <-done:
@@ -234,9 +230,9 @@ func (rm *ResponseManager) CloseWithNetworkError(p peer.ID, requestID graphsync.
 }
 
 // TerminateRequest indicates a request has finished sending data and should no longer be tracked
-func (rm *ResponseManager) TerminateRequest(p peer.ID, requestID graphsync.RequestID) {
+func (rm *ResponseManager) TerminateRequest(requestID graphsync.RequestID) {
 	done := make(chan struct{}, 1)
-	rm.send(&terminateRequestMessage{p, requestID, done}, nil)
+	rm.send(&terminateRequestMessage{requestID, done}, nil)
 	select {
 	case <-rm.ctx.Done():
 	case <-done:
diff --git a/responsemanager/messages.go b/responsemanager/messages.go
index 917d70c1..cb052652 100644
--- a/responsemanager/messages.go
+++ b/responsemanager/messages.go
@@ -20,13 +20,12 @@ func (prm *processRequestsMessage) handle(rm *ResponseManager) {
 }
 
 type pauseRequestMessage struct {
-	p         peer.ID
 	requestID graphsync.RequestID
 	response  chan error
 }
 
 func (prm *pauseRequestMessage) handle(rm *ResponseManager) {
-	err := rm.pauseRequest(prm.p, prm.requestID)
+	err := rm.pauseRequest(prm.requestID)
 	select {
 	case <-rm.ctx.Done():
 	case prm.response <- err:
@@ -34,14 +33,13 @@ func (prm *pauseRequestMessage) handle(rm *ResponseManager) {
 }
 
 type errorRequestMessage struct {
-	p         peer.ID
 	requestID graphsync.RequestID
 	err       error
 	response  chan error
 }
 
 func (erm *errorRequestMessage) handle(rm *ResponseManager) {
-	err := rm.abortRequest(rm.ctx, erm.p, erm.requestID, erm.err)
+	err := rm.abortRequest(rm.ctx, erm.requestID, erm.err)
 	select {
 	case <-rm.ctx.Done():
 	case erm.response <- err:
@@ -60,14 +58,13 @@ func (sm *synchronizeMessage) handle(rm *ResponseManager) {
 }
 
 type unpauseRequestMessage struct {
-	p          peer.ID
 	requestID  graphsync.RequestID
 	response   chan error
 	extensions []graphsync.ExtensionData
 }
 
 func (urm *unpauseRequestMessage) handle(rm *ResponseManager) {
-	err := rm.unpauseRequest(urm.p, urm.requestID, urm.extensions...)
+	err := rm.unpauseRequest(urm.requestID, urm.extensions...)
 	select {
 	case <-rm.ctx.Done():
 	case urm.response <- err:
@@ -75,12 +72,12 @@ func (urm *unpauseRequestMessage) handle(rm *ResponseManager) {
 }
 
 type responseUpdateRequest struct {
-	key        responseKey
+	requestID  graphsync.RequestID
 	updateChan chan<- []gsmsg.GraphSyncRequest
 }
 
 func (rur *responseUpdateRequest) handle(rm *ResponseManager) {
-	updates := rm.getUpdates(rur.key)
+	updates := rm.getUpdates(rur.requestID)
 	select {
 	case <-rm.ctx.Done():
 	case rur.updateChan <- updates:
@@ -89,12 +86,13 @@ func (rur *responseUpdateRequest) handle(rm *ResponseManager) {
 
 type finishTaskRequest struct {
 	task *peertask.Task
+	p    peer.ID
 	err  error
 	done chan struct{}
 }
 
 func (ftr *finishTaskRequest) handle(rm *ResponseManager) {
-	rm.finishTask(ftr.task, ftr.err)
+	rm.finishTask(ftr.task, ftr.p, ftr.err)
 	select {
 	case <-rm.ctx.Done():
 	case ftr.done <- struct{}{}:
@@ -103,11 +101,12 @@ func (ftr *finishTaskRequest) handle(rm *ResponseManager) {
 
 type startTaskRequest struct {
 	task         *peertask.Task
+	p            peer.ID
 	taskDataChan chan<- queryexecutor.ResponseTask
 }
 
 func (str *startTaskRequest) handle(rm *ResponseManager) {
-	taskData := rm.startTask(str.task)
+	taskData := rm.startTask(str.task, str.p)
 
 	select {
 	case <-rm.ctx.Done():
@@ -129,13 +128,12 @@ func (psm *peerStateMessage) handle(rm *ResponseManager) {
 }
 
 type terminateRequestMessage struct {
-	p         peer.ID
 	requestID graphsync.RequestID
 	done      chan<- struct{}
 }
 
 func (trm *terminateRequestMessage) handle(rm *ResponseManager) {
-	rm.terminateRequest(responseKey{trm.p, trm.requestID})
+	rm.terminateRequest(trm.requestID)
 	select {
 	case <-rm.ctx.Done():
 	case trm.done <- struct{}{}:
diff --git a/responsemanager/queryexecutor/queryexecutor.go b/responsemanager/queryexecutor/queryexecutor.go
index f43e19f9..563168ad 100644
--- a/responsemanager/queryexecutor/queryexecutor.go
+++ b/responsemanager/queryexecutor/queryexecutor.go
@@ -87,7 +87,7 @@ func (qe *QueryExecutor) ExecuteTask(_ context.Context, pid peer.ID, task *peert
 	// StartTask lets us block until this task is at the top of the execution stack
 	responseTaskChan := make(chan ResponseTask)
 	var rt ResponseTask
-	qe.manager.StartTask(task, responseTaskChan)
+	qe.manager.StartTask(task, pid, responseTaskChan)
 	select {
 	case rt = <-responseTaskChan:
 	case <-qe.ctx.Done():
@@ -109,7 +109,7 @@ func (qe *QueryExecutor) ExecuteTask(_ context.Context, pid peer.ID, task *peert
 			span.SetStatus(codes.Error, err.Error())
 		}
 	}
-	qe.manager.FinishTask(task, err)
+	qe.manager.FinishTask(task, pid, err)
 	log.Debugw("finishing response execution", "id", rt.Request.ID(), "peer", pid.String(), "root_cid", rt.Request.Root().String())
 	return false
 }
@@ -159,7 +159,7 @@ func (qe *QueryExecutor) checkForUpdates(
 			return err
 		case <-taskData.Signals.UpdateSignal:
 			updateChan := make(chan []gsmsg.GraphSyncRequest)
-			qe.manager.GetUpdates(p, taskData.Request.ID(), updateChan)
+			qe.manager.GetUpdates(taskData.Request.ID(), updateChan)
 			select {
 			case updates := <-updateChan:
 				for _, update := range updates {
@@ -279,9 +279,9 @@ func (qe *QueryExecutor) sendResponse(ctx context.Context, p peer.ID, taskData R
 
 // Manager providers an interface to the response manager
 type Manager interface {
-	StartTask(task *peertask.Task, responseTaskChan chan<- ResponseTask)
-	GetUpdates(p peer.ID, requestID graphsync.RequestID, updatesChan chan<- []gsmsg.GraphSyncRequest)
-	FinishTask(task *peertask.Task, err error)
+	StartTask(task *peertask.Task, p peer.ID, responseTaskChan chan<- ResponseTask)
+	GetUpdates(requestID graphsync.RequestID, updatesChan chan<- []gsmsg.GraphSyncRequest)
+	FinishTask(task *peertask.Task, p peer.ID, err error)
 }
 
 // BlockHooks is an interface for processing block hooks
diff --git a/responsemanager/queryexecutor/queryexecutor_test.go b/responsemanager/queryexecutor/queryexecutor_test.go
index 15fa66bd..af17cace 100644
--- a/responsemanager/queryexecutor/queryexecutor_test.go
+++ b/responsemanager/queryexecutor/queryexecutor_test.go
@@ -268,10 +268,11 @@ func newTestData(t *testing.T, blockCount int, expectedTraverse int) (*testData,
 	td := &testData{}
 	td.t = t
 	td.ctx, td.cancel = context.WithTimeout(ctx, 10*time.Second)
+	td.peer = testutil.GeneratePeers(1)[0]
 	td.blockStore = make(map[ipld.Link][]byte)
 	td.persistence = testutil.NewTestStore(td.blockStore)
 	td.task = &peertask.Task{}
-	td.manager = &fauxManager{ctx: ctx, t: t, expectedStartTask: td.task}
+	td.manager = &fauxManager{ctx: ctx, t: t, expectedStartTask: td.task, expectedPeer: td.peer}
 	td.blockHooks = hooks.NewBlockHooks()
 	td.updateHooks = hooks.NewUpdateHooks()
 	td.requestID = graphsync.NewRequestID()
@@ -280,7 +281,6 @@ func newTestData(t *testing.T, blockCount int, expectedTraverse int) (*testData,
 	td.extensionData = basicnode.NewBytes(testutil.RandomBytes(100))
 	td.extensionName = graphsync.ExtensionName("AppleSauce/McGee")
 	td.responseCode = graphsync.ResponseStatusCode(101)
-	td.peer = testutil.GeneratePeers(1)[0]
 
 	td.extension = graphsync.ExtensionData{
 		Name: td.extensionName,
@@ -367,10 +367,12 @@ type fauxManager struct {
 	t                 *testing.T
 	responseTask      ResponseTask
 	expectedStartTask *peertask.Task
+	expectedPeer      peer.ID
 }
 
-func (fm *fauxManager) StartTask(task *peertask.Task, responseTaskChan chan<- ResponseTask) {
+func (fm *fauxManager) StartTask(task *peertask.Task, p peer.ID, responseTaskChan chan<- ResponseTask) {
 	require.Same(fm.t, fm.expectedStartTask, task)
+	require.Equal(fm.t, fm.expectedPeer, p)
 	go func() {
 		select {
 		case <-fm.ctx.Done():
@@ -379,10 +381,11 @@ func (fm *fauxManager) StartTask(task *peertask.Task, responseTaskChan chan<- Re
 	}()
 }
 
-func (fm *fauxManager) GetUpdates(p peer.ID, requestID graphsync.RequestID, updatesChan chan<- []gsmsg.GraphSyncRequest) {
+func (fm *fauxManager) GetUpdates(requestID graphsync.RequestID, updatesChan chan<- []gsmsg.GraphSyncRequest) {
 }
 
-func (fm *fauxManager) FinishTask(task *peertask.Task, err error) {
+func (fm *fauxManager) FinishTask(task *peertask.Task, p peer.ID, err error) {
+	require.Equal(fm.t, fm.expectedPeer, p)
 }
 
 type fauxResponseStream struct {
diff --git a/responsemanager/responsemanager_test.go b/responsemanager/responsemanager_test.go
index 09ea3731..6d321bd5 100644
--- a/responsemanager/responsemanager_test.go
+++ b/responsemanager/responsemanager_test.go
@@ -174,7 +174,7 @@ func TestCancellationViaCommand(t *testing.T) {
 	td.assertSendBlock()
 
 	// send a cancellation
-	err := responseManager.CancelResponse(td.p, td.requestID)
+	err := responseManager.CancelResponse(td.ctx, td.requestID)
 	require.NoError(t, err)
 	close(waitForCancel)
 
@@ -218,22 +218,33 @@ func TestStats(t *testing.T) {
 	responseManager := td.nullTaskQueueResponseManager()
 	td.requestHooks.Register(selectorvalidator.SelectorValidator(100))
 	responseManager.Startup()
-	responseManager.ProcessRequests(td.ctx, td.p, td.requests)
+
+	p1 := td.p
+	reqid1 := td.requestID
+	req1 := td.requests
+
 	p2 := testutil.GeneratePeers(1)[0]
-	responseManager.ProcessRequests(td.ctx, p2, td.requests)
-	peerState := responseManager.PeerState(td.p)
+	reqid2 := graphsync.NewRequestID()
+	req2 := []gsmsg.GraphSyncRequest{
+		gsmsg.NewRequest(reqid2, td.blockChain.TipLink.(cidlink.Link).Cid, td.blockChain.Selector(), graphsync.Priority(0), td.extension),
+	}
+
+	responseManager.ProcessRequests(td.ctx, p1, req1)
+	responseManager.ProcessRequests(td.ctx, p2, req2)
+
+	peerState := responseManager.PeerState(p1)
 	require.Len(t, peerState.RequestStates, 1)
-	require.Equal(t, peerState.RequestStates[td.requestID], graphsync.Queued)
+	require.Equal(t, peerState.RequestStates[reqid1], graphsync.Queued)
 	require.Len(t, peerState.Pending, 1)
-	require.Equal(t, peerState.Pending[0], td.requestID)
+	require.Equal(t, peerState.Pending[0], reqid1)
 	require.Len(t, peerState.Active, 0)
 	// no inconsistencies
 	require.Len(t, peerState.Diagnostics(), 0)
 	peerState = responseManager.PeerState(p2)
 	require.Len(t, peerState.RequestStates, 1)
-	require.Equal(t, peerState.RequestStates[td.requestID], graphsync.Queued)
+	require.Equal(t, peerState.RequestStates[reqid2], graphsync.Queued)
 	require.Len(t, peerState.Pending, 1)
-	require.Equal(t, peerState.Pending[0], td.requestID)
+	require.Equal(t, peerState.Pending[0], reqid2)
 	require.Len(t, peerState.Active, 0)
 	// no inconsistencies
 	require.Len(t, peerState.Diagnostics(), 0)
@@ -502,7 +513,7 @@ func TestValidationAndExtensions(t *testing.T) {
 		td.assertPausedRequest()
 		td.assertRequestDoesNotCompleteWhilePaused()
 		testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks")
-		err := responseManager.UnpauseResponse(td.p, td.requestID)
+		err := responseManager.UnpauseResponse(td.ctx, td.requestID)
 		require.NoError(t, err)
 		td.assertCompleteRequestWith(graphsync.RequestCompletedFull)
 	})
@@ -560,7 +571,7 @@ func TestValidationAndExtensions(t *testing.T) {
 		td.assertRequestDoesNotCompleteWhilePaused()
 		td.verifyNResponses(blockCount)
 		td.assertPausedRequest()
-		err := responseManager.UnpauseResponse(td.p, td.requestID, td.extensionResponse)
+		err := responseManager.UnpauseResponse(td.ctx, td.requestID, td.extensionResponse)
 		require.NoError(t, err)
 		td.assertReceiveExtensionResponse()
 		td.assertCompleteRequestWith(graphsync.RequestCompletedFull)
@@ -579,7 +590,7 @@ func TestValidationAndExtensions(t *testing.T) {
 		td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
 			blkIndex++
 			if blkIndex == blockCount {
-				err := responseManager.PauseResponse(p, requestData.ID())
+				err := responseManager.PauseResponse(td.ctx, requestData.ID())
 				require.NoError(t, err)
 			}
 		})
@@ -587,7 +598,7 @@ func TestValidationAndExtensions(t *testing.T) {
 		td.assertRequestDoesNotCompleteWhilePaused()
 		td.verifyNResponses(blockCount + 1)
 		td.assertPausedRequest()
-		err := responseManager.UnpauseResponse(td.p, td.requestID)
+		err := responseManager.UnpauseResponse(td.ctx, td.requestID)
 		require.NoError(t, err)
 		td.verifyNResponses(td.blockChainLength - (blockCount + 1))
 		td.assertCompleteRequestWith(graphsync.RequestCompletedFull)
@@ -606,7 +617,7 @@ func TestValidationAndExtensions(t *testing.T) {
 		})
 		go func() {
 			<-advance
-			err := responseManager.UnpauseResponse(td.p, td.requestID)
+			err := responseManager.UnpauseResponse(td.ctx, td.requestID)
 			require.NoError(t, err)
 		}()
 		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
@@ -780,7 +791,7 @@ func TestValidationAndExtensions(t *testing.T) {
 			td.assertCompleteRequestWith(graphsync.RequestFailedUnknown)
 
 			// cannot unpause
-			err := responseManager.UnpauseResponse(td.p, td.requestID)
+			err := responseManager.UnpauseResponse(td.ctx, td.requestID)
 			require.Error(t, err)
 		})
 	})
@@ -856,7 +867,7 @@ func TestNetworkErrors(t *testing.T) {
 		td.notifyBlockSendsNetworkError(err)
 		td.assertNetworkErrors(err, 1)
 		td.assertRequestCleared()
-		err = responseManager.UnpauseResponse(td.p, td.requestID, td.extensionResponse)
+		err = responseManager.UnpauseResponse(td.ctx, td.requestID, td.extensionResponse)
 		require.Error(t, err)
 	})
 }
diff --git a/responsemanager/server.go b/responsemanager/server.go
index 80758d04..faf277f6 100644
--- a/responsemanager/server.go
+++ b/responsemanager/server.go
@@ -45,21 +45,21 @@ func (rm *ResponseManager) run() {
 	}
 }
 
-func (rm *ResponseManager) terminateRequest(key responseKey) {
-	ipr, ok := rm.inProgressResponses[key]
+func (rm *ResponseManager) terminateRequest(requestID graphsync.RequestID) {
+	ipr, ok := rm.inProgressResponses[requestID]
 	if !ok {
 		return
 	}
-	rm.connManager.Unprotect(key.p, key.requestID.Tag())
-	delete(rm.inProgressResponses, key)
+	rm.connManager.Unprotect(ipr.peer, requestID.Tag())
+	delete(rm.inProgressResponses, requestID)
 	ipr.cancelFn()
 	ipr.span.End()
 }
 
-func (rm *ResponseManager) processUpdate(ctx context.Context, key responseKey, update gsmsg.GraphSyncRequest) {
-	response, ok := rm.inProgressResponses[key]
+func (rm *ResponseManager) processUpdate(ctx context.Context, requestID graphsync.RequestID, update gsmsg.GraphSyncRequest) {
+	response, ok := rm.inProgressResponses[requestID]
 	if !ok || response.state == graphsync.CompletingSend {
-		log.Warnf("received update for non existent request, peer %s, request ID %s", key.p.Pretty(), key.requestID.String())
+		log.Warnf("received update for non existent request ID %s", requestID.String())
 		return
 	}
 
@@ -89,7 +89,7 @@ func (rm *ResponseManager) processUpdate(ctx context.Context, key responseKey, u
 		}
 		return
 	} // else this is a paused response, so the update needs to be handled here and not in the executor
-	result := rm.updateHooks.ProcessUpdateHooks(key.p, response.request, update)
+	result := rm.updateHooks.ProcessUpdateHooks(response.peer, response.request, update)
 	_ = response.responseStream.Transaction(func(rb responseassembler.ResponseBuilder) error {
 		for _, extension := range result.Extensions {
 			rb.SendExtensionData(extension)
@@ -106,7 +106,7 @@ func (rm *ResponseManager) processUpdate(ctx context.Context, key responseKey, u
 		}
 		return
 	}
 	if result.Unpause {
-		err := rm.unpauseRequest(key.p, key.requestID)
+		err := rm.unpauseRequest(requestID)
 		if err != nil {
 			span.RecordError(err)
 			span.SetStatus(codes.Error, result.Err.Error())
@@ -115,11 +115,10 @@ func (rm *ResponseManager) processUpdate(ctx context.Context, key responseKey, u
 	}
 }
 
-func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
-	key := responseKey{p, requestID}
-	inProgressResponse, ok := rm.inProgressResponses[key]
+func (rm *ResponseManager) unpauseRequest(requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
+	inProgressResponse, ok := rm.inProgressResponses[requestID]
 	if !ok {
-		return errors.New("could not find request")
+		return graphsync.RequestNotFoundErr{}
 	}
 	if inProgressResponse.state != graphsync.Paused {
 		return errors.New("request is not paused")
@@ -133,16 +132,17 @@ func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.Request
 			return nil
 		})
 	}
-	rm.responseQueue.PushTask(p, peertask.Task{Topic: key, Priority: math.MaxInt32, Work: 1})
+	rm.responseQueue.PushTask(inProgressResponse.peer, peertask.Task{Topic: requestID, Priority: math.MaxInt32, Work: 1})
 	return nil
 }
 
-func (rm *ResponseManager) abortRequest(ctx context.Context, p peer.ID, requestID graphsync.RequestID, err error) error {
-	key := responseKey{p, requestID}
-	rm.responseQueue.Remove(key, key.p)
-	response, ok := rm.inProgressResponses[key]
+func (rm *ResponseManager) abortRequest(ctx context.Context, requestID graphsync.RequestID, err error) error {
+	response, ok := rm.inProgressResponses[requestID]
+	if ok {
+		rm.responseQueue.Remove(requestID, response.peer)
+	}
 	if !ok || response.state == graphsync.CompletingSend {
-		return errors.New("could not find request")
+		return graphsync.RequestNotFoundErr{}
 	}
 
 	_, span := otel.Tracer("graphsync").Start(trace.ContextWithSpan(ctx, response.span),
@@ -158,13 +158,13 @@ func (rm *ResponseManager) abortRequest(ctx context.Context, p peer.ID, requestI
 	if response.state != graphsync.Running {
 		if ipldutil.IsContextCancelErr(err) {
 			response.responseStream.ClearRequest()
-			rm.terminateRequest(key)
-			rm.cancelledListeners.NotifyCancelledListeners(p, response.request)
+			rm.terminateRequest(requestID)
+			rm.cancelledListeners.NotifyCancelledListeners(response.peer, response.request)
 			return nil
 		}
 		if err == queryexecutor.ErrNetworkError {
 			response.responseStream.ClearRequest()
-			rm.terminateRequest(key)
+			rm.terminateRequest(requestID)
 			return nil
 		}
 		response.state = graphsync.CompletingSend
@@ -189,13 +189,12 @@ func (rm *ResponseManager) processRequests(p peer.ID, requests []gsmsg.GraphSync
 	defer messageSpan.End()
 
 	for _, request := range requests {
-		key := responseKey{p: p, requestID: request.ID()}
 		switch request.Type() {
 		case graphsync.RequestTypeCancel:
-			_ = rm.abortRequest(ctx, p, request.ID(), ipldutil.ContextCancelError{})
+			_ = rm.abortRequest(ctx, request.ID(), ipldutil.ContextCancelError{})
 			continue
 		case graphsync.RequestTypeUpdate:
-			rm.processUpdate(ctx, key, request)
+			rm.processUpdate(ctx, request.ID(), request)
 			continue
 		default:
 		}
@@ -222,7 +221,7 @@ func (rm *ResponseManager) processRequests(p peer.ID, requests []gsmsg.GraphSync
 		))
 		rctx, cancelFn := context.WithCancel(rctx)
 		sub := &subscriber{
-			p:                     key.p,
+			p:                     p,
 			request:               request,
 			requestCloser:         rm,
 			blockSentListeners:    rm.blockSentListeners,
@@ -231,16 +230,17 @@ func (rm *ResponseManager) processRequests(p peer.ID, requests []gsmsg.GraphSync
 			connManager:           rm.connManager,
 		}
 		log.Infow("graphsync request initiated", "request id", request.ID().String(), "peer", p, "root", request.Root())
-		ipr, ok := rm.inProgressResponses[key]
+		ipr, ok := rm.inProgressResponses[request.ID()]
 		if ok && ipr.state == graphsync.Running {
 			log.Warnf("there is an identical request already in progress", "request id", request.ID().String(), "peer", p)
 		}
-		rm.inProgressResponses[key] =
+		rm.inProgressResponses[request.ID()] =
 			&inProgressResponseStatus{
 				ctx:      rctx,
 				span:     responseSpan,
 				cancelFn: cancelFn,
+				peer:     p,
 				request:  request,
 				signals: queryexecutor.ResponseSignals{
 					PauseSignal:  make(chan struct{}, 1),
@@ -249,23 +249,23 @@ func (rm *ResponseManager) processRequests(p peer.ID, requests []gsmsg.GraphSync
 				},
 				state:          graphsync.Queued,
 				startTime:      time.Now(),
-				responseStream: rm.responseAssembler.NewStream(ctx, key.p, key.requestID, sub),
+				responseStream: rm.responseAssembler.NewStream(ctx, p, request.ID(), sub),
 			}
 		// TODO: Use a better work estimation metric.
-		rm.responseQueue.PushTask(p, peertask.Task{Topic: key, Priority: int(request.Priority()), Work: 1})
+		rm.responseQueue.PushTask(p, peertask.Task{Topic: request.ID(), Priority: int(request.Priority()), Work: 1})
 	}
 }
 
-func (rm *ResponseManager) taskDataForKey(key responseKey) queryexecutor.ResponseTask {
-	response, hasResponse := rm.inProgressResponses[key]
+func (rm *ResponseManager) taskDataForKey(requestID graphsync.RequestID) queryexecutor.ResponseTask {
+	response, hasResponse := rm.inProgressResponses[requestID]
 	if !hasResponse || response.state == graphsync.CompletingSend {
 		return queryexecutor.ResponseTask{Empty: true}
 	}
-	log.Infow("graphsync response processing begins", "request id", key.requestID.String(), "peer", key.p, "total time", time.Since(response.startTime))
+	log.Infow("graphsync response processing begins", "request id", requestID.String(), "peer", response.peer, "total time", time.Since(response.startTime))
 	if response.loader == nil || response.traverser == nil {
-		loader, traverser, isPaused, err := (&queryPreparer{rm.requestHooks, rm.linkSystem, rm.maxLinksPerRequest}).prepareQuery(response.ctx, key.p, response.request, response.responseStream, response.signals)
+		loader, traverser, isPaused, err := (&queryPreparer{rm.requestHooks, rm.linkSystem, rm.maxLinksPerRequest}).prepareQuery(response.ctx, response.peer, response.request, response.responseStream, response.signals)
 		if err != nil {
 			response.state = graphsync.CompletingSend
 			response.span.RecordError(err)
@@ -292,20 +292,20 @@ func (rm *ResponseManager) taskDataForKey(key responseKey) queryexecutor.Respons
 	}
 }
 
-func (rm *ResponseManager) startTask(task *peertask.Task) queryexecutor.ResponseTask {
-	key := task.Topic.(responseKey)
-	taskData := rm.taskDataForKey(key)
+func (rm *ResponseManager) startTask(task *peertask.Task, p peer.ID) queryexecutor.ResponseTask {
+	requestID := task.Topic.(graphsync.RequestID)
+	taskData := rm.taskDataForKey(requestID)
 	if taskData.Empty {
-		rm.responseQueue.TaskDone(key.p, task)
+		rm.responseQueue.TaskDone(p, task)
 	}
 	return taskData
 }
 
-func (rm *ResponseManager) finishTask(task *peertask.Task, err error) {
-	key := task.Topic.(responseKey)
-	rm.responseQueue.TaskDone(key.p, task)
-	response, ok := rm.inProgressResponses[key]
+func (rm *ResponseManager) finishTask(task *peertask.Task, p peer.ID, err error) {
+	requestID := task.Topic.(graphsync.RequestID)
+	rm.responseQueue.TaskDone(p, task)
+	response, ok := rm.inProgressResponses[requestID]
 	if !ok {
 		return
 	}
@@ -313,7 +313,7 @@ func (rm *ResponseManager) finishTask(task *peertask.Task, err error) {
 		response.state = graphsync.Paused
 		return
 	}
-	log.Infow("graphsync response processing complete (messages stil sending)", "request id", key.requestID.String(), "peer", key.p, "total time", time.Since(response.startTime))
+	log.Infow("graphsync response processing complete (messages stil sending)", "request id", requestID.String(), "peer", p, "total time", time.Since(response.startTime))
 
 	if err != nil {
 		response.span.RecordError(err)
@@ -322,21 +322,21 @@ func (rm *ResponseManager) finishTask(task *peertask.Task, err error) {
 	}
 
 	if ipldutil.IsContextCancelErr(err) {
-		rm.cancelledListeners.NotifyCancelledListeners(key.p, response.request)
-		rm.terminateRequest(key)
+		rm.cancelledListeners.NotifyCancelledListeners(p, response.request)
+		rm.terminateRequest(requestID)
 		return
 	}
 
 	if err == queryexecutor.ErrNetworkError {
-		rm.terminateRequest(key)
+		rm.terminateRequest(requestID)
 		return
 	}
 
 	response.state = graphsync.CompletingSend
 }
 
-func (rm *ResponseManager) getUpdates(key responseKey) []gsmsg.GraphSyncRequest {
-	response, ok := rm.inProgressResponses[key]
+func (rm *ResponseManager) getUpdates(requestID graphsync.RequestID) []gsmsg.GraphSyncRequest {
+	response, ok := rm.inProgressResponses[requestID]
 	if !ok {
 		return nil
 	}
@@ -345,11 +345,10 @@ func (rm *ResponseManager) getUpdates(key responseKey) []gsmsg.GraphSyncRequest {
 	return updates
 }
 
-func (rm *ResponseManager) pauseRequest(p peer.ID, requestID graphsync.RequestID) error {
-	key := responseKey{p, requestID}
-	inProgressResponse, ok := rm.inProgressResponses[key]
+func (rm *ResponseManager) pauseRequest(requestID graphsync.RequestID) error {
+	inProgressResponse, ok := rm.inProgressResponses[requestID]
 	if !ok || inProgressResponse.state == graphsync.CompletingSend {
-		return errors.New("could not find request")
+		return graphsync.RequestNotFoundErr{}
 	}
 	if inProgressResponse.state == graphsync.Paused {
 		return errors.New("request is already paused")
@@ -366,8 +365,8 @@ func (rm *ResponseManager) peerState(p peer.ID) peerstate.PeerState {
 	rm.responseQueue.WithPeerTopics(p, func(peerTopics *peertracker.PeerTrackerTopics) {
 		requestStates := make(graphsync.RequestStates)
 		for key, ipr := range rm.inProgressResponses {
-			if key.p == p {
-				requestStates[key.requestID] = ipr.state
+			if ipr.peer == p {
+				requestStates[key] = ipr.state
 			}
 		}
 		peerState = peerstate.PeerState{RequestStates: requestStates, TaskQueueState: fromPeerTopics(peerTopics)}
@@ -381,11 +380,11 @@ func fromPeerTopics(pt *peertracker.PeerTrackerTopics) peerstate.TaskQueueState
 	}
 	active := make([]graphsync.RequestID, 0, len(pt.Active))
 	for _, topic := range pt.Active {
-		active = append(active, topic.(responseKey).requestID)
+		active = append(active, topic.(graphsync.RequestID))
 	}
 	pending := make([]graphsync.RequestID, 0, len(pt.Pending))
 	for _, topic := range pt.Pending {
-		pending = append(pending, topic.(responseKey).requestID)
+		pending = append(pending, topic.(graphsync.RequestID))
 	}
 	return peerstate.TaskQueueState{
 		Active:  active,
diff --git a/responsemanager/subscriber.go b/responsemanager/subscriber.go
index 8e3992e7..a5ef20a2 100644
--- a/responsemanager/subscriber.go
+++ b/responsemanager/subscriber.go
@@ -12,8 +12,8 @@ import (
 
 // RequestCloser can cancel request on a network error
 type RequestCloser interface {
-	TerminateRequest(p peer.ID, requestID graphsync.RequestID)
-	CloseWithNetworkError(p peer.ID, requestID graphsync.RequestID)
+	TerminateRequest(requestID graphsync.RequestID)
+	CloseWithNetworkError(requestID graphsync.RequestID)
 }
 
 type subscriber struct {
@@ -33,10 +33,10 @@ func (s *subscriber) OnNext(_ notifications.Topic, event notifications.Event) {
 	}
 	switch responseEvent.Name {
 	case messagequeue.Error:
-		s.requestCloser.CloseWithNetworkError(s.p, s.request.ID())
+		s.requestCloser.CloseWithNetworkError(s.request.ID())
 		responseCode := responseEvent.Metadata.ResponseCodes[s.request.ID()]
 		if responseCode.IsTerminal() {
-			s.requestCloser.TerminateRequest(s.p, s.request.ID())
+			s.requestCloser.TerminateRequest(s.request.ID())
 		}
 		s.networkErrorListeners.NotifyNetworkErrorListeners(s.p, s.request, responseEvent.Err)
 	case messagequeue.Sent:
@@ -46,7 +46,7 @@ func (s *subscriber) OnNext(_ notifications.Topic, event notifications.Event) {
 		}
 		responseCode := responseEvent.Metadata.ResponseCodes[s.request.ID()]
 		if responseCode.IsTerminal() {
-			s.requestCloser.TerminateRequest(s.p, s.request.ID())
+			s.requestCloser.TerminateRequest(s.request.ID())
 			s.completedListeners.NotifyCompletedListeners(s.p, s.request, responseCode)
 		}
 	}
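Not part of the patch above: a minimal caller-side sketch of the unified API this change introduces, assuming an already-constructed graphsync.GraphExchange and a RequestID captured from a hook or request channel. The helper name pauseThenResume and the wrapped error message are illustrative only. As the impl/graphsync.go wrappers show, Pause, Unpause and Cancel try the request manager first and fall back to the response manager when the ID is not found there, so the same call works whether the local node is the requestor or the responder for that RequestID.

package example

import (
	"context"
	"errors"
	"fmt"

	graphsync "github.com/ipfs/go-graphsync"
)

// pauseThenResume pauses an in-progress transfer identified by id and later
// resumes it, optionally attaching extension data to the unpause message.
func pauseThenResume(ctx context.Context, gs graphsync.GraphExchange, id graphsync.RequestID, exts ...graphsync.ExtensionData) error {
	if err := gs.Pause(ctx, id); err != nil {
		// RequestNotFoundErr is returned when neither the request manager nor
		// the response manager is tracking this RequestID.
		var notFound graphsync.RequestNotFoundErr
		if errors.As(err, &notFound) {
			return fmt.Errorf("no in-progress request or response for %s: %w", id, err)
		}
		return err
	}

	// ... do whatever work required the transfer to be paused ...

	// Unpause accepts optional extensions that are sent along with the resume.
	return gs.Unpause(ctx, id, exts...)
}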