From 48e6cc4de4a6a407d98e51e4f3d705f89f353ddf Mon Sep 17 00:00:00 2001
From: James Phillips <slackpad@gmail.com>
Date: Thu, 13 Aug 2015 11:38:39 -0700
Subject: [PATCH] Merge pull request #1170 from hashicorp/b-connection-spam

Fixes #1165 by having threads wait for any outstanding connect to finish.
---
 consul/client_test.go | 41 ++++++++++++++++++
 consul/pool.go        | 97 +++++++++++++++++++++++++++++++------------
 2 files changed, 111 insertions(+), 27 deletions(-)

diff --git a/consul/client_test.go b/consul/client_test.go
index c799696abc61..274784d818b9 100644
--- a/consul/client_test.go
+++ b/consul/client_test.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"net"
 	"os"
+	"sync"
 	"testing"
 	"time"
 
@@ -189,6 +190,46 @@ func TestClient_RPC(t *testing.T) {
 	})
 }
 
+func TestClient_RPC_Pool(t *testing.T) {
+	dir1, s1 := testServer(t)
+	defer os.RemoveAll(dir1)
+	defer s1.Shutdown()
+
+	dir2, c1 := testClient(t)
+	defer os.RemoveAll(dir2)
+	defer c1.Shutdown()
+
+	// Try to join.
+	addr := fmt.Sprintf("127.0.0.1:%d",
+		s1.config.SerfLANConfig.MemberlistConfig.BindPort)
+	if _, err := c1.JoinLAN([]string{addr}); err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if len(s1.LANMembers()) != 2 || len(c1.LANMembers()) != 2 {
+		t.Fatalf("bad len")
+	}
+
+	// Blast out a bunch of RPC requests at the same time to try to get
+	// contention opening new connections.
+	var wg sync.WaitGroup
+	for i := 0; i < 150; i++ {
+		wg.Add(1)
+
+		go func() {
+			defer wg.Done()
+			var out struct{}
+			testutil.WaitForResult(func() (bool, error) {
+				err := c1.RPC("Status.Ping", struct{}{}, &out)
+				return err == nil, err
+			}, func(err error) {
+				t.Fatalf("err: %v", err)
+			})
+		}()
+	}
+
+	wg.Wait()
+}
+
 func TestClient_RPC_TLS(t *testing.T) {
 	dir1, conf1 := testServerConfig(t, "a.testco.internal")
 	conf1.VerifyIncoming = true
diff --git a/consul/pool.go b/consul/pool.go
index 3512fa621270..0cd0a99dfd01 100644
--- a/consul/pool.go
+++ b/consul/pool.go
@@ -114,6 +114,12 @@ func (c *Conn) returnClient(client *StreamClient) {
 	}
 }
 
+// markForUse does all the bookkeeping required to ready a connection for use.
+func (c *Conn) markForUse() {
+	c.lastUsed = time.Now()
+	atomic.AddInt32(&c.refCount, 1)
+}
+
 // ConnPool is used to maintain a connection pool to other
 // Consul servers. This is used to reduce the latency of
 // RPC requests between servers. It is only used to pool
@@ -134,6 +140,12 @@ type ConnPool struct {
 	// Pool maps an address to a open connection
 	pool map[string]*Conn
 
+	// limiter is used to throttle the number of connect attempts
+	// to a given address. The first thread will attempt a connection
+	// and put a channel in here, which all other threads will wait
+	// on to close.
+	limiter map[string]chan struct{}
+
 	// TLS wrapper
 	tlsWrap tlsutil.DCWrapper
 
@@ -153,6 +165,7 @@ func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap
 		maxTime:    maxTime,
 		maxStreams: maxStreams,
 		pool:       make(map[string]*Conn),
+		limiter:    make(map[string]chan struct{}),
 		tlsWrap:    tlsWrap,
 		shutdownCh: make(chan struct{}),
 	}
@@ -180,28 +193,69 @@ func (p *ConnPool) Shutdown() error {
 	return nil
 }
 
-// Acquire is used to get a connection that is
-// pooled or to return a new connection
+// acquire will return a pooled connection, if available. Otherwise it will
+// wait for an existing connection attempt to finish, if one is in progress,
+// and will return that one if it succeeds. If all else fails, it will return a
+// newly-created connection and add it to the pool.
 func (p *ConnPool) acquire(dc string, addr net.Addr, version int) (*Conn, error) {
-	// Check for a pooled ocnn
-	if conn := p.getPooled(addr, version); conn != nil {
-		return conn, nil
+	// Check to see if there's a pooled connection available. This is up
+	// here since it should be the vastly more common case than the rest
+	// of the code here.
+	p.Lock()
+	c := p.pool[addr.String()]
+	if c != nil {
+		c.markForUse()
+		p.Unlock()
+		return c, nil
 	}
 
-	// Create a new connection
-	return p.getNewConn(dc, addr, version)
-}
+	// If not (while we are still locked), set up the throttling structure
+	// for this address, which will make everyone else wait until our
+	// attempt is done.
+	var wait chan struct{}
+	var ok bool
+	if wait, ok = p.limiter[addr.String()]; !ok {
+		wait = make(chan struct{})
+		p.limiter[addr.String()] = wait
+	}
+	isLeadThread := !ok
+	p.Unlock()
+
+	// If we are the lead thread, make the new connection and then wake
+	// everybody else up to see if we got it.
+	if isLeadThread {
+		c, err := p.getNewConn(dc, addr, version)
+		p.Lock()
+		delete(p.limiter, addr.String())
+		close(wait)
+		if err != nil {
+			p.Unlock()
+			return nil, err
+		}
+
+		p.pool[addr.String()] = c
+		p.Unlock()
+		return c, nil
+	}
+
+	// Otherwise, wait for the lead thread to attempt the connection
+	// and use what's in the pool at that point.
+	select {
+	case <-p.shutdownCh:
+		return nil, fmt.Errorf("rpc error: shutdown")
+	case <-wait:
+	}
 
-// getPooled is used to return a pooled connection
-func (p *ConnPool) getPooled(addr net.Addr, version int) *Conn {
+	// See if the lead thread was able to get us a connection.
 	p.Lock()
-	c := p.pool[addr.String()]
-	if c != nil {
-		c.lastUsed = time.Now()
-		atomic.AddInt32(&c.refCount, 1)
+	if c := p.pool[addr.String()]; c != nil {
+		c.markForUse()
+		p.Unlock()
+		return c, nil
 	}
+
 	p.Unlock()
-	return c
+	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
 }
 
 // getNewConn is used to return a new connection
@@ -272,18 +326,7 @@ func (p *ConnPool) getNewConn(dc string, addr net.Addr, version int) (*Conn, err
 		version:  version,
 		pool:     p,
 	}
-
-	// Track this connection, handle potential race condition
-	p.Lock()
-	if existing := p.pool[addr.String()]; existing != nil {
-		c.Close()
-		p.Unlock()
-		return existing, nil
-	} else {
-		p.pool[addr.String()] = c
-		p.Unlock()
-		return c, nil
-	}
+	return c, nil
 }
 
 // clearConn is used to clear any cached connection, potentially in response to an erro