Skip to content

Handle context cancellation properly #428

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 34 additions & 20 deletions requestmanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,14 @@ func (rm *RequestManager) NewRequest(ctx context.Context,

inProgressRequestChan := make(chan inProgressRequest)

rm.send(&newRequestMessage{requestID, span, p, root, selectorNode, extensions, inProgressRequestChan}, ctx.Done())
err := rm.send(&newRequestMessage{requestID, span, p, root, selectorNode, extensions, inProgressRequestChan}, ctx.Done())
if err != nil {
return rm.emptyResponse()
}
var receivedInProgressRequest inProgressRequest
select {
case <-rm.ctx.Done():
return rm.emptyResponse()
case <-ctx.Done():
return rm.emptyResponse()
case receivedInProgressRequest = <-inProgressRequestChan:
}

Expand Down Expand Up @@ -283,12 +284,13 @@ func (rm *RequestManager) cancelRequestAndClose(requestID graphsync.RequestID,
// CancelRequest cancels the given request ID and waits for the request to terminate
func (rm *RequestManager) CancelRequest(ctx context.Context, requestID graphsync.RequestID) error {
terminated := make(chan error, 1)
rm.send(&cancelRequestMessage{requestID, terminated, graphsync.RequestClientCancelledErr{}}, ctx.Done())
err := rm.send(&cancelRequestMessage{requestID, terminated, graphsync.RequestClientCancelledErr{}}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-terminated:
return err
}
Expand All @@ -300,19 +302,20 @@ func (rm *RequestManager) ProcessResponses(p peer.ID,
responses []gsmsg.GraphSyncResponse,
blks []blocks.Block) {

rm.send(&processResponsesMessage{p, responses, blks}, nil)
_ = rm.send(&processResponsesMessage{p, responses, blks}, nil)
}

// UnpauseRequest unpauses a request that was paused in a block hook based request ID
// Can also send extensions with unpause
func (rm *RequestManager) UnpauseRequest(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
rm.send(&unpauseRequestMessage{requestID, extensions, response}, ctx.Done())
err := rm.send(&unpauseRequestMessage{requestID, extensions, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
Expand All @@ -321,12 +324,13 @@ func (rm *RequestManager) UnpauseRequest(ctx context.Context, requestID graphsyn
// PauseRequest pauses an in progress request (may take 1 or more blocks to process)
func (rm *RequestManager) PauseRequest(ctx context.Context, requestID graphsync.RequestID) error {
response := make(chan error, 1)
rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
err := rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
Expand All @@ -335,26 +339,27 @@ func (rm *RequestManager) PauseRequest(ctx context.Context, requestID graphsync.
// UpdateRequest updates an in progress request
func (rm *RequestManager) UpdateRequest(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
rm.send(&updateRequestMessage{requestID, extensions, response}, ctx.Done())
err := rm.send(&updateRequestMessage{requestID, extensions, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
}

// GetRequestTask gets data for the given task in the request queue
func (rm *RequestManager) GetRequestTask(p peer.ID, task *peertask.Task, requestExecutionChan chan executor.RequestTask) {
rm.send(&getRequestTaskMessage{p, task, requestExecutionChan}, nil)
_ = rm.send(&getRequestTaskMessage{p, task, requestExecutionChan}, nil)
}

// ReleaseRequestTask releases a task request the requestQueue
func (rm *RequestManager) ReleaseRequestTask(p peer.ID, task *peertask.Task, err error) {
done := make(chan struct{}, 1)
rm.send(&releaseRequestTaskMessage{p, task, err, done}, nil)
_ = rm.send(&releaseRequestTaskMessage{p, task, err, done}, nil)
select {
case <-rm.ctx.Done():
case <-done:
Expand All @@ -364,7 +369,7 @@ func (rm *RequestManager) ReleaseRequestTask(p peer.ID, task *peertask.Task, err
// PeerState gets stats on all outgoing requests for a given peer
func (rm *RequestManager) PeerState(p peer.ID) peerstate.PeerState {
response := make(chan peerstate.PeerState)
rm.send(&peerStateMessage{p, response}, nil)
_ = rm.send(&peerStateMessage{p, response}, nil)
select {
case <-rm.ctx.Done():
return peerstate.PeerState{}
Expand Down Expand Up @@ -392,11 +397,20 @@ func (rm *RequestManager) Shutdown() {
rm.cancel()
}

func (rm *RequestManager) send(message requestManagerMessage, done <-chan struct{}) {
func (rm *RequestManager) send(message requestManagerMessage, done <-chan struct{}) error {
// prioritize cancelled context
select {
case <-done:
return errors.New("unable to send message before cancellation")
default:
}
select {
case <-rm.ctx.Done():
return rm.ctx.Err()
case <-done:
return errors.New("unable to send message before cancellation")
case rm.messages <- message:
return nil
}
}

Expand Down
3 changes: 2 additions & 1 deletion requestmanager/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func (rm *RequestManager) run() {
for {
select {
case message := <-rm.messages:

message.handle(rm)
case <-rm.ctx.Done():
return
Expand Down Expand Up @@ -304,13 +305,13 @@ func (rm *RequestManager) processResponses(p peer.ID,
for _, blk := range blks {
blkMap[blk.Cid()] = blk.RawData()
}
rm.updateLastResponses(filteredResponses)
for _, response := range filteredResponses {
reconciledLoader := rm.inProgressRequestStatuses[response.RequestID()].reconciledLoader
if reconciledLoader != nil {
reconciledLoader.IngestResponse(response.Metadata(), trace.LinkFromContext(ctx), blkMap)
}
}
rm.updateLastResponses(filteredResponses)
rm.processTerminations(filteredResponses)
log.Debugf("end processing responses for peer %s", p)
}
Expand Down
55 changes: 34 additions & 21 deletions responsemanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,19 @@ func New(ctx context.Context,

// ProcessRequests processes incoming requests for the given peer
func (rm *ResponseManager) ProcessRequests(ctx context.Context, p peer.ID, requests []gsmsg.GraphSyncRequest) {
rm.send(&processRequestsMessage{p, requests}, ctx.Done())
_ = rm.send(&processRequestsMessage{p, requests}, ctx.Done())
}

// UnpauseResponse unpauses a response that was previously paused
func (rm *ResponseManager) UnpauseResponse(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
rm.send(&unpauseRequestMessage{requestID, response, extensions}, ctx.Done())
err := rm.send(&unpauseRequestMessage{requestID, response, extensions}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
Expand All @@ -177,12 +178,13 @@ func (rm *ResponseManager) UnpauseResponse(ctx context.Context, requestID graphs
// PauseResponse pauses an in progress response (may take 1 or more blocks to process)
func (rm *ResponseManager) PauseResponse(ctx context.Context, requestID graphsync.RequestID) error {
response := make(chan error, 1)
rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
err := rm.send(&pauseRequestMessage{requestID, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
Expand All @@ -191,12 +193,13 @@ func (rm *ResponseManager) PauseResponse(ctx context.Context, requestID graphsyn
// CancelResponse cancels an in progress response
func (rm *ResponseManager) CancelResponse(ctx context.Context, requestID graphsync.RequestID) error {
response := make(chan error, 1)
rm.send(&errorRequestMessage{requestID, queryexecutor.ErrCancelledByCommand, response}, ctx.Done())
err := rm.send(&errorRequestMessage{requestID, queryexecutor.ErrCancelledByCommand, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
Expand All @@ -205,12 +208,13 @@ func (rm *ResponseManager) CancelResponse(ctx context.Context, requestID graphsy
// UpdateRequest updates an in progress response
func (rm *ResponseManager) UpdateResponse(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
rm.send(&updateRequestMessage{requestID, extensions, response}, ctx.Done())
err := rm.send(&updateRequestMessage{requestID, extensions, response}, ctx.Done())
if err != nil {
return err
}
select {
case <-rm.ctx.Done():
return errors.New("context cancelled")
case <-ctx.Done():
return ctx.Err()
case err := <-response:
return err
}
Expand All @@ -219,7 +223,7 @@ func (rm *ResponseManager) UpdateResponse(ctx context.Context, requestID graphsy
// Synchronize is a utility method that blocks until all current messages are processed
func (rm *ResponseManager) synchronize() {
sync := make(chan error)
rm.send(&synchronizeMessage{sync}, nil)
_ = rm.send(&synchronizeMessage{sync}, nil)
select {
case <-rm.ctx.Done():
case <-sync:
Expand All @@ -228,18 +232,18 @@ func (rm *ResponseManager) synchronize() {

// StartTask starts the given task from the peer task queue
func (rm *ResponseManager) StartTask(task *peertask.Task, p peer.ID, responseTaskChan chan<- queryexecutor.ResponseTask) {
rm.send(&startTaskRequest{task, p, responseTaskChan}, nil)
_ = rm.send(&startTaskRequest{task, p, responseTaskChan}, nil)
}

// GetUpdates is called to read pending updates for a task and clear them
func (rm *ResponseManager) GetUpdates(requestID graphsync.RequestID, updatesChan chan<- []gsmsg.GraphSyncRequest) {
rm.send(&responseUpdateRequest{requestID, updatesChan}, nil)
_ = rm.send(&responseUpdateRequest{requestID, updatesChan}, nil)
}

// FinishTask marks a task from the task queue as done
func (rm *ResponseManager) FinishTask(task *peertask.Task, p peer.ID, err error) {
done := make(chan struct{}, 1)
rm.send(&finishTaskRequest{task, p, err, done}, nil)
_ = rm.send(&finishTaskRequest{task, p, err, done}, nil)
select {
case <-rm.ctx.Done():
case <-done:
Expand All @@ -249,7 +253,7 @@ func (rm *ResponseManager) FinishTask(task *peertask.Task, p peer.ID, err error)
// CloseWithNetworkError closes a request due to a network error
func (rm *ResponseManager) CloseWithNetworkError(requestID graphsync.RequestID) {
done := make(chan error, 1)
rm.send(&errorRequestMessage{requestID, queryexecutor.ErrNetworkError, done}, nil)
_ = rm.send(&errorRequestMessage{requestID, queryexecutor.ErrNetworkError, done}, nil)
select {
case <-rm.ctx.Done():
case <-done:
Expand All @@ -259,7 +263,7 @@ func (rm *ResponseManager) CloseWithNetworkError(requestID graphsync.RequestID)
// TerminateRequest indicates a request has finished sending data and should no longer be tracked
func (rm *ResponseManager) TerminateRequest(requestID graphsync.RequestID) {
done := make(chan struct{}, 1)
rm.send(&terminateRequestMessage{requestID, done}, nil)
_ = rm.send(&terminateRequestMessage{requestID, done}, nil)
select {
case <-rm.ctx.Done():
case <-done:
Expand All @@ -269,7 +273,7 @@ func (rm *ResponseManager) TerminateRequest(requestID graphsync.RequestID) {
// PeerState gets current state of the outgoing responses for a given peer
func (rm *ResponseManager) PeerState(p peer.ID) peerstate.PeerState {
response := make(chan peerstate.PeerState)
rm.send(&peerStateMessage{p, response}, nil)
_ = rm.send(&peerStateMessage{p, response}, nil)
select {
case <-rm.ctx.Done():
return peerstate.PeerState{}
Expand All @@ -278,11 +282,20 @@ func (rm *ResponseManager) PeerState(p peer.ID) peerstate.PeerState {
}
}

func (rm *ResponseManager) send(message responseManagerMessage, done <-chan struct{}) {
func (rm *ResponseManager) send(message responseManagerMessage, done <-chan struct{}) error {
// prioritize cancelled context
select {
case <-done:
return errors.New("unable to send message before cancellation")
default:
}
select {
case <-rm.ctx.Done():
return rm.ctx.Err()
case <-done:
return errors.New("unable to send message before cancellation")
case rm.messages <- message:
return nil
}
}

Expand Down