Skip to content

Commit

Permalink
🐛 Refactor certificate watcher to use polling instead of fsnotify (#3020)
Browse files Browse the repository at this point in the history

* Reestablish watch for the certificate paths

* Remove fsnotify and use cached read watcher

* Simplify return
  • Loading branch information
m-messiah authored Nov 26, 2024
1 parent b88f351 commit 8e44a43
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 120 deletions.
1 change: 0 additions & 1 deletion examples/scratch-env/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
github.com/evanphx/json-patch/v5 v5.9.0 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/zapr v1.3.0 // indirect
Expand Down
2 changes: 0 additions & 2 deletions examples/scratch-env/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8
github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ=
github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg=
github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E=
github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.23.0

require (
github.com/evanphx/json-patch/v5 v5.9.0
github.com/fsnotify/fsnotify v1.7.0
github.com/go-logr/logr v1.4.2
github.com/go-logr/zapr v1.3.0
github.com/google/go-cmp v0.6.0
Expand Down Expand Up @@ -41,6 +40,7 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
Expand Down
161 changes: 57 additions & 104 deletions pkg/certwatcher/certwatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,58 +17,55 @@ limitations under the License.
package certwatcher

import (
"bytes"
"context"
"crypto/tls"
"fmt"
"os"
"sync"
"time"

"github.com/fsnotify/fsnotify"
kerrors "k8s.io/apimachinery/pkg/util/errors"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait"
"sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics"
logf "sigs.k8s.io/controller-runtime/pkg/internal/log"
)

var log = logf.RuntimeLog.WithName("certwatcher")

// CertWatcher watches certificate and key files for changes. When either file
// changes, it reads and parses both and calls an optional callback with the new
// certificate.
const defaultWatchInterval = 10 * time.Second

// CertWatcher watches certificate and key files for changes.
// It always returns the cached version,
// but periodically reads and parses certificate and key for changes
// and calls an optional callback with the new certificate.
type CertWatcher struct {
// The embedded RWMutex is taken by updateCachedCertificate before it
// writes currentCert and cachedKeyPEMBlock.
sync.RWMutex

// currentCert is the most recently loaded certificate.
currentCert *tls.Certificate
// watcher is the fsnotify watcher used by the event-based Watch loop.
watcher *fsnotify.Watcher
// interval is the period of the ticker that Start uses to re-read the
// certificate and key.
interval time.Duration

// certPath and keyPath are the files read by ReadCertificate.
certPath string
keyPath string

// cachedKeyPEMBlock holds the raw key file contents that belong to
// currentCert; updateCachedCertificate compares new key bytes against it.
cachedKeyPEMBlock []byte

// callback is a function to be invoked when the certificate changes.
callback func(tls.Certificate)
}

// New returns a new CertWatcher watching the given certificate and key.
func New(certPath, keyPath string) (*CertWatcher, error) {
var err error

cw := &CertWatcher{
certPath: certPath,
keyPath: keyPath,
interval: defaultWatchInterval,
}

// Initial read of certificate and key.
if err := cw.ReadCertificate(); err != nil {
return nil, err
}

cw.watcher, err = fsnotify.NewWatcher()
if err != nil {
return nil, err
}
return cw, cw.ReadCertificate()
}

return cw, nil
// WithWatchInterval sets the interval at which the certificate and key files
// are re-read and returns the CertWatcher so that calls can be chained.
//
// Start hands the interval to time.NewTicker, which panics on a non-positive
// duration, so a non-positive value is ignored and the current interval is kept.
func (cw *CertWatcher) WithWatchInterval(interval time.Duration) *CertWatcher {
	if interval > 0 {
		cw.interval = interval
	}
	return cw
}

// RegisterCallback registers a callback to be invoked when the certificate changes.
Expand All @@ -91,72 +88,64 @@ func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate,

// Start starts the watch on the certificate and key files.
func (cw *CertWatcher) Start(ctx context.Context) error {
files := sets.New(cw.certPath, cw.keyPath)

{
var watchErr error
if err := wait.PollUntilContextTimeout(ctx, 1*time.Second, 10*time.Second, true, func(ctx context.Context) (done bool, err error) {
for _, f := range files.UnsortedList() {
if err := cw.watcher.Add(f); err != nil {
watchErr = err
return false, nil //nolint:nilerr // We want to keep trying.
}
// We've added the watch, remove it from the set.
files.Delete(f)
}
return true, nil
}); err != nil {
return fmt.Errorf("failed to add watches: %w", kerrors.NewAggregate([]error{err, watchErr}))
}
}

go cw.Watch()
ticker := time.NewTicker(cw.interval)
defer ticker.Stop()

log.Info("Starting certificate watcher")

// Block until the context is done.
<-ctx.Done()

return cw.watcher.Close()
}

// Watch reads events from the watcher's channel and reacts to changes.
func (cw *CertWatcher) Watch() {
for {
select {
case event, ok := <-cw.watcher.Events:
// Channel is closed.
if !ok {
return
case <-ctx.Done():
return nil
case <-ticker.C:
if err := cw.ReadCertificate(); err != nil {
log.Error(err, "failed read certificate")
}
}
}
}

cw.handleEvent(event)

case err, ok := <-cw.watcher.Errors:
// Channel is closed.
if !ok {
return
}
// updateCachedCertificate compares the new certificate and key with the cached ones,
// stores them if they differ, and reports whether the cache was updated.
func (cw *CertWatcher) updateCachedCertificate(cert *tls.Certificate, keyPEMBlock []byte) bool {
cw.Lock()
defer cw.Unlock()

log.Error(err, "certificate watch error")
}
if cw.currentCert != nil &&
bytes.Equal(cw.currentCert.Certificate[0], cert.Certificate[0]) &&
bytes.Equal(cw.cachedKeyPEMBlock, keyPEMBlock) {
log.V(7).Info("certificate already cached")
return false
}
cw.currentCert = cert
cw.cachedKeyPEMBlock = keyPEMBlock
return true
}

// ReadCertificate reads the certificate and key files from disk, parses them,
// and updates the current certificate on the watcher. If a callback is set, it
// and updates the current certificate on the watcher if updated. If a callback is set, it
// is invoked with the new certificate.
func (cw *CertWatcher) ReadCertificate() error {
metrics.ReadCertificateTotal.Inc()
cert, err := tls.LoadX509KeyPair(cw.certPath, cw.keyPath)
certPEMBlock, err := os.ReadFile(cw.certPath)
if err != nil {
metrics.ReadCertificateErrors.Inc()
return err
}
keyPEMBlock, err := os.ReadFile(cw.keyPath)
if err != nil {
metrics.ReadCertificateErrors.Inc()
return err
}

cw.Lock()
cw.currentCert = &cert
cw.Unlock()
cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
metrics.ReadCertificateErrors.Inc()
return err
}

if !cw.updateCachedCertificate(&cert, keyPEMBlock) {
return nil
}

log.Info("Updated current TLS certificate")

Expand All @@ -170,39 +159,3 @@ func (cw *CertWatcher) ReadCertificate() error {
}
return nil
}

// handleEvent reacts to a single fsnotify event. Events that may have changed
// the file contents trigger a re-read of the certificate; Remove and Chmod
// events additionally re-register the watch on the event's file name.
func (cw *CertWatcher) handleEvent(event fsnotify.Event) {
	// Anything other than a write, remove, create or chmod is ignored.
	relevant := isWrite(event) || isRemove(event) || isCreate(event) || isChmod(event)
	if !relevant {
		return
	}

	log.V(1).Info("certificate event", "event", event)

	// Re-add the watch under the same name after a Remove or Chmod event.
	needsRewatch := isRemove(event) || isChmod(event)
	if needsRewatch {
		err := cw.watcher.Add(event.Name)
		if err != nil {
			log.Error(err, "error re-watching file")
		}
	}

	err := cw.ReadCertificate()
	if err != nil {
		log.Error(err, "error re-reading certificate")
	}
}

// isWrite reports whether the event carries the fsnotify.Write operation.
func isWrite(event fsnotify.Event) bool {
	return event.Has(fsnotify.Write)
}

// isCreate reports whether the event carries the fsnotify.Create operation.
func isCreate(event fsnotify.Event) bool {
	return event.Has(fsnotify.Create)
}

// isRemove reports whether the event carries the fsnotify.Remove operation.
func isRemove(event fsnotify.Event) bool {
	return event.Has(fsnotify.Remove)
}

// isChmod reports whether the event carries the fsnotify.Chmod operation.
func isChmod(event fsnotify.Event) bool {
	return event.Has(fsnotify.Chmod)
}
1 change: 1 addition & 0 deletions pkg/certwatcher/certwatcher_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

logf "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
)
Expand Down
50 changes: 39 additions & 11 deletions pkg/certwatcher/certwatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/prometheus/client_golang/prometheus/testutil"

