-
Notifications
You must be signed in to change notification settings - Fork 37
Fix simultaneous dials #246
Changes from all commits
5be49a7
cbe6057
cbb1532
7912714
ee6ff3c
ca02ff0
1ea1a82
e9d8076
9976396
58851a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,14 +5,21 @@ import ( | |
"errors" | ||
"sync" | ||
|
||
"github.com/hashicorp/go-multierror" | ||
|
||
"github.com/libp2p/go-libp2p-core/network" | ||
"github.com/libp2p/go-libp2p-core/peer" | ||
|
||
ma "github.com/multiformats/go-multiaddr" | ||
) | ||
|
||
// TODO: change this text when we fix the bug | ||
var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW") | ||
var errDialFailed = errors.New("dial failed") | ||
|
||
// DialFunc is the type of function expected by DialSync. | ||
type DialFunc func(context.Context, peer.ID) (*Conn, error) | ||
type DialFunc func(context.Context, peer.ID, DialDedupFunc) (*Conn, error) | ||
|
||
// DialDedupFunc is a function that deduplicates a set of multiaddrs from active dials | ||
type DialDedupFunc func([]ma.Multiaddr) []ma.Multiaddr | ||
|
||
// NewDialSync constructs a new DialSync | ||
func NewDialSync(dfn DialFunc) *DialSync { | ||
|
@@ -34,17 +41,61 @@ type activeDial struct { | |
id peer.ID | ||
refCnt int | ||
refCntLk sync.Mutex | ||
ctx context.Context | ||
cancel func() | ||
|
||
addrs map[string]struct{} | ||
addrsLk sync.Mutex | ||
|
||
err error | ||
conn *Conn | ||
waitch chan struct{} | ||
connch chan *Conn | ||
errch chan error | ||
dialch chan struct{} | ||
donech chan struct{} | ||
|
||
ds *DialSync | ||
} | ||
|
||
func (ad *activeDial) wait(ctx context.Context) (*Conn, error) { | ||
func (ad *activeDial) dial(ctx context.Context) (*Conn, error) { | ||
defer ad.decref() | ||
|
||
dialCtx := ad.dialContext(ctx) | ||
res := make(chan *Conn, 1) | ||
go func() { | ||
defer close(res) | ||
|
||
c, err := ad.ds.dialFunc(dialCtx, ad.id, ad.dedup) | ||
|
||
if err != nil { | ||
select { | ||
case ad.errch <- err: | ||
case <-ad.waitch: | ||
} | ||
|
||
return | ||
} | ||
|
||
res <- c | ||
|
||
select { | ||
case ad.connch <- c: | ||
case <-ad.waitch: | ||
} | ||
}() | ||
|
||
// first try to get the context specific connection, if any | ||
select { | ||
case c := <-res: | ||
if c != nil { | ||
return c, nil | ||
} | ||
case <-ctx.Done(): | ||
return nil, ctx.Err() | ||
} | ||
|
||
// we don't have a context-specific connection, join the other dials | ||
select { | ||
case <-ad.waitch: | ||
return ad.conn, ad.err | ||
|
@@ -53,47 +104,96 @@ func (ad *activeDial) wait(ctx context.Context) (*Conn, error) { | |
} | ||
} | ||
|
||
func (ad *activeDial) dialContext(ctx context.Context) context.Context { | ||
dialCtx := ad.ctx | ||
|
||
forceDirect, reason := network.GetForceDirectDial(ctx) | ||
if forceDirect { | ||
dialCtx = network.WithForceDirectDial(dialCtx, reason) | ||
} | ||
|
||
simConnect, reason := network.GetSimultaneousConnect(ctx) | ||
if simConnect { | ||
dialCtx = network.WithSimultaneousConnect(dialCtx, reason) | ||
} | ||
|
||
return dialCtx | ||
} | ||
|
||
func (ad *activeDial) dedup(addrs []ma.Multiaddr) (result []ma.Multiaddr) { | ||
ad.addrsLk.Lock() | ||
defer ad.addrsLk.Unlock() | ||
|
||
for _, a := range addrs { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if we never bother to dial one of these addresses because we're canceled? We'll just skip that overall. Test case:
With the current code, dial B will, I believe, fail to dial any of the addresses dial A was trying to dial. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no, the dials will still happen -- they use the background context, so there should be no interference. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, I see. |
||
key := a.String() | ||
|
||
_, active := ad.addrs[key] | ||
if active { | ||
continue | ||
} | ||
|
||
result = append(result, a) | ||
ad.addrs[key] = struct{}{} | ||
} | ||
|
||
return result | ||
} | ||
|
||
func (ad *activeDial) incref() { | ||
ad.refCntLk.Lock() | ||
defer ad.refCntLk.Unlock() | ||
ad.refCnt++ | ||
} | ||
|
||
func (ad *activeDial) decref() { | ||
// make sure to always take locks in correct order. | ||
ad.ds.dialsLk.Lock() | ||
ad.refCntLk.Lock() | ||
ad.refCnt-- | ||
maybeZero := (ad.refCnt <= 0) | ||
ad.refCntLk.Unlock() | ||
|
||
// make sure to always take locks in correct order. | ||
if maybeZero { | ||
ad.ds.dialsLk.Lock() | ||
ad.refCntLk.Lock() | ||
// check again after lock swap drop to make sure nobody else called incref | ||
// in between locks | ||
if ad.refCnt <= 0 { | ||
ad.cancel() | ||
delete(ad.ds.dials, ad.id) | ||
} | ||
ad.refCntLk.Unlock() | ||
ad.ds.dialsLk.Unlock() | ||
if ad.refCnt == 0 { | ||
ad.cancel() | ||
close(ad.donech) | ||
delete(ad.ds.dials, ad.id) | ||
} | ||
ad.refCntLk.Unlock() | ||
ad.ds.dialsLk.Unlock() | ||
} | ||
|
||
func (ad *activeDial) start(ctx context.Context) { | ||
ad.conn, ad.err = ad.ds.dialFunc(ctx, ad.id) | ||
|
||
// This isn't the user's context so we should fix the error. | ||
switch ad.err { | ||
case context.Canceled: | ||
// The dial was canceled with `CancelDial`. | ||
ad.err = errDialCanceled | ||
case context.DeadlineExceeded: | ||
// We hit an internal timeout, not a context timeout. | ||
ad.err = ErrDialTimeout | ||
defer ad.cancel() | ||
defer close(ad.waitch) | ||
|
||
dialCnt := 0 | ||
for { | ||
select { | ||
case <-ad.dialch: | ||
dialCnt++ | ||
|
||
case ad.conn = <-ad.connch: | ||
ad.err = nil | ||
return | ||
|
||
case err := <-ad.errch: | ||
if err != ErrNoNewAddresses { | ||
ad.err = multierror.Append(ad.err, err) | ||
} | ||
|
||
dialCnt-- | ||
if dialCnt == 0 { | ||
if ad.err == nil { | ||
ad.err = errDialFailed | ||
} | ||
|
||
return | ||
} | ||
|
||
case <-ctx.Done(): | ||
if ad.err == nil { | ||
ad.err = errDialFailed | ||
} | ||
return | ||
} | ||
} | ||
close(ad.waitch) | ||
ad.cancel() | ||
} | ||
|
||
func (ds *DialSync) getActiveDial(p peer.ID) *activeDial { | ||
|
@@ -102,15 +202,17 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial { | |
|
||
actd, ok := ds.dials[p] | ||
if !ok { | ||
// This code intentionally uses the background context. Otherwise, if the first call | ||
// to Dial is canceled, subsequent dial calls will also be canceled. | ||
// XXX: this also breaks direct connection logic. We will need to pipe the | ||
// information through some other way. | ||
adctx, cancel := context.WithCancel(context.Background()) | ||
actd = &activeDial{ | ||
id: p, | ||
ctx: adctx, | ||
cancel: cancel, | ||
addrs: make(map[string]struct{}), | ||
waitch: make(chan struct{}), | ||
connch: make(chan *Conn), | ||
errch: make(chan error), | ||
dialch: make(chan struct{}), | ||
donech: make(chan struct{}), | ||
ds: ds, | ||
} | ||
ds.dials[p] = actd | ||
|
@@ -127,14 +229,27 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial { | |
// DialLock initiates a dial to the given peer if there are none in progress | ||
// then waits for the dial to that peer to complete. | ||
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { | ||
return ds.getActiveDial(p).wait(ctx) | ||
} | ||
|
||
// CancelDial cancels all in-progress dials to the given peer. | ||
func (ds *DialSync) CancelDial(p peer.ID) { | ||
ds.dialsLk.Lock() | ||
defer ds.dialsLk.Unlock() | ||
if ad, ok := ds.dials[p]; ok { | ||
ad.cancel() | ||
var ad *activeDial | ||
|
||
startDial: | ||
for { | ||
ad = ds.getActiveDial(p) | ||
|
||
// signal the start of dial | ||
select { | ||
case ad.dialch <- struct{}{}: | ||
break startDial | ||
case <-ad.waitch: | ||
// we lost a race, we need to try again because the connection might not be what we want | ||
ad.decref() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are we dereferencing? Won't that cause the in-progress dial to be canceled if the original caller bails? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean, we'll just restart, but that's pretty inefficient if we were half-way through opening a bunch of connections. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Turns out, we only invoke this code if the active dial is currently "finishing". So we're not going to cancel anything. |
||
|
||
select { | ||
case <-ad.donech: | ||
case <-ctx.Done(): | ||
return nil, ctx.Err() | ||
} | ||
} | ||
} | ||
|
||
return ad.dial(ctx) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: by go convention,
xContext
means "x but with a context". Rename toderiveDialContext
.