diff --git a/client/client.go b/client/client.go index 72e1c487df6..abe399c5d02 100644 --- a/client/client.go +++ b/client/client.go @@ -15,6 +15,8 @@ package pd import ( "context" + "reflect" + "sort" "strings" "sync" "time" @@ -135,7 +137,8 @@ type client struct { ctx context.Context cancel context.CancelFunc - security SecurityOption + security SecurityOption + gRPCDialOptions []grpc.DialOption } // SecurityOption records options about tls @@ -145,13 +148,23 @@ type SecurityOption struct { KeyPath string } +// ClientOption configures client. +type ClientOption func(c *client) + +// WithGRPCDialOptions configures the client with gRPC dial options. +func WithGRPCDialOptions(opts ...grpc.DialOption) ClientOption { + return func(c *client) { + c.gRPCDialOptions = append(c.gRPCDialOptions, opts...) + } +} + // NewClient creates a PD client. -func NewClient(pdAddrs []string, security SecurityOption) (Client, error) { - return NewClientWithContext(context.Background(), pdAddrs, security) +func NewClient(pdAddrs []string, security SecurityOption, opts ...ClientOption) (Client, error) { + return NewClientWithContext(context.Background(), pdAddrs, security, opts...) } // NewClientWithContext creates a PD client with context. -func NewClientWithContext(ctx context.Context, pdAddrs []string, security SecurityOption) (Client, error) { +func NewClientWithContext(ctx context.Context, pdAddrs []string, security SecurityOption, opts ...ClientOption) (Client, error) { log.Info("[pd] create pd client with endpoints", zap.Strings("pd-address", pdAddrs)) ctx1, cancel := context.WithCancel(ctx) c := &client{ @@ -164,6 +177,9 @@ func NewClientWithContext(ctx context.Context, pdAddrs []string, security Securi security: security, } c.connMu.clientConns = make(map[string]*grpc.ClientConn) + for _, opt := range opts { + opt(c) + } if err := c.initRetry(c.initClusterID); err != nil { cancel() @@ -188,6 +204,14 @@ func (c *client) updateURLs(members []*pdpb.Member) { for _, m := range members { urls = append(urls, m.GetClientUrls()...) } + + sort.Strings(urls) + // the url list is same. + if reflect.DeepEqual(c.urls, urls) { + return + } + + log.Info("[pd] update member urls", zap.Strings("old-urls", c.urls), zap.Strings("new-urls", urls)) c.urls = urls } @@ -228,7 +252,7 @@ func (c *client) updateLeader() error { ctx, cancel := context.WithTimeout(c.ctx, updateLeaderTimeout) members, err := c.getMembers(ctx, u) if err != nil { - log.Warn("cannot update leader", zap.String("address", u), zap.Error(err)) + log.Warn("[pd] cannot update leader", zap.String("address", u), zap.Error(err)) } cancel() if err != nil || members.GetLeader() == nil || len(members.GetLeader().GetClientUrls()) == 0 { @@ -289,7 +313,7 @@ func (c *client) getOrCreateGRPCConn(addr string) (*grpc.ClientConn, error) { return conn, nil } - cc, err := grpcutil.GetClientConn(addr, c.security.CAPath, c.security.CertPath, c.security.KeyPath) + cc, err := grpcutil.GetClientConn(addr, c.security.CAPath, c.security.CertPath, c.security.KeyPath, c.gRPCDialOptions...) if err != nil { return nil, errors.WithStack(err) } diff --git a/client/client_test.go b/client/client_test.go index 6a8fd8f5fe4..1943eab6722 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/pd/server" "github.com/pingcap/pd/server/core" "go.uber.org/goleak" + "google.golang.org/grpc" ) func TestClient(t *testing.T) { @@ -470,6 +471,28 @@ func (s *testClientSuite) TestScatterRegion(c *C) { c.Succeed() } +func (s *testClientSuite) TestUpdateURLs(c *C) { + members := []*pdpb.Member{ + {Name: "pd4", ClientUrls: []string{"tmp//pd4"}}, + {Name: "pd1", ClientUrls: []string{"tmp//pd1"}}, + {Name: "pd3", ClientUrls: []string{"tmp//pd3"}}, + {Name: "pd2", ClientUrls: []string{"tmp//pd2"}}, + } + getURLs := func(ms []*pdpb.Member) (urls []string) { + for _, m := range ms { + urls = append(urls, m.GetClientUrls()[0]) + } + return + } + cli := &client{} + cli.updateURLs(members[1:]) + c.Assert(cli.urls, DeepEquals, getURLs([]*pdpb.Member{members[1], members[3], members[2]})) + cli.updateURLs(members[1:]) + c.Assert(cli.urls, DeepEquals, getURLs([]*pdpb.Member{members[1], members[3], members[2]})) + cli.updateURLs(members) + c.Assert(cli.urls, DeepEquals, getURLs([]*pdpb.Member{members[1], members[3], members[2], members[0]})) +} + func (s *testClientSuite) TestTsLessEqual(c *C) { c.Assert(tsLessEqual(9, 9, 9, 9), IsTrue) c.Assert(tsLessEqual(8, 9, 9, 8), IsTrue) @@ -490,3 +513,17 @@ func (s *testClientCtxSuite) TestClientCtx(c *C) { c.Assert(err, NotNil) c.Assert(time.Since(start), Less, time.Second*4) } + +var _ = Suite(&testClientDialOptionSuite{}) + +type testClientDialOptionSuite struct{} + +func (s *testClientDialOptionSuite) TestGRPCDialOption(c *C) { + start := time.Now() + ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond) + defer cancel() + // nolint + _, err := NewClientWithContext(ctx, []string{"localhost:8080"}, SecurityOption{}, WithGRPCDialOptions(grpc.WithBlock(), grpc.WithTimeout(time.Second))) + c.Assert(err, NotNil) + c.Assert(time.Since(start), Greater, 800*time.Millisecond) +}