Skip to content

Commit

Permalink
Merge pull request #1170 from hashicorp/b-connection-spam
Browse files Browse the repository at this point in the history
Fixes #1165 by having threads wait for any outstanding connect to finish.
  • Loading branch information
slackpad committed Aug 13, 2015
2 parents 00e35cd + 614bf44 commit 009f0fb
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 27 deletions.
41 changes: 41 additions & 0 deletions consul/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net"
"os"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -189,6 +190,46 @@ func TestClient_RPC(t *testing.T) {
})
}

func TestClient_RPC_Pool(t *testing.T) {
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()

dir2, c1 := testClient(t)
defer os.RemoveAll(dir2)
defer c1.Shutdown()

// Try to join.
addr := fmt.Sprintf("127.0.0.1:%d",
s1.config.SerfLANConfig.MemberlistConfig.BindPort)
if _, err := c1.JoinLAN([]string{addr}); err != nil {
t.Fatalf("err: %v", err)
}
if len(s1.LANMembers()) != 2 || len(c1.LANMembers()) != 2 {
t.Fatalf("bad len")
}

// Blast out a bunch of RPC requests at the same time to try to get
// contention opening new connections.
var wg sync.WaitGroup
for i := 0; i < 150; i++ {
wg.Add(1)

go func() {
defer wg.Done()
var out struct{}
testutil.WaitForResult(func() (bool, error) {
err := c1.RPC("Status.Ping", struct{}{}, &out)
return err == nil, err
}, func(err error) {
t.Fatalf("err: %v", err)
})
}()
}

wg.Wait()
}

func TestClient_RPC_TLS(t *testing.T) {
dir1, conf1 := testServerConfig(t, "a.testco.internal")
conf1.VerifyIncoming = true
Expand Down
97 changes: 70 additions & 27 deletions consul/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ func (c *Conn) returnClient(client *StreamClient) {
}
}

// markForUse does all the bookkeeping required to ready a connection for use.
func (c *Conn) markForUse() {
c.lastUsed = time.Now()
atomic.AddInt32(&c.refCount, 1)
}

// ConnPool is used to maintain a connection pool to other
// Consul servers. This is used to reduce the latency of
// RPC requests between servers. It is only used to pool
Expand All @@ -134,6 +140,12 @@ type ConnPool struct {
// Pool maps an address to a open connection
pool map[string]*Conn

// limiter is used to throttle the number of connect attempts
// to a given address. The first thread will attempt a connection
// and put a channel in here, which all other threads will wait
// on to close.
limiter map[string]chan struct{}

// TLS wrapper
tlsWrap tlsutil.DCWrapper

Expand All @@ -153,6 +165,7 @@ func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap
maxTime: maxTime,
maxStreams: maxStreams,
pool: make(map[string]*Conn),
limiter: make(map[string]chan struct{}),
tlsWrap: tlsWrap,
shutdownCh: make(chan struct{}),
}
Expand Down Expand Up @@ -180,28 +193,69 @@ func (p *ConnPool) Shutdown() error {
return nil
}

// Acquire is used to get a connection that is
// pooled or to return a new connection
// acquire will return a pooled connection, if available. Otherwise it will
// wait for an existing connection attempt to finish, if one if in progress,
// and will return that one if it succeeds. If all else fails, it will return a
// newly-created connection and add it to the pool.
func (p *ConnPool) acquire(dc string, addr net.Addr, version int) (*Conn, error) {
// Check for a pooled ocnn
if conn := p.getPooled(addr, version); conn != nil {
return conn, nil
// Check to see if there's a pooled connection available. This is up
// here since it should the the vastly more common case than the rest
// of the code here.
p.Lock()
c := p.pool[addr.String()]
if c != nil {
c.markForUse()
p.Unlock()
return c, nil
}

// Create a new connection
return p.getNewConn(dc, addr, version)
}
// If not (while we are still locked), set up the throttling structure
// for this address, which will make everyone else wait until our
// attempt is done.
var wait chan struct{}
var ok bool
if wait, ok = p.limiter[addr.String()]; !ok {
wait = make(chan struct{})
p.limiter[addr.String()] = wait
}
isLeadThread := !ok
p.Unlock()

// If we are the lead thread, make the new connection and then wake
// everybody else up to see if we got it.
if isLeadThread {
c, err := p.getNewConn(dc, addr, version)
p.Lock()
delete(p.limiter, addr.String())
close(wait)
if err != nil {
p.Unlock()
return nil, err
}

p.pool[addr.String()] = c
p.Unlock()
return c, nil
}

// Otherwise, wait for the lead thread to attempt the connection
// and use what's in the pool at that point.
select {
case <-p.shutdownCh:
return nil, fmt.Errorf("rpc error: shutdown")
case <-wait:
}

// getPooled is used to return a pooled connection
func (p *ConnPool) getPooled(addr net.Addr, version int) *Conn {
// See if the lead thread was able to get us a connection.
p.Lock()
c := p.pool[addr.String()]
if c != nil {
c.lastUsed = time.Now()
atomic.AddInt32(&c.refCount, 1)
if c := p.pool[addr.String()]; c != nil {
c.markForUse()
p.Unlock()
return c, nil
}

p.Unlock()
return c
return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
}

// getNewConn is used to return a new connection
Expand Down Expand Up @@ -272,18 +326,7 @@ func (p *ConnPool) getNewConn(dc string, addr net.Addr, version int) (*Conn, err
version: version,
pool: p,
}

// Track this connection, handle potential race condition
p.Lock()
if existing := p.pool[addr.String()]; existing != nil {
c.Close()
p.Unlock()
return existing, nil
} else {
p.pool[addr.String()] = c
p.Unlock()
return c, nil
}
return c, nil
}

// clearConn is used to clear any cached connection, potentially in response to an erro
Expand Down

0 comments on commit 009f0fb

Please sign in to comment.