From 6fe44d75b200cb406ff1cbc19f4a62400aa1db8e Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Tue, 2 Apr 2024 11:28:47 +0800 Subject: [PATCH] client: support specifying target member (#7909) ref tikv/pd#7905 Signed-off-by: Ryan Leung Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- .gitignore | 1 + client/errs/errno.go | 1 + client/grpcutil/grpcutil.go | 10 +++++----- client/http/client.go | 19 ++++++++++++++++++- client/http/client_test.go | 16 +++++++++++++++- client/http/interface.go | 5 ++++- client/http/request_info.go | 7 +++++++ client/mock_pd_service_discovery.go | 2 +- 8 files changed, 52 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 748d24872b6..b9be6099e24 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ coverage.xml coverage *.txt go.work* +embedded_assets_handler.go diff --git a/client/errs/errno.go b/client/errs/errno.go index c095bbe4b4a..c3f5c27275a 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -51,6 +51,7 @@ var ( ErrClientGetClusterInfo = errors.Normalize("get cluster info failed", errors.RFCCodeText("PD:client:ErrClientGetClusterInfo")) ErrClientUpdateMember = errors.Normalize("update member failed, %v", errors.RFCCodeText("PD:client:ErrUpdateMember")) ErrClientNoAvailableMember = errors.Normalize("no available member", errors.RFCCodeText("PD:client:ErrClientNoAvailableMember")) + ErrClientNoTargetMember = errors.Normalize("no target member", errors.RFCCodeText("PD:client:ErrClientNoTargetMember")) ErrClientProtoUnmarshal = errors.Normalize("failed to unmarshal proto", errors.RFCCodeText("PD:proto:ErrClientProtoUnmarshal")) ErrClientGetMultiResponse = errors.Normalize("get invalid value response %v, must only one", errors.RFCCodeText("PD:client:ErrClientGetMultiResponse")) ErrClientGetServingEndpoint = errors.Normalize("get serving endpoint failed", errors.RFCCodeText("PD:client:ErrClientGetServingEndpoint")) diff --git a/client/grpcutil/grpcutil.go b/client/grpcutil/grpcutil.go index fb9e84f0ca1..0e987825c02 100644 --- a/client/grpcutil/grpcutil.go +++ b/client/grpcutil/grpcutil.go @@ -124,17 +124,17 @@ func getValueFromMetadata(ctx context.Context, key string, f func(context.Contex // GetOrCreateGRPCConn returns the corresponding grpc client connection of the given addr. // Returns the old one if's already existed in the clientConns; otherwise creates a new one and returns it. -func GetOrCreateGRPCConn(ctx context.Context, clientConns *sync.Map, addr string, tlsCfg *tls.Config, opt ...grpc.DialOption) (*grpc.ClientConn, error) { - conn, ok := clientConns.Load(addr) +func GetOrCreateGRPCConn(ctx context.Context, clientConns *sync.Map, url string, tlsCfg *tls.Config, opt ...grpc.DialOption) (*grpc.ClientConn, error) { + conn, ok := clientConns.Load(url) if ok { // TODO: check the connection state. return conn.(*grpc.ClientConn), nil } dCtx, cancel := context.WithTimeout(ctx, dialTimeout) defer cancel() - cc, err := GetClientConn(dCtx, addr, tlsCfg, opt...) + cc, err := GetClientConn(dCtx, url, tlsCfg, opt...) failpoint.Inject("unreachableNetwork2", func(val failpoint.Value) { - if val, ok := val.(string); ok && val == addr { + if val, ok := val.(string); ok && val == url { cc = nil err = errors.Errorf("unreachable network") } @@ -142,7 +142,7 @@ func GetOrCreateGRPCConn(ctx context.Context, clientConns *sync.Map, addr string if err != nil { return nil, err } - conn, loaded := clientConns.LoadOrStore(addr, cc) + conn, loaded := clientConns.LoadOrStore(url, cc) if !loaded { // Successfully stored the connection. return cc, nil diff --git a/client/http/client.go b/client/http/client.go index 18802346a4c..30144ebe2c5 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -129,8 +129,13 @@ func (ci *clientInner) requestWithRetry( if len(clients) == 0 { return errs.ErrClientNoAvailableMember } + skipNum := 0 for _, cli := range clients { url := cli.GetURL() + if reqInfo.targetURL != "" && reqInfo.targetURL != url { + skipNum++ + continue + } statusCode, err = ci.doRequest(ctx, url, reqInfo, headerOpts...) if err == nil || noNeedRetry(statusCode) { return err @@ -138,6 +143,9 @@ func (ci *clientInner) requestWithRetry( log.Debug("[pd] request url failed", zap.String("source", ci.source), zap.Bool("is-leader", cli.IsConnectedToLeader()), zap.String("url", url), zap.Error(err)) } + if skipNum == len(clients) { + return errs.ErrClientNoTargetMember + } return err } if reqInfo.bo == nil { @@ -244,6 +252,7 @@ type client struct { callerID string respHandler respHandleFunc bo *retry.Backoffer + targetURL string } // ClientOption configures the HTTP client. @@ -343,6 +352,13 @@ func (c *client) WithBackoffer(bo *retry.Backoffer) Client { return &newClient } +// WithTargetURL sets and returns a new client with the given target URL. +func (c *client) WithTargetURL(targetURL string) Client { + newClient := *c + newClient.targetURL = targetURL + return &newClient +} + // Header key definition constants. const ( pdAllowFollowerHandleKey = "PD-Allow-Follower-Handle" @@ -363,7 +379,8 @@ func (c *client) request(ctx context.Context, reqInfo *requestInfo, headerOpts . return c.inner.requestWithRetry(ctx, reqInfo. WithCallerID(c.callerID). WithRespHandler(c.respHandler). - WithBackoffer(c.bo), + WithBackoffer(c.bo). + WithTargetURL(c.targetURL), headerOpts...) } diff --git a/client/http/client_test.go b/client/http/client_test.go index 49faefefaec..8769fa53f9a 100644 --- a/client/http/client_test.go +++ b/client/http/client_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/stretchr/testify/require" + "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/retry" "go.uber.org/atomic" ) @@ -49,7 +50,7 @@ func TestPDAllowFollowerHandleHeader(t *testing.T) { re.Equal(2, checked) } -func TestCallerID(t *testing.T) { +func TestWithCallerID(t *testing.T) { re := require.New(t) checked := 0 expectedVal := atomic.NewString(defaultCallerID) @@ -96,3 +97,16 @@ func TestWithBackoffer(t *testing.T) { re.InDelta(3*time.Second, time.Since(start), float64(250*time.Millisecond)) re.ErrorIs(err, context.DeadlineExceeded) } + +func TestWithTargetURL(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c := newClientWithMockServiceDiscovery("test-with-target-url", []string{"http://127.0.0.1", "http://127.0.0.2", "http://127.0.0.3"}) + defer c.Close() + + _, err := c.WithTargetURL("http://127.0.0.4").GetStatus(ctx) + re.ErrorIs(err, errs.ErrClientNoTargetMember) + _, err = c.WithTargetURL("http://127.0.0.2").GetStatus(ctx) + re.ErrorContains(err, "connect: connection refused") +} diff --git a/client/http/interface.go b/client/http/interface.go index 13d684e648b..7b15291d9e7 100644 --- a/client/http/interface.go +++ b/client/http/interface.go @@ -116,6 +116,8 @@ type Client interface { WithRespHandler(func(resp *http.Response, res any) error) Client // WithBackoffer sets and returns a new client with the given backoffer. WithBackoffer(*retry.Backoffer) Client + // WithTargetURL sets and returns a new client with the given target URL. + WithTargetURL(string) Client // Close gracefully closes the HTTP client. Close() } @@ -472,7 +474,8 @@ func (c *client) GetStatus(ctx context.Context) (*State, error) { WithName(getStatusName). WithURI(Status). WithMethod(http.MethodGet). - WithResp(&status)) + WithResp(&status), + WithAllowFollowerHandle()) if err != nil { return nil, err } diff --git a/client/http/request_info.go b/client/http/request_info.go index 93a4ecf5307..0ce7072d1ba 100644 --- a/client/http/request_info.go +++ b/client/http/request_info.go @@ -91,6 +91,7 @@ type requestInfo struct { res any respHandler respHandleFunc bo *retry.Backoffer + targetURL string } // newRequestInfo creates a new request info. @@ -146,6 +147,12 @@ func (ri *requestInfo) WithBackoffer(bo *retry.Backoffer) *requestInfo { return ri } +// WithTargetURL sets the target URL of the request. +func (ri *requestInfo) WithTargetURL(targetURL string) *requestInfo { + ri.targetURL = targetURL + return ri +} + func (ri *requestInfo) getURL(addr string) string { return fmt.Sprintf("%s%s", addr, ri.uri) } diff --git a/client/mock_pd_service_discovery.go b/client/mock_pd_service_discovery.go index b33c8405af9..17613a2f9e4 100644 --- a/client/mock_pd_service_discovery.go +++ b/client/mock_pd_service_discovery.go @@ -41,7 +41,7 @@ func NewMockPDServiceDiscovery(urls []string, tlsCfg *tls.Config) *mockPDService func (m *mockPDServiceDiscovery) Init() error { m.clients = make([]ServiceClient, 0, len(m.urls)) for _, url := range m.urls { - m.clients = append(m.clients, newPDServiceClient(url, url, nil, false)) + m.clients = append(m.clients, newPDServiceClient(url, m.urls[0], nil, false)) } return nil }