diff --git a/dial_sync.go b/dial_sync.go index 03ba9cc0..6b3f2afe 100644 --- a/dial_sync.go +++ b/dial_sync.go @@ -8,8 +8,8 @@ import ( "github.com/libp2p/go-libp2p-core/peer" ) -// DialWorerFunc is used by DialSync to spawn a new dial worker -type dialWorkerFunc func(peer.ID, <-chan dialRequest) error +// dialWorkerFunc is used by DialSync to spawn a new dial worker +type dialWorkerFunc func(peer.ID, <-chan dialRequest) // newDialSync constructs a new DialSync func newDialSync(worker dialWorkerFunc) *DialSync { @@ -93,12 +93,7 @@ func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) { reqch: make(chan dialRequest), ds: ds, } - - if err := ds.dialWorker(p, actd.reqch); err != nil { - cancel() - return nil, err - } - + go ds.dialWorker(p, actd.reqch) ds.dials[p] = actd } @@ -108,9 +103,9 @@ func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) { return actd, nil } -// DialLock initiates a dial to the given peer if there are none in progress +// Dial initiates a dial to the given peer if there are none in progress // then waits for the dial to that peer to complete. -func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { +func (ds *DialSync) Dial(ctx context.Context, p peer.ID) (*Conn, error) { ad, err := ds.getActiveDial(p) if err != nil { return nil, err diff --git a/dial_sync_test.go b/dial_sync_test.go index 441c8b7d..0d9c6ca4 100644 --- a/dial_sync_test.go +++ b/dial_sync_test.go @@ -15,7 +15,7 @@ func getMockDialFunc() (dialWorkerFunc, func(), context.Context, <-chan struct{} dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care dialctx, cancel := context.WithCancel(context.Background()) ch := make(chan struct{}) - f := func(p peer.ID, reqch <-chan dialRequest) error { + f := func(p peer.ID, reqch <-chan dialRequest) { defer cancel() dfcalls <- struct{}{} go func() { @@ -24,7 +24,6 @@ func getMockDialFunc() (dialWorkerFunc, func(), context.Context, <-chan struct{} req.resch <- dialResponse{conn: new(Conn)} } }() - return nil } var once sync.Once @@ -38,14 +37,14 @@ func TestBasicDialSync(t *testing.T) { finished := make(chan struct{}, 2) go func() { - if _, err := dsync.DialLock(context.Background(), p); err != nil { + if _, err := dsync.Dial(context.Background(), p); err != nil { t.Error(err) } finished <- struct{}{} }() go func() { - if _, err := dsync.DialLock(context.Background(), p); err != nil { + if _, err := dsync.Dial(context.Background(), p); err != nil { t.Error(err) } finished <- struct{}{} @@ -74,7 +73,7 @@ func TestDialSyncCancel(t *testing.T) { finished := make(chan struct{}) go func() { - _, err := dsync.DialLock(ctx1, p) + _, err := dsync.Dial(ctx1, p) if err != ctx1.Err() { t.Error("should have gotten context error") } @@ -90,7 +89,7 @@ func TestDialSyncCancel(t *testing.T) { // Add a second dialwait in so two actors are waiting on the same dial go func() { - _, err := dsync.DialLock(context.Background(), p) + _, err := dsync.Dial(context.Background(), p) if err != nil { t.Error(err) } @@ -123,7 +122,7 @@ func TestDialSyncAllCancel(t *testing.T) { finished := make(chan struct{}) go func() { - if _, err := dsync.DialLock(ctx, p); err != ctx.Err() { + if _, err := dsync.Dial(ctx, p); err != ctx.Err() { t.Error("should have gotten context error") } finished <- struct{}{} @@ -131,7 +130,7 @@ func TestDialSyncAllCancel(t *testing.T) { // Add a second dialwait in so two actors are waiting on the same dial go func() { - if _, err := dsync.DialLock(ctx, p); err != ctx.Err() { + if _, err := dsync.Dial(ctx, p); err != ctx.Err() { t.Error("should have gotten context error") } finished <- struct{}{} @@ -155,14 +154,14 @@ func TestDialSyncAllCancel(t *testing.T) { // should be able to successfully dial that peer again done() - if _, err := dsync.DialLock(context.Background(), p); err != nil { + if _, err := dsync.Dial(context.Background(), p); err != nil { t.Fatal(err) } } func TestFailFirst(t *testing.T) { var count int32 - f := func(p peer.ID, reqch <-chan dialRequest) error { + f := func(p peer.ID, reqch <-chan dialRequest) { go func() { for { req, ok := <-reqch @@ -178,33 +177,29 @@ func TestFailFirst(t *testing.T) { atomic.AddInt32(&count, 1) } }() - return nil } ds := newDialSync(f) - p := peer.ID("testing") ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - _, err := ds.DialLock(ctx, p) - if err == nil { + if _, err := ds.Dial(ctx, p); err == nil { t.Fatal("expected gophers to have eaten the modem") } - c, err := ds.DialLock(ctx, p) + c, err := ds.Dial(ctx, p) if err != nil { t.Fatal(err) } - if c == nil { t.Fatal("should have gotten a 'real' conn back") } } func TestStressActiveDial(t *testing.T) { - ds := newDialSync(func(p peer.ID, reqch <-chan dialRequest) error { + ds := newDialSync(func(p peer.ID, reqch <-chan dialRequest) { go func() { for { req, ok := <-reqch @@ -214,7 +209,6 @@ func TestStressActiveDial(t *testing.T) { req.resch <- dialResponse{} } }() - return nil }) wg := sync.WaitGroup{} @@ -223,7 +217,7 @@ func TestStressActiveDial(t *testing.T) { makeDials := func() { for i := 0; i < 10000; i++ { - ds.DialLock(context.Background(), pid) + ds.Dial(context.Background(), pid) } wg.Done() } @@ -235,24 +229,3 @@ func TestStressActiveDial(t *testing.T) { wg.Wait() } - -func TestDialSelf(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - self := peer.ID("ABC") - s := NewSwarm(ctx, self, nil, nil) - defer s.Close() - - // this should fail - _, err := s.dsync.DialLock(ctx, self) - if err != ErrDialToSelf { - t.Fatal("expected error from self dial") - } - - // do it twice to make sure we get a new active dial object that fails again - _, err = s.dsync.DialLock(ctx, self) - if err != ErrDialToSelf { - t.Fatal("expected error from self dial") - } -} diff --git a/dial_test.go b/dial_test.go index 6258d0ed..3bbc5e2b 100644 --- a/dial_test.go +++ b/dial_test.go @@ -672,7 +672,7 @@ func TestDialSimultaneousJoin(t *testing.T) { } } -func TestDialSelf2(t *testing.T) { +func TestDialSelf(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/swarm.go b/swarm.go index 35eb4156..c34497f7 100644 --- a/swarm.go +++ b/swarm.go @@ -122,7 +122,7 @@ func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc } } - s.dsync = newDialSync(s.startDialWorker) + s.dsync = newDialSync(s.dialWorkerLoop) s.limiter = newDialLimiter(s.dialAddr) s.proc = goprocessctx.WithContext(ctx) s.ctx = goprocessctx.OnClosingContext(s.proc) diff --git a/swarm_dial.go b/swarm_dial.go index 83a0468f..21b5b84e 100644 --- a/swarm_dial.go +++ b/swarm_dial.go @@ -259,7 +259,7 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) { ctx, cancel := context.WithTimeout(ctx, network.GetDialPeerTimeout(ctx)) defer cancel() - conn, err = s.dsync.DialLock(ctx, p) + conn, err = s.dsync.Dial(ctx, p) if err == nil { return conn, nil } @@ -294,16 +294,7 @@ type dialResponse struct { err error } -// startDialWorker starts an active dial goroutine that synchronizes and executes concurrent dials -func (s *Swarm) startDialWorker(p peer.ID, reqch <-chan dialRequest) error { - if p == s.local { - return ErrDialToSelf - } - - go s.dialWorkerLoop(p, reqch) - return nil -} - +// dialWorkerLoop synchronizes and executes concurrent dials to a single peer func (s *Swarm) dialWorkerLoop(p peer.ID, reqch <-chan dialRequest) { defer s.limiter.clearAllPeerDials(p)