Skip to content

Commit

Permalink
client,server: better err msg when PD endpoint missing certificate (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Rustin170506 authored Jun 29, 2021
1 parent dfc046f commit 3fd3637
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 13 deletions.
10 changes: 4 additions & 6 deletions cmd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,11 @@ func newCliCommand() *cobra.Command {
if err != nil {
return errors.Annotate(err, "fail to validate TLS settings")
}
if tlsConfig != nil {
if strings.Contains(cliPdAddr, "http://") {
return errors.New("PD endpoint scheme should be https")
}
} else if !strings.Contains(cliPdAddr, "http://") {
return errors.New("PD endpoint scheme should be http")

if err := verifyPdEndpoint(cliPdAddr, tlsConfig != nil); err != nil {
return errors.Trace(err)
}

grpcTLSOption, err := credential.ToGRPCDialOption()
if err != nil {
return errors.Annotate(err, "fail to validate TLS settings")
Expand Down
8 changes: 2 additions & 6 deletions cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,8 @@ func loadAndVerifyServerConfig(cmd *cobra.Command) (*config.ServerConfig, error)
return nil, cerror.ErrInvalidServerOption.GenWithStack("empty PD address")
}
for _, ep := range strings.Split(serverPdAddr, ",") {
if conf.Security.IsTLSEnabled() {
if strings.Index(ep, "http://") == 0 {
return nil, cerror.ErrInvalidServerOption.GenWithStack("PD endpoint scheme should be https")
}
} else if strings.Index(ep, "http://") != 0 {
return nil, cerror.ErrInvalidServerOption.GenWithStack("PD endpoint scheme should be http")
if err := verifyPdEndpoint(ep, conf.Security.IsTLSEnabled()); err != nil {
return nil, cerror.ErrInvalidServerOption.Wrap(err).GenWithStackByCause()
}
}

Expand Down
16 changes: 15 additions & 1 deletion cmd/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,21 @@ func (s *serverSuite) TestLoadAndVerifyServerConfig(c *check.C) {
initServerCmd(cmd)
c.Assert(cmd.ParseFlags([]string{"--pd=aa"}), check.IsNil)
_, err = loadAndVerifyServerConfig(cmd)
c.Assert(err, check.ErrorMatches, ".*PD endpoint scheme should be http.*")
c.Assert(err, check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*")

// test invalid PD address(without host)
cmd = new(cobra.Command)
initServerCmd(cmd)
c.Assert(cmd.ParseFlags([]string{"--pd=http://"}), check.IsNil)
_, err = loadAndVerifyServerConfig(cmd)
c.Assert(err, check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*")

// test missing certificate
cmd = new(cobra.Command)
initServerCmd(cmd)
c.Assert(cmd.ParseFlags([]string{"--pd=https://aa"}), check.IsNil)
_, err = loadAndVerifyServerConfig(cmd)
c.Assert(err, check.ErrorMatches, ".*PD endpoint scheme is https, please provide certificate.*")

// test undefined flag
cmd = new(cobra.Command)
Expand Down
29 changes: 29 additions & 0 deletions cmd/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ var errOwnerNotFound = liberrors.New("owner not found")

var tsGapWarnning int64 = 86400 * 1000 // 1 day in milliseconds

// Endpoint schemes.
const (
HTTP = "http"
HTTPS = "https"
)

func addSecurityFlags(flags *pflag.FlagSet, isServer bool) {
flags.StringVar(&caPath, "ca", "", "CA certificate path for TLS connection")
flags.StringVar(&certPath, "cert", "", "Certificate path for TLS connection")
Expand Down Expand Up @@ -375,3 +381,26 @@ func confirmLargeDataGap(ctx context.Context, cmd *cobra.Command, startTs uint64
}
return nil
}

// verifyPdEndpoint verifies whether the pd endpoint is a valid http or https URL.
// The certificate is required when using https.
func verifyPdEndpoint(pdEndpoint string, useTLS bool) error {
u, err := url.Parse(pdEndpoint)
if err != nil {
return errors.Annotate(err, "parse PD endpoint")
}
if (u.Scheme != HTTP && u.Scheme != HTTPS) || u.Host == "" {
return errors.New("PD endpoint should be a valid http or https URL")
}

if useTLS {
if u.Scheme == HTTP {
return errors.New("PD endpoint scheme should be https")
}
} else {
if u.Scheme == HTTPS {
return errors.New("PD endpoint scheme is https, please provide certificate")
}
}
return nil
}
37 changes: 37 additions & 0 deletions cmd/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,40 @@ func (s *utilsSuite) TestProxyFields(c *check.C) {
}
}
}

func (s *utilsSuite) TestVerifyPdEndpoint(c *check.C) {
defer testleak.AfterTest(c)()
// empty URL.
url := ""
c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*")

// invalid URL.
url = "\n hi"
c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*invalid control character in URL.*")

// http URL without host.
url = "http://"
c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*")

// https URL without host.
url = "https://"
c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*")

// postgres scheme.
url = "postgres://postgres@localhost/cargo_registry"
c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint should be a valid http or https URL.*")

// https scheme without TLS.
url = "https://aa"
c.Assert(verifyPdEndpoint(url, false), check.ErrorMatches, ".*PD endpoint scheme is https, please provide certificate.*")

// http scheme with TLS.
url = "http://aa"
c.Assert(verifyPdEndpoint(url, true), check.ErrorMatches, ".*PD endpoint scheme should be https.*")

// valid http URL.
c.Assert(verifyPdEndpoint("http://aa", false), check.IsNil)

// valid https URL with TLS.
c.Assert(verifyPdEndpoint("https://aa", true), check.IsNil)
}

0 comments on commit 3fd3637

Please sign in to comment.