Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client,server: better err msg when PD endpoint missing certificate #2138

Merged
merged 12 commits into from
Jun 29, 2021
10 changes: 4 additions & 6 deletions cmd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,11 @@ func newCliCommand() *cobra.Command {
if err != nil {
return errors.Annotate(err, "fail to validate TLS settings")
}
if tlsConfig != nil {
if strings.Contains(cliPdAddr, "http://") {
return errors.New("PD endpoint scheme should be https")
}
} else if !strings.Contains(cliPdAddr, "http://") {
return errors.New("PD endpoint scheme should be http")

if err := verifyPdEndpoint(cliPdAddr, tlsConfig != nil); err != nil {
return errors.Trace(err)
}

grpcTLSOption, err := credential.ToGRPCDialOption()
if err != nil {
return errors.Annotate(err, "fail to validate TLS settings")
Expand Down
8 changes: 2 additions & 6 deletions cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,8 @@ func loadAndVerifyServerConfig(cmd *cobra.Command) (*config.ServerConfig, error)
return nil, cerror.ErrInvalidServerOption.GenWithStack("empty PD address")
}
for _, ep := range strings.Split(serverPdAddr, ",") {
if conf.Security.IsTLSEnabled() {
if strings.Index(ep, "http://") == 0 {
return nil, cerror.ErrInvalidServerOption.GenWithStack("PD endpoint scheme should be https")
}
} else if strings.Index(ep, "http://") != 0 {
return nil, cerror.ErrInvalidServerOption.GenWithStack("PD endpoint scheme should be http")
if err := verifyPdEndpoint(ep, conf.Security.IsTLSEnabled()); err != nil {
return nil, cerror.ErrInvalidServerOption.Wrap(err).GenWithStackByCause()
}
}

Expand Down
16 changes: 15 additions & 1 deletion cmd/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,21 @@ func (s *serverSuite) TestLoadAndVerifyServerConfig(c *check.C) {
initServerCmd(cmd)
c.Assert(cmd.ParseFlags([]string{"--pd=aa"}), check.IsNil)
_, err = loadAndVerifyServerConfig(cmd)
c.Assert(err, check.ErrorMatches, ".*PD endpoint scheme should be http.*")
c.Assert(err, check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*")

// test invalid PD address(without host)
cmd = new(cobra.Command)
initServerCmd(cmd)
c.Assert(cmd.ParseFlags([]string{"--pd=http://"}), check.IsNil)
_, err = loadAndVerifyServerConfig(cmd)
c.Assert(err, check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*")

// test missing certificate
cmd = new(cobra.Command)
initServerCmd(cmd)
c.Assert(cmd.ParseFlags([]string{"--pd=https://aa"}), check.IsNil)
_, err = loadAndVerifyServerConfig(cmd)
c.Assert(err, check.ErrorMatches, ".*PD endpoint scheme is https, please provide certificate.*")

// test undefined flag
cmd = new(cobra.Command)
Expand Down
29 changes: 29 additions & 0 deletions cmd/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ var errOwnerNotFound = liberrors.New("owner not found")

var tsGapWarnning int64 = 86400 * 1000 // 1 day in milliseconds

// Endpoint schemes.
const (
HTTP = "http"
HTTPS = "https"
)

func addSecurityFlags(flags *pflag.FlagSet, isServer bool) {
flags.StringVar(&caPath, "ca", "", "CA certificate path for TLS connection")
flags.StringVar(&certPath, "cert", "", "Certificate path for TLS connection")
Expand Down Expand Up @@ -375,3 +381,26 @@ func confirmLargeDataGap(ctx context.Context, cmd *cobra.Command, startTs uint64
}
return nil
}

// verifyPdEndpoint verifies whether the pd endpoint is a valid http or https URL.
// The certificate is required when using https.
func verifyPdEndpoint(pdEndpoint string, useTLS bool) error {
u, err := url.Parse(pdEndpoint)
if err != nil {
return errors.Annotate(err, "parse PD endpoint")
}
if (u.Scheme != HTTP && u.Scheme != HTTPS) || u.Host == "" {
return errors.New("PD endpoint should be a valid http or https URL")
}

if useTLS {
if u.Scheme == HTTP {
return errors.New("PD endpoint scheme should be https")
}
} else {
if u.Scheme == HTTPS {
return errors.New("PD endpoint scheme is https, please provide certificate")
}
}
return nil
}
37 changes: 37 additions & 0 deletions cmd/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,40 @@ func (s *utilsSuite) TestProxyFields(c *check.C) {
}
}
}

func (s *utilsSuite) TestVerifyPdEndpoint(c *check.C) {
defer testleak.AfterTest(c)()
// empty URL.
url := ""
c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*")

// invalid URL.
url = "\n hi"
c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*invalid control character in URL.*")

// http URL without host.
url = "http://"
c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*")

// https URL without host.
url = "https://"
c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*")

// postgres scheme.
url = "postgres://postgres@localhost/cargo_registry"
c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*")

// https scheme without TLS.
url = "https://aa"
c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint scheme is https, please provide certificate.*")

// http scheme with TLS.
url = "http://aa"
c.Assert(verifyPdEndpoint(url, true), check.ErrorMatches, ".*PD endpoint scheme should be https.*")

// valid http URL.
c.Assert(verifyPdEndpoint("http://aa", false), check.IsNil)

// valid https URL with TLS.
c.Assert(verifyPdEndpoint("https://aa", true), check.IsNil)
}