From 4757f910d94baf32db3ebae3f6e6621521cb54d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Magiera?= Date: Mon, 19 Feb 2018 03:17:27 +0100 Subject: [PATCH] dial delay: simplify the code --- dial_delay.go | 121 ++++++++++++++++++++++++++++++--------------- dial_delay_test.go | 6 +++ 2 files changed, 86 insertions(+), 41 deletions(-) diff --git a/dial_delay.go b/dial_delay.go index c7b589f4..de77d467 100644 --- a/dial_delay.go +++ b/dial_delay.go @@ -49,38 +49,38 @@ func delayDialAddrs(ctx context.Context, c <-chan ma.Multiaddr) (<-chan ma.Multi return nil, -1 } - outer: - for { - fill: + // fillBuckets reads pending addresses form the channel without blocking + fillBuckets := func() bool { for { select { case addr, ok := <-c: if !ok { - break outer + return false } put(addr) default: - break fill + return true } } + } - next, tier := get() - - // Nothing? Block! - if next == nil { - select { - case addr, ok := <-c: - if !ok { - break outer - } - put(addr) - case <-ctx.Done(): - return + // waitForMore woits for addresses from the channel + waitForMore := func() (bool, error) { + select { + case addr, ok := <-c: + if !ok { + return false, nil } - continue + put(addr) + case <-ctx.Done(): + return false, ctx.Err() } + return true, nil + } - // Jumping a tier? + // maybeJumpTier will check if the address tier is changing and optionally + // wait some time. + maybeJumpTier := func(tier int, next ma.Multiaddr) (cont bool, brk bool, err error) { if tier > lastTier && lastTier != -1 { // Wait the delay (preempt with new addresses or when the dialer // requests more addresses) @@ -88,10 +88,10 @@ func delayDialAddrs(ctx context.Context, c <-chan ma.Multiaddr) (<-chan ma.Multi case addr, ok := <-c: put(next) if !ok { - break outer + return false, true, nil } put(addr) - continue + return true, false, nil case <-delay.C: delay.Reset(tierDelay) case <-triggerNext: @@ -100,20 +100,24 @@ func delayDialAddrs(ctx context.Context, c <-chan ma.Multiaddr) (<-chan ma.Multi } delay.Reset(tierDelay) case <-ctx.Done(): - return + return false, false, ctx.Err() } } + // Note that we want to only update the tier after we've done the waiting + // or we were asked to finish early lastTier = tier + return false, false, nil + } + recvOrSend := func(next ma.Multiaddr) (brk bool, err error) { select { case addr, ok := <-c: put(next) if !ok { - break outer + return true, nil } put(addr) - continue case out <- next: // Always count the timeout since the last dial. if !delay.Stop() { @@ -121,34 +125,69 @@ func delayDialAddrs(ctx context.Context, c <-chan ma.Multiaddr) (<-chan ma.Multi } delay.Reset(tierDelay) case <-ctx.Done(): + return false, ctx.Err() + } + return false, nil + } + + // process the address stream + for { + if !fillBuckets() { + break // input channel closed + } + + next, tier := get() + + // Nothing? Block! + if next == nil { + ok, err := waitForMore() + if err != nil { + return + } + if !ok { + break // input channel closed + } + continue + } + + cont, brk, err := maybeJumpTier(tier, next) + if cont { + continue // received an address while waiting, in case it's lower tier + // look at it immediately + } + if brk { + break // input channel closed + } + if err != nil { + return + } + + brk, err = recvOrSend(next) + if brk { + break // input channel closed + } + if err != nil { return } } + // the channel is closed by now + c = nil + // finish sending for { next, tier := get() if next == nil { return } - if tier > lastTier && lastTier != -1 { - select { - case <-delay.C: - case <-triggerNext: - if !delay.Stop() { - <-delay.C - } - delay.Reset(tierDelay) - case <-ctx.Done(): - return - } + + _, _, err := maybeJumpTier(tier, next) + if err != nil { + return } - lastTier = tier - select { - case out <- next: - delay.Stop() - delay.Reset(tierDelay) - case <-ctx.Done(): + + _, err = recvOrSend(next) + if err != nil { return } } diff --git a/dial_delay_test.go b/dial_delay_test.go index 9fb4bea1..b5391d44 100644 --- a/dial_delay_test.go +++ b/dial_delay_test.go @@ -32,6 +32,12 @@ func prepare() { tierDelay = 32 * time.Millisecond // 2x windows timer resolution } +// addrChan creates a multiaddr channel with `nsync` size. If nsync is larger +// than 0, the entries will get pre-buffered in the channel. +// addrDelays is a set of addresses and delays between sending them. If a string +// starts with '/' it will be parsed as an address and sent to the channel. +// Otherwise it will get parsed as a time to sleep before sending next addresses +// or closing the channel func addrChan(t *testing.T, nsync int, addrDelays ...string) <-chan ma.Multiaddr { out := make(chan ma.Multiaddr, nsync) c := sync.NewCond(&sync.Mutex{})