Skip to content

Commit

Permalink
util,server: Automatically create TLS certificates
Browse files Browse the repository at this point in the history
If no `ssl-cert` or `ssl-key` are specified: Create a self signed
cert in the temp storage and use that.

This allows TLS to be used when no user created certificates are
available.

Especially for `tiup playground` and other simple cases this should be
sufficient.

Note that for `caching_sha2_password` support we will either need TLS
connections or RSA keypairs. This brings us a step closer in that
direction.

The created certificate are valid for 90 days and new certificates are
created every 30 days.

See also:
- "Automatic SSL and RSA File Generation" on https://dev.mysql.com/doc/refman/8.0/en/creating-ssl-rsa-files-using-mysql.html
- https://docs.pingcap.com/tidb/stable/enable-tls-between-clients-and-servers
- pingcap#9411
- pingcap#18084
  • Loading branch information
dveeden committed Aug 2, 2021
1 parent af39de5 commit 7763c4f
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 18 deletions.
2 changes: 1 addition & 1 deletion executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ func (e *SimpleExec) executeAlterInstance(s *ast.AlterInstanceStmt) error {
if s.ReloadTLS {
logutil.BgLogger().Info("execute reload tls", zap.Bool("NoRollbackOnError", s.NoRollbackOnError))
sm := e.ctx.GetSessionManager()
tlsCfg, err := util.LoadTLSCertificates(
tlsCfg, _, err := util.LoadTLSCertificates(
variable.GetSysVar("ssl_ca").Value,
variable.GetSysVar("ssl_key").Value,
variable.GetSysVar("ssl_cert").Value,
Expand Down
18 changes: 17 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,23 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
globalConnID: util.GlobalConnID{ServerID: 0, Is64bits: true},
}
setTxnScope()
tlsConfig, err := util.LoadTLSCertificates(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
tlsConfig, autoReload, err := util.LoadTLSCertificates(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)

// Automatically reload auto-generated certificates.
// The certificates are re-created every 30 days and are valid for 90 days.
if autoReload {
go func() {
for range time.Tick(time.Hour * 24 * 30) { // 30 days
logutil.BgLogger().Info("Rotating automatically created TLS Certificates")
tlsConfig, _, err = util.LoadTLSCertificates(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
if err != nil {
logutil.BgLogger().Warn("TLS Certificate rotation failed", zap.Error(err))
}
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))
}
}()
}

if err != nil {
logutil.BgLogger().Error("secure connection cert/key/ca load fail", zap.Error(err))
}
Expand Down
19 changes: 5 additions & 14 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -905,18 +905,8 @@ func (ts *tidbTestSerialSuite) TestTLS(c *C) {
c.Assert(err, IsNil)
}()
time.Sleep(time.Millisecond * 100)
err = cli.runTestTLSConnection(c, connOverrider) // We should get ErrNoTLS.
c.Assert(err, NotNil)
c.Assert(errors.Cause(err).Error(), Equals, mysql.ErrNoTLS.Error())

// Test SSL/TLS session vars
var v *variable.SessionVars
stats, err := server.Stats(v)
err = cli.runTestTLSConnection(c, connOverrider) // Relying on automatically created TLS certificates
c.Assert(err, IsNil)
c.Assert(stats, HasKey, "Ssl_server_not_after")
c.Assert(stats, HasKey, "Ssl_server_not_before")
c.Assert(stats["Ssl_server_not_after"], Equals, "")
c.Assert(stats["Ssl_server_not_before"], Equals, "")

server.Close()

Expand Down Expand Up @@ -952,7 +942,8 @@ func (ts *tidbTestSerialSuite) TestTLS(c *C) {
cli.runTestRegression(c, connOverrider, "TLSRegression")

// Test SSL/TLS session vars
stats, err = server.Stats(v)
var v *variable.SessionVars
stats, err := server.Stats(v)
c.Assert(err, IsNil)
c.Assert(stats, HasKey, "Ssl_server_not_after")
c.Assert(stats, HasKey, "Ssl_server_not_before")
Expand Down Expand Up @@ -996,9 +987,9 @@ func (ts *tidbTestSerialSuite) TestTLS(c *C) {
c.Assert(util.IsTLSExpiredError(x509.CertificateInvalidError{Reason: x509.CANotAuthorizedForThisName}), IsFalse)
c.Assert(util.IsTLSExpiredError(x509.CertificateInvalidError{Reason: x509.Expired}), IsTrue)

_, err = util.LoadTLSCertificates("", "wrong key", "wrong cert")
_, _, err = util.LoadTLSCertificates("", "wrong key", "wrong cert")
c.Assert(err, NotNil)
_, err = util.LoadTLSCertificates("wrong ca", "/tmp/server-key.pem", "/tmp/server-cert.pem")
_, _, err = util.LoadTLSCertificates("wrong ca", "/tmp/server-key.pem", "/tmp/server-cert.pem")
c.Assert(err, NotNil)
}

Expand Down
77 changes: 75 additions & 2 deletions util/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@ package util

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"net/http"
"os"
Expand Down Expand Up @@ -434,9 +438,18 @@ type SequenceTable interface {
}

// LoadTLSCertificates loads CA/KEY/CERT for special paths.
func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error) {
func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, autoReload bool, err error) {
autoReload = false
if len(cert) == 0 || len(key) == 0 {
return
autoReload = true
tempStoragePath := config.GetGlobalConfig().TempStoragePath
cert = tempStoragePath + "/cert.pem"
key = tempStoragePath + "/key.pem"
err = createTLSCertificates(cert, key)
if err != nil {
logutil.BgLogger().Warn("TLS Certificate creation failed", zap.Error(err))
return
}
}

var tlsCert tls.Certificate
Expand Down Expand Up @@ -546,3 +559,63 @@ func QueryStrForLog(query string) string {
}
return query
}

func createTLSCertificates(certpath string, keypath string) error {
privkey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return err
}

certValidity := 90 * 24 * time.Hour // 90 days
notBefore := time.Now()
notAfter := notBefore.Add(certValidity)
hostname, err := os.Hostname()
if err != nil {
return err
}

template := x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: notBefore,
NotAfter: notAfter,
DNSNames: []string{hostname},
}

// DER: Distinguished Encoding Rules, this is the ASN.1 encoding rule of the certificate.
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privkey.PublicKey, privkey)
if err != nil {
return err
}

certOut, err := os.Create(certpath)
if err != nil {
return err
}
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
return err
}
if err := certOut.Close(); err != nil {
return err
}

keyOut, err := os.OpenFile(keypath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}

privBytes, err := x509.MarshalPKCS8PrivateKey(privkey)
if err != nil {
return err
}

if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
return err
}

if err := keyOut.Close(); err != nil {
return err
}

logutil.BgLogger().Info("TLS Certificates created", zap.String("cert", certpath), zap.String("key", keypath), zap.Duration("validity", certValidity))
return nil
}

0 comments on commit 7763c4f

Please sign in to comment.