diff --git a/client/client.go b/client/client.go index e7a9d0d8a00..b6c3eee6036 100644 --- a/client/client.go +++ b/client/client.go @@ -15,10 +15,6 @@ package pd import ( "context" - "crypto/tls" - "crypto/x509" - "io/ioutil" - "net/url" "strings" "sync" "time" @@ -27,10 +23,10 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + "github.com/pingcap/pd/pkg/grpcutil" "github.com/pkg/errors" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" ) // Client is a PD (Placement Driver) client. @@ -272,43 +268,7 @@ func (c *client) getOrCreateGRPCConn(addr string) (*grpc.ClientConn, error) { return conn, nil } - opt := grpc.WithInsecure() - if len(c.security.CAPath) != 0 { - - certificates := []tls.Certificate{} - if len(c.security.CertPath) != 0 && len(c.security.KeyPath) != 0 { - // Load the client certificates from disk - certificate, err := tls.LoadX509KeyPair(c.security.CertPath, c.security.KeyPath) - if err != nil { - return nil, errors.Errorf("could not load client key pair: %s", err) - } - certificates = append(certificates, certificate) - } - - // Create a certificate pool from the certificate authority - certPool := x509.NewCertPool() - ca, err := ioutil.ReadFile(c.security.CAPath) - if err != nil { - return nil, errors.Errorf("could not read ca certificate: %s", err) - } - - // Append the certificates from the CA - if !certPool.AppendCertsFromPEM(ca) { - return nil, errors.New("failed to append ca certs") - } - - creds := credentials.NewTLS(&tls.Config{ - Certificates: certificates, - RootCAs: certPool, - }) - - opt = grpc.WithTransportCredentials(creds) - } - u, err := url.Parse(addr) - if err != nil { - return nil, errors.WithStack(err) - } - cc, err := grpc.Dial(u.Host, opt) + cc, err := grpcutil.GetClientConn(addr, c.security.CAPath, c.security.CertPath, c.security.KeyPath) if err != nil { return nil, errors.WithStack(err) } diff --git a/pkg/grpcutil/grpcutil.go b/pkg/grpcutil/grpcutil.go new file mode 100644 index 00000000000..af120b0f22f --- /dev/null +++ b/pkg/grpcutil/grpcutil.go @@ -0,0 +1,69 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package grpcutil + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net/url" + + "github.com/pkg/errors" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +// GetClientConn returns a gRPC client connection. +func GetClientConn(addr string, caPath string, certPath string, keyPath string) (*grpc.ClientConn, error) { + opt := grpc.WithInsecure() + if len(caPath) != 0 { + certificates := []tls.Certificate{} + if len(certPath) != 0 && len(keyPath) != 0 { + // Load the client certificates from disk + certificate, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, errors.Errorf("could not load client key pair: %s", err) + } + certificates = append(certificates, certificate) + } + + // Create a certificate pool from the certificate authority + certPool := x509.NewCertPool() + ca, err := ioutil.ReadFile(caPath) + if err != nil { + return nil, errors.Errorf("could not read ca certificate: %s", err) + } + + // Append the certificates from the CA + if !certPool.AppendCertsFromPEM(ca) { + return nil, errors.New("failed to append ca certs") + } + + creds := credentials.NewTLS(&tls.Config{ + Certificates: certificates, + RootCAs: certPool, + }) + + opt = grpc.WithTransportCredentials(creds) + } + u, err := url.Parse(addr) + if err != nil { + return nil, errors.WithStack(err) + } + cc, err := grpc.Dial(u.Host, opt) + if err != nil { + return nil, errors.WithStack(err) + } + return cc, nil +} diff --git a/server/region_syncer/client.go b/server/region_syncer/client.go index 0e8140a1e5e..7e524b5ddfe 100644 --- a/server/region_syncer/client.go +++ b/server/region_syncer/client.go @@ -15,14 +15,14 @@ package syncer import ( "context" - "net/url" "time" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + "github.com/pingcap/pd/pkg/grpcutil" "github.com/pingcap/pd/server/core" + "github.com/pkg/errors" "go.uber.org/zap" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -51,14 +51,9 @@ func (s *RegionSyncer) reset() { func (s *RegionSyncer) establish(addr string) (ClientStream, error) { s.reset() - u, err := url.Parse(addr) + cc, err := grpcutil.GetClientConn(addr, s.securityConfig.CAPath, s.securityConfig.CertPath, s.securityConfig.KeyPath) if err != nil { - return nil, err - } - - cc, err := grpc.Dial(u.Host, grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(msgSize))) - if err != nil { - return nil, err + return nil, errors.WithStack(err) } ctx, cancel := context.WithCancel(s.server.Context()) diff --git a/server/region_syncer/server.go b/server/region_syncer/server.go index ca430a71f23..e41d29fca68 100644 --- a/server/region_syncer/server.go +++ b/server/region_syncer/server.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + "github.com/pingcap/pd/server/config" "github.com/pingcap/pd/server/core" "github.com/pkg/errors" "go.uber.org/zap" @@ -59,19 +60,21 @@ type Server interface { GetStorage() *core.Storage Name() string GetMetaRegions() []*metapb.Region + GetSecurityConfig() *config.SecurityConfig } // RegionSyncer is used to sync the region information without raft. type RegionSyncer struct { sync.RWMutex - streams map[string]ServerStream - ctx context.Context - cancel context.CancelFunc - server Server - closed chan struct{} - wg sync.WaitGroup - history *historyBuffer - limit *ratelimit.Bucket + streams map[string]ServerStream + ctx context.Context + cancel context.CancelFunc + server Server + closed chan struct{} + wg sync.WaitGroup + history *historyBuffer + limit *ratelimit.Bucket + securityConfig *config.SecurityConfig } // NewRegionSyncer returns a region syncer. @@ -81,11 +84,12 @@ type RegionSyncer struct { // no longer etcd but go-leveldb. func NewRegionSyncer(s Server) *RegionSyncer { return &RegionSyncer{ - streams: make(map[string]ServerStream), - server: s, - closed: make(chan struct{}), - history: newHistoryBuffer(defaultHistoryBufferSize, s.GetStorage().GetRegionStorage()), - limit: ratelimit.NewBucketWithRate(defaultBucketRate, defaultBucketCapacity), + streams: make(map[string]ServerStream), + server: s, + closed: make(chan struct{}), + history: newHistoryBuffer(defaultHistoryBufferSize, s.GetStorage().GetRegionStorage()), + limit: ratelimit.NewBucketWithRate(defaultBucketRate, defaultBucketCapacity), + securityConfig: s.GetSecurityConfig(), } }