Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: adjust TLS behaviour for dumpling and lightning #37479

Merged
merged 21 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 43 additions & 74 deletions br/pkg/lightning/common/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,97 +17,62 @@ package common
import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"net/http"
"net/http/httptest"
"os"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/br/pkg/httputil"
"github.com/pingcap/tidb/util"
"github.com/tikv/client-go/v2/config"
pd "github.com/tikv/pd/client"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

type TLS struct {
caPath string
certPath string
keyPath string
inner *tls.Config
client *http.Client
url string
caPath string
certPath string
keyPath string
caBytes []byte
certBytes []byte
keyBytes []byte
inner *tls.Config
client *http.Client
url string
}

// ToTLSConfig constructs a `*tls.Config` from the CA, certification and key
// paths.
//
// If the CA path is empty, returns nil.
func ToTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
if len(caPath) == 0 {
return nil, nil
}

// Create a certificate pool from CA
certPool := x509.NewCertPool()
ca, err := os.ReadFile(caPath)
// NewTLS constructs a new HTTP client with TLS configured with the CA,
// certificate and key paths.
func NewTLS(caPath, certPath, keyPath, host string, caBytes, certBytes, keyBytes []byte) (*TLS, error) {
inner, err := util.NewTLSConfig(
util.WithCAPath(caPath),
util.WithCertAndKeyPath(certPath, keyPath),
util.WithCAContent(caBytes),
util.WithCertAndKeyContent(certBytes, keyBytes),
)
if err != nil {
return nil, errors.Annotate(err, "could not read ca certificate")
}

// Append the certificates from the CA
if !certPool.AppendCertsFromPEM(ca) {
return nil, errors.New("failed to append ca certs")
}

tlsConfig := &tls.Config{
RootCAs: certPool,
NextProtos: []string{"h2", "http/1.1"}, // specify `h2` to let Go use HTTP/2.
MinVersion: tls.VersionTLS12,
}

if len(certPath) != 0 && len(keyPath) != 0 {
loadCert := func() (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, errors.Annotate(err, "could not load client key pair")
}
return &cert, nil
}
tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return loadCert()
}
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return loadCert()
}
return nil, errors.Trace(err)
}
return tlsConfig, nil
}

// NewTLS constructs a new HTTP client with TLS configured with the CA,
// certificate and key paths.
//
// If the CA path is empty, returns an instance where TLS is disabled.
func NewTLS(caPath, certPath, keyPath, host string) (*TLS, error) {
if len(caPath) == 0 {
if inner == nil {
return &TLS{
inner: nil,
client: &http.Client{},
url: "http://" + host,
}, nil
}
inner, err := ToTLSConfig(caPath, certPath, keyPath)
if err != nil {
return nil, errors.Trace(err)
}

return &TLS{
caPath: caPath,
certPath: certPath,
keyPath: keyPath,
inner: inner,
client: httputil.NewClient(inner),
url: "https://" + host,
caPath: caPath,
certPath: certPath,
keyPath: keyPath,
caBytes: caBytes,
certBytes: certBytes,
keyBytes: keyBytes,
inner: inner,
client: httputil.NewClient(inner),
url: "https://" + host,
}, nil
}

Expand All @@ -129,11 +94,9 @@ func (tc *TLS) WithHost(host string) *TLS {
} else {
url = "http://" + host
}
return &TLS{
inner: tc.inner,
client: tc.client,
url: url,
}
shallowClone := *tc
shallowClone.url = url
return &shallowClone
}

// ToGRPCDialOption constructs a gRPC dial option.
Expand All @@ -156,14 +119,20 @@ func (tc *TLS) GetJSON(ctx context.Context, path string, v interface{}) error {
return GetJSON(ctx, tc.client, tc.url+path, v)
}

// ToPDSecurityOption converts the TLS configuration to a PD security option.
func (tc *TLS) ToPDSecurityOption() pd.SecurityOption {
return pd.SecurityOption{
CAPath: tc.caPath,
CertPath: tc.certPath,
KeyPath: tc.keyPath,
CAPath: tc.caPath,
CertPath: tc.certPath,
KeyPath: tc.keyPath,
SSLCABytes: tc.caBytes,
SSLCertBytes: tc.certBytes,
SSLKEYBytes: tc.keyBytes,
}
}

