Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client: fix scheme when tls config not match #7901

Merged
merged 19 commits into from
Mar 12, 2024
Merged
2 changes: 1 addition & 1 deletion client/http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (ci *clientInner) requestWithRetry(
return errs.ErrClientNoAvailableMember
}
for _, cli := range clients {
addr := cli.GetHTTPAddress()
addr := cli.GetAddress()
statusCode, err = ci.doRequest(ctx, addr, reqInfo, headerOpts...)
if err == nil || noNeedRetry(statusCode) {
return err
Expand Down
2 changes: 1 addition & 1 deletion client/mock_pd_service_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func NewMockPDServiceDiscovery(urls []string, tlsCfg *tls.Config) *mockPDService
func (m *mockPDServiceDiscovery) Init() error {
m.clients = make([]ServiceClient, 0, len(m.urls))
for _, url := range m.urls {
m.clients = append(m.clients, newPDServiceClient(url, url, m.tlsCfg, nil, false))
m.clients = append(m.clients, newPDServiceClient(url, url, nil, false))
}
return nil
}
Expand Down
84 changes: 34 additions & 50 deletions client/pd_service_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ const (
updateMemberTimeout = time.Second // Use a shorter timeout to recover faster from network isolation.
updateMemberBackOffBaseTime = 100 * time.Millisecond

httpScheme = "http"
httpsScheme = "https"
httpScheme = "http://"
httpsScheme = "https://"
)

// MemberHealthCheckInterval might be changed in the unit to shorten the testing time.
Expand Down Expand Up @@ -124,10 +124,8 @@ type ServiceDiscovery interface {

// ServiceClient is an interface that defines a set of operations for a raw PD gRPC client to specific PD server.
type ServiceClient interface {
// GetAddress returns the address information of the PD server.
// GetAddress returns the address with HTTP scheme of the PD server.
Copy link
Member

@HuSharp HuSharp Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the tests will return an https scheme?

GetAddress() string
// GetHTTPAddress returns the address with HTTP scheme of the PD server.
GetHTTPAddress() string
// GetClientConn returns the gRPC connection of the service client
GetClientConn() *grpc.ClientConn
// BuildGRPCTargetContext builds a context object with a gRPC context.
Expand Down Expand Up @@ -158,34 +156,15 @@ type pdServiceClient struct {
networkFailure atomic.Bool
}

func newPDServiceClient(addr, leaderAddr string, tlsCfg *tls.Config, conn *grpc.ClientConn, isLeader bool) ServiceClient {
var httpAddress string
if tlsCfg == nil {
if strings.HasPrefix(addr, httpsScheme) {
addr = strings.TrimPrefix(addr, httpsScheme)
httpAddress = fmt.Sprintf("%s%s", httpScheme, addr)
} else if strings.HasPrefix(addr, httpScheme) {
httpAddress = addr
} else {
httpAddress = fmt.Sprintf("%s://%s", httpScheme, addr)
}
} else {
if strings.HasPrefix(addr, httpsScheme) {
httpAddress = addr
} else if strings.HasPrefix(addr, httpScheme) {
addr = strings.TrimPrefix(addr, httpScheme)
httpAddress = fmt.Sprintf("%s%s", httpsScheme, addr)
} else {
httpAddress = fmt.Sprintf("%s://%s", httpsScheme, addr)
}
}

// NOTE: In the current implementation, the address passed in is bound to have an http scheme,
// because it is processed in `newPDServiceDiscovery`, and the url returned by etcd member is its own.
// When testing, the address is also bound to have an http scheme.
HuSharp marked this conversation as resolved.
Show resolved Hide resolved
func newPDServiceClient(addr, leaderAddr string, conn *grpc.ClientConn, isLeader bool) ServiceClient {
cli := &pdServiceClient{
HuSharp marked this conversation as resolved.
Show resolved Hide resolved
addr: addr,
httpAddress: httpAddress,
conn: conn,
isLeader: isLeader,
leaderAddr: leaderAddr,
addr: addr,
conn: conn,
isLeader: isLeader,
leaderAddr: leaderAddr,
}
if conn == nil {
cli.networkFailure.Store(true)
Expand All @@ -201,14 +180,6 @@ func (c *pdServiceClient) GetAddress() string {
return c.addr
}

// GetHTTPAddress implements ServiceClient.
func (c *pdServiceClient) GetHTTPAddress() string {
if c == nil {
return ""
}
return c.httpAddress
}

// BuildGRPCTargetContext implements ServiceClient.
func (c *pdServiceClient) BuildGRPCTargetContext(ctx context.Context, toLeader bool) context.Context {
if c == nil || c.isLeader {
Expand Down Expand Up @@ -506,7 +477,7 @@ func newPDServiceDiscovery(
tlsCfg: tlsCfg,
option: option,
}
urls = addrsToUrls(urls)
urls = addrsToUrls(urls, tlsCfg)
pdsd.urls.Store(urls)
return pdsd
}
Expand Down Expand Up @@ -1032,7 +1003,7 @@ func (c *pdServiceDiscovery) switchLeader(addrs []string) (bool, error) {
// If gRPC connect is created successfully or leader is new, still saves.
if addr != oldLeader.GetAddress() || newConn != nil {
// Set PD leader and Global TSO Allocator (which is also the PD leader)
leaderClient := newPDServiceClient(addr, addr, c.tlsCfg, newConn, true)
leaderClient := newPDServiceClient(addr, addr, newConn, true)
c.leader.Store(leaderClient)
}
// Run callbacks
Expand Down Expand Up @@ -1069,15 +1040,15 @@ func (c *pdServiceDiscovery) updateFollowers(members []*pdpb.Member, leader *pdp
log.Warn("[pd] failed to connect follower", zap.String("follower", addr), errs.ZapError(err))
continue
}
follower := newPDServiceClient(addr, leader.GetClientUrls()[0], c.tlsCfg, conn, false)
follower := newPDServiceClient(addr, leader.GetClientUrls()[0], conn, false)
c.followers.Store(addr, follower)
changed = true
}
delete(followers, addr)
} else {
changed = true
conn, err := c.GetOrCreateGRPCConn(addr)
follower := newPDServiceClient(addr, leader.GetClientUrls()[0], c.tlsCfg, conn, false)
follower := newPDServiceClient(addr, leader.GetClientUrls()[0], conn, false)
if err != nil || conn == nil {
log.Warn("[pd] failed to connect follower", zap.String("follower", addr), errs.ZapError(err))
}
Expand Down Expand Up @@ -1150,15 +1121,28 @@ func (c *pdServiceDiscovery) GetOrCreateGRPCConn(addr string) (*grpc.ClientConn,
return grpcutil.GetOrCreateGRPCConn(c.ctx, &c.clientConns, addr, c.tlsCfg, c.option.gRPCDialOptions...)
}

func addrsToUrls(addrs []string) []string {
func addrsToUrls(addrs []string, tlsCfg *tls.Config) []string {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
func addrsToUrls(addrs []string, tlsCfg *tls.Config) []string {
func addrsToURLs(addrs []string, tlsCfg *tls.Config) []string {

// Add default scheme "http://" to addrs.
urls := make([]string, 0, len(addrs))
for _, addr := range addrs {
if strings.Contains(addr, "://") {
urls = append(urls, addr)
} else {
urls = append(urls, "http://"+addr)
}
urls = append(urls, addrToUrl(addr, tlsCfg))
}
return urls
}

func addrToUrl(addr string, tlsCfg *tls.Config) string {
if tlsCfg == nil {
if strings.HasPrefix(addr, httpsScheme) {
addr = fmt.Sprintf("%s%s", httpScheme, strings.TrimPrefix(addr, httpsScheme))
} else if !strings.HasPrefix(addr, httpScheme) {
addr = fmt.Sprintf("%s%s", httpScheme, addr)
}
} else {
if strings.HasPrefix(addr, httpScheme) {
Copy link
Member Author

@CabinfeverB CabinfeverB Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only used for url.Parse. If we don't add a scheme, it will return an error when creating the gRPC connection.
If the etcd member URL is not the same as this one, the connection map will add a new KV pair (holding both https://127.0.0.1 and http://127.0.0.1).

addr = fmt.Sprintf("%s%s", httpsScheme, strings.TrimPrefix(addr, httpScheme))
} else if !strings.HasPrefix(addr, httpsScheme) {
addr = fmt.Sprintf("%s%s", httpsScheme, addr)
}
}
return addr
}
40 changes: 22 additions & 18 deletions client/pd_service_discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,14 @@ func (suite *serviceClientTestSuite) SetupSuite() {
leaderConn, err1 := grpc.Dial(suite.leaderServer.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
followerConn, err2 := grpc.Dial(suite.followerServer.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err1 == nil && err2 == nil {
suite.followerClient = newPDServiceClient(suite.followerServer.addr, suite.leaderServer.addr, nil, followerConn, false)
suite.leaderClient = newPDServiceClient(suite.leaderServer.addr, suite.leaderServer.addr, nil, leaderConn, true)
suite.followerClient = newPDServiceClient(
addrToUrl(suite.followerServer.addr, nil),
addrToUrl(suite.leaderServer.addr, nil),
followerConn, false)
suite.leaderClient = newPDServiceClient(
addrToUrl(suite.leaderServer.addr, nil),
addrToUrl(suite.leaderServer.addr, nil),
leaderConn, true)
suite.followerServer.server.leaderConn = suite.leaderClient.GetClientConn()
suite.followerServer.server.leaderAddr = suite.leaderClient.GetAddress()
return
Expand All @@ -166,16 +172,14 @@ func (suite *serviceClientTestSuite) TearDownSuite() {

func (suite *serviceClientTestSuite) TestServiceClient() {
re := suite.Require()
leaderAddress := suite.leaderServer.addr
followerAddress := suite.followerServer.addr
leaderAddress := addrToUrl(suite.leaderServer.addr, nil)
followerAddress := addrToUrl(suite.followerServer.addr, nil)

follower := suite.followerClient
leader := suite.leaderClient

re.Equal(follower.GetAddress(), followerAddress)
re.Equal(leader.GetAddress(), leaderAddress)
re.Equal(follower.GetHTTPAddress(), "http://"+followerAddress)
re.Equal(leader.GetHTTPAddress(), "http://"+leaderAddress)

re.True(follower.Available())
re.True(leader.Available())
Expand Down Expand Up @@ -303,16 +307,16 @@ func (suite *serviceClientTestSuite) TestServiceClientBalancer() {

func TestHTTPScheme(t *testing.T) {
re := require.New(t)
cli := newPDServiceClient("127.0.0.1:2379", "127.0.0.1:2379", nil, nil, false)
re.Equal("http://127.0.0.1:2379", cli.GetHTTPAddress())
cli = newPDServiceClient("https://127.0.0.1:2379", "127.0.0.1:2379", nil, nil, false)
re.Equal("http://127.0.0.1:2379", cli.GetHTTPAddress())
cli = newPDServiceClient("http://127.0.0.1:2379", "127.0.0.1:2379", nil, nil, false)
re.Equal("http://127.0.0.1:2379", cli.GetHTTPAddress())
cli = newPDServiceClient("127.0.0.1:2379", "127.0.0.1:2379", &tls.Config{}, nil, false)
re.Equal("https://127.0.0.1:2379", cli.GetHTTPAddress())
cli = newPDServiceClient("https://127.0.0.1:2379", "127.0.0.1:2379", &tls.Config{}, nil, false)
re.Equal("https://127.0.0.1:2379", cli.GetHTTPAddress())
cli = newPDServiceClient("http://127.0.0.1:2379", "127.0.0.1:2379", &tls.Config{}, nil, false)
re.Equal("https://127.0.0.1:2379", cli.GetHTTPAddress())
cli := newPDServiceClient(addrToUrl("127.0.0.1:2379", nil), addrToUrl("127.0.0.1:2379", nil), nil, false)
re.Equal("http://127.0.0.1:2379", cli.GetAddress())
cli = newPDServiceClient(addrToUrl("https://127.0.0.1:2379", nil), addrToUrl("127.0.0.1:2379", nil), nil, false)
re.Equal("http://127.0.0.1:2379", cli.GetAddress())
cli = newPDServiceClient(addrToUrl("http://127.0.0.1:2379", nil), addrToUrl("127.0.0.1:2379", nil), nil, false)
re.Equal("http://127.0.0.1:2379", cli.GetAddress())
cli = newPDServiceClient(addrToUrl("127.0.0.1:2379", &tls.Config{}), addrToUrl("127.0.0.1:2379", &tls.Config{}), nil, false)
re.Equal("https://127.0.0.1:2379", cli.GetAddress())
cli = newPDServiceClient(addrToUrl("https://127.0.0.1:2379", &tls.Config{}), addrToUrl("127.0.0.1:2379", &tls.Config{}), nil, false)
re.Equal("https://127.0.0.1:2379", cli.GetAddress())
cli = newPDServiceClient(addrToUrl("http://127.0.0.1:2379", &tls.Config{}), addrToUrl("127.0.0.1:2379", &tls.Config{}), nil, false)
re.Equal("https://127.0.0.1:2379", cli.GetAddress())
}
17 changes: 11 additions & 6 deletions tests/integrations/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,12 @@ func TestClientLeaderChange(t *testing.T) {
defer cluster.Destroy()

endpoints := runServer(re, cluster)
cli := setupCli(re, ctx, endpoints)
endpointsWithWrongURL := append([]string{}, endpoints...)
// inject wrong http scheme
for i := range endpointsWithWrongURL {
endpointsWithWrongURL[i] = "https://" + strings.TrimPrefix(endpointsWithWrongURL[i], "http://")
}
cli := setupCli(re, ctx, endpointsWithWrongURL)
defer cli.Close()
innerCli, ok := cli.(interface{ GetServiceDiscovery() pd.ServiceDiscovery })
re.True(ok)
Expand All @@ -127,14 +132,14 @@ func TestClientLeaderChange(t *testing.T) {
re.True(cluster.CheckTSOUnique(ts1))

leader := cluster.GetLeader()
waitLeader(re, innerCli.GetServiceDiscovery(), cluster.GetServer(leader).GetConfig().ClientUrls)
waitLeader(re, innerCli.GetServiceDiscovery(), cluster.GetServer(leader))

err = cluster.GetServer(leader).Stop()
re.NoError(err)
leader = cluster.WaitLeader()
re.NotEmpty(leader)

waitLeader(re, innerCli.GetServiceDiscovery(), cluster.GetServer(leader).GetConfig().ClientUrls)
waitLeader(re, innerCli.GetServiceDiscovery(), cluster.GetServer(leader))

// Check TS won't fall back after leader changed.
testutil.Eventually(re, func() bool {
Expand Down Expand Up @@ -955,10 +960,10 @@ func setupCli(re *require.Assertions, ctx context.Context, endpoints []string, o
return cli
}

func waitLeader(re *require.Assertions, cli pd.ServiceDiscovery, leader string) {
func waitLeader(re *require.Assertions, cli pd.ServiceDiscovery, leader *tests.TestServer) {
testutil.Eventually(re, func() bool {
cli.ScheduleCheckMemberChanged()
return cli.GetServingAddr() == leader
return cli.GetServingAddr() == leader.GetConfig().ClientUrls && leader.GetAddr() == cli.GetServingAddr()
})
}

Expand Down Expand Up @@ -1853,7 +1858,7 @@ func (suite *clientTestSuite) TestMemberUpdateBackOff() {
re.True(ok)

leader := cluster.GetLeader()
waitLeader(re, innerCli.GetServiceDiscovery(), cluster.GetServer(leader).GetConfig().ClientUrls)
waitLeader(re, innerCli.GetServiceDiscovery(), cluster.GetServer(leader))
memberID := cluster.GetServer(leader).GetLeader().GetMemberId()

re.NoError(failpoint.Enable("github.com/tikv/pd/server/leaderLoopCheckAgain", fmt.Sprintf("return(\"%d\")", memberID)))
Expand Down
Loading