Skip to content

Commit

Permalink
TLS certs reloader (dexidp#2964)
Browse files Browse the repository at this point in the history
Signed-off-by: Sean Liao <sean+git@liao.dev>
  • Loading branch information
seankhliao authored and michaelliau committed Oct 4, 2023
1 parent c9c020c commit 3256ffd
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 29 deletions.
165 changes: 136 additions & 29 deletions cmd/dex/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@ import (
"net/http"
"net/http/pprof"
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"sync/atomic"
"syscall"
"time"

gosundheit "github.com/AppsFlyer/go-sundheit"
"github.com/AppsFlyer/go-sundheit/checks"
gosundheithttp "github.com/AppsFlyer/go-sundheit/http"
"github.com/fsnotify/fsnotify"
"github.com/ghodss/yaml"
grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/oklog/run"
Expand Down Expand Up @@ -142,41 +146,26 @@ func runServe(options serveOptions) error {
}

if c.GRPC.TLSCert != "" {
// Parse certificates from certificate file and key file for server.
cert, err := tls.LoadX509KeyPair(c.GRPC.TLSCert, c.GRPC.TLSKey)
if err != nil {
return fmt.Errorf("invalid config: error parsing gRPC certificate file: %v", err)
}

tlsConfig := tls.Config{
Certificates: []tls.Certificate{cert},
baseTLSConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
CipherSuites: allowedTLSCiphers,
PreferServerCipherSuites: true,
}

if c.GRPC.TLSClientCA != "" {
// Parse certificates from client CA file to a new CertPool.
cPool := x509.NewCertPool()
clientCert, err := os.ReadFile(c.GRPC.TLSClientCA)
if err != nil {
return fmt.Errorf("invalid config: reading from client CA file: %v", err)
}
if !cPool.AppendCertsFromPEM(clientCert) {
return errors.New("invalid config: failed to parse client CA")
}

tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientCAs = cPool
tlsConfig, err := newTLSReloader(logger, c.GRPC.TLSCert, c.GRPC.TLSKey, c.GRPC.TLSClientCA, baseTLSConfig)
if err != nil {
return fmt.Errorf("invalid config: get gRPC TLS: %v", err)
}

if c.GRPC.TLSClientCA != "" {
// Only add metrics if client auth is enabled
grpcOptions = append(grpcOptions,
grpc.StreamInterceptor(grpcMetrics.StreamServerInterceptor()),
grpc.UnaryInterceptor(grpcMetrics.UnaryServerInterceptor()),
)
}

grpcOptions = append(grpcOptions, grpc.Creds(credentials.NewTLS(&tlsConfig)))
grpcOptions = append(grpcOptions, grpc.Creds(credentials.NewTLS(tlsConfig)))
}

s, err := c.Storage.Config.Open(logger)
Expand Down Expand Up @@ -431,18 +420,25 @@ func runServe(options serveOptions) error {
return fmt.Errorf("listening (%s) on %s: %v", name, c.Web.HTTPS, err)
}

baseTLSConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
CipherSuites: allowedTLSCiphers,
PreferServerCipherSuites: true,
}

tlsConfig, err := newTLSReloader(logger, c.Web.TLSCert, c.Web.TLSKey, "", baseTLSConfig)
if err != nil {
return fmt.Errorf("invalid config: get HTTP TLS: %v", err)
}

server := &http.Server{
Handler: serv,
TLSConfig: &tls.Config{
CipherSuites: allowedTLSCiphers,
PreferServerCipherSuites: true,
MinVersion: tls.VersionTLS12,
},
Handler: serv,
TLSConfig: tlsConfig,
}
defer server.Close()

group.Add(func() error {
return server.ServeTLS(l, c.Web.TLSCert, c.Web.TLSKey)
return server.ServeTLS(l, "", "")
}, func(err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
Expand Down Expand Up @@ -563,3 +559,114 @@ func pprofHandler(router *http.ServeMux) {
router.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
router.HandleFunc("/debug/pprof/trace", pprof.Trace)
}

// newTLSReloader returns a [tls.Config] with GetCertificate or GetConfigForClient set
// to reload certificates from the given paths on SIGHUP or on file creates (atomic update via rename).
func newTLSReloader(logger log.Logger, certFile, keyFile, caFile string, baseConfig *tls.Config) (*tls.Config, error) {
// trigger reload on channel
sigc := make(chan os.Signal, 1)
signal.Notify(sigc, syscall.SIGHUP)

// files to watch
watchFiles := map[string]struct{}{
certFile: {},
keyFile: {},
}
if caFile != "" {
watchFiles[caFile] = struct{}{}
}
watchDirs := make(map[string]struct{}) // dedupe dirs
for f := range watchFiles {
dir := filepath.Dir(f)
if !strings.HasPrefix(f, dir) {
// normalize name to have ./ prefix if only a local path was provided
// can't pass "" to watcher.Add
watchFiles[dir+string(filepath.Separator)+f] = struct{}{}
}
watchDirs[dir] = struct{}{}
}
// trigger reload on file change
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, fmt.Errorf("create watcher for TLS reloader: %v", err)
}
// recommended by fsnotify: watch the dir to handle renames
// https://pkg.go.dev/github.com/fsnotify/fsnotify#hdr-Watching_files
for dir := range watchDirs {
logger.Debugf("watching dir: %v", dir)
err := watcher.Add(dir)
if err != nil {
return nil, fmt.Errorf("watch dir for TLS reloader: %v", err)
}
}

// load once outside the goroutine so we can return an error on misconfig
initialConfig, err := loadTLSConfig(certFile, keyFile, caFile, baseConfig)
if err != nil {
return nil, fmt.Errorf("load TLS config: %v", err)
}

// stored version of current tls config
ptr := &atomic.Pointer[tls.Config]{}
ptr.Store(initialConfig)

// start background worker to reload certs
go func() {
loop:
for {
select {
case sig := <-sigc:
logger.Debug("reloading cert from signal: %v", sig)
case evt := <-watcher.Events:
if _, ok := watchFiles[evt.Name]; !ok || !evt.Has(fsnotify.Create) {
continue loop
}
logger.Debug("reloading cert from fsnotify: %v %v", evt.Name, evt.Op.String())
case err := <-watcher.Errors:
logger.Errorf("TLS reloader watch: %v", err)
}

loaded, err := loadTLSConfig(certFile, keyFile, caFile, baseConfig)
if err != nil {
logger.Errorf("reload TLS config: %v", err)
}
ptr.Store(loaded)
}
}()

conf := &tls.Config{}
// https://pkg.go.dev/crypto/tls#baseConfig
// Server configurations must set one of Certificates, GetCertificate or GetConfigForClient.
if caFile != "" {
// grpc will use this via tls.Server for mTLS
conf.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) { return ptr.Load(), nil }
} else {
// net/http only uses Certificates or GetCertificate
conf.GetCertificate = func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { return &ptr.Load().Certificates[0], nil }
}
return conf, nil
}

// loadTLSConfig loads the given file paths into a [tls.Config]
func loadTLSConfig(certFile, keyFile, caFile string, baseConfig *tls.Config) (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, fmt.Errorf("loading TLS keypair: %v", err)
}
loadedConfig := baseConfig.Clone() // copy
loadedConfig.Certificates = []tls.Certificate{cert}
if caFile != "" {
cPool := x509.NewCertPool()
clientCert, err := os.ReadFile(caFile)
if err != nil {
return nil, fmt.Errorf("reading from client CA file: %v", err)
}
if !cPool.AppendCertsFromPEM(clientCert) {
return nil, errors.New("failed to parse client CA")
}

loadedConfig.ClientAuth = tls.RequireAndVerifyClientCert
loadedConfig.ClientCAs = cPool
}
return loadedConfig, nil
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/coreos/go-oidc/v3 v3.6.0
github.com/dexidp/dex/api/v2 v2.1.0
github.com/felixge/httpsnoop v1.0.3
github.com/fsnotify/fsnotify v1.6.0
github.com/ghodss/yaml v1.0.0
github.com/go-ldap/ldap/v3 v3.4.4
github.com/go-sql-driver/mysql v1.7.1
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBd
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-asn1-ber/asn1-ber v1.5.4 h1:vXT6d/FNDiELJnLb6hGNa309LMsrCoYFvpwHDF0+Y1A=
Expand Down Expand Up @@ -303,6 +305,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
Expand Down

0 comments on commit 3256ffd

Please sign in to comment.