// ToTiKVSecurityConfig converts the TLS configuration to a TiKV security config.
// TODO: TiKV does not support pass in content.
func (tc *TLS) ToTiKVSecurityConfig() config.Security {
return config.Security{
ClusterSSLCA: tc.caPath,
Expand Down
28 changes: 10 additions & 18 deletions br/pkg/lightning/common/security_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestGetJSONInsecure(t *testing.T) {
u, err := url.Parse(mockServer.URL)
require.NoError(t, err)

tls, err := common.NewTLS("", "", "", u.Host)
tls, err := common.NewTLS("", "", "", u.Host, nil, nil, nil)
require.NoError(t, err)

var result struct{ Path string }
Expand Down Expand Up @@ -73,15 +73,8 @@ func TestGetJSONSecure(t *testing.T) {
func TestInvalidTLS(t *testing.T) {
tempDir := t.TempDir()
caPath := filepath.Join(tempDir, "ca.pem")
_, err := common.NewTLS(caPath, "", "", "localhost")
require.Regexp(t, "could not read ca certificate:.*", err.Error())

err = os.WriteFile(caPath, []byte("invalid ca content"), 0o644)
require.NoError(t, err)
_, err = common.NewTLS(caPath, "", "", "localhost")
require.Regexp(t, "failed to append ca certs", err.Error())

err = os.WriteFile(caPath, []byte(`-----BEGIN CERTIFICATE-----
caContent := []byte(`-----BEGIN CERTIFICATE-----
MIIBITCBxwIUf04/Hucshr7AynmgF8JeuFUEf9EwCgYIKoZIzj0EAwIwEzERMA8G
A1UEAwwIYnJfdGVzdHMwHhcNMjIwNDEzMDcyNDQxWhcNMjIwNDE1MDcyNDQxWjAT
MREwDwYDVQQDDAhicl90ZXN0czBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABL+X
Expand All @@ -90,20 +83,19 @@ wczUg0AbaFFaCI+FAk3K9vbB9JeIORgGKS+F1TKip5tvm96g7S5lq8SgY38SXVc3
ze4ZnCkwJdP2VdpI3WZsoI7zAiEAjP8X1c0iFwYxdAbQAveX+9msVrzyUpZOohi4
RtgQTNI=
-----END CERTIFICATE-----
`), 0o644)
`)
err := os.WriteFile(caPath, caContent, 0o644)
require.NoError(t, err)

certPath := filepath.Join(tempDir, "test.pem")
keyPath := filepath.Join(tempDir, "test.key")
tls, err := common.NewTLS(caPath, certPath, keyPath, "localhost")
_, err = tls.TLSConfig().GetCertificate(nil)
require.Regexp(t, "could not load client key pair: open.*", err.Error())

err = os.WriteFile(certPath, []byte("invalid cert content"), 0o644)
certContent := []byte("invalid cert content")
err = os.WriteFile(certPath, certContent, 0o644)
require.NoError(t, err)
err = os.WriteFile(keyPath, []byte("invalid key content"), 0o600)
keyContent := []byte("invalid key content")
err = os.WriteFile(keyPath, keyContent, 0o600)
require.NoError(t, err)
tls, err = common.NewTLS(caPath, certPath, keyPath, "localhost")
_, err = tls.TLSConfig().GetCertificate(nil)
require.Regexp(t, "could not load client key pair: tls.*", err.Error())
_, err = common.NewTLS(caPath, "", "", "localhost", caContent, certContent, keyContent)
require.ErrorContains(t, err, "tls: failed to find any PEM data in certificate input")
}
31 changes: 27 additions & 4 deletions br/pkg/lightning/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ import (
"github.com/BurntSushi/toml"
"github.com/docker/go-units"
gomysql "github.com/go-sql-driver/mysql"
"github.com/google/uuid"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/lightning/log"
tidbcfg "github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/util"
filter "github.com/pingcap/tidb/util/table-filter"
router "github.com/pingcap/tidb/util/table-router"
"go.uber.org/atomic"
Expand Down Expand Up @@ -155,7 +157,15 @@ func (cfg *Config) String() string {

func (cfg *Config) ToTLS() (*common.TLS, error) {
hostPort := net.JoinHostPort(cfg.TiDB.Host, strconv.Itoa(cfg.TiDB.StatusPort))
return common.NewTLS(cfg.Security.CAPath, cfg.Security.CertPath, cfg.Security.KeyPath, hostPort)
return common.NewTLS(
cfg.Security.CAPath,
cfg.Security.CertPath,
cfg.Security.KeyPath,
hostPort,
cfg.Security.CABytes,
cfg.Security.CertBytes,
cfg.Security.KeyBytes,
)
}

type Lightning struct {
Expand Down Expand Up @@ -559,6 +569,11 @@ type Security struct {
// TLSConfigName is used to set tls config for lightning in DM, so we don't expose this field to user
// DM may running many lightning instances at same time, so we need to set different tls config name for each lightning
TLSConfigName string `toml:"-" json:"-"`

// When DM/engine uses lightning as a library, it can directly pass in the content
CABytes []byte `toml:"-" json:"-"`
CertBytes []byte `toml:"-" json:"-"`
KeyBytes []byte `toml:"-" json:"-"`
}

// RegisterMySQL registers the TLS config with name "cluster" or security.TLSConfigName
Expand All @@ -567,7 +582,13 @@ func (sec *Security) RegisterMySQL() error {
if sec == nil {
return nil
}
tlsConfig, err := common.ToTLSConfig(sec.CAPath, sec.CertPath, sec.KeyPath)

tlsConfig, err := util.NewTLSConfig(
util.WithCAPath(sec.CAPath),
util.WithCertAndKeyPath(sec.CertPath, sec.KeyPath),
util.WithCAContent(sec.CABytes),
util.WithCertAndKeyContent(sec.CertBytes, sec.KeyBytes),
)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -1151,9 +1172,11 @@ func (cfg *Config) CheckAndAdjustSecurity() error {

switch cfg.TiDB.TLS {
case "":
if len(cfg.TiDB.Security.CAPath) > 0 {
if len(cfg.TiDB.Security.CAPath) > 0 || len(cfg.TiDB.Security.CABytes) > 0 ||
len(cfg.TiDB.Security.CertPath) > 0 || len(cfg.TiDB.Security.CertBytes) > 0 ||
len(cfg.TiDB.Security.KeyPath) > 0 || len(cfg.TiDB.Security.KeyBytes) > 0 {
if cfg.TiDB.Security.TLSConfigName == "" {
cfg.TiDB.Security.TLSConfigName = "cluster" // adjust this the default value
cfg.TiDB.Security.TLSConfigName = uuid.NewString() // adjust this the default value
}
cfg.TiDB.TLS = cfg.TiDB.Security.TLSConfigName
} else {
Expand Down
13 changes: 9 additions & 4 deletions br/pkg/lightning/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ func TestAdjustWillBatchImportRatioInvalid(t *testing.T) {
}

func TestAdjustSecuritySection(t *testing.T) {
uuidHolder := "<uuid>"
testCases := []struct {
input string
expectedCA string
Expand All @@ -302,7 +303,7 @@ func TestAdjustSecuritySection(t *testing.T) {
ca-path = "/path/to/ca.pem"
`,
expectedCA: "/path/to/ca.pem",
expectedTLS: "cluster",
expectedTLS: uuidHolder,
},
{
input: `
Expand All @@ -321,7 +322,7 @@ func TestAdjustSecuritySection(t *testing.T) {
ca-path = "/path/to/ca2.pem"
`,
expectedCA: "/path/to/ca2.pem",
expectedTLS: "cluster",
expectedTLS: uuidHolder,
},
{
input: `
Expand All @@ -330,7 +331,7 @@ func TestAdjustSecuritySection(t *testing.T) {
ca-path = "/path/to/ca2.pem"
`,
expectedCA: "/path/to/ca2.pem",
expectedTLS: "cluster",
expectedTLS: uuidHolder,
},
{
input: `
Expand All @@ -356,7 +357,11 @@ func TestAdjustSecuritySection(t *testing.T) {
err = cfg.Adjust(context.Background())
require.NoError(t, err, comment)
require.Equal(t, tc.expectedCA, cfg.TiDB.Security.CAPath, comment)
require.Equal(t, tc.expectedTLS, cfg.TiDB.TLS, comment)
if tc.expectedTLS == uuidHolder {
require.NotEmpty(t, cfg.TiDB.TLS, comment)
} else {
require.Equal(t, tc.expectedTLS, cfg.TiDB.TLS, comment)
}
}
// test different tls config name
cfg := config.NewConfig()
Expand Down
10 changes: 9 additions & 1 deletion br/pkg/lightning/lightning.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,15 @@ func New(globalCfg *config.GlobalConfig) *Lightning {
os.Exit(1)
}

tls, err := common.NewTLS(globalCfg.Security.CAPath, globalCfg.Security.CertPath, globalCfg.Security.KeyPath, globalCfg.App.StatusAddr)
tls, err := common.NewTLS(
globalCfg.Security.CAPath,
globalCfg.Security.CertPath,
globalCfg.Security.KeyPath,
globalCfg.App.StatusAddr,
globalCfg.Security.CABytes,
globalCfg.Security.CertBytes,
globalCfg.Security.KeyBytes,
)
if err != nil {
log.L().Fatal("failed to load TLS certificates", zap.Error(err))
}
Expand Down
Loading