-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
febbbd3
commit 232908b
Showing
4 changed files
with
341 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
FROM squidfunk/mkdocs-material:9.5 | ||
RUN pip install mkdocs-include-markdown-plugin | ||
RUN pip install mkdocs-include-markdown-plugin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package certreloader | ||
|
||
import ( | ||
"crypto/tls" | ||
"fmt" | ||
"sync" | ||
"time" | ||
) | ||
|
||
type Config struct { | ||
keyPath string | ||
certPath string | ||
reloadInterval time.Duration | ||
} | ||
|
||
type certReloader struct { | ||
cert *tls.Certificate | ||
mu sync.RWMutex | ||
nextReload time.Time | ||
Config | ||
} | ||
|
||
func NewCertReloader(config Config) (*certReloader, error) { | ||
reloader := certReloader{ | ||
Config: config, | ||
} | ||
|
||
reloader.mu.Lock() | ||
defer reloader.mu.Unlock() | ||
cert, err := reloader.loadCertificate() | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to load initial certificate: %w", err) | ||
} | ||
reloader.cert = &cert | ||
|
||
return &reloader, nil | ||
} | ||
|
||
func (r *certReloader) GetCertificate() (*tls.Certificate, error) { | ||
now := time.Now() | ||
// Read locking here before we do the time comparison | ||
// If a reload is in progress this will block and we will skip reloading in the current | ||
// call once we can continue | ||
r.mu.RLock() | ||
shouldReload := r.reloadInterval != 0 && r.nextReload.Before(now) | ||
r.mu.RUnlock() | ||
if shouldReload { | ||
// Need to release the read lock, otherwise we deadlock | ||
r.mu.Lock() | ||
defer r.mu.Unlock() | ||
cert, err := r.loadCertificate() | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to load TLS cert and key: %w", err) | ||
} | ||
r.cert = &cert | ||
r.nextReload = now.Add(r.reloadInterval) | ||
return r.cert, nil | ||
} | ||
return r.cert, nil | ||
} | ||
|
||
func (c *certReloader) loadCertificate() (tls.Certificate, error) { | ||
newCert, err := tls.LoadX509KeyPair(c.certPath, c.keyPath) | ||
if err != nil { | ||
return tls.Certificate{}, fmt.Errorf("failed to load key pair: %w", err) | ||
} | ||
|
||
return newCert, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,269 @@ | ||
package certreloader | ||
|
||
import ( | ||
"bytes" | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"crypto/tls" | ||
"crypto/x509" | ||
"crypto/x509/pkix" | ||
"encoding/pem" | ||
"fmt" | ||
"io" | ||
"math/big" | ||
"net" | ||
"os" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestNewCertReloader(t *testing.T) { | ||
cert1, key1, cleanup := generateValidCertificateFiles(t) | ||
defer cleanup() | ||
_, key2, cleanup := generateValidCertificateFiles(t) | ||
defer cleanup() | ||
|
||
tcs := []struct { | ||
name string | ||
config Config | ||
err error | ||
}{ | ||
{ | ||
name: "no config set", | ||
config: Config{}, | ||
err: fmt.Errorf("failed to load initial certificate: failed to load key pair: open : no such file or directory"), | ||
}, | ||
{ | ||
name: "invalid certs", | ||
config: Config{certPath: cert1, keyPath: key2}, | ||
err: fmt.Errorf("failed to load initial certificate: failed to load key pair: tls: private key does not match public key"), | ||
}, | ||
|
||
{ | ||
name: "valid certs", | ||
config: Config{certPath: cert1, keyPath: key1}, | ||
err: nil, | ||
}, | ||
} | ||
|
||
for _, tc := range tcs { | ||
t.Run(tc.name, func(t *testing.T) { | ||
reloader, err := NewCertReloader(tc.config) | ||
if err != nil { | ||
if tc.err == nil { | ||
t.Fatalf("NewCertReloader returned error when no error was expected: %s", err) | ||
} else if tc.err.Error() != err.Error() { | ||
t.Fatalf("expected error did not matched received error. expected: %v, received: %v", tc.err, err) | ||
} | ||
} else { | ||
if reloader == nil { | ||
t.Fatal("expected reloader to not be nil") | ||
} | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestCertificateReload(t *testing.T) { | ||
newCert, newKey, cleanup := generateValidCertificateFiles(t) | ||
defer cleanup() | ||
|
||
tcs := []struct { | ||
name string | ||
reloadInterval time.Duration | ||
newCert string | ||
newKey string | ||
err error | ||
}{ | ||
{ | ||
name: "reloads after interval", | ||
reloadInterval: time.Microsecond * 100, | ||
newCert: newCert, | ||
newKey: newKey, | ||
err: nil, | ||
}, | ||
} | ||
|
||
for _, tc := range tcs { | ||
t.Run(tc.name, func(t *testing.T) { | ||
cert, key, cleanup := generateValidCertificateFiles(t) | ||
defer cleanup() | ||
reloader, err := NewCertReloader(Config{ | ||
certPath: cert, | ||
keyPath: key, | ||
reloadInterval: tc.reloadInterval, | ||
}) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
// TODO: copy instead of rename | ||
if err := copyFile(tc.newCert, cert); err != nil { | ||
t.Fatalf("failed to move %s -> %s: %s", newCert, cert, err) | ||
} | ||
if err := copyFile(tc.newKey, key); err != nil { | ||
t.Fatalf("failed to move %s -> %s: %s", newKey, key, err) | ||
} | ||
time.Sleep(tc.reloadInterval * 2) | ||
|
||
actualCert, err := reloader.GetCertificate() | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
actualCertParsed, err := x509.ParseCertificate(actualCert.Certificate[0]) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
expectedCert, err := tls.LoadX509KeyPair(tc.newCert, tc.newKey) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
expectedCertParsed, err := x509.ParseCertificate(expectedCert.Certificate[0]) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
if expectedCertParsed.DNSNames[0] != actualCertParsed.DNSNames[0] { | ||
t.Fatalf("expected certificate was not returned by GetCertificate. expectedCert: %v, actualCert: %v", expectedCertParsed.DNSNames[0], actualCertParsed.DNSNames[0]) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func generateValidCertificate(t *testing.T) (*bytes.Buffer, *bytes.Buffer) { | ||
t.Helper() | ||
|
||
// set up our CA certificate | ||
ca := &x509.Certificate{ | ||
SerialNumber: big.NewInt(2019), | ||
Subject: pkix.Name{ | ||
Organization: []string{"Company, INC."}, | ||
Country: []string{"US"}, | ||
Province: []string{""}, | ||
Locality: []string{"San Francisco"}, | ||
StreetAddress: []string{"Golden Gate Bridge"}, | ||
PostalCode: []string{"94016"}, | ||
}, | ||
NotBefore: time.Now(), | ||
NotAfter: time.Now().AddDate(10, 0, 0), | ||
IsCA: true, | ||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, | ||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, | ||
BasicConstraintsValid: true, | ||
} | ||
|
||
// create our private and public key | ||
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
// create the CA | ||
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
// pem encode | ||
caPEM := new(bytes.Buffer) | ||
pem.Encode(caPEM, &pem.Block{ | ||
Type: "CERTIFICATE", | ||
Bytes: caBytes, | ||
}) | ||
|
||
caPrivKeyPEM := new(bytes.Buffer) | ||
pem.Encode(caPrivKeyPEM, &pem.Block{ | ||
Type: "RSA PRIVATE KEY", | ||
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey), | ||
}) | ||
|
||
// set up our server certificate | ||
cert := &x509.Certificate{ | ||
SerialNumber: big.NewInt(2019), | ||
Subject: pkix.Name{ | ||
Organization: []string{"Company, INC."}, | ||
Country: []string{"US"}, | ||
Province: []string{""}, | ||
Locality: []string{"San Francisco"}, | ||
StreetAddress: []string{"Golden Gate Bridge"}, | ||
PostalCode: []string{"94016"}, | ||
}, | ||
DNSNames: []string{randString(8)}, | ||
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, | ||
NotBefore: time.Now(), | ||
NotAfter: time.Now().AddDate(10, 0, 0), | ||
SubjectKeyId: []byte{1, 2, 3, 4, 6}, | ||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, | ||
KeyUsage: x509.KeyUsageDigitalSignature, | ||
} | ||
|
||
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) | ||
if err != nil { | ||
t.Fatalf("failed to create private key: %s", err) | ||
} | ||
|
||
certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey) | ||
if err != nil { | ||
t.Fatalf("failed to create certificate: %s", err) | ||
} | ||
|
||
certPEM := new(bytes.Buffer) | ||
pem.Encode(certPEM, &pem.Block{ | ||
Type: "CERTIFICATE", | ||
Bytes: certBytes, | ||
}) | ||
|
||
certPrivKeyPEM := new(bytes.Buffer) | ||
pem.Encode(certPrivKeyPEM, &pem.Block{ | ||
Type: "RSA PRIVATE KEY", | ||
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), | ||
}) | ||
|
||
return certPEM, certPrivKeyPEM | ||
} | ||
|
||
func generateValidCertificateFiles(t *testing.T) (string, string, func()) { | ||
t.Helper() | ||
certFile, err := os.CreateTemp("", "certreloader_cert") | ||
if err != nil { | ||
t.Fatalf("failed to create certFile: %s", err) | ||
} | ||
defer certFile.Close() | ||
keyFile, err := os.CreateTemp("", "certreloader_key") | ||
if err != nil { | ||
t.Fatalf("failed to create keyFile: %s", err) | ||
} | ||
defer keyFile.Close() | ||
|
||
certBytes, keyBytes := generateValidCertificate(t) | ||
if _, err := io.Copy(certFile, certBytes); err != nil { | ||
t.Fatalf("failed to copy certBytes into %s: %s", certFile.Name(), err) | ||
} | ||
if _, err := io.Copy(keyFile, keyBytes); err != nil { | ||
t.Fatalf("failed to copy keyBytes into %s: %s", keyFile.Name(), err) | ||
} | ||
|
||
return certFile.Name(), keyFile.Name(), func() { | ||
os.Remove(certFile.Name()) | ||
os.Remove(keyFile.Name()) | ||
} | ||
} | ||
|
||
func copyFile(src, dst string) error { | ||
data, err := os.ReadFile(src) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
return os.WriteFile(dst, data, 0o777) | ||
} | ||
|
||
func randString(n int) string { | ||
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" | ||
bytes := make([]byte, n) | ||
rand.Read(bytes) | ||
for i, b := range bytes { | ||
bytes[i] = alphanum[b%byte(len(alphanum))] | ||
} | ||
return string(bytes) | ||
} |