"sigs.k8s.io/controller-runtime/pkg/certwatcher"
"sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics"
)
Expand Down Expand Up @@ -80,7 +81,7 @@ var _ = Describe("CertWatcher", func() {
go func() {
defer GinkgoRecover()
defer close(doneCh)
Expect(watcher.Start(ctx)).To(Succeed())
Expect(watcher.WithWatchInterval(time.Second).Start(ctx)).To(Succeed())
}()
// wait till we read first cert
Eventually(func() error {
Expand Down Expand Up @@ -113,7 +114,7 @@ var _ = Describe("CertWatcher", func() {
Eventually(func() bool {
secondcert, _ := watcher.GetCertificate(nil)
first := firstcert.PrivateKey.(*rsa.PrivateKey)
return first.Equal(secondcert.PrivateKey)
return first.Equal(secondcert.PrivateKey) || firstcert.Leaf.SerialNumber == secondcert.Leaf.SerialNumber
}).ShouldNot(BeTrue())

ctxCancel()
Expand Down Expand Up @@ -143,14 +144,41 @@ var _ = Describe("CertWatcher", func() {
Eventually(func() bool {
secondcert, _ := watcher.GetCertificate(nil)
first := firstcert.PrivateKey.(*rsa.PrivateKey)
return first.Equal(secondcert.PrivateKey)
return first.Equal(secondcert.PrivateKey) || firstcert.Leaf.SerialNumber == secondcert.Leaf.SerialNumber
}).ShouldNot(BeTrue())

ctxCancel()
Eventually(doneCh, "4s").Should(BeClosed())
Expect(called.Load()).To(BeNumerically(">=", 1))
})

It("should reload currentCert after move out", func() {
doneCh := startWatcher()
called := atomic.Int64{}
watcher.RegisterCallback(func(crt tls.Certificate) {
called.Add(1)
Expect(crt.Certificate).ToNot(BeEmpty())
})

firstcert, _ := watcher.GetCertificate(nil)

Expect(os.Rename(certPath, certPath+".old")).To(Succeed())
Expect(os.Rename(keyPath, keyPath+".old")).To(Succeed())

err := writeCerts(certPath, keyPath, "192.168.0.3")
Expect(err).ToNot(HaveOccurred())

Eventually(func() bool {
secondcert, _ := watcher.GetCertificate(nil)
first := firstcert.PrivateKey.(*rsa.PrivateKey)
return first.Equal(secondcert.PrivateKey) || firstcert.Leaf.SerialNumber == secondcert.Leaf.SerialNumber
}, "10s", "1s").ShouldNot(BeTrue())

ctxCancel()
Eventually(doneCh, "4s").Should(BeClosed())
Expect(called.Load()).To(BeNumerically(">=", 1))
})

Context("prometheus metric read_certificate_total", func() {
var readCertificateTotalBefore float64
var readCertificateErrorsBefore float64
Expand All @@ -165,8 +193,8 @@ var _ = Describe("CertWatcher", func() {

Eventually(func() error {
readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal)
if readCertificateTotalAfter != readCertificateTotalBefore+1.0 {
return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter)
if readCertificateTotalAfter < readCertificateTotalBefore+1.0 {
return fmt.Errorf("metric read certificate total expected at least: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter)
}
return nil
}, "4s").Should(Succeed())
Expand All @@ -180,8 +208,8 @@ var _ = Describe("CertWatcher", func() {

Eventually(func() error {
readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal)
if readCertificateTotalAfter != readCertificateTotalBefore+1.0 {
return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter)
if readCertificateTotalAfter < readCertificateTotalBefore+1.0 {
return fmt.Errorf("metric read certificate total expected at least: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter)
}
readCertificateTotalBefore = readCertificateTotalAfter
return nil
Expand All @@ -192,15 +220,15 @@ var _ = Describe("CertWatcher", func() {
// Note, we expect at least two failed reads here: once the file is removed, every polling tick of the watcher fails to read it (previously os.Remove generated two fsnotify events: Chmod + Remove).
Eventually(func() error {
readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal)
if readCertificateTotalAfter != readCertificateTotalBefore+2.0 {
return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+2.0, readCertificateTotalAfter)
if readCertificateTotalAfter < readCertificateTotalBefore+2.0 {
return fmt.Errorf("metric read certificate total expected at least: %v and got: %v", readCertificateTotalBefore+2.0, readCertificateTotalAfter)
}
return nil
}, "4s").Should(Succeed())
Eventually(func() error {
readCertificateErrorsAfter := testutil.ToFloat64(metrics.ReadCertificateErrors)
if readCertificateErrorsAfter != readCertificateErrorsBefore+2.0 {
return fmt.Errorf("metric read certificate errors expected: %v and got: %v", readCertificateErrorsBefore+2.0, readCertificateErrorsAfter)
if readCertificateErrorsAfter < readCertificateErrorsBefore+2.0 {
return fmt.Errorf("metric read certificate errors expected at least: %v and got: %v", readCertificateErrorsBefore+2.0, readCertificateErrorsAfter)
}
return nil
}, "4s").Should(Succeed())
Expand Down
2 changes: 1 addition & 1 deletion pkg/certwatcher/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func Example() {
panic(err)
}

// Start goroutine with certwatcher running fsnotify against supplied certdir
// Start goroutine with certwatcher running against supplied cert
go func() {
if err := watcher.Start(ctx); err != nil {
panic(err)
Expand Down
1 change: 1 addition & 0 deletions pkg/certwatcher/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package metrics

import (
"github.com/prometheus/client_golang/prometheus"

"sigs.k8s.io/controller-runtime/pkg/metrics"
)

Expand Down

0 comments on commit 8e44a43

Please sign in to comment.