Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

Commit

Permalink
simplify the DialSync code
Browse files Browse the repository at this point in the history
It's easier to reason about this code if activeDial doesn't contain a pointer
back to DialSync (which already has a map of activeDials). It also allows us to
remove the memory footprint of the activeDial struct, so this should be
(slightly) more efficient.
  • Loading branch information
marten-seemann committed Aug 23, 2021
1 parent f3ae0cb commit 49c430f
Showing 1 changed file with 21 additions and 28 deletions.
49 changes: 21 additions & 28 deletions dial_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/libp2p/go-libp2p-core/peer"
)

// DialWorerFunc is used by DialSync to spawn a new dial worker
// dialWorkerFunc is used by DialSync to spawn a new dial worker
type dialWorkerFunc func(peer.ID, <-chan dialRequest) error

// newDialSync constructs a new DialSync
Expand All @@ -22,35 +22,26 @@ func newDialSync(worker dialWorkerFunc) *DialSync {
// DialSync is a dial synchronization helper that ensures that at most one dial
// to any given peer is active at any given time.
type DialSync struct {
mutex sync.Mutex
dials map[peer.ID]*activeDial
dialsLk sync.Mutex
dialWorker dialWorkerFunc
}

type activeDial struct {
id peer.ID
refCnt int

ctx context.Context
cancel func()

reqch chan dialRequest

ds *DialSync
}

func (ad *activeDial) decref() {
ad.ds.dialsLk.Lock()
ad.refCnt--
if ad.refCnt == 0 {
ad.cancel()
close(ad.reqch)
delete(ad.ds.dials, ad.id)
}
ad.ds.dialsLk.Unlock()
func (ad *activeDial) close() {
ad.cancel()
close(ad.reqch)
}

func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
func (ad *activeDial) dial(ctx context.Context) (*Conn, error) {
dialCtx := ad.ctx

if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect {
Expand All @@ -76,35 +67,29 @@ func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
}

func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) {
ds.dialsLk.Lock()
defer ds.dialsLk.Unlock()
ds.mutex.Lock()
defer ds.mutex.Unlock()

actd, ok := ds.dials[p]
if !ok {
// This code intentionally uses the background context. Otherwise, if the first call
// to Dial is canceled, subsequent dial calls will also be canceled.
// XXX: this also breaks direct connection logic. We will need to pipe the
// information through some other way.
adctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(context.Background())
actd = &activeDial{
id: p,
ctx: adctx,
ctx: ctx,
cancel: cancel,
reqch: make(chan dialRequest),
ds: ds,
}

if err := ds.dialWorker(p, actd.reqch); err != nil {
cancel()
return nil, err
}

ds.dials[p] = actd
}

// increase ref count before dropping dialsLk
// increase ref count before dropping mutex
actd.refCnt++

return actd, nil
}

Expand All @@ -116,6 +101,14 @@ func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
return nil, err
}

defer ad.decref()
return ad.dial(ctx, p)
defer func() {
ds.mutex.Lock()
defer ds.mutex.Unlock()
ad.refCnt--
if ad.refCnt == 0 {
ad.close()
delete(ds.dials, p)
}
}()
return ad.dial(ctx)
}

0 comments on commit 49c430f

Please sign in to comment.