diff --git a/dial_sync.go b/dial_sync.go index 03ba9cc0..074de3bc 100644 --- a/dial_sync.go +++ b/dial_sync.go @@ -8,7 +8,7 @@ import ( "github.com/libp2p/go-libp2p-core/peer" ) -// DialWorerFunc is used by DialSync to spawn a new dial worker +// dialWorkerFunc is used by DialSync to spawn a new dial worker type dialWorkerFunc func(peer.ID, <-chan dialRequest) error // newDialSync constructs a new DialSync @@ -22,35 +22,26 @@ func newDialSync(worker dialWorkerFunc) *DialSync { // DialSync is a dial synchronization helper that ensures that at most one dial // to any given peer is active at any given time. type DialSync struct { + mutex sync.Mutex dials map[peer.ID]*activeDial - dialsLk sync.Mutex dialWorker dialWorkerFunc } type activeDial struct { - id peer.ID refCnt int ctx context.Context cancel func() reqch chan dialRequest - - ds *DialSync } -func (ad *activeDial) decref() { - ad.ds.dialsLk.Lock() - ad.refCnt-- - if ad.refCnt == 0 { - ad.cancel() - close(ad.reqch) - delete(ad.ds.dials, ad.id) - } - ad.ds.dialsLk.Unlock() +func (ad *activeDial) close() { + ad.cancel() + close(ad.reqch) } -func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) { +func (ad *activeDial) dial(ctx context.Context) (*Conn, error) { dialCtx := ad.ctx if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect { @@ -76,8 +67,8 @@ func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) { } func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) { - ds.dialsLk.Lock() - defer ds.dialsLk.Unlock() + ds.mutex.Lock() + defer ds.mutex.Unlock() actd, ok := ds.dials[p] if !ok { @@ -85,26 +76,20 @@ func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) { // to Dial is canceled, subsequent dial calls will also be canceled. // XXX: this also breaks direct connection logic. We will need to pipe the // information through some other way. - adctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) actd = &activeDial{ - id: p, - ctx: adctx, + ctx: ctx, cancel: cancel, reqch: make(chan dialRequest), - ds: ds, } - if err := ds.dialWorker(p, actd.reqch); err != nil { cancel() return nil, err } - ds.dials[p] = actd } - - // increase ref count before dropping dialsLk + // increase ref count before dropping mutex actd.refCnt++ - return actd, nil } @@ -116,6 +101,14 @@ func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { return nil, err } - defer ad.decref() - return ad.dial(ctx, p) + defer func() { + ds.mutex.Lock() + defer ds.mutex.Unlock() + ad.refCnt-- + if ad.refCnt == 0 { + ad.close() + delete(ds.dials, p) + } + }() + return ad.dial(ctx) }