Skip to content

Commit

Permalink
store/tikv: refactor: remove import tidb.config.security in store/tikv (
Browse files Browse the repository at this point in the history
#22538)

Signed-off-by: shirly <AndreMouche@126.com>
  • Loading branch information
AndreMouche authored Jan 27, 2021
1 parent 6931ece commit fbcf75a
Show file tree
Hide file tree
Showing 16 changed files with 244 additions and 80 deletions.
2 changes: 1 addition & 1 deletion cmd/benchraw/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/log"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/store/tikv"
"github.com/pingcap/tidb/store/tikv/config"
"go.uber.org/zap"
)

Expand Down
50 changes: 4 additions & 46 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@
package config

import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/url"
"os"
"os/user"
Expand All @@ -33,6 +30,7 @@ import (
zaplog "github.com/pingcap/log"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
tikvcfg "github.com/pingcap/tidb/store/tikv/config"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/versioninfo"
tracing "github.com/uber/jaeger-client-go/config"
Expand Down Expand Up @@ -369,49 +367,9 @@ func (e *ErrConfigValidationFailed) Error() string {

}

// ToTLSConfig generates tls's config based on security section of the config.
func (s *Security) ToTLSConfig() (tlsConfig *tls.Config, err error) {
if len(s.ClusterSSLCA) != 0 {
certPool := x509.NewCertPool()
// Create a certificate pool from the certificate authority
var ca []byte
ca, err = ioutil.ReadFile(s.ClusterSSLCA)
if err != nil {
err = errors.Errorf("could not read ca certificate: %s", err)
return
}
// Append the certificates from the CA
if !certPool.AppendCertsFromPEM(ca) {
err = errors.New("failed to append ca certs")
return
}
tlsConfig = &tls.Config{
RootCAs: certPool,
ClientCAs: certPool,
}

if len(s.ClusterSSLCert) != 0 && len(s.ClusterSSLKey) != 0 {
getCert := func() (*tls.Certificate, error) {
// Load the client certificates from disk
cert, err := tls.LoadX509KeyPair(s.ClusterSSLCert, s.ClusterSSLKey)
if err != nil {
return nil, errors.Errorf("could not load client key pair: %s", err)
}
return &cert, nil
}
// pre-test cert's loading.
if _, err = getCert(); err != nil {
return
}
tlsConfig.GetClientCertificate = func(info *tls.CertificateRequestInfo) (certificate *tls.Certificate, err error) {
return getCert()
}
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (certificate *tls.Certificate, err error) {
return getCert()
}
}
}
return
// ClusterSecurity returns Security info for cluster
func (s *Security) ClusterSecurity() tikvcfg.Security {
return tikvcfg.NewSecurity(s.ClusterSSLCA, s.ClusterSSLCert, s.ClusterSSLKey, s.ClusterVerifyCN)
}

// Status is the status section of the config.
Expand Down
3 changes: 2 additions & 1 deletion config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,8 @@ xkNuJ2BlEGkwWLiRbKy1lNBBFUXKuhh3L/EIY10WTnr3TQzeL6H1
conf.Security.ClusterSSLCA = certFile
conf.Security.ClusterSSLCert = certFile
conf.Security.ClusterSSLKey = keyFile
tlsConfig, err := conf.Security.ToTLSConfig()
clusterSecurity := conf.Security.ClusterSecurity()
tlsConfig, err := clusterSecurity.ToTLSConfig()
c.Assert(err, IsNil)
c.Assert(tlsConfig, NotNil)

Expand Down
6 changes: 4 additions & 2 deletions executor/memtable_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ func getServerInfoByGRPC(ctx context.Context, address string, tp diagnosticspb.S
opt := grpc.WithInsecure()
security := config.GetGlobalConfig().Security
if len(security.ClusterSSLCA) != 0 {
tlsConfig, err := security.ToTLSConfig()
clusterSecurity := security.ClusterSecurity()
tlsConfig, err := clusterSecurity.ToTLSConfig()
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -527,7 +528,8 @@ func (e *clusterLogRetriever) startRetrieving(
opt := grpc.WithInsecure()
security := config.GetGlobalConfig().Security
if len(security.ClusterSSLCA) != 0 {
tlsConfig, err := security.ToTLSConfig()
clusterSecurity := security.ClusterSecurity()
tlsConfig, err := clusterSecurity.ToTLSConfig()
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
3 changes: 2 additions & 1 deletion server/http_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ func (s *Server) listenStatusHTTPServer() error {
}

logutil.BgLogger().Info("for status and metrics report", zap.String("listening on addr", s.statusAddr))
tlsConfig, err := s.cfg.Security.ToTLSConfig()
clusterSecurity := s.cfg.Security.ClusterSecurity()
tlsConfig, err := clusterSecurity.ToTLSConfig()
if err != nil {
logutil.BgLogger().Error("invalid TLS config", zap.Error(err))
return errors.Trace(err)
Expand Down
4 changes: 2 additions & 2 deletions server/rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ func NewRPCServer(config *config.Config, dom *domain.Domain, sm util.SessionMana
}
// For redirection the cop task.
mocktikv.GRPCClientFactory = func() mocktikv.Client {
return tikv.NewTestRPCClient(config.Security)
return tikv.NewTestRPCClient(config.Security.ClusterSecurity())
}
unistore.GRPCClientFactory = func() unistore.Client {
return tikv.NewTestRPCClient(config.Security)
return tikv.NewTestRPCClient(config.Security.ClusterSecurity())
}
diagnosticspb.RegisterDiagnosticsServer(s, rpcSrv)
tikvpb.RegisterTikvServer(s, rpcSrv)
Expand Down
13 changes: 7 additions & 6 deletions store/tikv/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ import (
"github.com/pingcap/kvproto/pkg/mpp"
"github.com/pingcap/kvproto/pkg/tikvpb"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/config"
tidbcfg "github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/store/tikv/config"
"github.com/pingcap/tidb/store/tikv/tikvrpc"
"github.com/pingcap/tidb/util/execdetails"
"github.com/pingcap/tidb/util/logutil"
Expand Down Expand Up @@ -118,7 +119,7 @@ func (a *connArray) Init(addr string, security config.Security, idleNotify *uint
opt = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))
}

cfg := config.GetGlobalConfig()
cfg := tidbcfg.GetGlobalConfig()
var (
unaryInterceptor grpc.UnaryClientInterceptor
streamInterceptor grpc.StreamClientInterceptor
Expand Down Expand Up @@ -255,7 +256,7 @@ func NewTestRPCClient(security config.Security) Client {
return newRPCClient(security)
}

func (c *rpcClient) getConnArray(addr string, enableBatch bool, opt ...func(cfg *config.TiKVClient)) (*connArray, error) {
func (c *rpcClient) getConnArray(addr string, enableBatch bool, opt ...func(cfg *tidbcfg.TiKVClient)) (*connArray, error) {
c.RLock()
if c.isClosed {
c.RUnlock()
Expand All @@ -273,13 +274,13 @@ func (c *rpcClient) getConnArray(addr string, enableBatch bool, opt ...func(cfg
return array, nil
}

func (c *rpcClient) createConnArray(addr string, enableBatch bool, opts ...func(cfg *config.TiKVClient)) (*connArray, error) {
func (c *rpcClient) createConnArray(addr string, enableBatch bool, opts ...func(cfg *tidbcfg.TiKVClient)) (*connArray, error) {
c.Lock()
defer c.Unlock()
array, ok := c.conns[addr]
if !ok {
var err error
client := config.GetGlobalConfig().TiKVClient
client := tidbcfg.GetGlobalConfig().TiKVClient
for _, opt := range opts {
opt(&client)
}
Expand Down Expand Up @@ -363,7 +364,7 @@ func (c *rpcClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R

// TiDB RPC server supports batch RPC, but batch connection will send heart beat, It's not necessary since
// request to TiDB is not high frequency.
if config.GetGlobalConfig().TiKVClient.MaxBatchSize > 0 && enableBatch {
if tidbcfg.GetGlobalConfig().TiKVClient.MaxBatchSize > 0 && enableBatch {
if batchReq := req.ToBatchCommandsRequest(); batchReq != nil {
defer trace.StartRegion(ctx, req.Type.String()).End()
return sendBatchRequest(ctx, addr, connArray.batchConn, batchReq, timeout)
Expand Down
5 changes: 3 additions & 2 deletions store/tikv/client_fail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/failpoint"
"github.com/pingcap/kvproto/pkg/tikvpb"
"github.com/pingcap/tidb/config"
tidbcfg "github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/store/tikv/config"
"github.com/pingcap/tidb/store/tikv/tikvrpc"
)

Expand Down Expand Up @@ -52,7 +53,7 @@ func (s *testClientFailSuite) TestPanicInRecvLoop(c *C) {
})

// Start batchRecvLoop, and it should panic in `failPendingRequests`.
_, err := rpcClient.getConnArray(addr, true, func(cfg *config.TiKVClient) { cfg.GrpcConnectionCount = 1 })
_, err := rpcClient.getConnArray(addr, true, func(cfg *tidbcfg.TiKVClient) { cfg.GrpcConnectionCount = 1 })
c.Assert(err, IsNil, Commentf("cannot establish local connection due to env problems(e.g. heavy load in test machine), please retry again"))

req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
Expand Down
9 changes: 5 additions & 4 deletions store/tikv/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ import (
"github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/pingcap/kvproto/pkg/tikvpb"
"github.com/pingcap/tidb/config"
tidbcfg "github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/store/tikv/config"
"github.com/pingcap/tidb/store/tikv/tikvrpc"
)

Expand All @@ -47,13 +48,13 @@ var _ = SerialSuites(&testClientFailSuite{})
var _ = SerialSuites(&testClientSerialSuite{})

func setMaxBatchSize(size uint) {
newConf := config.NewConfig()
newConf := tidbcfg.NewConfig()
newConf.TiKVClient.MaxBatchSize = size
config.StoreGlobalConfig(newConf)
tidbcfg.StoreGlobalConfig(newConf)
}

func (s *testClientSerialSuite) TestConn(c *C) {
maxBatchSize := config.GetGlobalConfig().TiKVClient.MaxBatchSize
maxBatchSize := tidbcfg.GetGlobalConfig().TiKVClient.MaxBatchSize
setMaxBatchSize(0)

client := newRPCClient(config.Security{})
Expand Down
85 changes: 85 additions & 0 deletions store/tikv/config/security.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2021 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package config

import (
"crypto/tls"
"crypto/x509"
"io/ioutil"

"github.com/pingcap/errors"
)

// Security is the security section of the config.
type Security struct {
ClusterSSLCA string `toml:"cluster-ssl-ca" json:"cluster-ssl-ca"`
ClusterSSLCert string `toml:"cluster-ssl-cert" json:"cluster-ssl-cert"`
ClusterSSLKey string `toml:"cluster-ssl-key" json:"cluster-ssl-key"`
ClusterVerifyCN []string `toml:"cluster-verify-cn" json:"cluster-verify-cn"`
}

// NewSecurity creates a Security.
func NewSecurity(sslCA, sslCert, sslKey string, verityCN []string) Security {
return Security{
ClusterSSLCA: sslCA,
ClusterSSLCert: sslCert,
ClusterSSLKey: sslKey,
ClusterVerifyCN: verityCN,
}
}

// ToTLSConfig generates tls's config based on security section of the config.
func (s *Security) ToTLSConfig() (tlsConfig *tls.Config, err error) {
if len(s.ClusterSSLCA) != 0 {
certPool := x509.NewCertPool()
// Create a certificate pool from the certificate authority
var ca []byte
ca, err = ioutil.ReadFile(s.ClusterSSLCA)
if err != nil {
err = errors.Errorf("could not read ca certificate: %s", err)
return
}
// Append the certificates from the CA
if !certPool.AppendCertsFromPEM(ca) {
err = errors.New("failed to append ca certs")
return
}
tlsConfig = &tls.Config{
RootCAs: certPool,
ClientCAs: certPool,
}

if len(s.ClusterSSLCert) != 0 && len(s.ClusterSSLKey) != 0 {
getCert := func() (*tls.Certificate, error) {
// Load the client certificates from disk
cert, err := tls.LoadX509KeyPair(s.ClusterSSLCert, s.ClusterSSLKey)
if err != nil {
return nil, errors.Errorf("could not load client key pair: %s", err)
}
return &cert, nil
}
// pre-test cert's loading.
if _, err = getCert(); err != nil {
return
}
tlsConfig.GetClientCertificate = func(info *tls.CertificateRequestInfo) (certificate *tls.Certificate, err error) {
return getCert()
}
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (certificate *tls.Certificate, err error) {
return getCert()
}
}
}
return
}
Loading

0 comments on commit fbcf75a

Please sign in to comment.