Skip to content

Commit

Permalink
Add a path for transitioning to TLS on an existing cluster (#3001)
Browse files Browse the repository at this point in the history
Fixes #1705
  • Loading branch information
kyhavlov authored May 10, 2017
1 parent 6eba69f commit 5bab68b
Show file tree
Hide file tree
Showing 16 changed files with 224 additions and 43 deletions.
3 changes: 3 additions & 0 deletions command/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,9 @@ func (a *Agent) consulConfig() (*consul.Config, error) {

// Copy the TLS configuration
base.VerifyIncoming = a.config.VerifyIncoming || a.config.VerifyIncomingRPC
if a.config.CAPath != "" || a.config.CAFile != "" {
base.UseTLS = true
}
base.VerifyOutgoing = a.config.VerifyOutgoing
base.VerifyServerHostname = a.config.VerifyServerHostname
base.CAFile = a.config.CAFile
Expand Down
6 changes: 6 additions & 0 deletions consul/agent/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ type Server struct {
NonVoter bool
Addr net.Addr
Status serf.MemberStatus

// If true, use TLS when connecting to this server
UseTLS bool
}

// Key returns the corresponding Key
Expand Down Expand Up @@ -72,6 +75,8 @@ func IsConsulServer(m serf.Member) (bool, *Server) {
datacenter := m.Tags["dc"]
_, bootstrap := m.Tags["bootstrap"]

_, useTLS := m.Tags["use_tls"]

expect := 0
expect_str, ok := m.Tags["expect"]
var err error
Expand Down Expand Up @@ -135,6 +140,7 @@ func IsConsulServer(m serf.Member) (bool, *Server) {
RaftVersion: raft_vsn,
Status: m.Status,
NonVoter: nonVoter,
UseTLS: useTLS,
}
return true, parts
}
4 changes: 4 additions & 0 deletions consul/agent/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func TestIsConsulServer(t *testing.T) {
"vsn": "1",
"expect": "3",
"raft_vsn": "3",
"use_tls": "1",
},
Status: serf.StatusLeft,
}
Expand Down Expand Up @@ -95,6 +96,9 @@ func TestIsConsulServer(t *testing.T) {
if parts.Status != serf.StatusLeft {
t.Fatalf("bad: %v", parts.Status)
}
if !parts.UseTLS {
t.Fatalf("bad: %v", parts.UseTLS)
}
m.Tags["bootstrap"] = "1"
m.Tags["disabled"] = "1"
ok, parts = agent.IsConsulServer(m)
Expand Down
6 changes: 3 additions & 3 deletions consul/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func NewClient(config *Config) (*Client, error) {
// Create server
c := &Client{
config: config,
connPool: NewPool(config.RPCSrcAddr, config.LogOutput, clientRPCConnMaxIdle, clientMaxStreams, tlsWrap),
connPool: NewPool(config.RPCSrcAddr, config.LogOutput, clientRPCConnMaxIdle, clientMaxStreams, tlsWrap, config.VerifyOutgoing),
eventCh: make(chan serf.Event, serfEventBacklog),
logger: logger,
shutdownCh: make(chan struct{}),
Expand Down Expand Up @@ -334,7 +334,7 @@ func (c *Client) RPC(method string, args interface{}, reply interface{}) error {
}

// Forward to remote Consul
if err := c.connPool.RPC(c.config.Datacenter, server.Addr, server.Version, method, args, reply); err != nil {
if err := c.connPool.RPC(c.config.Datacenter, server.Addr, server.Version, method, server.UseTLS, args, reply); err != nil {
c.servers.NotifyFailedServer(server)
c.logger.Printf("[ERR] consul: RPC failed to server %s: %v", server.Addr, err)
return err
Expand All @@ -361,7 +361,7 @@ func (c *Client) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io

// Request the operation.
var reply structs.SnapshotResponse
snap, err := SnapshotRPC(c.connPool, c.config.Datacenter, server.Addr, args, in, &reply)
snap, err := SnapshotRPC(c.connPool, c.config.Datacenter, server.Addr, server.UseTLS, args, in, &reply)
if err != nil {
return err
}
Expand Down
10 changes: 7 additions & 3 deletions consul/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,15 @@ type Config struct {
// must match a provided certificate authority. This can be used to force client auth.
VerifyIncoming bool

// VerifyOutgoing is used to verify the authenticity of outgoing connections.
// VerifyOutgoing is used to force verification of the authenticity of outgoing connections.
// This means that TLS requests are used, and TCP requests are not made. TLS connections
// must match a provided certificate authority. This is used to verify authenticity of
// server nodes.
// must match a provided certificate authority.
VerifyOutgoing bool

// UseTLS is used to enable TLS for outgoing connections to other TLS-capable Consul
// servers. This doesn't imply any verification; it only enables TLS if possible.
UseTLS bool

// VerifyServerHostname is used to enable hostname verification of servers. This
// ensures that the certificate presented is valid for server.<datacenter>.<domain>.
// This prevents a compromised client from being restarted as a server, and then
Expand Down Expand Up @@ -439,6 +442,7 @@ func (c *Config) tlsConfig() *tlsutil.Config {
VerifyIncoming: c.VerifyIncoming,
VerifyOutgoing: c.VerifyOutgoing,
VerifyServerHostname: c.VerifyServerHostname,
UseTLS: c.UseTLS,
CAFile: c.CAFile,
CAPath: c.CAPath,
CertFile: c.CertFile,
Expand Down
28 changes: 16 additions & 12 deletions consul/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ type ConnPool struct {
// TLS wrapper
tlsWrap tlsutil.DCWrapper

// forceTLS is used to enforce outgoing TLS verification
forceTLS bool

// Used to indicate the pool is shutdown
shutdown bool
shutdownCh chan struct{}
Expand All @@ -154,7 +157,7 @@ type ConnPool struct {
// Set maxTime to 0 to disable reaping. maxStreams is used to control
// the number of idle streams allowed.
// If TLS settings are provided, outgoing connections use TLS when forceTLS is
// set or when the caller requests TLS for the target server.
func NewPool(src *net.TCPAddr, logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.DCWrapper) *ConnPool {
func NewPool(src *net.TCPAddr, logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.DCWrapper, forceTLS bool) *ConnPool {
pool := &ConnPool{
src: src,
logOutput: logOutput,
Expand All @@ -163,6 +166,7 @@ func NewPool(src *net.TCPAddr, logOutput io.Writer, maxTime time.Duration, maxSt
pool: make(map[string]*Conn),
limiter: make(map[string]chan struct{}),
tlsWrap: tlsWrap,
forceTLS: forceTLS,
shutdownCh: make(chan struct{}),
}
if maxTime > 0 {
Expand Down Expand Up @@ -193,7 +197,7 @@ func (p *ConnPool) Shutdown() error {
// wait for an existing connection attempt to finish, if one is in progress,
// and will return that one if it succeeds. If all else fails, it will return a
// newly-created connection and add it to the pool.
func (p *ConnPool) acquire(dc string, addr net.Addr, version int) (*Conn, error) {
func (p *ConnPool) acquire(dc string, addr net.Addr, version int, useTLS bool) (*Conn, error) {
addrStr := addr.String()

// Check to see if there's a pooled connection available. This is up
Expand Down Expand Up @@ -222,7 +226,7 @@ func (p *ConnPool) acquire(dc string, addr net.Addr, version int) (*Conn, error)
// If we are the lead thread, make the new connection and then wake
// everybody else up to see if we got it.
if isLeadThread {
c, err := p.getNewConn(dc, addr, version)
c, err := p.getNewConn(dc, addr, version, useTLS)
p.Lock()
delete(p.limiter, addrStr)
close(wait)
Expand Down Expand Up @@ -267,7 +271,7 @@ type HalfCloser interface {

// DialTimeout is used to establish a raw connection to the given server, with a
// given connection timeout.
func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration) (net.Conn, HalfCloser, error) {
func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration, useTLS bool) (net.Conn, HalfCloser, error) {
// Try to dial the conn
d := &net.Dialer{LocalAddr: p.src, Timeout: timeout}
conn, err := d.Dial("tcp", addr.String())
Expand All @@ -284,7 +288,7 @@ func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration)
}

// Check if TLS is enabled
if p.tlsWrap != nil {
if (useTLS || p.forceTLS) && p.tlsWrap != nil {
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
conn.Close()
Expand All @@ -304,9 +308,9 @@ func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration)
}

// getNewConn is used to return a new connection
func (p *ConnPool) getNewConn(dc string, addr net.Addr, version int) (*Conn, error) {
func (p *ConnPool) getNewConn(dc string, addr net.Addr, version int, useTLS bool) (*Conn, error) {
// Get a new, raw connection.
conn, _, err := p.DialTimeout(dc, addr, defaultDialTimeout)
conn, _, err := p.DialTimeout(dc, addr, defaultDialTimeout, useTLS)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -372,11 +376,11 @@ func (p *ConnPool) releaseConn(conn *Conn) {
}

// getClient is used to get a usable client for an address and protocol version
func (p *ConnPool) getClient(dc string, addr net.Addr, version int) (*Conn, *StreamClient, error) {
func (p *ConnPool) getClient(dc string, addr net.Addr, version int, useTLS bool) (*Conn, *StreamClient, error) {
retries := 0
START:
// Try to get a conn first
conn, err := p.acquire(dc, addr, version)
conn, err := p.acquire(dc, addr, version, useTLS)
if err != nil {
return nil, nil, fmt.Errorf("failed to get conn: %v", err)
}
Expand All @@ -398,9 +402,9 @@ START:
}

// RPC is used to make an RPC call to a remote host
func (p *ConnPool) RPC(dc string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
func (p *ConnPool) RPC(dc string, addr net.Addr, version int, method string, useTLS bool, args interface{}, reply interface{}) error {
// Get a usable client
conn, sc, err := p.getClient(dc, addr, version)
conn, sc, err := p.getClient(dc, addr, version, useTLS)
if err != nil {
return fmt.Errorf("rpc error: %v", err)
}
Expand All @@ -423,7 +427,7 @@ func (p *ConnPool) RPC(dc string, addr net.Addr, version int, method string, arg
// returns true if healthy, false if an error occurred
func (p *ConnPool) PingConsulServer(s *agent.Server) (bool, error) {
// Get a usable client
conn, sc, err := p.getClient(s.Datacenter, s.Addr, s.Version)
conn, sc, err := p.getClient(s.Datacenter, s.Addr, s.Version, s.UseTLS)
if err != nil {
return false, err
}
Expand Down
9 changes: 7 additions & 2 deletions consul/raft_rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,23 @@ type RaftLayer struct {
closed bool
closeCh chan struct{}
closeLock sync.Mutex

// tlsFunc is a callback to determine whether to use TLS for connecting to
// a given Raft server
tlsFunc func(raft.ServerAddress) bool
}

// NewRaftLayer is used to initialize a new RaftLayer which can
// be used as a StreamLayer for Raft. If a tlsWrap is provided,
// outgoing connections use TLS for each server address that tlsFunc
// reports as TLS-enabled.
func NewRaftLayer(src, addr net.Addr, tlsWrap tlsutil.Wrapper) *RaftLayer {
func NewRaftLayer(src, addr net.Addr, tlsWrap tlsutil.Wrapper, tlsFunc func(raft.ServerAddress) bool) *RaftLayer {
	// Dial invokes tlsFunc before it checks tlsWrap, so a nil callback would
	// panic on the first outgoing connection. Default to "always use TLS",
	// which leaves the decision solely to whether a tlsWrap was provided.
	if tlsFunc == nil {
		tlsFunc = func(raft.ServerAddress) bool { return true }
	}

	layer := &RaftLayer{
		src:     src,
		addr:    addr,
		connCh:  make(chan net.Conn),
		tlsWrap: tlsWrap,
		closeCh: make(chan struct{}),
		tlsFunc: tlsFunc,
	}
	return layer
}
Expand Down Expand Up @@ -93,7 +98,7 @@ func (l *RaftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net
}

// Check for tls mode
if l.tlsWrap != nil {
if l.tlsFunc(address) && l.tlsWrap != nil {
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
conn.Close()
Expand Down
4 changes: 2 additions & 2 deletions consul/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func (s *Server) forwardLeader(server *agent.Server, method string, args interfa
if server == nil {
return structs.ErrNoLeader
}
return s.connPool.RPC(s.config.Datacenter, server.Addr, server.Version, method, args, reply)
return s.connPool.RPC(s.config.Datacenter, server.Addr, server.Version, method, server.UseTLS, args, reply)
}

// forwardDC is used to forward an RPC call to a remote DC, or fail if no servers
Expand All @@ -274,7 +274,7 @@ func (s *Server) forwardDC(method, dc string, args interface{}, reply interface{
}

metrics.IncrCounter([]string{"consul", "rpc", "cross-dc", dc}, 1)
if err := s.connPool.RPC(dc, server.Addr, server.Version, method, args, reply); err != nil {
if err := s.connPool.RPC(dc, server.Addr, server.Version, method, server.UseTLS, args, reply); err != nil {
manager.NotifyFailedServer(server)
s.logger.Printf("[ERR] consul: RPC failed to server %s in DC %q: %v", server.Addr, dc, err)
return err
Expand Down
2 changes: 1 addition & 1 deletion consul/serf.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (s *Server) maybeBootstrap() {
// Retry with exponential backoff to get peer status from this server
for attempt := uint(0); attempt < maxPeerRetries; attempt++ {
if err := s.connPool.RPC(s.config.Datacenter, server.Addr, server.Version,
"Status.Peers", &struct{}{}, &peers); err != nil {
"Status.Peers", server.UseTLS, &struct{}{}, &peers); err != nil {
nextRetry := time.Duration((1 << attempt) * peerRetryBase)
s.logger.Printf("[ERR] consul: Failed to confirm peer status for %s: %v. Retrying in "+
"%v...", server.Name, err, nextRetry.String())
Expand Down
29 changes: 27 additions & 2 deletions consul/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ func NewServer(config *Config) (*Server, error) {
}
logger := log.New(config.LogOutput, "", log.LstdFlags)

// Check if TLS is enabled
if config.CAFile != "" || config.CAPath != "" {
config.UseTLS = true
}

// Create the TLS wrapper for outgoing connections.
tlsConf := config.tlsConfig()
tlsWrap, err := tlsConf.OutgoingTLSWrapper()
Expand Down Expand Up @@ -261,7 +266,7 @@ func NewServer(config *Config) (*Server, error) {
autopilotRemoveDeadCh: make(chan struct{}),
autopilotShutdownCh: make(chan struct{}),
config: config,
connPool: NewPool(config.RPCSrcAddr, config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap),
connPool: NewPool(config.RPCSrcAddr, config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap, config.VerifyOutgoing),
eventChLAN: make(chan serf.Event, 256),
eventChWAN: make(chan serf.Event, 256),
localConsuls: make(map[raft.ServerAddress]*agent.Server),
Expand Down Expand Up @@ -393,6 +398,9 @@ func (s *Server) setupSerf(conf *serf.Config, ch chan serf.Event, path string, w
if s.config.NonVoter {
conf.Tags["nonvoter"] = "1"
}
if s.config.UseTLS {
conf.Tags["use_tls"] = "1"
}
conf.MemberlistConfig.LogOutput = s.config.LogOutput
conf.LogOutput = s.config.LogOutput
conf.EventCh = ch
Expand Down Expand Up @@ -626,7 +634,24 @@ func (s *Server) setupRPC(tlsWrap tlsutil.DCWrapper) error {
// Provide a DC specific wrapper. Raft replication is only
// ever done in the same datacenter, so we can provide it as a constant.
wrapper := tlsutil.SpecificDC(s.config.Datacenter, tlsWrap)
s.raftLayer = NewRaftLayer(s.config.RPCSrcAddr, s.config.RPCAdvertise, wrapper)

// Define a callback for determining whether to wrap a connection with TLS
tlsFunc := func(address raft.ServerAddress) bool {
if s.config.VerifyOutgoing {
return true
}

s.localLock.RLock()
server, ok := s.localConsuls[address]
s.localLock.RUnlock()

if !ok {
return false
}

return server.UseTLS
}
s.raftLayer = NewRaftLayer(s.config.RPCSrcAddr, s.config.RPCAdvertise, wrapper, tlsFunc)
return nil
}

Expand Down
Loading

0 comments on commit 5bab68b

Please sign in to comment.