Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new verify_server_hostname to mitigate possibility of MITM #927

Merged
merged 11 commits into from
May 12, 2015
2 changes: 2 additions & 0 deletions command/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,12 @@ func (a *Agent) consulConfig() *consul.Config {
// Copy the TLS configuration
base.VerifyIncoming = a.config.VerifyIncoming
base.VerifyOutgoing = a.config.VerifyOutgoing
base.VerifyServerHostname = a.config.VerifyServerHostname
base.CAFile = a.config.CAFile
base.CertFile = a.config.CertFile
base.KeyFile = a.config.KeyFile
base.ServerName = a.config.ServerName
base.Domain = a.config.Domain

// Setup the ServerUp callback
base.ServerUp = a.state.ConsulServerUp
Expand Down
11 changes: 11 additions & 0 deletions command/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ type Config struct {
// certificate authority. This is used to verify authenticity of server nodes.
VerifyOutgoing bool `mapstructure:"verify_outgoing"`

// VerifyServerHostname is used to enable hostname verification of servers. This
// ensures that the certificate presented is valid for server.<datacenter>.<domain>.
// This prevents a compromised client from being restarted as a server, and then
// intercepting request traffic as well as being added as a raft peer. This should be
// enabled by default with VerifyOutgoing, but for legacy reasons we cannot break
// existing clients.
VerifyServerHostname bool `mapstructure:"verify_server_hostname"`

// CAFile is a path to a certificate authority file. This is used with VerifyIncoming
// or VerifyOutgoing to verify the TLS connection.
CAFile string `mapstructure:"ca_file"`
Expand Down Expand Up @@ -838,6 +846,9 @@ func MergeConfig(a, b *Config) *Config {
if b.VerifyOutgoing {
result.VerifyOutgoing = true
}
if b.VerifyServerHostname {
result.VerifyServerHostname = true
}
if b.CAFile != "" {
result.CAFile = b.CAFile
}
Expand Down
6 changes: 5 additions & 1 deletion command/agent/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ func TestDecodeConfig(t *testing.T) {
}

// TLS
input = `{"verify_incoming": true, "verify_outgoing": true}`
input = `{"verify_incoming": true, "verify_outgoing": true, "verify_server_hostname": true}`
config, err = DecodeConfig(bytes.NewReader([]byte(input)))
if err != nil {
t.Fatalf("err: %s", err)
Expand All @@ -259,6 +259,10 @@ func TestDecodeConfig(t *testing.T) {
t.Fatalf("bad: %#v", config)
}

if config.VerifyServerHostname != true {
t.Fatalf("bad: %#v", config)
}

// TLS keys
input = `{"ca_file": "my/ca/file", "cert_file": "my.cert", "key_file": "key.pem", "server_name": "example.com"}`
config, err = DecodeConfig(bytes.NewReader([]byte(input)))
Expand Down
12 changes: 5 additions & 7 deletions consul/client.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package consul

import (
"crypto/tls"
"fmt"
"log"
"math/rand"
Expand Down Expand Up @@ -91,10 +90,9 @@ func NewClient(config *Config) (*Client, error) {
config.LogOutput = os.Stderr
}

// Create the tlsConfig
var tlsConfig *tls.Config
var err error
if tlsConfig, err = config.tlsConfig().OutgoingTLSConfig(); err != nil {
// Create the tls Wrapper
tlsWrap, err := config.tlsConfig().OutgoingTLSWrapper()
if err != nil {
return nil, err
}

Expand All @@ -104,7 +102,7 @@ func NewClient(config *Config) (*Client, error) {
// Create server
c := &Client{
config: config,
connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsConfig),
connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap),
eventCh: make(chan serf.Event, 256),
logger: logger,
shutdownCh: make(chan struct{}),
Expand Down Expand Up @@ -357,7 +355,7 @@ func (c *Client) RPC(method string, args interface{}, reply interface{}) error {

// Forward to remote Consul
TRY_RPC:
if err := c.connPool.RPC(server.Addr, server.Version, method, args, reply); err != nil {
if err := c.connPool.RPC(c.config.Datacenter, server.Addr, server.Version, method, args, reply); err != nil {
c.lastServer = nil
c.lastRPCTime = time.Time{}
return err
Expand Down
29 changes: 21 additions & 8 deletions consul/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ type Config struct {
// Node name is the name we use to advertise. Defaults to hostname.
NodeName string

// Domain is the DNS domain for the records. Defaults to "consul."
Domain string

// RaftConfig is the configuration used for Raft in the local DC
RaftConfig *raft.Config

Expand Down Expand Up @@ -100,6 +103,14 @@ type Config struct {
// server nodes.
VerifyOutgoing bool

// VerifyServerHostname is used to enable hostname verification of servers. This
// ensures that the certificate presented is valid for server.<datacenter>.<domain>.
// This prevents a compromised client from being restarted as a server, and then
// intercepting request traffic as well as being added as a raft peer. This should be
// enabled by default with VerifyOutgoing, but for legacy reasons we cannot break
// existing clients.
VerifyServerHostname bool

// CAFile is a path to a certificate authority file. This is used with VerifyIncoming
// or VerifyOutgoing to verify the TLS connection.
CAFile string
Expand Down Expand Up @@ -267,13 +278,15 @@ func DefaultConfig() *Config {

func (c *Config) tlsConfig() *tlsutil.Config {
tlsConf := &tlsutil.Config{
VerifyIncoming: c.VerifyIncoming,
VerifyOutgoing: c.VerifyOutgoing,
CAFile: c.CAFile,
CertFile: c.CertFile,
KeyFile: c.KeyFile,
NodeName: c.NodeName,
ServerName: c.ServerName}

VerifyIncoming: c.VerifyIncoming,
VerifyOutgoing: c.VerifyOutgoing,
VerifyServerHostname: c.VerifyServerHostname,
CAFile: c.CAFile,
CertFile: c.CertFile,
KeyFile: c.KeyFile,
NodeName: c.NodeName,
ServerName: c.ServerName,
Domain: c.Domain,
}
return tlsConf
}
27 changes: 13 additions & 14 deletions consul/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package consul

import (
"container/list"
"crypto/tls"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -135,8 +134,8 @@ type ConnPool struct {
// Pool maps an address to a open connection
pool map[string]*Conn

// TLS settings
tlsConfig *tls.Config
// TLS wrapper
tlsWrap tlsutil.DCWrapper

// Used to indicate the pool is shutdown
shutdown bool
Expand All @@ -148,13 +147,13 @@ type ConnPool struct {
// Set maxTime to 0 to disable reaping. maxStreams is used to control
// the number of idle streams allowed.
// If TLS settings are provided outgoing connections use TLS.
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsConfig *tls.Config) *ConnPool {
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.DCWrapper) *ConnPool {
pool := &ConnPool{
logOutput: logOutput,
maxTime: maxTime,
maxStreams: maxStreams,
pool: make(map[string]*Conn),
tlsConfig: tlsConfig,
tlsWrap: tlsWrap,
shutdownCh: make(chan struct{}),
}
if maxTime > 0 {
Expand Down Expand Up @@ -183,14 +182,14 @@ func (p *ConnPool) Shutdown() error {

// Acquire is used to get a connection that is
// pooled or to return a new connection
func (p *ConnPool) acquire(addr net.Addr, version int) (*Conn, error) {
func (p *ConnPool) acquire(dc string, addr net.Addr, version int) (*Conn, error) {
// Check for a pooled ocnn
if conn := p.getPooled(addr, version); conn != nil {
return conn, nil
}

// Create a new connection
return p.getNewConn(addr, version)
return p.getNewConn(dc, addr, version)
}

// getPooled is used to return a pooled connection
Expand All @@ -206,7 +205,7 @@ func (p *ConnPool) getPooled(addr net.Addr, version int) *Conn {
}

// getNewConn is used to return a new connection
func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) {
func (p *ConnPool) getNewConn(dc string, addr net.Addr, version int) (*Conn, error) {
// Try to dial the conn
conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
if err != nil {
Expand All @@ -220,15 +219,15 @@ func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) {
}

// Check if TLS is enabled
if p.tlsConfig != nil {
if p.tlsWrap != nil {
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
conn.Close()
return nil, err
}

// Wrap the connection in a TLS client
tlsConn, err := tlsutil.WrapTLSClient(conn, p.tlsConfig)
tlsConn, err := p.tlsWrap(dc, conn)
if err != nil {
conn.Close()
return nil, err
Expand Down Expand Up @@ -314,11 +313,11 @@ func (p *ConnPool) releaseConn(conn *Conn) {
}

// getClient is used to get a usable client for an address and protocol version
func (p *ConnPool) getClient(addr net.Addr, version int) (*Conn, *StreamClient, error) {
func (p *ConnPool) getClient(dc string, addr net.Addr, version int) (*Conn, *StreamClient, error) {
retries := 0
START:
// Try to get a conn first
conn, err := p.acquire(addr, version)
conn, err := p.acquire(dc, addr, version)
if err != nil {
return nil, nil, fmt.Errorf("failed to get conn: %v", err)
}
Expand All @@ -340,9 +339,9 @@ START:
}

// RPC is used to make an RPC call to a remote host
func (p *ConnPool) RPC(addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
func (p *ConnPool) RPC(dc string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
// Get a usable client
conn, sc, err := p.getClient(addr, version)
conn, sc, err := p.getClient(dc, addr, version)
if err != nil {
return fmt.Errorf("rpc error: %v", err)
}
Expand Down
22 changes: 11 additions & 11 deletions consul/raft_rpc.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package consul

import (
"crypto/tls"
"fmt"
"github.com/hashicorp/consul/tlsutil"
"net"
"sync"
"time"

"github.com/hashicorp/consul/tlsutil"
)

// RaftLayer implements the raft.StreamLayer interface,
Expand All @@ -18,8 +18,8 @@ type RaftLayer struct {
// connCh is used to accept connections
connCh chan net.Conn

// TLS configuration
tlsConfig *tls.Config
// TLS wrapper
tlsWrap tlsutil.Wrapper

// Tracks if we are closed
closed bool
Expand All @@ -30,12 +30,12 @@ type RaftLayer struct {
// NewRaftLayer is used to initialize a new RaftLayer which can
// be used as a StreamLayer for Raft. If a tlsConfig is provided,
// then the connection will use TLS.
func NewRaftLayer(addr net.Addr, tlsConfig *tls.Config) *RaftLayer {
func NewRaftLayer(addr net.Addr, tlsWrap tlsutil.Wrapper) *RaftLayer {
layer := &RaftLayer{
addr: addr,
connCh: make(chan net.Conn),
tlsConfig: tlsConfig,
closeCh: make(chan struct{}),
addr: addr,
connCh: make(chan net.Conn),
tlsWrap: tlsWrap,
closeCh: make(chan struct{}),
}
return layer
}
Expand Down Expand Up @@ -87,15 +87,15 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error
}

// Check for tls mode
if l.tlsConfig != nil {
if l.tlsWrap != nil {
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
conn.Close()
return nil, err
}

// Wrap the connection in a TLS client
conn, err = tlsutil.WrapTLSClient(conn, l.tlsConfig)
conn, err = l.tlsWrap(conn)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions consul/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func (s *Server) forwardLeader(method string, args interface{}, reply interface{
if server == nil {
return structs.ErrNoLeader
}
return s.connPool.RPC(server.Addr, server.Version, method, args, reply)
return s.connPool.RPC(s.config.Datacenter, server.Addr, server.Version, method, args, reply)
}

// forwardDC is used to forward an RPC call to a remote DC, or fail if no servers
Expand All @@ -229,7 +229,7 @@ func (s *Server) forwardDC(method, dc string, args interface{}, reply interface{

// Forward to remote Consul
metrics.IncrCounter([]string{"consul", "rpc", "cross-dc", dc}, 1)
return s.connPool.RPC(server.Addr, server.Version, method, args, reply)
return s.connPool.RPC(dc, server.Addr, server.Version, method, args, reply)
}

// globalRPC is used to forward an RPC request to one server in each datacenter.
Expand Down
16 changes: 10 additions & 6 deletions consul/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/golang-lru"
"github.com/hashicorp/raft"
"github.com/hashicorp/raft-boltdb"
Expand Down Expand Up @@ -182,9 +183,9 @@ func NewServer(config *Config) (*Server, error) {
config.LogOutput = os.Stderr
}

// Create the tlsConfig for outgoing connections
// Create the tls wrapper for outgoing connections
tlsConf := config.tlsConfig()
tlsConfig, err := tlsConf.OutgoingTLSConfig()
tlsWrap, err := tlsConf.OutgoingTLSWrapper()
if err != nil {
return nil, err
}
Expand All @@ -207,7 +208,7 @@ func NewServer(config *Config) (*Server, error) {
// Create server
s := &Server{
config: config,
connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsConfig),
connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap),
eventChLAN: make(chan serf.Event, 256),
eventChWAN: make(chan serf.Event, 256),
localConsuls: make(map[string]*serverParts),
Expand Down Expand Up @@ -242,7 +243,7 @@ func NewServer(config *Config) (*Server, error) {
}

// Initialize the RPC layer
if err := s.setupRPC(tlsConfig); err != nil {
if err := s.setupRPC(tlsWrap); err != nil {
s.Shutdown()
return nil, fmt.Errorf("Failed to start RPC layer: %v", err)
}
Expand Down Expand Up @@ -410,7 +411,7 @@ func (s *Server) setupRaft() error {
}

// setupRPC is used to setup the RPC listener
func (s *Server) setupRPC(tlsConfig *tls.Config) error {
func (s *Server) setupRPC(tlsWrap tlsutil.DCWrapper) error {
// Create endpoints
s.endpoints.Status = &Status{s}
s.endpoints.Catalog = &Catalog{s}
Expand Down Expand Up @@ -453,7 +454,10 @@ func (s *Server) setupRPC(tlsConfig *tls.Config) error {
return fmt.Errorf("RPC advertise address is not advertisable: %v", addr)
}

s.raftLayer = NewRaftLayer(advertise, tlsConfig)
// Provide a DC specific wrapper. Raft replication is only
// ever done in the same datacenter, so we can provide it as a constant.
wrapper := tlsutil.SpecificDC(s.config.Datacenter, tlsWrap)
s.raftLayer = NewRaftLayer(advertise, wrapper)
return nil
}

Expand Down
Loading