diff --git a/agent/agent.go b/agent/agent.go index 58e374e72a80..dec221bff59b 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -393,7 +393,11 @@ func (a *Agent) Start() error { // waiting to discover a consul server consulCfg.ServerUp = a.sync.SyncFull.Trigger - a.tlsConfigurator = tlsutil.NewConfigurator(c.ToTLSUtilConfig()) + tlsConfigurator, err := tlsutil.NewConfigurator(c.ToTLSUtilConfig(), a.logger) + if err != nil { + return err + } + a.tlsConfigurator = tlsConfigurator // Setup either the client or the server. if c.ServerMode { @@ -662,10 +666,7 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) { var tlscfg *tls.Config _, isTCP := l.(*tcpKeepAliveListener) if isTCP && proto == "https" { - tlscfg, err = a.tlsConfigurator.IncomingHTTPSConfig() - if err != nil { - return err - } + tlscfg = a.tlsConfigurator.IncomingHTTPSConfig() l = tls.NewListener(l, tlscfg) } srv := &HTTPServer{ @@ -2232,11 +2233,7 @@ func (a *Agent) addCheck(check *structs.HealthCheck, chkType *structs.CheckType, chkType.Interval = checks.MinInterval } - a.tlsConfigurator.AddCheck(string(check.CheckID), chkType.TLSSkipVerify) - tlsClientConfig, err := a.tlsConfigurator.OutgoingTLSConfigForCheck(string(check.CheckID)) - if err != nil { - return fmt.Errorf("Failed to set up TLS: %v", err) - } + tlsClientConfig := a.tlsConfigurator.OutgoingTLSConfigForCheck(chkType.TLSSkipVerify) http := &checks.CheckHTTP{ Notify: a.State, @@ -2287,12 +2284,7 @@ func (a *Agent) addCheck(check *structs.HealthCheck, chkType *structs.CheckType, var tlsClientConfig *tls.Config if chkType.GRPCUseTLS { - var err error - a.tlsConfigurator.AddCheck(string(check.CheckID), chkType.TLSSkipVerify) - tlsClientConfig, err = a.tlsConfigurator.OutgoingTLSConfigForCheck(string(check.CheckID)) - if err != nil { - return fmt.Errorf("Failed to set up TLS: %v", err) - } + tlsClientConfig = a.tlsConfigurator.OutgoingTLSConfigForCheck(chkType.TLSSkipVerify) } grpc := &checks.CheckGRPC{ @@ -2431,7 +2423,6 @@ func (a *Agent) removeCheckLocked(checkID types.CheckID, persist bool) error { return fmt.Errorf("CheckID missing") } - a.tlsConfigurator.RemoveCheck(string(checkID)) a.cancelCheckMonitors(checkID) a.State.RemoveCheck(checkID) @@ -3559,6 +3550,10 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error { // the checks and service registrations. a.loadTokens(newCfg) + if err := a.tlsConfigurator.Update(newCfg.ToTLSUtilConfig()); err != nil { + return fmt.Errorf("Failed reloading tls configuration: %s", err) + } + // Reload service/check definitions and metadata. if err := a.loadServices(newCfg); err != nil { return fmt.Errorf("Failed reloading services: %s", err) diff --git a/agent/agent_test.go b/agent/agent_test.go index 1046dc10a08c..ee364d90aa94 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -2,6 +2,7 @@ package agent import ( "bytes" + "crypto/tls" "encoding/json" "fmt" "io/ioutil" @@ -3366,11 +3367,7 @@ func TestAgent_SetupProxyManager(t *testing.T) { ports { http = -1 } data_dir = "` + dataDir + `" ` - c := TestConfig( - // randomPortsSource(false), - config.Source{Name: t.Name(), Format: "hcl", Data: hcl}, - ) - a, err := New(c) + a, err := NewUnstartedAgent(t, t.Name(), hcl) require.NoError(t, err) require.Error(t, a.setupProxyManager(), "setupProxyManager should fail with invalid HTTP API config") @@ -3378,11 +3375,7 @@ func TestAgent_SetupProxyManager(t *testing.T) { ports { http = 8001 } data_dir = "` + dataDir + `" ` - c = TestConfig( - // randomPortsSource(false), - config.Source{Name: t.Name(), Format: "hcl", Data: hcl}, - ) - a, err = New(c) + a, err = NewUnstartedAgent(t, t.Name(), hcl) require.NoError(t, err) require.NoError(t, a.setupProxyManager()) } @@ -3543,3 +3536,107 @@ func TestAgent_loadTokens(t *testing.T) { require.Equal("foxtrot", a.tokens.ReplicationToken()) }) } + +func TestAgent_ReloadConfigOutgoingRPCConfig(t *testing.T) { + t.Parallel() + dataDir := testutil.TempDir(t, "agent") // we manage the data dir + defer os.RemoveAll(dataDir) + hcl := ` + data_dir = "` + dataDir + `" + verify_outgoing = true + ca_file = "../test/ca/root.cer" + cert_file = "../test/key/ourdomain.cer" + key_file = "../test/key/ourdomain.key" + verify_server_hostname = false + ` + a, err := NewUnstartedAgent(t, t.Name(), hcl) + require.NoError(t, err) + tlsConf := a.tlsConfigurator.OutgoingRPCConfig() + require.True(t, tlsConf.InsecureSkipVerify) + require.Len(t, tlsConf.ClientCAs.Subjects(), 1) + require.Len(t, tlsConf.RootCAs.Subjects(), 1) + + hcl = ` + data_dir = "` + dataDir + `" + verify_outgoing = true + ca_path = "../test/ca_path" + cert_file = "../test/key/ourdomain.cer" + key_file = "../test/key/ourdomain.key" + verify_server_hostname = true + ` + c := TestConfig(config.Source{Name: t.Name(), Format: "hcl", Data: hcl}) + require.NoError(t, a.ReloadConfig(c)) + tlsConf = a.tlsConfigurator.OutgoingRPCConfig() + require.False(t, tlsConf.InsecureSkipVerify) + require.Len(t, tlsConf.RootCAs.Subjects(), 2) + require.Len(t, tlsConf.ClientCAs.Subjects(), 2) +} + +func TestAgent_ReloadConfigIncomingRPCConfig(t *testing.T) { + t.Parallel() + dataDir := testutil.TempDir(t, "agent") // we manage the data dir + defer os.RemoveAll(dataDir) + hcl := ` + data_dir = "` + dataDir + `" + verify_outgoing = true + ca_file = "../test/ca/root.cer" + cert_file = "../test/key/ourdomain.cer" + key_file = "../test/key/ourdomain.key" + verify_server_hostname = false + ` + a, err := NewUnstartedAgent(t, t.Name(), hcl) + require.NoError(t, err) + tlsConf := a.tlsConfigurator.IncomingRPCConfig() + require.NotNil(t, tlsConf.GetConfigForClient) + tlsConf, err = tlsConf.GetConfigForClient(nil) + require.NoError(t, err) + require.NotNil(t, tlsConf) + require.True(t, tlsConf.InsecureSkipVerify) + require.Len(t, tlsConf.ClientCAs.Subjects(), 1) + require.Len(t, tlsConf.RootCAs.Subjects(), 1) + + hcl = ` + data_dir = "` + dataDir + `" + verify_outgoing = true + ca_path = "../test/ca_path" + cert_file = "../test/key/ourdomain.cer" + key_file = "../test/key/ourdomain.key" + verify_server_hostname = true + ` + c := TestConfig(config.Source{Name: t.Name(), Format: "hcl", Data: hcl}) + require.NoError(t, a.ReloadConfig(c)) + tlsConf, err = tlsConf.GetConfigForClient(nil) + require.NoError(t, err) + require.False(t, tlsConf.InsecureSkipVerify) + require.Len(t, tlsConf.ClientCAs.Subjects(), 2) + require.Len(t, tlsConf.RootCAs.Subjects(), 2) +} + +func TestAgent_ReloadConfigTLSConfigFailure(t *testing.T) { + t.Parallel() + dataDir := testutil.TempDir(t, "agent") // we manage the data dir + defer os.RemoveAll(dataDir) + hcl := ` + data_dir = "` + dataDir + `" + verify_outgoing = true + ca_file = "../test/ca/root.cer" + cert_file = "../test/key/ourdomain.cer" + key_file = "../test/key/ourdomain.key" + verify_server_hostname = false + ` + a, err := NewUnstartedAgent(t, t.Name(), hcl) + require.NoError(t, err) + tlsConf := a.tlsConfigurator.IncomingRPCConfig() + + hcl = ` + data_dir = "` + dataDir + `" + verify_incoming = true + ` + c := TestConfig(config.Source{Name: t.Name(), Format: "hcl", Data: hcl}) + require.Error(t, a.ReloadConfig(c)) + tlsConf, err = tlsConf.GetConfigForClient(nil) + require.NoError(t, err) + require.Equal(t, tls.NoClientCert, tlsConf.ClientAuth) + require.Len(t, tlsConf.ClientCAs.Subjects(), 1) + require.Len(t, tlsConf.RootCAs.Subjects(), 1) +} diff --git a/agent/config/runtime.go b/agent/config/runtime.go index 0596034f648b..9aefadc0bd31 100644 --- a/agent/config/runtime.go +++ b/agent/config/runtime.go @@ -1580,12 +1580,13 @@ func (c *RuntimeConfig) Sanitized() map[string]interface{} { return sanitize("rt", reflect.ValueOf(c)).Interface().(map[string]interface{}) } -func (c *RuntimeConfig) ToTLSUtilConfig() *tlsutil.Config { - return &tlsutil.Config{ +func (c *RuntimeConfig) ToTLSUtilConfig() tlsutil.Config { + return tlsutil.Config{ VerifyIncoming: c.VerifyIncoming, VerifyIncomingRPC: c.VerifyIncomingRPC, VerifyIncomingHTTPS: c.VerifyIncomingHTTPS, VerifyOutgoing: c.VerifyOutgoing, + VerifyServerHostname: c.VerifyServerHostname, CAFile: c.CAFile, CAPath: c.CAPath, CertFile: c.CertFile, diff --git a/agent/config/runtime_test.go b/agent/config/runtime_test.go index e9889f541171..6de42304e137 100644 --- a/agent/config/runtime_test.go +++ b/agent/config/runtime_test.go @@ -5434,6 +5434,7 @@ func TestRuntime_ToTLSUtilConfig(t *testing.T) { VerifyIncomingRPC: true, VerifyIncomingHTTPS: true, VerifyOutgoing: true, + VerifyServerHostname: true, CAFile: "a", CAPath: "b", CertFile: "c", @@ -5450,6 +5451,7 @@ func TestRuntime_ToTLSUtilConfig(t *testing.T) { require.Equal(t, c.VerifyIncomingRPC, r.VerifyIncomingRPC) require.Equal(t, c.VerifyIncomingHTTPS, r.VerifyIncomingHTTPS) require.Equal(t, c.VerifyOutgoing, r.VerifyOutgoing) + require.Equal(t, c.VerifyServerHostname, r.VerifyServerHostname) require.Equal(t, c.CAFile, r.CAFile) require.Equal(t, c.CAPath, r.CAPath) require.Equal(t, c.CertFile, r.CertFile) diff --git a/agent/consul/client.go b/agent/consul/client.go index 48e279a12bae..d43fbf08d461 100644 --- a/agent/consul/client.go +++ b/agent/consul/client.go @@ -86,10 +86,16 @@ type Client struct { EnterpriseClient } -// NewClient is used to construct a new Consul client from the -// configuration, potentially returning an error +// NewClient is used to construct a new Consul client from the configuration, +// potentially returning an error. +// NewClient only used to help setting up a client for testing. Normal code +// exercises NewClientLogger. func NewClient(config *Config) (*Client, error) { - return NewClientLogger(config, nil, tlsutil.NewConfigurator(config.ToTLSUtilConfig())) + c, err := tlsutil.NewConfigurator(config.ToTLSUtilConfig(), nil) + if err != nil { + return nil, err + } + return NewClientLogger(config, nil, c) } func NewClientLogger(config *Config, logger *log.Logger, tlsConfigurator *tlsutil.Configurator) (*Client, error) { @@ -113,12 +119,6 @@ func NewClientLogger(config *Config, logger *log.Logger, tlsConfigurator *tlsuti config.LogOutput = os.Stderr } - // Create the tls Wrapper - tlsWrap, err := tlsConfigurator.OutgoingRPCWrapper() - if err != nil { - return nil, err - } - // Create a logger if logger == nil { logger = log.New(config.LogOutput, "", log.LstdFlags) @@ -129,7 +129,7 @@ func NewClientLogger(config *Config, logger *log.Logger, tlsConfigurator *tlsuti LogOutput: config.LogOutput, MaxTime: clientRPCConnMaxIdle, MaxStreams: clientMaxStreams, - TLSWrapper: tlsWrap, + TLSWrapper: tlsConfigurator.OutgoingRPCWrapper(), ForceTLS: config.VerifyOutgoing, } @@ -158,6 +158,7 @@ func NewClientLogger(config *Config, logger *log.Logger, tlsConfigurator *tlsuti CacheConfig: clientACLCacheConfig, Sentinel: nil, } + var err error if c.acls, err = NewACLResolver(&aclConfig); err != nil { c.Shutdown() return nil, fmt.Errorf("Failed to create ACL resolver: %v", err) diff --git a/agent/consul/config.go b/agent/consul/config.go index 4f8146e44db8..2a7066443452 100644 --- a/agent/consul/config.go +++ b/agent/consul/config.go @@ -379,8 +379,8 @@ type Config struct { CAConfig *structs.CAConfiguration } -func (c *Config) ToTLSUtilConfig() *tlsutil.Config { - return &tlsutil.Config{ +func (c *Config) ToTLSUtilConfig() tlsutil.Config { + return tlsutil.Config{ VerifyIncoming: c.VerifyIncoming, VerifyOutgoing: c.VerifyOutgoing, CAFile: c.CAFile, diff --git a/agent/consul/server.go b/agent/consul/server.go index a81c8310bed5..330321f08a90 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -252,11 +252,17 @@ type Server struct { EnterpriseServer } +// NewServer is only used to help setting up a server for testing. Normal code +// exercises NewServerLogger. func NewServer(config *Config) (*Server, error) { - return NewServerLogger(config, nil, new(token.Store), tlsutil.NewConfigurator(config.ToTLSUtilConfig())) + c, err := tlsutil.NewConfigurator(config.ToTLSUtilConfig(), nil) + if err != nil { + return nil, err + } + return NewServerLogger(config, nil, new(token.Store), c) } -// NewServer is used to construct a new Consul server from the +// NewServerLogger is used to construct a new Consul server from the // configuration, potentially returning an error func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store, tlsConfigurator *tlsutil.Configurator) (*Server, error) { // Check the protocol version. @@ -296,18 +302,6 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store, tl } } - // Create the TLS wrapper for outgoing connections. - tlsWrap, err := tlsConfigurator.OutgoingRPCWrapper() - if err != nil { - return nil, err - } - - // Get the incoming TLS config. - incomingTLS, err := tlsConfigurator.IncomingRPCConfig() - if err != nil { - return nil, err - } - // Create the tombstone GC. gc, err := state.NewTombstoneGC(config.TombstoneTTL, config.TombstoneTTLGranularity) if err != nil { @@ -322,7 +316,7 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store, tl LogOutput: config.LogOutput, MaxTime: serverRPCCache, MaxStreams: serverMaxStreams, - TLSWrapper: tlsWrap, + TLSWrapper: tlsConfigurator.OutgoingRPCWrapper(), ForceTLS: config.VerifyOutgoing, } @@ -338,7 +332,7 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store, tl reconcileCh: make(chan serf.Member, reconcileChSize), router: router.NewRouter(logger, config.Datacenter), rpcServer: rpc.NewServer(), - rpcTLS: incomingTLS, + rpcTLS: tlsConfigurator.IncomingRPCConfig(), reassertLeaderCh: make(chan chan error), segmentLAN: make(map[string]*serf.Serf, len(config.Segments)), sessionTimers: NewSessionTimers(), @@ -373,7 +367,7 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store, tl } // Initialize the RPC layer. - if err := s.setupRPC(tlsWrap); err != nil { + if err := s.setupRPC(tlsConfigurator.OutgoingRPCWrapper()); err != nil { s.Shutdown() return nil, fmt.Errorf("Failed to start RPC layer: %v", err) } diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index ba53c5517002..f33616a01cd3 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -179,7 +179,11 @@ func newServer(c *Config) (*Server, error) { w = os.Stderr } logger := log.New(w, c.NodeName+" - ", log.LstdFlags|log.Lmicroseconds) - srv, err := NewServerLogger(c, logger, new(token.Store), tlsutil.NewConfigurator(c.ToTLSUtilConfig())) + tlsConf, err := tlsutil.NewConfigurator(c.ToTLSUtilConfig(), logger) + if err != nil { + return nil, err + } + srv, err := NewServerLogger(c, logger, new(token.Store), tlsConf) if err != nil { return nil, err } diff --git a/agent/testagent.go b/agent/testagent.go index e9749611dbf4..64343f071a15 100644 --- a/agent/testagent.go +++ b/agent/testagent.go @@ -18,9 +18,11 @@ import ( metrics "github.com/armon/go-metrics" uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/consul/agent/ae" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/consul" + "github.com/hashicorp/consul/agent/local" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/lib/freeport" @@ -103,6 +105,24 @@ func NewTestAgent(t *testing.T, name string, hcl string) *TestAgent { return a } +func NewUnstartedAgent(t *testing.T, name string, hcl string) (*Agent, error) { + c := TestConfig(config.Source{Name: name, Format: "hcl", Data: hcl}) + a, err := New(c) + if err != nil { + return nil, err + } + a.State = local.NewState(LocalConfig(c), a.logger, a.tokens) + a.sync = ae.NewStateSyncer(a.State, c.AEInterval, a.shutdownCh, a.logger) + a.delegate = &consul.Client{} + a.State.TriggerSyncChanges = a.sync.SyncChanges.Trigger + tlsConfigurator, err := tlsutil.NewConfigurator(c.ToTLSUtilConfig(), nil) + if err != nil { + return nil, err + } + a.tlsConfigurator = tlsConfigurator + return a, nil +} + // Start starts a test agent. It fails the test if the agent could not be started. func (a *TestAgent) Start(t *testing.T) *TestAgent { require := require.New(t) @@ -149,7 +169,9 @@ func (a *TestAgent) Start(t *testing.T) *TestAgent { agent.LogWriter = a.LogWriter agent.logger = log.New(logOutput, a.Name+" - ", log.LstdFlags|log.Lmicroseconds) agent.MemSink = metrics.NewInmemSink(1*time.Second, time.Minute) - agent.tlsConfigurator = tlsutil.NewConfigurator(a.Config.ToTLSUtilConfig()) + tlsConfigurator, err := tlsutil.NewConfigurator(a.Config.ToTLSUtilConfig(), nil) + require.NoError(err) + agent.tlsConfigurator = tlsConfigurator // we need the err var in the next exit condition if err := agent.Start(); err == nil { diff --git a/tlsutil/config.go b/tlsutil/config.go index 19d08aad3e5d..c9f2222bdf51 100644 --- a/tlsutil/config.go +++ b/tlsutil/config.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "log" "net" "strings" "sync" @@ -23,6 +24,7 @@ type Wrapper func(conn net.Conn) (net.Conn, error) // TLSLookup maps the tls_min_version configuration to the internal value var TLSLookup = map[string]uint16{ + "": tls.VersionTLS10, // default in golang "tls10": tls.VersionTLS10, "tls11": tls.VersionTLS11, "tls12": tls.VersionTLS12, @@ -114,14 +116,7 @@ type Config struct { // KeyPair is used to open and parse a certificate and key file func (c *Config) KeyPair() (*tls.Certificate, error) { - if c.CertFile == "" || c.KeyFile == "" { - return nil, nil - } - cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile) - if err != nil { - return nil, fmt.Errorf("Failed to load cert/key pair: %v", err) - } - return &cert, err + return loadKeyPair(c.CertFile, c.KeyFile) } // SpecificDC is used to invoke a static datacenter @@ -135,96 +130,115 @@ func SpecificDC(dc string, tlsWrap DCWrapper) Wrapper { } } -// Wrap a net.Conn into a client tls connection, performing any -// additional verification as needed. -// -// As of go 1.3, crypto/tls only supports either doing no certificate -// verification, or doing full verification including of the peer's -// DNS name. For consul, we want to validate that the certificate is -// signed by a known CA, but because consul doesn't use DNS names for -// node names, we don't verify the certificate DNS names. Since go 1.3 -// no longer supports this mode of operation, we have to do it -// manually. -func (c *Config) wrapTLSClient(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { - var err error - var tlsConn *tls.Conn - - tlsConn = tls.Client(conn, tlsConfig) +// Configurator holds a Config and is responsible for generating all the +// *tls.Config necessary for Consul. Except the one in the api package. +type Configurator struct { + sync.RWMutex + base *Config + cert *tls.Certificate + cas *x509.CertPool + logger *log.Logger + version int +} - // If crypto/tls is doing verification, there's no need to do - // our own. - if tlsConfig.InsecureSkipVerify == false { - return tlsConn, nil +// NewConfigurator creates a new Configurator and sets the provided +// configuration. +func NewConfigurator(config Config, logger *log.Logger) (*Configurator, error) { + c := &Configurator{logger: logger} + err := c.Update(config) + if err != nil { + return nil, err } + return c, nil +} - // If verification is not turned on, don't do it. - if !c.VerifyOutgoing { - return tlsConn, nil +// Update updates the internal configuration which is used to generate +// *tls.Config. +// This function acquires a write lock because it writes the new config. +func (c *Configurator) Update(config Config) error { + cert, err := loadKeyPair(config.CertFile, config.KeyFile) + if err != nil { + return err } - - if err = tlsConn.Handshake(); err != nil { - tlsConn.Close() - return nil, err + cas, err := loadCAs(config.CAFile, config.CAPath) + if err != nil { + return err } - // The following is lightly-modified from the doFullHandshake - // method in crypto/tls's handshake_client.go. - opts := x509.VerifyOptions{ - Roots: tlsConfig.RootCAs, - CurrentTime: time.Now(), - DNSName: "", - Intermediates: x509.NewCertPool(), + if err = c.check(config, cas, cert); err != nil { + return err } + c.Lock() + c.base = &config + c.cert = cert + c.cas = cas + c.version++ + c.Unlock() + c.log("Update") + return nil +} - certs := tlsConn.ConnectionState().PeerCertificates - for i, cert := range certs { - if i == 0 { - continue +func (c *Configurator) check(config Config, cas *x509.CertPool, cert *tls.Certificate) error { + // Check if a minimum TLS version was set + if config.TLSMinVersion != "" { + if _, ok := TLSLookup[config.TLSMinVersion]; !ok { + return fmt.Errorf("TLSMinVersion: value %s not supported, please specify one of [tls10,tls11,tls12]", config.TLSMinVersion) } - opts.Intermediates.AddCert(cert) } - _, err = certs[0].Verify(opts) - if err != nil { - tlsConn.Close() - return nil, err + // Ensure we have a CA if VerifyOutgoing is set + if config.VerifyOutgoing && cas == nil { + return fmt.Errorf("VerifyOutgoing set, and no CA certificate provided!") } - return tlsConn, err -} - -// Configurator holds a Config and is responsible for generating all the -// *tls.Config necessary for Consul. Except the one in the api package. -type Configurator struct { - sync.Mutex - base *Config - checks map[string]bool + // Ensure we have a CA and cert if VerifyIncoming is set + if config.VerifyIncoming || config.VerifyIncomingRPC || config.VerifyIncomingHTTPS { + if cas == nil { + return fmt.Errorf("VerifyIncoming set, and no CA certificate provided!") + } + if cert == nil { + return fmt.Errorf("VerifyIncoming set, and no Cert/Key pair provided!") + } + } + return nil } -// NewConfigurator creates a new Configurator and sets the provided -// configuration. -// Todo (Hans): should config be a value instead a pointer to avoid side -// effects? -func NewConfigurator(config *Config) *Configurator { - return &Configurator{base: config, checks: map[string]bool{}} +func loadKeyPair(certFile, keyFile string) (*tls.Certificate, error) { + if certFile == "" || keyFile == "" { + return nil, nil + } + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, fmt.Errorf("Failed to load cert/key pair: %v", err) + } + return &cert, nil } -// Update updates the internal configuration which is used to generate -// *tls.Config. -func (c *Configurator) Update(config *Config) { - c.Lock() - defer c.Unlock() - c.base = config +func loadCAs(caFile, caPath string) (*x509.CertPool, error) { + if caFile != "" { + return rootcerts.LoadCAFile(caFile) + } else if caPath != "" { + pool, err := rootcerts.LoadCAPath(caPath) + if err != nil { + return nil, err + } + // make sure to not return an empty pool because this is not + // the users intention when providing a path for CAs. + if len(pool.Subjects()) == 0 { + return nil, fmt.Errorf("Error loading CA: path %q has no CAs", caPath) + } + return pool, nil + } + return nil, nil } // commonTLSConfig generates a *tls.Config from the base configuration the // Configurator has. It accepts an additional flag in case a config is needed // for incoming TLS connections. -func (c *Configurator) commonTLSConfig(additionalVerifyIncomingFlag bool) (*tls.Config, error) { - if c.base == nil { - return nil, fmt.Errorf("No base config") - } - +// This function acquires a read lock because it reads from the config. +func (c *Configurator) commonTLSConfig(additionalVerifyIncomingFlag bool) *tls.Config { + c.RLock() + defer c.RUnlock() tlsConfig := &tls.Config{ InsecureSkipVerify: !c.base.VerifyServerHostname, } @@ -233,156 +247,215 @@ func (c *Configurator) commonTLSConfig(additionalVerifyIncomingFlag bool) (*tls. if len(c.base.CipherSuites) != 0 { tlsConfig.CipherSuites = c.base.CipherSuites } - if c.base.PreferServerCipherSuites { - tlsConfig.PreferServerCipherSuites = true - } - // Add cert/key - cert, err := c.base.KeyPair() - if err != nil { - return nil, err - } else if cert != nil { - tlsConfig.Certificates = []tls.Certificate{*cert} - } + tlsConfig.PreferServerCipherSuites = c.base.PreferServerCipherSuites - // Check if a minimum TLS version was set - if c.base.TLSMinVersion != "" { - tlsvers, ok := TLSLookup[c.base.TLSMinVersion] - if !ok { - return nil, fmt.Errorf("TLSMinVersion: value %s not supported, please specify one of [tls10,tls11,tls12]", c.base.TLSMinVersion) - } - tlsConfig.MinVersion = tlsvers + tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return c.cert, nil } - - // Ensure we have a CA if VerifyOutgoing is set - if c.base.VerifyOutgoing && c.base.CAFile == "" && c.base.CAPath == "" { - return nil, fmt.Errorf("VerifyOutgoing set, and no CA certificate provided!") + tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return c.cert, nil } - // Parse the CA certs if any - if c.base.CAFile != "" { - pool, err := rootcerts.LoadCAFile(c.base.CAFile) - if err != nil { - return nil, err - } - tlsConfig.ClientCAs = pool - tlsConfig.RootCAs = pool - } else if c.base.CAPath != "" { - pool, err := rootcerts.LoadCAPath(c.base.CAPath) - if err != nil { - return nil, err - } - tlsConfig.ClientCAs = pool - tlsConfig.RootCAs = pool - } + tlsConfig.ClientCAs = c.cas + tlsConfig.RootCAs = c.cas + + // This is possible because TLSLookup also contains "" with golang's + // default (tls10). And because the initial check makes sure the + // version correctly matches. + tlsConfig.MinVersion = TLSLookup[c.base.TLSMinVersion] // Set ClientAuth if necessary if c.base.VerifyIncoming || additionalVerifyIncomingFlag { - if c.base.CAFile == "" && c.base.CAPath == "" { - return nil, fmt.Errorf("VerifyIncoming set, and no CA certificate provided!") - } - if len(tlsConfig.Certificates) == 0 { - return nil, fmt.Errorf("VerifyIncoming set, and no Cert/Key pair provided!") - } - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert } - return tlsConfig, nil + return tlsConfig +} + +// This function acquires a read lock because it reads from the config. +func (c *Configurator) outgoingRPCTLSDisabled() bool { + c.RLock() + defer c.RUnlock() + return c.cas == nil && !c.base.VerifyOutgoing +} + +// This function acquires a read lock because it reads from the config. +func (c *Configurator) someValuesFromConfig() (bool, bool, string) { + c.RLock() + defer c.RUnlock() + return c.base.VerifyServerHostname, c.base.VerifyOutgoing, c.base.Domain +} + +// This function acquires a read lock because it reads from the config. +func (c *Configurator) verifyIncomingRPC() bool { + c.RLock() + defer c.RUnlock() + return c.base.VerifyIncomingRPC +} + +// This function acquires a read lock because it reads from the config. +func (c *Configurator) verifyIncomingHTTPS() bool { + c.RLock() + defer c.RUnlock() + return c.base.VerifyIncomingHTTPS +} + +// This function acquires a read lock because it reads from the config. +func (c *Configurator) enableAgentTLSForChecks() bool { + c.RLock() + defer c.RUnlock() + return c.base.EnableAgentTLSForChecks +} + +// This function acquires a read lock because it reads from the config. +func (c *Configurator) serverNameOrNodeName() string { + c.RLock() + defer c.RUnlock() + if c.base.ServerName != "" { + return c.base.ServerName + } + return c.base.NodeName } // IncomingRPCConfig generates a *tls.Config for incoming RPC connections. -func (c *Configurator) IncomingRPCConfig() (*tls.Config, error) { - return c.commonTLSConfig(c.base.VerifyIncomingRPC) +func (c *Configurator) IncomingRPCConfig() *tls.Config { + c.log("IncomingRPCConfig") + config := c.commonTLSConfig(c.verifyIncomingRPC()) + config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return c.IncomingRPCConfig(), nil + } + return config } // IncomingHTTPSConfig generates a *tls.Config for incoming HTTPS connections. -func (c *Configurator) IncomingHTTPSConfig() (*tls.Config, error) { - return c.commonTLSConfig(c.base.VerifyIncomingHTTPS) +func (c *Configurator) IncomingHTTPSConfig() *tls.Config { + c.log("IncomingHTTPSConfig") + config := c.commonTLSConfig(c.verifyIncomingHTTPS()) + config.NextProtos = []string{"h2", "http/1.1"} + config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return c.IncomingHTTPSConfig(), nil + } + return config } // IncomingTLSConfig generates a *tls.Config for outgoing TLS connections for // checks. This function is seperated because there is an extra flag to // consider for checks. EnableAgentTLSForChecks and InsecureSkipVerify has to // be checked for checks. -func (c *Configurator) OutgoingTLSConfigForCheck(id string) (*tls.Config, error) { - if !c.base.EnableAgentTLSForChecks { +func (c *Configurator) OutgoingTLSConfigForCheck(skipVerify bool) *tls.Config { + c.log("OutgoingTLSConfigForCheck") + if !c.enableAgentTLSForChecks() { return &tls.Config{ - InsecureSkipVerify: c.getSkipVerifyForCheck(id), - }, nil + InsecureSkipVerify: skipVerify, + } } - tlsConfig, err := c.commonTLSConfig(false) - if err != nil { - return nil, err - } - tlsConfig.InsecureSkipVerify = c.getSkipVerifyForCheck(id) - tlsConfig.ServerName = c.base.ServerName - if tlsConfig.ServerName == "" { - tlsConfig.ServerName = c.base.NodeName - } + config := c.commonTLSConfig(false) + config.InsecureSkipVerify = skipVerify + config.ServerName = c.serverNameOrNodeName() - return tlsConfig, nil + return config } // OutgoingRPCConfig generates a *tls.Config for outgoing RPC connections. If // there is a CA or VerifyOutgoing is set, a *tls.Config will be provided, // otherwise we assume that no TLS should be used. -func (c *Configurator) OutgoingRPCConfig() (*tls.Config, error) { - useTLS := c.base.CAFile != "" || c.base.CAPath != "" || c.base.VerifyOutgoing - if !useTLS { - return nil, nil +func (c *Configurator) OutgoingRPCConfig() *tls.Config { + c.log("OutgoingRPCConfig") + if c.outgoingRPCTLSDisabled() { + return nil } return c.commonTLSConfig(false) } // OutgoingRPCWrapper wraps the result of OutgoingRPCConfig in a DCWrapper. It // decides if verify server hostname should be used. -func (c *Configurator) OutgoingRPCWrapper() (DCWrapper, error) { - // Get the TLS config - tlsConfig, err := c.OutgoingRPCConfig() - if err != nil { - return nil, err +func (c *Configurator) OutgoingRPCWrapper() DCWrapper { + c.log("OutgoingRPCWrapper") + if c.outgoingRPCTLSDisabled() { + return nil } - // Check if TLS is not enabled - if tlsConfig == nil { - return nil, nil + // Generate the wrapper based on dc + return func(dc string, conn net.Conn) (net.Conn, error) { + return c.wrapTLSClient(dc, conn) } +} - // Generate the wrapper based on hostname verification - wrapper := func(dc string, conn net.Conn) (net.Conn, error) { - if c.base.VerifyServerHostname { - // Strip the trailing '.' from the domain if any - domain := strings.TrimSuffix(c.base.Domain, ".") - tlsConfig = tlsConfig.Clone() - tlsConfig.ServerName = "server." + dc + "." + domain - } - return c.base.wrapTLSClient(conn, tlsConfig) +// This function acquires a read lock because it reads from the config. +func (c *Configurator) log(name string) { + if c.logger != nil { + c.RLock() + defer c.RUnlock() + c.logger.Printf("[DEBUG] tlsutil: %s with version %d", name, c.version) } - - return wrapper, nil } -// AddCheck adds a check to the internal check map together with the skipVerify -// value, which is used when generating a *tls.Config for this check. -func (c *Configurator) AddCheck(id string, skipVerify bool) { - c.Lock() - defer c.Unlock() - c.checks[id] = skipVerify -} +// Wrap a net.Conn into a client tls connection, performing any +// additional verification as needed. +// +// As of go 1.3, crypto/tls only supports either doing no certificate +// verification, or doing full verification including of the peer's +// DNS name. For consul, we want to validate that the certificate is +// signed by a known CA, but because consul doesn't use DNS names for +// node names, we don't verify the certificate DNS names. Since go 1.3 +// no longer supports this mode of operation, we have to do it +// manually. +func (c *Configurator) wrapTLSClient(dc string, conn net.Conn) (net.Conn, error) { + var err error + var tlsConn *tls.Conn -// RemoveCheck removes a check from the internal check map. -func (c *Configurator) RemoveCheck(id string) { - c.Lock() - defer c.Unlock() - delete(c.checks, id) -} + config := c.OutgoingRPCConfig() + verifyServerHostname, verifyOutgoing, domain := c.someValuesFromConfig() -func (c *Configurator) getSkipVerifyForCheck(id string) bool { - c.Lock() - defer c.Unlock() - return c.checks[id] + if verifyServerHostname { + // Strip the trailing '.' from the domain if any + domain = strings.TrimSuffix(domain, ".") + config.ServerName = "server." + dc + "." + domain + } + tlsConn = tls.Client(conn, config) + + // If crypto/tls is doing verification, there's no need to do + // our own. + if !config.InsecureSkipVerify { + return tlsConn, nil + } + + // If verification is not turned on, don't do it. + if !verifyOutgoing { + return tlsConn, nil + } + + if err = tlsConn.Handshake(); err != nil { + tlsConn.Close() + return nil, err + } + + // The following is lightly-modified from the doFullHandshake + // method in crypto/tls's handshake_client.go. + opts := x509.VerifyOptions{ + Roots: config.RootCAs, + CurrentTime: time.Now(), + DNSName: "", + Intermediates: x509.NewCertPool(), + } + + certs := tlsConn.ConnectionState().PeerCertificates + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + + _, err = certs[0].Verify(opts) + if err != nil { + tlsConn.Close() + return nil, err + } + + return tlsConn, err } // ParseCiphers parse ciphersuites from the comma-separated string into diff --git a/tlsutil/config_test.go b/tlsutil/config_test.go index de213c0def47..5515871c771d 100644 --- a/tlsutil/config_test.go +++ b/tlsutil/config_test.go @@ -1,10 +1,13 @@ package tlsutil import ( + "bytes" "crypto/tls" "crypto/x509" + "fmt" "io" "io/ioutil" + "log" "net" "reflect" "strings" @@ -14,135 +17,15 @@ import ( "github.com/stretchr/testify/require" ) -func TestConfig_KeyPair_None(t *testing.T) { - conf := &Config{} - cert, err := conf.KeyPair() - if err != nil { - t.Fatalf("err: %v", err) - } - if cert != nil { - t.Fatalf("bad: %v", cert) - } -} - -func TestConfig_KeyPair_Valid(t *testing.T) { - conf := &Config{ - CertFile: "../test/key/ourdomain.cer", - KeyFile: "../test/key/ourdomain.key", - } - cert, err := conf.KeyPair() - if err != nil { - t.Fatalf("err: %v", err) - } - if cert == nil { - t.Fatalf("expected cert") - } -} - -func TestConfigurator_OutgoingTLS_MissingCA(t *testing.T) { - conf := &Config{ - VerifyOutgoing: true, - } - c := NewConfigurator(conf) - tlsConf, err := c.OutgoingRPCConfig() - require.Error(t, err) - require.Nil(t, tlsConf) -} - -func TestConfigurator_OutgoingTLS_OnlyCA(t *testing.T) { - conf := &Config{ - CAFile: "../test/ca/root.cer", - } - c := NewConfigurator(conf) - tlsConf, err := c.OutgoingRPCConfig() - require.NoError(t, err) - require.NotNil(t, tlsConf) -} - -func TestConfigurator_OutgoingTLS_VerifyOutgoing(t *testing.T) { - conf := &Config{ - VerifyOutgoing: true, - CAFile: "../test/ca/root.cer", - } - c := NewConfigurator(conf) - tlsConf, err := c.OutgoingRPCConfig() - require.NoError(t, err) - require.NotNil(t, tlsConf) - require.Len(t, tlsConf.RootCAs.Subjects(), 1) - require.Empty(t, tlsConf.ServerName) - require.True(t, tlsConf.InsecureSkipVerify) -} - -func TestConfigurator_OutgoingRPC_ServerName(t *testing.T) { - conf := &Config{ - VerifyOutgoing: true, - CAFile: "../test/ca/root.cer", - ServerName: "consul.example.com", - } - c := NewConfigurator(conf) - tlsConf, err := c.OutgoingRPCConfig() - require.NoError(t, err) - require.NotNil(t, tlsConf) - require.Len(t, tlsConf.RootCAs.Subjects(), 1) - require.Empty(t, tlsConf.ServerName) - require.True(t, tlsConf.InsecureSkipVerify) -} - -func TestConfigurator_OutgoingTLS_VerifyHostname(t *testing.T) { - conf := &Config{ - VerifyOutgoing: true, - VerifyServerHostname: true, - CAFile: "../test/ca/root.cer", - } - c := NewConfigurator(conf) - tlsConf, err := c.OutgoingRPCConfig() - require.NoError(t, err) - require.NotNil(t, tlsConf) - require.Len(t, tlsConf.RootCAs.Subjects(), 1) - require.False(t, tlsConf.InsecureSkipVerify) -} - -func TestConfigurator_OutgoingTLS_WithKeyPair(t *testing.T) { - conf := &Config{ - VerifyOutgoing: true, - CAFile: "../test/ca/root.cer", - CertFile: "../test/key/ourdomain.cer", - KeyFile: "../test/key/ourdomain.key", - } - c := NewConfigurator(conf) - tlsConf, err := c.OutgoingRPCConfig() - require.NoError(t, err) - require.NotNil(t, tlsConf) - require.True(t, tlsConf.InsecureSkipVerify) - require.Len(t, tlsConf.Certificates, 1) -} - -func TestConfigurator_OutgoingTLS_TLSMinVersion(t *testing.T) { - tlsVersions := []string{"tls10", "tls11", "tls12"} - for _, version := range tlsVersions { - conf := &Config{ - VerifyOutgoing: true, - CAFile: "../test/ca/root.cer", - TLSMinVersion: version, - } - c := NewConfigurator(conf) - tlsConf, err := c.OutgoingRPCConfig() - require.NoError(t, err) - require.NotNil(t, tlsConf) - require.Equal(t, tlsConf.MinVersion, TLSLookup[version]) - } -} - func startTLSServer(config *Config) (net.Conn, chan error) { errc := make(chan error, 1) - c := NewConfigurator(config) - tlsConfigServer, err := c.IncomingRPCConfig() + c, err := NewConfigurator(*config, nil) if err != nil { errc <- err return nil, errc } - + tlsConfigServer := c.IncomingRPCConfig() client, server := net.Pipe() // Use yamux to buffer the reads, otherwise it's easy to deadlock @@ -172,7 +55,7 @@ func startTLSServer(config *Config) (net.Conn, chan error) { } func TestConfigurator_outgoingWrapper_OK(t *testing.T) { - config := &Config{ + config := Config{ CAFile: "../test/hostname/CertAuth.crt", CertFile: "../test/hostname/Alice.crt", KeyFile: "../test/hostname/Alice.key", @@ -181,14 +64,44 @@ func TestConfigurator_outgoingWrapper_OK(t *testing.T) { Domain: "consul", } - client, errc := startTLSServer(config) + client, errc := startTLSServer(&config) + if client == nil { + t.Fatalf("startTLSServer err: %v", <-errc) + } + + c, err := NewConfigurator(config, nil) + require.NoError(t, err) + wrap := c.OutgoingRPCWrapper() + require.NotNil(t, wrap) + + tlsClient, err := wrap("dc1", client) + require.NoError(t, err) + + defer tlsClient.Close() + err = tlsClient.(*tls.Conn).Handshake() + require.NoError(t, err) + + err = <-errc + require.NoError(t, err) +} + +func TestConfigurator_outgoingWrapper_noverify_OK(t *testing.T) { + config := Config{ + CAFile: "../test/hostname/CertAuth.crt", + CertFile: "../test/hostname/Alice.crt", + KeyFile: "../test/hostname/Alice.key", + Domain: "consul", + } + + client, errc := startTLSServer(&config) if client == nil { t.Fatalf("startTLSServer err: %v", <-errc) } - c := NewConfigurator(config) - wrap, err := c.OutgoingRPCWrapper() + c, err := NewConfigurator(config, nil) require.NoError(t, err) + wrap := c.OutgoingRPCWrapper() + require.NotNil(t, wrap) tlsClient, err := wrap("dc1", client) require.NoError(t, err) @@ -202,7 +115,7 @@ func TestConfigurator_outgoingWrapper_OK(t *testing.T) { } func TestConfigurator_outgoingWrapper_BadDC(t *testing.T) { - config := &Config{ + config := Config{ CAFile: "../test/hostname/CertAuth.crt", CertFile: "../test/hostname/Alice.crt", KeyFile: "../test/hostname/Alice.key", @@ -211,14 +124,14 @@ func TestConfigurator_outgoingWrapper_BadDC(t *testing.T) { Domain: "consul", } - client, errc := startTLSServer(config) + client, errc := startTLSServer(&config) if client == nil { t.Fatalf("startTLSServer err: %v", <-errc) } - c := NewConfigurator(config) - wrap, err := c.OutgoingRPCWrapper() + c, err := NewConfigurator(config, nil) require.NoError(t, err) + wrap := c.OutgoingRPCWrapper() tlsClient, err := wrap("dc2", client) require.NoError(t, err) @@ -232,7 +145,7 @@ func TestConfigurator_outgoingWrapper_BadDC(t *testing.T) { } func TestConfigurator_outgoingWrapper_BadCert(t *testing.T) { - config := &Config{ + config := Config{ CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key", @@ -241,14 +154,14 @@ func TestConfigurator_outgoingWrapper_BadCert(t *testing.T) { Domain: "consul", } - client, errc := startTLSServer(config) + client, errc := startTLSServer(&config) if client == nil { t.Fatalf("startTLSServer err: %v", <-errc) } - c := NewConfigurator(config) - wrap, err := c.OutgoingRPCWrapper() + c, err := NewConfigurator(config, nil) require.NoError(t, err) + wrap := c.OutgoingRPCWrapper() tlsClient, err := wrap("dc1", client) require.NoError(t, err) @@ -263,23 +176,22 @@ func TestConfigurator_outgoingWrapper_BadCert(t *testing.T) { } func TestConfigurator_wrapTLS_OK(t *testing.T) { - config := &Config{ + config := Config{ CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key", VerifyOutgoing: true, } - client, errc := startTLSServer(config) + client, errc := startTLSServer(&config) if client == nil { t.Fatalf("startTLSServer err: %v", <-errc) } - c := NewConfigurator(config) - clientConfig, err := c.OutgoingRPCConfig() + c, err := NewConfigurator(config, nil) require.NoError(t, err) - tlsClient, err := config.wrapTLSClient(client, clientConfig) + tlsClient, err := c.wrapTLSClient("dc1", client) require.NoError(t, err) tlsClient.Close() @@ -298,16 +210,14 @@ func TestConfigurator_wrapTLS_BadCert(t *testing.T) { t.Fatalf("startTLSServer err: %v", <-errc) } - clientConfig := &Config{ + clientConfig := Config{ CAFile: "../test/ca/root.cer", VerifyOutgoing: true, } - c := NewConfigurator(clientConfig) - clientTLSConfig, err := c.OutgoingRPCConfig() + c, err := NewConfigurator(clientConfig, nil) require.NoError(t, err) - - tlsClient, err := clientConfig.wrapTLSClient(client, clientTLSConfig) + tlsClient, err := c.wrapTLSClient("dc1", client) require.Error(t, err) require.Nil(t, tlsClient) @@ -365,355 +275,485 @@ func TestConfig_ParseCiphers(t *testing.T) { tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, } v, err := ParseCiphers(testOk) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if got, want := v, ciphers; !reflect.DeepEqual(got, want) { t.Fatalf("got ciphers %#v want %#v", got, want) } - testBad := "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,cipherX" - if _, err := ParseCiphers(testBad); err == nil { - t.Fatal("should fail on unsupported cipherX") - } -} - -func TestConfigurator_IncomingHTTPSConfig_CA_PATH(t *testing.T) { - conf := &Config{CAPath: "../test/ca_path"} + _, err = ParseCiphers("TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,cipherX") + require.Error(t, err) - c := NewConfigurator(conf) - tlsConf, err := c.IncomingHTTPSConfig() + v, err = ParseCiphers("") require.NoError(t, err) - require.Len(t, tlsConf.ClientCAs.Subjects(), 2) + require.Equal(t, []uint16{}, v) } -func TestConfigurator_IncomingHTTPS(t *testing.T) { - conf := &Config{ - VerifyIncoming: true, - CAFile: "../test/ca/root.cer", - CertFile: "../test/key/ourdomain.cer", - KeyFile: "../test/key/ourdomain.key", +func TestConfigurator_loadKeyPair(t *testing.T) { + type variant struct { + cert, key string + shoulderr bool + isnil bool } - c := NewConfigurator(conf) - tlsConf, err := c.IncomingHTTPSConfig() - require.NoError(t, err) - require.NotNil(t, tlsConf) - require.Len(t, tlsConf.ClientCAs.Subjects(), 1) - require.Equal(t, tlsConf.ClientAuth, tls.RequireAndVerifyClientCert) - require.Len(t, tlsConf.Certificates, 1) -} - -func TestConfigurator_IncomingHTTPS_MissingCA(t *testing.T) { - conf := &Config{ - VerifyIncoming: true, - CertFile: "../test/key/ourdomain.cer", - KeyFile: "../test/key/ourdomain.key", + variants := []variant{ + {"", "", false, true}, + {"bogus", "", false, true}, + {"", "bogus", false, true}, + {"../test/key/ourdomain.cer", "", false, true}, + {"", "../test/key/ourdomain.key", false, true}, + {"bogus", "bogus", true, true}, + {"../test/key/ourdomain.cer", "../test/key/ourdomain.key", + false, false}, } - c := NewConfigurator(conf) - _, err := c.IncomingHTTPSConfig() - require.Error(t, err) -} - -func TestConfigurator_IncomingHTTPS_MissingKey(t *testing.T) { - conf := &Config{ - VerifyIncoming: true, - CAFile: "../test/ca/root.cer", + for _, v := range variants { + cert1, err1 := loadKeyPair(v.cert, v.key) + config := &Config{CertFile: v.cert, KeyFile: v.key} + cert2, err2 := config.KeyPair() + if v.shoulderr { + require.Error(t, err1) + require.Error(t, err2) + } else { + require.NoError(t, err1) + require.NoError(t, err2) + } + if v.isnil { + require.Nil(t, cert1) + require.Nil(t, cert2) + } else { + require.NotNil(t, cert1) + require.NotNil(t, cert2) + } } - c := NewConfigurator(conf) - _, err := c.IncomingHTTPSConfig() - require.Error(t, err) } -func TestConfigurator_IncomingHTTPS_NoVerify(t *testing.T) { - conf := &Config{} - c := NewConfigurator(conf) - tlsConf, err := c.IncomingHTTPSConfig() +func TestConfig_SpecifyDC(t *testing.T) { + require.Nil(t, SpecificDC("", nil)) + dcwrap := func(dc string, conn net.Conn) (net.Conn, error) { return nil, nil } + wrap := SpecificDC("", dcwrap) + require.NotNil(t, wrap) + conn, err := wrap(nil) require.NoError(t, err) - require.NotNil(t, tlsConf) - require.Nil(t, tlsConf.ClientCAs) - require.Equal(t, tlsConf.ClientAuth, tls.NoClientCert) - require.Empty(t, tlsConf.Certificates) -} - -func TestConfigurator_IncomingHTTPS_TLSMinVersion(t *testing.T) { - tlsVersions := []string{"tls10", "tls11", "tls12"} - for _, version := range tlsVersions { - conf := &Config{ - VerifyIncoming: true, - CAFile: "../test/ca/root.cer", - CertFile: "../test/key/ourdomain.cer", - KeyFile: "../test/key/ourdomain.key", - TLSMinVersion: version, - } - c := NewConfigurator(conf) - tlsConf, err := c.IncomingHTTPSConfig() - require.NoError(t, err) - require.NotNil(t, tlsConf) - require.Equal(t, tlsConf.MinVersion, TLSLookup[version]) - } + require.Nil(t, conn) } -func TestConfigurator_IncomingHTTPSCAPath_Valid(t *testing.T) { - - c := NewConfigurator(&Config{CAPath: "../test/ca_path"}) - tlsConf, err := c.IncomingHTTPSConfig() +func TestConfigurator_NewConfigurator(t *testing.T) { + buf := bytes.Buffer{} + logger := log.New(&buf, "logger: ", log.Lshortfile) + c, err := NewConfigurator(Config{}, logger) require.NoError(t, err) - require.Len(t, tlsConf.ClientCAs.Subjects(), 2) -} + require.NotNil(t, c) + require.Equal(t, logger, c.logger) -func TestConfigurator_CommonTLSConfigNoBaseConfig(t *testing.T) { - c := NewConfigurator(nil) - _, err := c.commonTLSConfig(false) + c, err = NewConfigurator(Config{VerifyOutgoing: true}, nil) require.Error(t, err) + require.Nil(t, c) +} + +func TestConfigurator_ErrorPropagation(t *testing.T) { + type variant struct { + config Config + shouldErr bool + excludeCheck bool + } + cafile := "../test/ca/root.cer" + capath := "../test/ca_path" + certfile := "../test/key/ourdomain.cer" + keyfile := "../test/key/ourdomain.key" + variants := []variant{ + {Config{}, false, false}, + {Config{TLSMinVersion: "tls9"}, true, false}, + {Config{TLSMinVersion: ""}, false, false}, + {Config{TLSMinVersion: "tls10"}, false, false}, + {Config{TLSMinVersion: "tls11"}, false, false}, + {Config{TLSMinVersion: "tls12"}, false, false}, + {Config{VerifyOutgoing: true, CAFile: "", CAPath: ""}, true, false}, + {Config{VerifyOutgoing: false, CAFile: "", CAPath: ""}, false, false}, + {Config{VerifyOutgoing: false, CAFile: cafile, CAPath: ""}, + false, false}, + {Config{VerifyOutgoing: false, CAFile: "", CAPath: capath}, + false, false}, + {Config{VerifyOutgoing: false, CAFile: cafile, CAPath: capath}, + false, false}, + {Config{VerifyOutgoing: true, CAFile: cafile, CAPath: ""}, + false, false}, + {Config{VerifyOutgoing: true, CAFile: "", CAPath: capath}, + false, false}, + {Config{VerifyOutgoing: true, CAFile: cafile, CAPath: capath}, + false, false}, + {Config{VerifyIncoming: true, CAFile: "", CAPath: ""}, true, false}, + {Config{VerifyIncomingRPC: true, CAFile: "", CAPath: ""}, + true, false}, + {Config{VerifyIncomingHTTPS: true, CAFile: "", CAPath: ""}, + true, false}, + {Config{VerifyIncoming: true, CAFile: cafile, CAPath: ""}, true, false}, + {Config{VerifyIncoming: true, CAFile: "", CAPath: capath}, true, false}, + {Config{VerifyIncoming: true, CAFile: "", CAPath: capath, + CertFile: certfile, KeyFile: keyfile}, false, false}, + {Config{CertFile: "bogus", KeyFile: "bogus"}, true, true}, + {Config{CAFile: "bogus"}, true, true}, + {Config{CAPath: "bogus"}, true, true}, + } + + c := &Configurator{} + for i, v := range variants { + info := fmt.Sprintf("case %d", i) + _, err1 := NewConfigurator(v.config, nil) + err2 := c.Update(v.config) + + var err3 error + if !v.excludeCheck { + cert, err := v.config.KeyPair() + require.NoError(t, err, info) + cas, _ := loadCAs(v.config.CAFile, v.config.CAPath) + require.NoError(t, err, info) + err3 = c.check(v.config, cas, cert) + } + if v.shouldErr { + require.Error(t, err1, info) + require.Error(t, err2, info) + if !v.excludeCheck { + require.Error(t, err3, info) + } + } else { + require.NoError(t, err1, info) + require.NoError(t, err2, info) + if !v.excludeCheck { + require.NoError(t, err3, info) + } + } + } } func TestConfigurator_CommonTLSConfigServerNameNodeName(t *testing.T) { type variant struct { - config *Config + config Config result string } variants := []variant{ - {config: &Config{NodeName: "node", ServerName: "server"}, + {config: Config{NodeName: "node", ServerName: "server"}, result: "server"}, - {config: &Config{ServerName: "server"}, + {config: Config{ServerName: "server"}, result: "server"}, - {config: &Config{NodeName: "node"}, + {config: Config{NodeName: "node"}, result: "node"}, } for _, v := range variants { - c := NewConfigurator(v.config) - tlsConf, err := c.commonTLSConfig(false) + c, err := NewConfigurator(v.config, nil) require.NoError(t, err) + tlsConf := c.commonTLSConfig(false) require.Empty(t, tlsConf.ServerName) } } -func TestConfigurator_CommonTLSConfigCipherSuites(t *testing.T) { - c := NewConfigurator(&Config{}) - tlsConfig, err := c.commonTLSConfig(false) - require.NoError(t, err) - require.Empty(t, tlsConfig.CipherSuites) - - conf := &Config{CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}} - c.Update(conf) - tlsConfig, err = c.commonTLSConfig(false) - require.NoError(t, err) - require.Equal(t, conf.CipherSuites, tlsConfig.CipherSuites) +func TestConfigurator_loadCAs(t *testing.T) { + type variant struct { + cafile, capath string + shouldErr bool + isNil bool + count int + } + variants := []variant{ + {"", "", false, true, 0}, + {"bogus", "", true, true, 0}, + {"", "bogus", true, true, 0}, + {"", "../test/bin", true, true, 0}, + {"../test/ca/root.cer", "", false, false, 1}, + {"", "../test/ca_path", false, false, 2}, + {"../test/ca/root.cer", "../test/ca_path", false, false, 1}, + } + for i, v := range variants { + cas, err := loadCAs(v.cafile, v.capath) + info := fmt.Sprintf("case %d", i) + if v.shouldErr { + require.Error(t, err, info) + } else { + require.NoError(t, err, info) + } + if v.isNil { + require.Nil(t, cas, info) + } else { + require.NotNil(t, cas, info) + require.Len(t, cas.Subjects(), v.count, info) + } + } } -func TestConfigurator_CommonTLSConfigCertKey(t *testing.T) { - c := NewConfigurator(&Config{}) - tlsConf, err := c.commonTLSConfig(false) +func TestConfigurator_CommonTLSConfigInsecureSkipVerify(t *testing.T) { + c, err := NewConfigurator(Config{}, nil) require.NoError(t, err) - require.Empty(t, tlsConf.Certificates) + tlsConf := c.commonTLSConfig(false) + require.True(t, tlsConf.InsecureSkipVerify) - c.Update(&Config{CertFile: "/something/bogus", KeyFile: "/more/bogus"}) - tlsConf, err = c.commonTLSConfig(false) - require.Error(t, err) + require.NoError(t, c.Update(Config{VerifyServerHostname: false})) + tlsConf = c.commonTLSConfig(false) + require.True(t, tlsConf.InsecureSkipVerify) - c.Update(&Config{CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key"}) - tlsConf, err = c.commonTLSConfig(false) - require.NoError(t, err) - require.Len(t, tlsConf.Certificates, 1) + require.NoError(t, c.Update(Config{VerifyServerHostname: true})) + tlsConf = c.commonTLSConfig(false) + require.False(t, tlsConf.InsecureSkipVerify) } -func TestConfigurator_CommonTLSConfigTLSMinVersion(t *testing.T) { - tlsVersions := []string{"tls10", "tls11", "tls12"} - for _, version := range tlsVersions { - c := NewConfigurator(&Config{TLSMinVersion: version}) - tlsConf, err := c.commonTLSConfig(false) - require.NoError(t, err) - require.Equal(t, tlsConf.MinVersion, TLSLookup[version]) - } +func TestConfigurator_CommonTLSConfigPreferServerCipherSuites(t *testing.T) { + c, err := NewConfigurator(Config{}, nil) + require.NoError(t, err) + tlsConf := c.commonTLSConfig(false) + require.False(t, tlsConf.PreferServerCipherSuites) - c := NewConfigurator(&Config{TLSMinVersion: "tlsBOGUS"}) - _, err := c.commonTLSConfig(false) - require.Error(t, err) -} + require.NoError(t, c.Update(Config{PreferServerCipherSuites: false})) + tlsConf = c.commonTLSConfig(false) + require.False(t, tlsConf.PreferServerCipherSuites) -func TestConfigurator_CommonTLSConfigValidateVerifyOutgoingCA(t *testing.T) { - c := NewConfigurator(&Config{VerifyOutgoing: true}) - _, err := c.commonTLSConfig(false) - require.Error(t, err) + require.NoError(t, c.Update(Config{PreferServerCipherSuites: true})) + tlsConf = c.commonTLSConfig(false) + require.True(t, tlsConf.PreferServerCipherSuites) } -func TestConfigurator_CommonTLSConfigLoadCA(t *testing.T) { - c := NewConfigurator(&Config{}) - tlsConf, err := c.commonTLSConfig(false) +func TestConfigurator_CommonTLSConfigCipherSuites(t *testing.T) { + c, err := NewConfigurator(Config{}, nil) require.NoError(t, err) - require.Nil(t, tlsConf.RootCAs) - require.Nil(t, tlsConf.ClientCAs) + tlsConf := c.commonTLSConfig(false) + require.Empty(t, tlsConf.CipherSuites) - c.Update(&Config{CAFile: "/something/bogus"}) - _, err = c.commonTLSConfig(false) - require.Error(t, err) + conf := Config{CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}} + require.NoError(t, c.Update(conf)) + tlsConf = c.commonTLSConfig(false) + require.Equal(t, conf.CipherSuites, tlsConf.CipherSuites) +} - c.Update(&Config{CAPath: "/something/bogus/"}) - _, err = c.commonTLSConfig(false) - require.Error(t, err) +func TestConfigurator_CommonTLSConfigGetClientCertificate(t *testing.T) { + c, err := NewConfigurator(Config{}, nil) + require.NoError(t, err) - c.Update(&Config{CAFile: "../test/ca/root.cer"}) - tlsConf, err = c.commonTLSConfig(false) + cert, err := c.commonTLSConfig(false).GetCertificate(nil) require.NoError(t, err) - require.Len(t, tlsConf.RootCAs.Subjects(), 1) - require.Len(t, tlsConf.ClientCAs.Subjects(), 1) + require.Nil(t, cert) - c.Update(&Config{CAPath: "../test/ca_path"}) - tlsConf, err = c.commonTLSConfig(false) + c.cert = &tls.Certificate{} + cert, err = c.commonTLSConfig(false).GetCertificate(nil) require.NoError(t, err) - require.Len(t, tlsConf.RootCAs.Subjects(), 2) - require.Len(t, tlsConf.ClientCAs.Subjects(), 2) + require.Equal(t, c.cert, cert) - c.Update(&Config{CAFile: "../test/ca/root.cer", CAPath: "../test/ca_path"}) - tlsConf, err = c.commonTLSConfig(false) + cert, err = c.commonTLSConfig(false).GetClientCertificate(nil) require.NoError(t, err) - require.Len(t, tlsConf.RootCAs.Subjects(), 1) - require.Len(t, tlsConf.ClientCAs.Subjects(), 1) + require.Equal(t, c.cert, cert) } -func TestConfigurator_CommonTLSConfigVerifyIncoming(t *testing.T) { - c := NewConfigurator(&Config{}) - tlsConf, err := c.commonTLSConfig(false) +func TestConfigurator_CommonTLSConfigCAs(t *testing.T) { + c, err := NewConfigurator(Config{}, nil) require.NoError(t, err) - require.Equal(t, tls.NoClientCert, tlsConf.ClientAuth) - - c.Update(&Config{VerifyIncoming: true}) - tlsConf, err = c.commonTLSConfig(false) - require.Error(t, err) - - c.Update(&Config{VerifyIncoming: true, CAFile: "../test/ca/root.cer"}) - tlsConf, err = c.commonTLSConfig(false) - require.Error(t, err) + require.Nil(t, c.commonTLSConfig(false).ClientCAs) + require.Nil(t, c.commonTLSConfig(false).RootCAs) - c.Update(&Config{VerifyIncoming: true, CAFile: "../test/ca/root.cer", CertFile: "../test/cert/ourdomain.cer"}) - tlsConf, err = c.commonTLSConfig(false) - require.Error(t, err) - - c.Update(&Config{VerifyIncoming: true, CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key"}) - tlsConf, err = c.commonTLSConfig(false) - require.NoError(t, err) - require.Equal(t, tls.RequireAndVerifyClientCert, tlsConf.ClientAuth) + c.cas = &x509.CertPool{} + require.Equal(t, c.cas, c.commonTLSConfig(false).ClientCAs) + require.Equal(t, c.cas, c.commonTLSConfig(false).RootCAs) +} - c.Update(&Config{VerifyIncoming: false, CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key"}) - tlsConf, err = c.commonTLSConfig(true) +func TestConfigurator_CommonTLSConfigTLSMinVersion(t *testing.T) { + c, err := NewConfigurator(Config{TLSMinVersion: ""}, nil) require.NoError(t, err) - require.Equal(t, tls.RequireAndVerifyClientCert, tlsConf.ClientAuth) + require.Equal(t, c.commonTLSConfig(false).MinVersion, TLSLookup["tls10"]) - c.Update(&Config{VerifyServerHostname: false, CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key"}) - tlsConf, err = c.commonTLSConfig(false) - require.NoError(t, err) - require.True(t, tlsConf.InsecureSkipVerify) + tlsVersions := []string{"tls10", "tls11", "tls12"} + for _, version := range tlsVersions { + require.NoError(t, c.Update(Config{TLSMinVersion: version})) + require.Equal(t, c.commonTLSConfig(false).MinVersion, + TLSLookup[version]) + } - c.Update(&Config{VerifyServerHostname: true, CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key"}) - tlsConf, err = c.commonTLSConfig(false) - require.NoError(t, err) - require.False(t, tlsConf.InsecureSkipVerify) + require.Error(t, c.Update(Config{TLSMinVersion: "tlsBOGUS"})) } -func TestConfigurator_IncomingRPCConfig(t *testing.T) { - c := NewConfigurator(&Config{}) - tlsConf, err := c.IncomingRPCConfig() - require.NoError(t, err) - require.Equal(t, tls.NoClientCert, tlsConf.ClientAuth) +func TestConfigurator_CommonTLSConfigVerifyIncoming(t *testing.T) { + c := Configurator{base: &Config{}} + type variant struct { + verify bool + additional bool + expected tls.ClientAuthType + } + variants := []variant{ + {false, false, tls.NoClientCert}, + {true, false, tls.RequireAndVerifyClientCert}, + {false, true, tls.RequireAndVerifyClientCert}, + {true, true, tls.RequireAndVerifyClientCert}, + } + for _, v := range variants { + c.base.VerifyIncoming = v.verify + require.Equal(t, v.expected, + c.commonTLSConfig(v.additional).ClientAuth) + } +} - c.Update(&Config{VerifyIncoming: true, CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key"}) - tlsConf, err = c.IncomingRPCConfig() - require.NoError(t, err) - require.Equal(t, tls.RequireAndVerifyClientCert, tlsConf.ClientAuth) +func TestConfigurator_OutgoingRPCTLSDisabled(t *testing.T) { + c := Configurator{base: &Config{}} + type variant struct { + verify bool + file string + path string + expected bool + } + cafile := "../test/ca/root.cer" + capath := "../test/ca_path" + variants := []variant{ + {false, "", "", true}, + {false, cafile, "", false}, + {false, "", capath, false}, + {false, cafile, capath, false}, + {true, "", "", false}, + {true, cafile, "", false}, + {true, "", capath, false}, + {true, cafile, capath, false}, + } + for i, v := range variants { + info := fmt.Sprintf("case %d", i) + cas, err := loadCAs(v.file, v.path) + require.NoError(t, err, info) + c.cas = cas + c.base.VerifyOutgoing = v.verify + require.Equal(t, v.expected, c.outgoingRPCTLSDisabled(), info) + } +} + +func TestConfigurator_SomeValuesFromConfig(t *testing.T) { + c := Configurator{base: &Config{ + VerifyServerHostname: true, + VerifyOutgoing: true, + Domain: "abc.de", + }} + one, two, three := c.someValuesFromConfig() + require.Equal(t, c.base.VerifyServerHostname, one) + require.Equal(t, c.base.VerifyOutgoing, two) + require.Equal(t, c.base.Domain, three) +} - c.Update(&Config{VerifyIncomingRPC: true, CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key"}) - tlsConf, err = c.IncomingRPCConfig() - require.NoError(t, err) - require.Equal(t, tls.RequireAndVerifyClientCert, tlsConf.ClientAuth) +func TestConfigurator_VerifyIncomingRPC(t *testing.T) { + c := Configurator{base: &Config{ + VerifyIncomingRPC: true, + }} + verify := c.verifyIncomingRPC() + require.Equal(t, c.base.VerifyIncomingRPC, verify) +} - c.Update(&Config{VerifyIncomingHTTPS: true, CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key"}) - tlsConf, err = c.IncomingRPCConfig() - require.NoError(t, err) - require.Equal(t, tls.NoClientCert, tlsConf.ClientAuth) +func TestConfigurator_VerifyIncomingHTTPS(t *testing.T) { + c := Configurator{base: &Config{ + VerifyIncomingHTTPS: true, + }} + verify := c.verifyIncomingHTTPS() + require.Equal(t, c.base.VerifyIncomingHTTPS, verify) } -func TestConfigurator_IncomingHTTPSConfig(t *testing.T) { - c := NewConfigurator(&Config{}) - tlsConf, err := c.IncomingHTTPSConfig() - require.NoError(t, err) - require.Equal(t, tls.NoClientCert, tlsConf.ClientAuth) +func TestConfigurator_EnableAgentTLSForChecks(t *testing.T) { + c := Configurator{base: &Config{ + EnableAgentTLSForChecks: true, + }} + enabled := c.enableAgentTLSForChecks() + require.Equal(t, c.base.EnableAgentTLSForChecks, enabled) +} - c.Update(&Config{VerifyIncoming: true, CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key"}) - tlsConf, err = c.IncomingHTTPSConfig() - require.NoError(t, err) +func TestConfigurator_IncomingRPCConfig(t *testing.T) { + c, err := NewConfigurator(Config{ + VerifyIncomingRPC: true, + CAFile: "../test/ca/root.cer", + CertFile: "../test/key/ourdomain.cer", + KeyFile: "../test/key/ourdomain.key", + }, nil) + require.NoError(t, err) + tlsConf := c.IncomingRPCConfig() require.Equal(t, tls.RequireAndVerifyClientCert, tlsConf.ClientAuth) - - c.Update(&Config{VerifyIncomingHTTPS: true, CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key"}) - tlsConf, err = c.IncomingHTTPSConfig() + require.NotNil(t, tlsConf.GetConfigForClient) + tlsConf, err = tlsConf.GetConfigForClient(nil) require.NoError(t, err) require.Equal(t, tls.RequireAndVerifyClientCert, tlsConf.ClientAuth) - - c.Update(&Config{VerifyIncomingRPC: true, CAFile: "../test/ca/root.cer", CertFile: "../test/key/ourdomain.cer", KeyFile: "../test/key/ourdomain.key"}) - tlsConf, err = c.IncomingHTTPSConfig() - require.NoError(t, err) - require.Equal(t, tls.NoClientCert, tlsConf.ClientAuth) } -func TestConfigurator_OutgoingRPCConfig(t *testing.T) { - c := NewConfigurator(&Config{}) - tlsConf, err := c.OutgoingRPCConfig() - require.NoError(t, err) - require.Nil(t, tlsConf) - - c.Update(&Config{VerifyOutgoing: true}) - tlsConf, err = c.OutgoingRPCConfig() - require.Error(t, err) - - c.Update(&Config{VerifyOutgoing: true, CAFile: "../test/ca/root.cer"}) - tlsConf, err = c.OutgoingRPCConfig() - require.NoError(t, err) - - c.Update(&Config{VerifyOutgoing: true, CAPath: "../test/ca_path"}) - tlsConf, err = c.OutgoingRPCConfig() - require.NoError(t, err) +func TestConfigurator_IncomingHTTPSConfig(t *testing.T) { + c := Configurator{base: &Config{}} + require.Equal(t, []string{"h2", "http/1.1"}, c.IncomingHTTPSConfig().NextProtos) } func TestConfigurator_OutgoingTLSConfigForChecks(t *testing.T) { - c := NewConfigurator(&Config{}) - tlsConf, err := c.OutgoingTLSConfigForCheck("") - require.NoError(t, err) - require.False(t, tlsConf.InsecureSkipVerify) + c := Configurator{base: &Config{ + TLSMinVersion: "tls12", + EnableAgentTLSForChecks: false, + }} + tlsConf := c.OutgoingTLSConfigForCheck(true) + require.Equal(t, true, tlsConf.InsecureSkipVerify) + require.Equal(t, uint16(0), tlsConf.MinVersion) + + c.base.EnableAgentTLSForChecks = true + c.base.ServerName = "servername" + tlsConf = c.OutgoingTLSConfigForCheck(true) + require.Equal(t, true, tlsConf.InsecureSkipVerify) + require.Equal(t, TLSLookup[c.base.TLSMinVersion], tlsConf.MinVersion) + require.Equal(t, c.base.ServerName, tlsConf.ServerName) +} - c.Update(&Config{EnableAgentTLSForChecks: true}) - tlsConf, err = c.OutgoingTLSConfigForCheck("") - require.NoError(t, err) - require.False(t, tlsConf.InsecureSkipVerify) +func TestConfigurator_OutgoingRPCConfig(t *testing.T) { + c := Configurator{base: &Config{}} + require.Nil(t, c.OutgoingRPCConfig()) + c.base.VerifyOutgoing = true + require.NotNil(t, c.OutgoingRPCConfig()) +} - c.AddCheck("c1", true) - c.Update(&Config{EnableAgentTLSForChecks: true}) - tlsConf, err = c.OutgoingTLSConfigForCheck("c1") - require.NoError(t, err) - require.True(t, tlsConf.InsecureSkipVerify) +func TestConfigurator_OutgoingRPCWrapper(t *testing.T) { + c := Configurator{base: &Config{}} + require.Nil(t, c.OutgoingRPCWrapper()) + c.base.VerifyOutgoing = true + wrap := c.OutgoingRPCWrapper() + require.NotNil(t, wrap) + t.Log("TODO: actually call wrap here eventually") +} - c.AddCheck("c1", false) - c.Update(&Config{EnableAgentTLSForChecks: true}) - tlsConf, err = c.OutgoingTLSConfigForCheck("c1") +func TestConfigurator_UpdateChecks(t *testing.T) { + c, err := NewConfigurator(Config{}, nil) require.NoError(t, err) - require.False(t, tlsConf.InsecureSkipVerify) + require.NoError(t, c.Update(Config{})) + require.Error(t, c.Update(Config{VerifyOutgoing: true})) + require.Error(t, c.Update(Config{VerifyIncoming: true, + CAFile: "../test/ca/root.cer"})) + require.False(t, c.base.VerifyIncoming) + require.False(t, c.base.VerifyOutgoing) + require.Equal(t, c.version, 2) +} - c.AddCheck("c1", false) - c.Update(&Config{EnableAgentTLSForChecks: true}) - tlsConf, err = c.OutgoingTLSConfigForCheck("c1") +func TestConfigurator_UpdateSetsStuff(t *testing.T) { + c, err := NewConfigurator(Config{}, nil) require.NoError(t, err) - require.False(t, tlsConf.InsecureSkipVerify) + require.Nil(t, c.cas) + require.Nil(t, c.cert) + require.Equal(t, c.base, &Config{}) + require.Equal(t, 1, c.version) - c.Update(&Config{EnableAgentTLSForChecks: true, NodeName: "node", ServerName: "server"}) - tlsConf, err = c.OutgoingTLSConfigForCheck("") - require.NoError(t, err) - require.Equal(t, "server", tlsConf.ServerName) + require.Error(t, c.Update(Config{VerifyOutgoing: true})) + require.Equal(t, c.version, 1) - c.Update(&Config{EnableAgentTLSForChecks: true, ServerName: "server"}) - tlsConf, err = c.OutgoingTLSConfigForCheck("") - require.NoError(t, err) - require.Equal(t, "server", tlsConf.ServerName) + config := Config{ + CAFile: "../test/ca/root.cer", + CertFile: "../test/key/ourdomain.cer", + KeyFile: "../test/key/ourdomain.key", + } + require.NoError(t, c.Update(config)) + require.NotNil(t, c.cas) + require.Len(t, c.cas.Subjects(), 1) + require.NotNil(t, c.cert) + require.Equal(t, c.base, &config) + require.Equal(t, 2, c.version) +} - c.Update(&Config{EnableAgentTLSForChecks: true, NodeName: "node"}) - tlsConf, err = c.OutgoingTLSConfigForCheck("") - require.NoError(t, err) - require.Equal(t, "node", tlsConf.ServerName) +func TestConfigurator_ServerNameOrNodeName(t *testing.T) { + c := Configurator{base: &Config{}} + type variant struct { + server, node, expected string + } + variants := []variant{ + {"", "", ""}, + {"a", "", "a"}, + {"", "b", "b"}, + {"a", "b", "a"}, + } + for _, v := range variants { + c.base.ServerName = v.server + c.base.NodeName = v.node + require.Equal(t, v.expected, c.serverNameOrNodeName()) + } } diff --git a/website/source/docs/agent/options.html.md b/website/source/docs/agent/options.html.md index c3a360bade3b..e0cdfd7ffcc8 100644 --- a/website/source/docs/agent/options.html.md +++ b/website/source/docs/agent/options.html.md @@ -1731,6 +1731,8 @@ items which are reloaded include: * Services * Watches * HTTP Client Address +* TLS Configuration + * Please be aware that this is currently limited to reload a configuration that is already TLS enabled. You cannot enable or disable TLS only with reloading. * Node Metadata * Metric Prefix Filter * Discard Check Output