Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #1165 by having threads wait for any outstanding connect to finish. #1170

Merged
merged 5 commits into from
Aug 13, 2015
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions consul/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net"
"os"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -189,6 +190,46 @@ func TestClient_RPC(t *testing.T) {
})
}

// TestClient_RPC_Pool joins a client to a server over the LAN and then fires
// many Status.Ping RPCs concurrently, so that multiple goroutines contend on
// opening a new pooled connection to the same server at the same time.
func TestClient_RPC_Pool(t *testing.T) {
	dir1, s1 := testServer(t)
	defer os.RemoveAll(dir1)
	defer s1.Shutdown()

	dir2, c1 := testClient(t)
	defer os.RemoveAll(dir2)
	defer c1.Shutdown()

	// Try to join.
	addr := fmt.Sprintf("127.0.0.1:%d",
		s1.config.SerfLANConfig.MemberlistConfig.BindPort)
	if _, err := c1.JoinLAN([]string{addr}); err != nil {
		t.Fatalf("err: %v", err)
	}
	if len(s1.LANMembers()) != 2 || len(c1.LANMembers()) != 2 {
		t.Fatalf("bad len")
	}

	// Blast out a bunch of RPC requests at the same time to try to get
	// contention opening new connections.
	var wg sync.WaitGroup
	for i := 0; i < 150; i++ {
		wg.Add(1)

		go func() {
			defer wg.Done()
			var out struct{}
			testutil.WaitForResult(func() (bool, error) {
				err := c1.RPC("Status.Ping", struct{}{}, &out)
				return err == nil, err
			}, func(err error) {
				// Fatalf/FailNow may only be called from the goroutine
				// running the test function, so report the failure with
				// Errorf here; the goroutine ends right after this callback
				// returns and the test is still marked as failed.
				t.Errorf("err: %v", err)
			})
		}()
	}

	// Wait for every RPC goroutine to finish before the deferred shutdowns run.
	wg.Wait()
}

func TestClient_RPC_TLS(t *testing.T) {
dir1, conf1 := testServerConfig(t, "a.testco.internal")
conf1.VerifyIncoming = true
Expand Down
97 changes: 70 additions & 27 deletions consul/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ func (c *Conn) returnClient(client *StreamClient) {
}
}

// markForUse performs the bookkeeping needed when a connection is handed out
// for use: it stamps lastUsed with the current time and atomically increments
// the connection's reference count.
func (c *Conn) markForUse() {
	now := time.Now()
	c.lastUsed = now
	atomic.AddInt32(&c.refCount, 1)
}

// ConnPool is used to maintain a connection pool to other
// Consul servers. This is used to reduce the latency of
// RPC requests between servers. It is only used to pool
Expand All @@ -134,6 +140,12 @@ type ConnPool struct {
// Pool maps an address to a open connection
pool map[string]*Conn

// limiter is used to throttle the number of connect attempts
// to a given address. The first thread will attempt a connection
// and put a channel in here, which all other threads will wait
// on to close.
limiter map[string]chan struct{}

// TLS wrapper
tlsWrap tlsutil.DCWrapper

Expand All @@ -153,6 +165,7 @@ func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap
maxTime: maxTime,
maxStreams: maxStreams,
pool: make(map[string]*Conn),
limiter: make(map[string]chan struct{}),
tlsWrap: tlsWrap,
shutdownCh: make(chan struct{}),
}
Expand Down Expand Up @@ -180,28 +193,69 @@ func (p *ConnPool) Shutdown() error {
return nil
}

// Acquire is used to get a connection that is
// pooled or to return a new connection
// acquire will return a pooled connection, if available. Otherwise it will
// wait for an existing connection attempt to finish, if one is in progress,
// and will return that one if it succeeds. If all else fails, it will return a
// newly-created connection and add it to the pool.
func (p *ConnPool) acquire(dc string, addr net.Addr, version int) (*Conn, error) {
// Check for a pooled conn
if conn := p.getPooled(addr, version); conn != nil {
return conn, nil
// Check to see if there's a pooled connection available. This is up
// here since it should be the vastly more common case than the rest
// of the code here.
p.Lock()
c := p.pool[addr.String()]
if c != nil {
c.markForUse()
p.Unlock()
return c, nil
}

// Create a new connection
return p.getNewConn(dc, addr, version)
}
// If not (while we are still locked), set up the throttling structure
// for this address, which will make everyone else wait until our
// attempt is done.
var wait chan struct{}
var ok bool
if wait, ok = p.limiter[addr.String()]; !ok {
wait = make(chan struct{}, 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably don't need to buffer it, since we just close

p.limiter[addr.String()] = wait
}
isLeadThread := !ok
p.Unlock()

// If we are the lead thread, make the new connection and then wake
// everybody else up to see if we got it.
if isLeadThread {
c, err := p.getNewConn(dc, addr, version)
p.Lock()
delete(p.limiter, addr.String())
close(wait)
if err != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super minor, but can we move the code in the defer to after getNewConn and delete the limiter entry and add the pool entry inside the same critical section? Just to save a defer and another round of locking

p.Unlock()
return nil, err
}

p.pool[addr.String()] = c
p.Unlock()
return c, nil
}

// Otherwise, wait for the lead thread to attempt the connection
// and use what's in the pool at that point.
select {
case <-p.shutdownCh:
return nil, fmt.Errorf("rpc error: shutdown")
case <-wait:
}

// getPooled is used to return a pooled connection
func (p *ConnPool) getPooled(addr net.Addr, version int) *Conn {
// See if the lead thread was able to get us a connection.
p.Lock()
c := p.pool[addr.String()]
if c != nil {
c.lastUsed = time.Now()
atomic.AddInt32(&c.refCount, 1)
if c := p.pool[addr.String()]; c != nil {
c.markForUse()
p.Unlock()
return c, nil
}

p.Unlock()
return c
return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
}

// getNewConn is used to return a new connection
Expand Down Expand Up @@ -272,18 +326,7 @@ func (p *ConnPool) getNewConn(dc string, addr net.Addr, version int) (*Conn, err
version: version,
pool: p,
}

// Track this connection, handle potential race condition
p.Lock()
if existing := p.pool[addr.String()]; existing != nil {
c.Close()
p.Unlock()
return existing, nil
} else {
p.pool[addr.String()] = c
p.Unlock()
return c, nil
}
return c, nil
}

// clearConn is used to clear any cached connection, potentially in response to an erro
Expand Down