Skip to content

Commit

Permalink
Issue #138: Add option to disable cert fallback
Browse files Browse the repository at this point in the history
The default behavior of the Go TLS server implementation
is to fall back to the first provided certificate if no
exact matching certificate could be found. It can be
desirable to disable this behavior to have more control
over when a TLS connection is established. This patch
adds a 'strictmatch' option to the listener which allows
to disable the default fallback behavior.

TODO(fs): need test for new getCertificate() function
  • Loading branch information
magiconair committed Aug 30, 2016
1 parent 34d6959 commit 6d9a165
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 25 deletions.
28 changes: 18 additions & 10 deletions cert/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,15 @@ import (
)

// Source provides the interface for dynamic certificate sources.
//
// Certificates() loads certificates for TLS connections.
// The first certificate is used as the default certificate
// if the client does not support SNI or no matching certificate
// could be found. TLS certificates can be updated at runtime.
//
// LoadClientCAs() provides certificates for client certificate
// authentication.
type Source interface {
// Certificates() loads certificates for TLS connections.
// The first certificate is used as the default certificate
// if the client does not support SNI or no matching certificate
// could be found. TLS certificates can be updated at runtime.
Certificates() chan []tls.Certificate

// LoadClientCAs() provides certificates for client certificate
// authentication.
LoadClientCAs() (*x509.CertPool, error)
}

Expand Down Expand Up @@ -72,6 +71,11 @@ func NewSource(cfg config.CertSource) (Source, error) {
}
}

const (
StrictMatch = true
NoStrictMatch = false
)

// TLSConfig creates a tls.Config which sets the
// GetCertificate field to a certificate store
// which uses the given source to update the
Expand All @@ -80,14 +84,18 @@ func NewSource(cfg config.CertSource) (Source, error) {
// It also sets the ClientCAs field if
// src.LoadClientCAs returns a non-nil value
// and sets ClientAuth to RequireAndVerifyClientCert.
func TLSConfig(src Source) (*tls.Config, error) {
func TLSConfig(src Source, strictMatch bool) (*tls.Config, error) {
clientCAs, err := src.LoadClientCAs()
if err != nil {
return nil, err
}

store := NewStore()
x := &tls.Config{GetCertificate: store.GetCertificate}
x := &tls.Config{
GetCertificate: func(clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
return getCertificate(store.certstore(), clientHello, strictMatch)
},
}

if clientCAs != nil {
x.ClientCAs = clientCAs
Expand Down
2 changes: 1 addition & 1 deletion cert/source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ func TestVaultSource(t *testing.T) {
// the HTTPS client can validate the certificate presented by the
// server.
func testSource(t *testing.T, source Source, rootCAs *x509.CertPool, sleep time.Duration) {
srvConfig, err := TLSConfig(source)
srvConfig, err := TLSConfig(source, NoStrictMatch)
if err != nil {
t.Fatalf("TLSConfig: got %q want nil", err)
}
Expand Down
24 changes: 13 additions & 11 deletions cert/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,18 @@ func (s *Store) SetCertificates(certs []tls.Certificate) {
log.Printf("[INFO] cert: Store has certificates for [%q]", strings.Join(names, ","))
}

// GetCertificate returns a matching certificate for the given clientHello if possible
// or the first certificate from the store.
func (s *Store) GetCertificate(clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
return getCertificate(s.cs.Load().(certstore), clientHello)
func (s *Store) certstore() certstore {
return s.cs.Load().(certstore)
}

func getCertificate(cs certstore, clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
func getCertificate(cs certstore, clientHello *tls.ClientHelloInfo, strictMatch bool) (cert *tls.Certificate, err error) {
if len(cs.Certificates) == 0 {
return nil, errors.New("cert: no certificates certstoreured")
return nil, errors.New("cert: no certificates stored")
}

if len(cs.Certificates) == 1 || cs.NameToCertificate == nil {
// There's only one choice, so no point doing any work.
// There's only one choice, so no point doing any work.
// However, if fallback is disabled we need to check.
if !strictMatch && (len(cs.Certificates) == 1 || cs.NameToCertificate == nil) {
return &cs.Certificates[0], nil
}

Expand All @@ -59,8 +58,7 @@ func getCertificate(cs certstore, clientHello *tls.ClientHelloInfo) (cert *tls.C
return cert, nil
}

// try replacing labels in the name with wildcards until we get a
// match.
// try replacing labels in the name with wildcards until we get a match
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
Expand All @@ -70,7 +68,11 @@ func getCertificate(cs certstore, clientHello *tls.ClientHelloInfo) (cert *tls.C
}
}

// If nothing matches, return the first certificate.
// If nothing matches, return the first certificate
// unless fallback to the first cert is disabled.
if strictMatch {
return nil, nil
}
return &cs.Certificates[0], nil
}

Expand Down
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type Listen struct {
ReadTimeout time.Duration
WriteTimeout time.Duration
CertSource CertSource
StrictMatch bool
}

type UI struct {
Expand Down
2 changes: 2 additions & 0 deletions config/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ func parseListen(cfg string, cs map[string]CertSource, readTimeout, writeTimeout
}
l.CertSource = c
l.Scheme = "https"
case "strictmatch":
l.StrictMatch = (v == "true")
}
}
return
Expand Down
12 changes: 12 additions & 0 deletions config/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,18 @@ func TestParseListen(t *testing.T) {
},
"",
},
{
":123;cs=name;strictmatch=true",
Listen{
Addr: ":123",
Scheme: "https",
CertSource: CertSource{
Type: "foo",
},
StrictMatch: true,
},
"",
},
}

for i, tt := range tests {
Expand Down
12 changes: 10 additions & 2 deletions fabio.properties
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,16 @@
#
# General options:
#
# read timeout: rt=<duration>
# write timeout: wt=<duration>
# rt: Sets the read timeout as a duration value (e.g. '3s')
#
# wt: Sets the write timeout as a duration value (e.g. '3s')
#
# strictmatch: When set to 'true' the certificate source must provide
# a certificate that matches the hostname for the connection
# to be established. Otherwise, the first certificate is used
# if no matching certificate was found. This matches the default
# behavior of the Go TLS server implementation.
#
#
# HTTPS listeners require a certificate source which is
# configured by setting the 'cs' option to the name of
Expand Down
2 changes: 1 addition & 1 deletion listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func listenAndServe(l config.Listen, h http.Handler) {
exit.Fatal("[FATAL] ", err)
}

srv.TLSConfig, err = cert.TLSConfig(src)
srv.TLSConfig, err = cert.TLSConfig(src, l.StrictMatch)
if err != nil {
exit.Fatal("[FATAL] ", err)
}
Expand Down

0 comments on commit 6d9a165

Please sign in to comment.