Skip to content

ipldutil: simplify state synchronization, add docs #300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 112 additions & 97 deletions ipldutil/traverser.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"io"
"sync"

dagpb "github.com/ipld/go-codec-dagpb"
"github.com/ipld/go-ipld-prime"
Expand Down Expand Up @@ -50,27 +51,35 @@ type TraversalBuilder struct {
// Traverser is an interface for performing a selector traversal that operates iteratively --
// it stops and waits for a manual load every time a block boundary is encountered
type Traverser interface {
// IsComplete returns the completion state (boolean) and if so, the final error result from IPLD
// IsComplete returns the completion state (boolean) and if so, the final
// error result from IPLD.
//
// Note that CurrentRequest will block if the traverser is performing an
// IPLD load.
IsComplete() (bool, error)
// Current request returns the current link waiting to be loaded

// CurrentRequest returns the parameters for the current block load waiting
// to be fulfilled in order to advance further.
//
// Note that CurrentRequest will block if the traverser is performing an
// IPLD load.
CurrentRequest() (ipld.Link, ipld.LinkContext)
// Advance advances the traversal successfully by supplying the given reader as the result of the next IPLD load

// Advance advances the traversal successfully by supplying the given reader
// as the result of the next IPLD load.
Advance(reader io.Reader) error
// Error errors the traversal by returning the given error as the result of the next IPLD load

// Error errors the traversal by returning the given error as the result of
// the next IPLD load.
Error(err error)

// Shutdown cancels the traversal
Shutdown(ctx context.Context)

// NBlocksTraversed returns the number of blocks successfully traversed
NBlocksTraversed() int
}

type state struct {
isDone bool
completionErr error
currentLink ipld.Link
currentContext ipld.LinkContext
}

type nextResponse struct {
input io.Reader
err error
Expand All @@ -81,26 +90,24 @@ type nextResponse struct {
func (tb TraversalBuilder) Start(parentCtx context.Context) Traverser {
ctx, cancel := context.WithCancel(parentCtx)
t := &traverser{
blocksCount: 0,
parentCtx: parentCtx,
ctx: ctx,
cancel: cancel,
root: tb.Root,
selector: tb.Selector,
visitor: defaultVisitor,
chooser: dagpb.AddSupportToChooser(basicnode.Chooser),
linkSystem: tb.LinkSystem,
budget: tb.Budget,
awaitRequest: make(chan struct{}, 1),
stateChan: make(chan state, 1),
responses: make(chan nextResponse),
stopped: make(chan struct{}),
ctx: ctx,
cancel: cancel,
root: tb.Root,
selector: tb.Selector,
linkSystem: tb.LinkSystem,
budget: tb.Budget,
responses: make(chan nextResponse),
stopped: make(chan struct{}),
}
if tb.Visitor != nil {
t.visitor = tb.Visitor
} else {
t.visitor = defaultVisitor
}
if tb.Chooser != nil {
t.chooser = tb.Chooser
} else {
t.chooser = dagpb.AddSupportToChooser(basicnode.Chooser)
}
if tb.LinkSystem.DecoderChooser == nil {
t.linkSystem.DecoderChooser = defaultLinkSystem.DecoderChooser
Expand All @@ -119,75 +126,75 @@ func (tb TraversalBuilder) Start(parentCtx context.Context) Traverser {
// traverser is a class to perform a selector traversal that stops every time a new block is loaded
// and waits for manual input (in the form of advance or error)
type traverser struct {
blocksCount int
parentCtx context.Context
ctx context.Context
cancel context.CancelFunc
root ipld.Link
selector ipld.Node
visitor traversal.AdvVisitFn
linkSystem ipld.LinkSystem
chooser traversal.LinkTargetNodePrototypeChooser
currentLink ipld.Link
currentContext ipld.LinkContext
budget *traversal.Budget
blocksCount int
ctx context.Context
cancel context.CancelFunc
root ipld.Link
selector ipld.Node
visitor traversal.AdvVisitFn
linkSystem ipld.LinkSystem
chooser traversal.LinkTargetNodePrototypeChooser
budget *traversal.Budget

// stateMu is held while a block is being loaded.
// It is released when a StorageReadOpener callback is received,
// so that the user can inspect the state and use Advance or Error.
// Advance/Error grab the mutex and let StorageReadOpener return.
// The four state fields are only safe to read while the mutex isn't held.
stateMu sync.Mutex
isDone bool
completionErr error
awaitRequest chan struct{}
stateChan chan state
responses chan nextResponse
stopped chan struct{}
currentLink ipld.Link
currentContext ipld.LinkContext

// responses blocks LinkSystem block loads (in the method "loader")
// until the next Advance or Error method call.
responses chan nextResponse

// stopped is closed when the traverser is stopped,
// due to being finishing, cancelled, or shut down.
stopped chan struct{}
}

func (t *traverser) NBlocksTraversed() int {
return t.blocksCount
}

func (t *traverser) loader(lnkCtx ipld.LinkContext, lnk ipld.Link) (io.Reader, error) {
// A StorageReadOpener call came in; update the state and release the lock.
// We can't simply unlock the mutex inside the <-t.responses case,
// as then we'd deadlock with the other side trying to send.
// The other side can't lock after sending to t.responses,
// as otherwise the load might start before the mutex is held.
t.currentLink = lnk
t.currentContext = lnkCtx
t.stateMu.Unlock()

select {
case <-t.ctx.Done():
return nil, ContextCancelError{}
case t.stateChan <- state{false, nil, lnk, lnkCtx}:
}
select {
case <-t.ctx.Done():
// We got cancelled, so we won't load this block via the responses chan.
// Lock the mutex again, until writeDone gives the user their final error.
t.stateMu.Lock()
return nil, ContextCancelError{}
case response := <-t.responses:
return response.input, response.err
}
}

func (t *traverser) checkState() {
select {
case <-t.awaitRequest:
select {
case <-t.ctx.Done():
t.isDone = true
t.completionErr = ContextCancelError{}
case newState := <-t.stateChan:
t.isDone = newState.isDone
t.completionErr = newState.completionErr
t.currentLink = newState.currentLink
t.currentContext = newState.currentContext
}
default:
}
}

func (t *traverser) writeDone(err error) {
select {
case <-t.ctx.Done():
case t.stateChan <- state{true, err, nil, ipld.LinkContext{Ctx: t.ctx}}:
}
t.isDone = true
t.completionErr = err
t.currentContext = ipld.LinkContext{Ctx: t.ctx}

// The traversal is done, so there won't be another StorageReadOpener call.
// Unlock the state so the user can use IsComplete etc.
t.stateMu.Unlock()
}

func (t *traverser) start() {
select {
case <-t.ctx.Done():
close(t.stopped)
return
case t.awaitRequest <- struct{}{}:
}
// Grab the state mutex until the first StorageReadOpener call comes in.
t.stateMu.Lock()

go func() {
defer close(t.stopped)
ns, err := t.chooser(t.root, ipld.LinkContext{Ctx: t.ctx})
Expand Down Expand Up @@ -225,6 +232,9 @@ func (t *traverser) start() {
}()
}

// Shutdown cancels the traverser's context as passed to Start,
// and blocks until the traverser is fully stopped
// or until ctx is cancelled.
func (t *traverser) Shutdown(ctx context.Context) {
t.cancel()
select {
Expand All @@ -233,55 +243,60 @@ func (t *traverser) Shutdown(ctx context.Context) {
}
}

// IsComplete returns true if a traversal is complete
func (t *traverser) IsComplete() (bool, error) {
t.checkState()
// If the state is currently held due to an ongoing block load,
// block until it's finished or until the traverser stops,
// which then enables us to read the fields directly.
t.stateMu.Lock()
defer t.stateMu.Unlock()
return t.isDone, t.completionErr
}

// CurrentRequest returns the current block load waiting to be fulfilled in order
// to advance further
func (t *traverser) CurrentRequest() (ipld.Link, ipld.LinkContext) {
t.checkState()
// Just like IsComplete.
t.stateMu.Lock()
defer t.stateMu.Unlock()
return t.currentLink, t.currentContext
}

// Advance advances the traversal with an io.Reader for the next requested block
func (t *traverser) Advance(reader io.Reader) error {
isComplete, _ := t.IsComplete()
if isComplete {
return errors.New("cannot advance when done")
}
// Just like IsComplete, block until we're ready to load another block.
// We leave it to the other goroutine to unlock the mutex,
// once the next StorageReadOpener call comes in or the traversal is done.
t.stateMu.Lock()

select {
case <-t.ctx.Done():
return ContextCancelError{}
case t.awaitRequest <- struct{}{}:
if t.isDone {
// The other goroutine won't unlock, so we have to unlock.
t.stateMu.Unlock()
return errors.New("cannot advance when done")
}

select {
case <-t.ctx.Done():
// The other goroutine won't unlock, so we have to unlock.
t.stateMu.Unlock()
return ContextCancelError{}
case t.responses <- nextResponse{reader, nil}:
case t.responses <- nextResponse{input: reader}:
}

t.blocksCount++
return nil
}

// Error aborts the traversal with an error
func (t *traverser) Error(err error) {
isComplete, _ := t.IsComplete()
if isComplete {
return
}
select {
case <-t.ctx.Done():
// Just like Advance.
t.stateMu.Lock()

if t.isDone {
// The other goroutine won't unlock, so we have to unlock.
t.stateMu.Unlock()
return
case t.awaitRequest <- struct{}{}:
}

select {
case <-t.ctx.Done():
case t.responses <- nextResponse{nil, err}:
// The other goroutine won't unlock, so we have to unlock.
t.stateMu.Unlock()
case t.responses <- nextResponse{err: err}:
}
}
28 changes: 28 additions & 0 deletions ipldutil/traverser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ipldutil
import (
"bytes"
"context"
"errors"
"fmt"
"math"
"testing"
Expand Down Expand Up @@ -137,6 +138,33 @@ func TestTraverser(t *testing.T) {
}.Start(ctx)
checkTraverseSequence(ctx, t, traverser, []blocks.Block{}, &traversal.ErrBudgetExceeded{BudgetKind: "link", Link: blockChain.TipLink})
})

t.Run("started with shutdown context, then calls methods after done", func(t *testing.T) {
cancelledCtx, cancel := context.WithCancel(ctx)
cancel()
testdata := testutil.NewTestIPLDTree()
ssb := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any)
sel := ssb.ExploreRecursive(selector.RecursionLimitNone(), ssb.ExploreAll(ssb.ExploreRecursiveEdge())).Node()
traverser := TraversalBuilder{
Root: testdata.RootNodeLnk,
Selector: sel,
}.Start(cancelledCtx)

var err error
// To ensure the state isn't broken, do multiple calls.
for i := 0; i < 3; i++ {
err = traverser.Advance(bytes.NewBuffer(nil))
require.Error(t, err)

traverser.Error(errors.New("foo"))

done, err := traverser.IsComplete()
require.True(t, done)
require.Error(t, err)

_, _ = traverser.CurrentRequest()
}
})
}

func checkTraverseSequence(ctx context.Context, t *testing.T, traverser Traverser, expectedBlks []blocks.Block, finalErr error) {
Expand Down