diff --git a/cmd/client.go b/cmd/client.go index 3bffc3fb804..dc3bac3533f 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -137,13 +137,11 @@ func newCliCommand() *cobra.Command { if err != nil { return errors.Annotate(err, "fail to validate TLS settings") } - if tlsConfig != nil { - if strings.Contains(cliPdAddr, "http://") { - return errors.New("PD endpoint scheme should be https") - } - } else if !strings.Contains(cliPdAddr, "http://") { - return errors.New("PD endpoint scheme should be http") + + if err := verifyPdEndpoint(cliPdAddr, tlsConfig != nil); err != nil { + return errors.Trace(err) } + grpcTLSOption, err := credential.ToGRPCDialOption() if err != nil { return errors.Annotate(err, "fail to validate TLS settings") diff --git a/cmd/server.go b/cmd/server.go index ed9fc2a128d..b57076273f2 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -203,12 +203,8 @@ func loadAndVerifyServerConfig(cmd *cobra.Command) (*config.ServerConfig, error) return nil, cerror.ErrInvalidServerOption.GenWithStack("empty PD address") } for _, ep := range strings.Split(serverPdAddr, ",") { - if conf.Security.IsTLSEnabled() { - if strings.Index(ep, "http://") == 0 { - return nil, cerror.ErrInvalidServerOption.GenWithStack("PD endpoint scheme should be https") - } - } else if strings.Index(ep, "http://") != 0 { - return nil, cerror.ErrInvalidServerOption.GenWithStack("PD endpoint scheme should be http") + if err := verifyPdEndpoint(ep, conf.Security.IsTLSEnabled()); err != nil { + return nil, cerror.ErrInvalidServerOption.Wrap(err).GenWithStackByCause() } } diff --git a/cmd/server_test.go b/cmd/server_test.go index c3c444b2908..0cfaa158938 100644 --- a/cmd/server_test.go +++ b/cmd/server_test.go @@ -88,7 +88,21 @@ func (s *serverSuite) TestLoadAndVerifyServerConfig(c *check.C) { initServerCmd(cmd) c.Assert(cmd.ParseFlags([]string{"--pd=aa"}), check.IsNil) _, err = loadAndVerifyServerConfig(cmd) - c.Assert(err, check.ErrorMatches, ".*PD endpoint scheme should be http.*") + c.Assert(err, check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*") + + // test invalid PD address(without host) + cmd = new(cobra.Command) + initServerCmd(cmd) + c.Assert(cmd.ParseFlags([]string{"--pd=http://"}), check.IsNil) + _, err = loadAndVerifyServerConfig(cmd) + c.Assert(err, check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*") + + // test missing certificate + cmd = new(cobra.Command) + initServerCmd(cmd) + c.Assert(cmd.ParseFlags([]string{"--pd=https://aa"}), check.IsNil) + _, err = loadAndVerifyServerConfig(cmd) + c.Assert(err, check.ErrorMatches, ".*PD endpoint scheme is https, please provide certificate.*") // test undefined flag cmd = new(cobra.Command) diff --git a/cmd/util.go b/cmd/util.go index 77aaf0d609b..abceca37cdb 100644 --- a/cmd/util.go +++ b/cmd/util.go @@ -60,6 +60,12 @@ var errOwnerNotFound = liberrors.New("owner not found") var tsGapWarnning int64 = 86400 * 1000 // 1 day in milliseconds +// Endpoint schemes. +const ( + HTTP = "http" + HTTPS = "https" +) + func addSecurityFlags(flags *pflag.FlagSet, isServer bool) { flags.StringVar(&caPath, "ca", "", "CA certificate path for TLS connection") flags.StringVar(&certPath, "cert", "", "Certificate path for TLS connection") @@ -375,3 +381,26 @@ func confirmLargeDataGap(ctx context.Context, cmd *cobra.Command, startTs uint64 } return nil } + +// verifyPdEndpoint verifies whether the pd endpoint is a valid http or https URL. +// The certificate is required when using https. +func verifyPdEndpoint(pdEndpoint string, useTLS bool) error { + u, err := url.Parse(pdEndpoint) + if err != nil { + return errors.Annotate(err, "parse PD endpoint") + } + if (u.Scheme != HTTP && u.Scheme != HTTPS) || u.Host == "" { + return errors.New("PD endpoint should be a valid http or https URL") + } + + if useTLS { + if u.Scheme == HTTP { + return errors.New("PD endpoint scheme should be https") + } + } else { + if u.Scheme == HTTPS { + return errors.New("PD endpoint scheme is https, please provide certificate") + } + } + return nil +} diff --git a/cmd/util_test.go b/cmd/util_test.go index 703ecb6819d..09514314187 100644 --- a/cmd/util_test.go +++ b/cmd/util_test.go @@ -55,3 +55,40 @@ func (s *utilsSuite) TestProxyFields(c *check.C) { } } } + +func (s *utilsSuite) TestVerifyPdEndpoint(c *check.C) { + defer testleak.AfterTest(c)() + // empty URL. + url := "" + c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*") + + // invalid URL. + url = "\n hi" + c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*invalid control character in URL.*") + + // http URL without host. + url = "http://" + c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*") + + // https URL without host. + url = "https://" + c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*") + + // postgres scheme. + url = "postgres://postgres@localhost/cargo_registry" + c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*") + + // https scheme without TLS. + url = "https://aa" + c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint scheme is https, please provide certificate.*") + + // http scheme with TLS. + url = "http://aa" + c.Assert(verifyPdEndpoint(url, true), check.ErrorMatches, ".*PD endpoint scheme should be https.*") + + // valid http URL. + c.Assert(verifyPdEndpoint("http://aa", false), check.IsNil) + + // valid https URL with TLS. + c.Assert(verifyPdEndpoint("https://aa", true), check.IsNil) +}