diff --git a/cmd/extensions/main.go b/cmd/extensions/main.go index 33c0411004..02fcbb8847 100644 --- a/cmd/extensions/main.go +++ b/cmd/extensions/main.go @@ -139,6 +139,11 @@ func main() { } // https server and the items that share the Mux for routing httpsServer := https.NewServer(ctlConf.CertFile, ctlConf.KeyFile) + cancelTLS, err := httpsServer.WatchForCertificateChanges() + if err != nil { + logger.WithError(err).Fatal("Got an error while watching certificate changes") + } + defer cancelTLS() wh := webhooks.NewWebHook(httpsServer.Mux) api := apiserver.NewAPIServer(httpsServer.Mux) diff --git a/pkg/util/https/server.go b/pkg/util/https/server.go index b1b6311571..34a2a471b8 100644 --- a/pkg/util/https/server.go +++ b/pkg/util/https/server.go @@ -16,50 +16,111 @@ package https import ( "context" + cryptotls "crypto/tls" "net/http" + "sync" + "time" + "agones.dev/agones/pkg/util/fswatch" "agones.dev/agones/pkg/util/runtime" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) +const ( + tlsDir = "/certs/" +) + // tls is a http server interface to enable easier testing type tls interface { Close() error ListenAndServeTLS(certFile, keyFile string) error } +// certServer holds the Server certificate +type certServer struct { + certs *cryptotls.Certificate + certMu sync.Mutex +} + // Server is a HTTPs server that conforms to the runner interface // we use in /cmd/controller, and has a public Mux that can be updated // has a default 404 handler, to make discovery of k8s services a bit easier. type Server struct { - logger *logrus.Entry - Mux *http.ServeMux - tls tls - certFile string - keyFile string + certServer certServer + logger *logrus.Entry + Mux *http.ServeMux + tls tls + certFile string + keyFile string } // NewServer returns a Server instance. func NewServer(certFile, keyFile string) *Server { mux := http.NewServeMux() - tls := &http.Server{ - Addr: ":8081", - Handler: mux, - } wh := &Server{ Mux: mux, - tls: tls, certFile: certFile, keyFile: keyFile, } - wh.Mux.HandleFunc("/", wh.defaultHandler) wh.logger = runtime.NewLoggerWithType(wh) + wh.setupServer() + wh.Mux.HandleFunc("/", wh.defaultHandler) return wh } +func (s *Server) setupServer() { + s.tls = &http.Server{ + Addr: ":8081", + Handler: s.Mux, + TLSConfig: &cryptotls.Config{ + GetCertificate: s.getCertificate, + }, + } + + tlsCert, err := cryptotls.LoadX509KeyPair(tlsDir+"server.crt", tlsDir+"server.key") + if err != nil { + s.logger.WithError(err).Error("could not load Initial TLS certs; keeping old one") + return + } + + s.certServer.certMu.Lock() + defer s.certServer.certMu.Unlock() + s.certServer.certs = &tlsCert +} + +// getCertificate returns the current TLS certificate +func (s *Server) getCertificate(hello *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) { + s.certServer.certMu.Lock() + defer s.certServer.certMu.Unlock() + return s.certServer.certs, nil +} + +// WatchForCertificateChanges watches for changes in the certificate files +func (s *Server) WatchForCertificateChanges() (func(), error) { + + cancelTLS, err := fswatch.Watch(s.logger, tlsDir, time.Second, func() { + // Load the new TLS certificate + s.logger.Info("TLS certs changed, reloading") + tlsCert, err := cryptotls.LoadX509KeyPair(tlsDir+"server.crt", tlsDir+"server.key") + if err != nil { + s.logger.WithError(err).Error("could not load TLS certs; keeping old one") + return + } + s.certServer.certMu.Lock() + defer s.certServer.certMu.Unlock() + s.certServer.certs = &tlsCert + s.logger.Info("TLS certs updated") + }) + if err != nil { + s.logger.WithError(err).Fatal("could not create watcher for TLS certs") + return nil, err + } + return cancelTLS, nil +} + // Run runs the webhook server, starting a https listener. // Will close the http server on stop channel close. func (s *Server) Run(ctx context.Context, _ int) error {