Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

Fix simultaneous dials #246

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 159 additions & 44 deletions dial_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,21 @@ import (
"errors"
"sync"

"github.com/hashicorp/go-multierror"

"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"

ma "github.com/multiformats/go-multiaddr"
)

// Package-level errors used by the dial synchronization code.
var (
	// TODO: change this text when we fix the bug
	errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW")
	errDialFailed   = errors.New("dial failed")
)

// DialFunc is the type of function expected by DialSync.
type DialFunc func(context.Context, peer.ID) (*Conn, error)
type DialFunc func(context.Context, peer.ID, DialDedupFunc) (*Conn, error)

// DialDedupFunc is a function that deduplicates a set of multiaddrs from active dials.
// It takes a slice of candidate addresses and returns a slice of addresses.
// A DialFunc receives one of these as its third argument.
type DialDedupFunc func([]ma.Multiaddr) []ma.Multiaddr

// NewDialSync constructs a new DialSync
func NewDialSync(dfn DialFunc) *DialSync {
Expand All @@ -34,17 +41,61 @@ type activeDial struct {
id peer.ID
refCnt int
refCntLk sync.Mutex
ctx context.Context
cancel func()

addrs map[string]struct{}
addrsLk sync.Mutex

err error
conn *Conn
waitch chan struct{}
connch chan *Conn
errch chan error
dialch chan struct{}
donech chan struct{}

ds *DialSync
}

func (ad *activeDial) wait(ctx context.Context) (*Conn, error) {
func (ad *activeDial) dial(ctx context.Context) (*Conn, error) {
defer ad.decref()

dialCtx := ad.dialContext(ctx)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: by Go convention, `xContext` means "x, but with a context". Rename this to `deriveDialContext`.

res := make(chan *Conn, 1)
go func() {
defer close(res)

c, err := ad.ds.dialFunc(dialCtx, ad.id, ad.dedup)

if err != nil {
select {
case ad.errch <- err:
case <-ad.waitch:
}

return
}

res <- c

select {
case ad.connch <- c:
case <-ad.waitch:
}
}()

// first try to get the context specific connection, if any
select {
case c := <-res:
if c != nil {
return c, nil
}
case <-ctx.Done():
return nil, ctx.Err()
}

// we don't have a context-specific connection, join the other dials
select {
case <-ad.waitch:
return ad.conn, ad.err
Expand All @@ -53,47 +104,96 @@ func (ad *activeDial) wait(ctx context.Context) (*Conn, error) {
}
}

// dialContext builds the context handed to the dial function. It starts from
// the active dial's own context (ad.ctx) and, when the caller's ctx carries the
// force-direct-dial or simultaneous-connect markers, re-applies them (with
// their reasons) on top of it.
func (ad *activeDial) dialContext(ctx context.Context) context.Context {
	derived := ad.ctx

	if direct, why := network.GetForceDirectDial(ctx); direct {
		derived = network.WithForceDirectDial(derived, why)
	}

	if simultaneous, why := network.GetSimultaneousConnect(ctx); simultaneous {
		derived = network.WithSimultaneousConnect(derived, why)
	}

	return derived
}

func (ad *activeDial) dedup(addrs []ma.Multiaddr) (result []ma.Multiaddr) {
ad.addrsLk.Lock()
defer ad.addrsLk.Unlock()

for _, a := range addrs {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we never get around to dialing one of these addresses because we're canceled? That address will then be skipped entirely, since it is already marked as active.

Test case:

  1. Start dial A.
  2. Start dial B.
  3. Cancel dial A.
  4. Dial B should pick up where dial A left off.

With the current code, dial B will, I believe, fail to dial any of the addresses dial A was trying to dial.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the dials will still happen: they use the background context, so there should be no interference.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see.

key := a.String()

_, active := ad.addrs[key]
if active {
continue
}

result = append(result, a)
ad.addrs[key] = struct{}{}
}

return result
}

// incref adds one reference to this active dial while holding refCntLk.
func (ad *activeDial) incref() {
	ad.refCntLk.Lock()
	ad.refCnt++
	ad.refCntLk.Unlock()
}

func (ad *activeDial) decref() {
// make sure to always take locks in correct order.
ad.ds.dialsLk.Lock()
ad.refCntLk.Lock()
ad.refCnt--
maybeZero := (ad.refCnt <= 0)
ad.refCntLk.Unlock()

// make sure to always take locks in correct order.
if maybeZero {
ad.ds.dialsLk.Lock()
ad.refCntLk.Lock()
// check again after lock swap drop to make sure nobody else called incref
// in between locks
if ad.refCnt <= 0 {
ad.cancel()
delete(ad.ds.dials, ad.id)
}
ad.refCntLk.Unlock()
ad.ds.dialsLk.Unlock()
if ad.refCnt == 0 {
ad.cancel()
close(ad.donech)
delete(ad.ds.dials, ad.id)
}
ad.refCntLk.Unlock()
ad.ds.dialsLk.Unlock()
}

func (ad *activeDial) start(ctx context.Context) {
ad.conn, ad.err = ad.ds.dialFunc(ctx, ad.id)

// This isn't the user's context so we should fix the error.
switch ad.err {
case context.Canceled:
// The dial was canceled with `CancelDial`.
ad.err = errDialCanceled
case context.DeadlineExceeded:
// We hit an internal timeout, not a context timeout.
ad.err = ErrDialTimeout
defer ad.cancel()
defer close(ad.waitch)

dialCnt := 0
for {
select {
case <-ad.dialch:
dialCnt++

case ad.conn = <-ad.connch:
ad.err = nil
return

case err := <-ad.errch:
if err != ErrNoNewAddresses {
ad.err = multierror.Append(ad.err, err)
}

dialCnt--
if dialCnt == 0 {
if ad.err == nil {
ad.err = errDialFailed
}

return
}

case <-ctx.Done():
if ad.err == nil {
ad.err = errDialFailed
}
return
}
}
close(ad.waitch)
ad.cancel()
}

func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
Expand All @@ -102,15 +202,17 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {

actd, ok := ds.dials[p]
if !ok {
// This code intentionally uses the background context. Otherwise, if the first call
// to Dial is canceled, subsequent dial calls will also be canceled.
// XXX: this also breaks direct connection logic. We will need to pipe the
// information through some other way.
adctx, cancel := context.WithCancel(context.Background())
actd = &activeDial{
id: p,
ctx: adctx,
cancel: cancel,
addrs: make(map[string]struct{}),
waitch: make(chan struct{}),
connch: make(chan *Conn),
errch: make(chan error),
dialch: make(chan struct{}),
donech: make(chan struct{}),
ds: ds,
}
ds.dials[p] = actd
Expand All @@ -127,14 +229,27 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
// DialLock initiates a dial to the given peer if there are none in progress
// then waits for the dial to that peer to complete.
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
return ds.getActiveDial(p).wait(ctx)
}

// CancelDial cancels all in-progress dials to the given peer.
func (ds *DialSync) CancelDial(p peer.ID) {
ds.dialsLk.Lock()
defer ds.dialsLk.Unlock()
if ad, ok := ds.dials[p]; ok {
ad.cancel()
var ad *activeDial

startDial:
for {
ad = ds.getActiveDial(p)

// signal the start of dial
select {
case ad.dialch <- struct{}{}:
break startDial
case <-ad.waitch:
// we lost a race, we need to try again because the connection might not be what we want
ad.decref()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we decrementing the reference count (`decref`) here? Won't that cause the in-progress dial to be canceled if the original caller bails?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, we'll just restart, but that's pretty inefficient if we were half-way through opening a bunch of connections.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turns out, we only invoke this code if the active dial is currently "finishing". So we're not going to cancel anything.


select {
case <-ad.donech:
case <-ctx.Done():
return nil, ctx.Err()
}
}
}

return ad.dial(ctx)
}
Loading