Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client: fix scheme when tls config not match #7901

Merged
merged 19 commits into from
Mar 12, 2024
Merged
16 changes: 8 additions & 8 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ type Client interface {
GetClusterID(ctx context.Context) uint64
// GetAllMembers gets the members Info from PD
GetAllMembers(ctx context.Context) ([]*pdpb.Member, error)
// GetLeaderAddr returns current leader's address. It returns "" before
// GetLeaderURL returns current leader's URL. It returns "" before
// syncing leader from server.
GetLeaderAddr() string
GetLeaderURL() string
// GetRegion gets a region and its leader Peer from PD by key.
// The region may expire after split. Caller is responsible for caching and
// taking care of region change.
Expand Down Expand Up @@ -575,7 +575,7 @@ func (c *client) setup() error {
}

// Register callbacks
c.pdSvcDiscovery.AddServingAddrSwitchedCallback(c.scheduleUpdateTokenConnection)
c.pdSvcDiscovery.AddServingURLSwitchedCallback(c.scheduleUpdateTokenConnection)

// Create dispatchers
c.createTokenDispatcher()
Expand Down Expand Up @@ -680,9 +680,9 @@ func (c *client) GetClusterID(context.Context) uint64 {
return c.pdSvcDiscovery.GetClusterID()
}

// GetLeaderAddr returns the leader address.
func (c *client) GetLeaderAddr() string {
return c.pdSvcDiscovery.GetServingAddr()
// GetLeaderURL returns the leader URL.
func (c *client) GetLeaderURL() string {
return c.pdSvcDiscovery.GetServingURL()
}

// GetServiceDiscovery returns the client-side service discovery object
Expand Down Expand Up @@ -1403,8 +1403,8 @@ func IsLeaderChange(err error) bool {
}

func trimHTTPPrefix(str string) string {
str = strings.TrimPrefix(str, "http://")
str = strings.TrimPrefix(str, "https://")
str = strings.TrimPrefix(str, httpScheme)
str = strings.TrimPrefix(str, httpsScheme)
return str
}

Expand Down
12 changes: 6 additions & 6 deletions client/http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,13 @@ func (ci *clientInner) requestWithRetry(
return errs.ErrClientNoAvailableMember
}
for _, cli := range clients {
addr := cli.GetHTTPAddress()
statusCode, err = ci.doRequest(ctx, addr, reqInfo, headerOpts...)
url := cli.GetURL()
statusCode, err = ci.doRequest(ctx, url, reqInfo, headerOpts...)
if err == nil || noNeedRetry(statusCode) {
return err
}
log.Debug("[pd] request addr failed",
zap.String("source", ci.source), zap.Bool("is-leader", cli.IsConnectedToLeader()), zap.String("addr", addr), zap.Error(err))
log.Debug("[pd] request url failed",
zap.String("source", ci.source), zap.Bool("is-leader", cli.IsConnectedToLeader()), zap.String("url", url), zap.Error(err))
}
return err
}
Expand All @@ -160,19 +160,19 @@ func noNeedRetry(statusCode int) bool {

func (ci *clientInner) doRequest(
ctx context.Context,
addr string, reqInfo *requestInfo,
url string, reqInfo *requestInfo,
headerOpts ...HeaderOption,
) (int, error) {
var (
source = ci.source
callerID = reqInfo.callerID
name = reqInfo.name
url = reqInfo.getURL(addr)
method = reqInfo.method
body = reqInfo.body
res = reqInfo.res
respHandler = reqInfo.respHandler
)
url = reqInfo.getURL(url)
logFields := []zap.Field{
zap.String("source", source),
zap.String("name", name),
Expand Down
4 changes: 2 additions & 2 deletions client/meta_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (c *client) Put(ctx context.Context, key, value []byte, opts ...OpOption) (
Lease: options.lease,
PrevKv: options.prevKv,
}
ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr())
ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderURL())
cli := c.metaStorageClient()
if cli == nil {
cancel()
Expand Down Expand Up @@ -162,7 +162,7 @@ func (c *client) Get(ctx context.Context, key []byte, opts ...OpOption) (*meta_s
Limit: options.limit,
Revision: options.revision,
}
ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr())
ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderURL())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about changing the definition of func BuildForwardContext to BuildForwardContext(..., url string).

cli := c.metaStorageClient()
if cli == nil {
cancel()
Expand Down
16 changes: 8 additions & 8 deletions client/mock_pd_service_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func NewMockPDServiceDiscovery(urls []string, tlsCfg *tls.Config) *mockPDService
func (m *mockPDServiceDiscovery) Init() error {
m.clients = make([]ServiceClient, 0, len(m.urls))
for _, url := range m.urls {
m.clients = append(m.clients, newPDServiceClient(url, url, m.tlsCfg, nil, false))
m.clients = append(m.clients, newPDServiceClient(url, url, nil, false))
}
return nil
}
Expand All @@ -62,13 +62,13 @@ func (m *mockPDServiceDiscovery) GetKeyspaceGroupID() uint32
func (m *mockPDServiceDiscovery) GetServiceURLs() []string { return nil }
func (m *mockPDServiceDiscovery) GetServingEndpointClientConn() *grpc.ClientConn { return nil }
func (m *mockPDServiceDiscovery) GetClientConns() *sync.Map { return nil }
func (m *mockPDServiceDiscovery) GetServingAddr() string { return "" }
func (m *mockPDServiceDiscovery) GetBackupAddrs() []string { return nil }
func (m *mockPDServiceDiscovery) GetServingURL() string { return "" }
func (m *mockPDServiceDiscovery) GetBackupURLs() []string { return nil }
func (m *mockPDServiceDiscovery) GetServiceClient() ServiceClient { return nil }
func (m *mockPDServiceDiscovery) GetOrCreateGRPCConn(addr string) (*grpc.ClientConn, error) {
func (m *mockPDServiceDiscovery) GetOrCreateGRPCConn(url string) (*grpc.ClientConn, error) {
return nil, nil
}
func (m *mockPDServiceDiscovery) ScheduleCheckMemberChanged() {}
func (m *mockPDServiceDiscovery) CheckMemberChanged() error { return nil }
func (m *mockPDServiceDiscovery) AddServingAddrSwitchedCallback(callbacks ...func()) {}
func (m *mockPDServiceDiscovery) AddServiceAddrsSwitchedCallback(callbacks ...func()) {}
func (m *mockPDServiceDiscovery) ScheduleCheckMemberChanged() {}
func (m *mockPDServiceDiscovery) CheckMemberChanged() error { return nil }
func (m *mockPDServiceDiscovery) AddServingURLSwitchedCallback(callbacks ...func()) {}
func (m *mockPDServiceDiscovery) AddServiceURLsSwitchedCallback(callbacks ...func()) {}
Loading
Loading