diff --git a/client/client.go b/client/client.go index 5e00aea837f..86c2e04523a 100644 --- a/client/client.go +++ b/client/client.go @@ -15,10 +15,6 @@ package pd import ( "context" - "crypto/tls" - "crypto/x509" - "io/ioutil" - "net/url" "strings" "sync" "time" @@ -26,11 +22,11 @@ import ( opentracing "github.com/opentracing/opentracing-go" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" - log "github.com/pingcap/log" + "github.com/pingcap/log" + "github.com/pingcap/pd/pkg/grpcutil" "github.com/pkg/errors" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" ) // Client is a PD (Placement Driver) client. @@ -272,43 +268,7 @@ func (c *client) getOrCreateGRPCConn(addr string) (*grpc.ClientConn, error) { return conn, nil } - opt := grpc.WithInsecure() - if len(c.security.CAPath) != 0 { - - certificates := []tls.Certificate{} - if len(c.security.CertPath) != 0 && len(c.security.KeyPath) != 0 { - // Load the client certificates from disk - certificate, err := tls.LoadX509KeyPair(c.security.CertPath, c.security.KeyPath) - if err != nil { - return nil, errors.Errorf("could not load client key pair: %s", err) - } - certificates = append(certificates, certificate) - } - - // Create a certificate pool from the certificate authority - certPool := x509.NewCertPool() - ca, err := ioutil.ReadFile(c.security.CAPath) - if err != nil { - return nil, errors.Errorf("could not read ca certificate: %s", err) - } - - // Append the certificates from the CA - if !certPool.AppendCertsFromPEM(ca) { - return nil, errors.New("failed to append ca certs") - } - - creds := credentials.NewTLS(&tls.Config{ - Certificates: certificates, - RootCAs: certPool, - }) - - opt = grpc.WithTransportCredentials(creds) - } - u, err := url.Parse(addr) - if err != nil { - return nil, errors.WithStack(err) - } - cc, err := grpc.Dial(u.Host, opt) + cc, err := grpcutil.GetClientConn(addr, c.security.CAPath, c.security.CertPath, c.security.KeyPath) if err != nil { return nil, errors.WithStack(err) } diff --git a/pkg/grpcutil/grpcutil.go b/pkg/grpcutil/grpcutil.go new file mode 100644 index 00000000000..af120b0f22f --- /dev/null +++ b/pkg/grpcutil/grpcutil.go @@ -0,0 +1,69 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package grpcutil + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net/url" + + "github.com/pkg/errors" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +// GetClientConn returns a gRPC client connection. +func GetClientConn(addr string, caPath string, certPath string, keyPath string) (*grpc.ClientConn, error) { + opt := grpc.WithInsecure() + if len(caPath) != 0 { + certificates := []tls.Certificate{} + if len(certPath) != 0 && len(keyPath) != 0 { + // Load the client certificates from disk + certificate, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, errors.Errorf("could not load client key pair: %s", err) + } + certificates = append(certificates, certificate) + } + + // Create a certificate pool from the certificate authority + certPool := x509.NewCertPool() + ca, err := ioutil.ReadFile(caPath) + if err != nil { + return nil, errors.Errorf("could not read ca certificate: %s", err) + } + + // Append the certificates from the CA + if !certPool.AppendCertsFromPEM(ca) { + return nil, errors.New("failed to append ca certs") + } + + creds := credentials.NewTLS(&tls.Config{ + Certificates: certificates, + RootCAs: certPool, + }) + + opt = grpc.WithTransportCredentials(creds) + } + u, err := url.Parse(addr) + if err != nil { + return nil, errors.WithStack(err) + } + cc, err := grpc.Dial(u.Host, opt) + if err != nil { + return nil, errors.WithStack(err) + } + return cc, nil +} diff --git a/server/config.go b/server/config.go index 8799d4fe6fc..79c231a3df1 100644 --- a/server/config.go +++ b/server/config.go @@ -775,15 +775,23 @@ type SecurityConfig struct { KeyPath string `toml:"key-path" json:"key-path"` } +// ConvertToMap is used to convert SecurityConfig to a map. +func (s *SecurityConfig) ConvertToMap() map[string]string { + return map[string]string{ + "caPath": s.CAPath, + "certPath": s.CertPath, + "keyPath": s.KeyPath} +} + // ToTLSConfig generatres tls config. -func (s SecurityConfig) ToTLSConfig() (*tls.Config, error) { - if len(s.CertPath) == 0 && len(s.KeyPath) == 0 { +func ToTLSConfig(config map[string]string) (*tls.Config, error) { + if len(config["certPath"]) == 0 && len(config["keyPath"]) == 0 { return nil, nil } tlsInfo := transport.TLSInfo{ - CertFile: s.CertPath, - KeyFile: s.KeyPath, - TrustedCAFile: s.CAPath, + CertFile: config["certPath"], + KeyFile: config["keyPath"], + TrustedCAFile: config["caPath"], } tlsConfig, err := tlsInfo.ClientConfig() if err != nil { diff --git a/server/config_test.go b/server/config_test.go index ea09c4c4ea4..2c98851cbb7 100644 --- a/server/config_test.go +++ b/server/config_test.go @@ -32,7 +32,7 @@ type testConfigSuite struct{} func (s *testConfigSuite) TestTLS(c *C) { cfg := NewConfig() - tls, err := cfg.Security.ToTLSConfig() + tls, err := ToTLSConfig(cfg.Security.ConvertToMap()) c.Assert(err, IsNil) c.Assert(tls, IsNil) } diff --git a/server/join.go b/server/join.go index 744dd4af978..3c4cd999b28 100644 --- a/server/join.go +++ b/server/join.go @@ -108,7 +108,7 @@ func PrepareJoinCluster(cfg *Config) error { } // Below are cases without data directory. - tlsConfig, err := cfg.Security.ToTLSConfig() + tlsConfig, err := ToTLSConfig(cfg.Security.ConvertToMap()) if err != nil { return err } diff --git a/server/region_syncer/client.go b/server/region_syncer/client.go index 3a794c16a95..e5581bf8fbc 100644 --- a/server/region_syncer/client.go +++ b/server/region_syncer/client.go @@ -15,14 +15,14 @@ package syncer import ( "context" - "net/url" "time" "github.com/pingcap/kvproto/pkg/pdpb" - log "github.com/pingcap/log" + "github.com/pingcap/log" + "github.com/pingcap/pd/pkg/grpcutil" "github.com/pingcap/pd/server/core" + "github.com/pkg/errors" "go.uber.org/zap" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -51,14 +51,9 @@ func (s *RegionSyncer) reset() { func (s *RegionSyncer) establish(addr string) (ClientStream, error) { s.reset() - u, err := url.Parse(addr) + cc, err := grpcutil.GetClientConn(addr, s.securityConfig["caPath"], s.securityConfig["certPath"], s.securityConfig["keyPath"]) if err != nil { - return nil, err - } - - cc, err := grpc.Dial(u.Host, grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(msgSize))) - if err != nil { - return nil, err + return nil, errors.WithStack(err) } ctx, cancel := context.WithCancel(s.server.Context()) diff --git a/server/region_syncer/server.go b/server/region_syncer/server.go index 377c53fa7a1..796a4c2217c 100644 --- a/server/region_syncer/server.go +++ b/server/region_syncer/server.go @@ -22,7 +22,7 @@ import ( "github.com/juju/ratelimit" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" - log "github.com/pingcap/log" + "github.com/pingcap/log" "github.com/pingcap/pd/server/core" "github.com/pkg/errors" "go.uber.org/zap" @@ -59,19 +59,21 @@ type Server interface { GetStorage() *core.KV Name() string GetMetaRegions() []*metapb.Region + GetSecurityConfig() map[string]string } // RegionSyncer is used to sync the region information without raft. type RegionSyncer struct { sync.RWMutex - streams map[string]ServerStream - ctx context.Context - cancel context.CancelFunc - server Server - closed chan struct{} - wg sync.WaitGroup - history *historyBuffer - limit *ratelimit.Bucket + streams map[string]ServerStream + ctx context.Context + cancel context.CancelFunc + server Server + closed chan struct{} + wg sync.WaitGroup + history *historyBuffer + limit *ratelimit.Bucket + securityConfig map[string]string } // NewRegionSyncer returns a region syncer. @@ -81,11 +83,12 @@ type RegionSyncer struct { // no longer etcd but go-leveldb. func NewRegionSyncer(s Server) *RegionSyncer { return &RegionSyncer{ - streams: make(map[string]ServerStream), - server: s, - closed: make(chan struct{}), - history: newHistoryBuffer(defaultHistoryBufferSize, s.GetStorage().GetRegionKV()), - limit: ratelimit.NewBucketWithRate(defaultBucketRate, defaultBucketCapacity), + streams: make(map[string]ServerStream), + server: s, + closed: make(chan struct{}), + history: newHistoryBuffer(defaultHistoryBufferSize, s.GetStorage().GetRegionKV()), + limit: ratelimit.NewBucketWithRate(defaultBucketRate, defaultBucketCapacity), + securityConfig: s.GetSecurityConfig(), } } diff --git a/server/server.go b/server/server.go index 0462214f520..226d20a222f 100644 --- a/server/server.go +++ b/server/server.go @@ -156,7 +156,7 @@ func (s *Server) startEtcd(ctx context.Context) error { if err != nil { return errors.WithStack(err) } - tlsConfig, err := s.cfg.Security.ToTLSConfig() + tlsConfig, err := ToTLSConfig(s.cfg.Security.ConvertToMap()) if err != nil { return err } @@ -726,9 +726,9 @@ func (s *Server) GetClusterVersion() semver.Version { return s.scheduleOpt.loadClusterVersion() } -// GetSecurityConfig get the security config. -func (s *Server) GetSecurityConfig() *SecurityConfig { - return &s.cfg.Security +// GetSecurityConfig get paths of the security config. +func (s *Server) GetSecurityConfig() map[string]string { + return s.cfg.Security.ConvertToMap() } // IsNamespaceExist returns whether the namespace exists. diff --git a/server/util.go b/server/util.go index f3119a4a1cd..9ce968b21ce 100644 --- a/server/util.go +++ b/server/util.go @@ -281,7 +281,7 @@ func subTimeByWallClock(after time.Time, before time.Time) time.Duration { // InitHTTPClient initials a http client. func InitHTTPClient(svr *Server) error { - tlsConfig, err := svr.GetSecurityConfig().ToTLSConfig() + tlsConfig, err := ToTLSConfig(svr.GetSecurityConfig()) if err != nil { return err }