Skip to content
This repository has been archived by the owner on Nov 24, 2023. It is now read-only.

config: check and correct format of addr and URL #937

Merged
merged 7 commits into from
Aug 28, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions dm/master/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,9 @@ func (c *Config) configFromFile(path string) error {

// adjust adjusts configs
func (c *Config) adjust() error {
// MasterAddr's format may be "scheme://host:port", "host:port" or ":port"
host, port, err := net.SplitHostPort(utils.UnwrapScheme(c.MasterAddr))
c.MasterAddr = utils.UnwrapScheme(c.MasterAddr)
// MasterAddr's format may be "host:port" or ":port"
host, port, err := net.SplitHostPort(c.MasterAddr)
if err != nil {
return terror.ErrMasterHostPortNotValid.Delegate(err, c.MasterAddr)
}
Expand All @@ -236,8 +237,9 @@ func (c *Config) adjust() error {
}
c.AdvertiseAddr = c.MasterAddr
} else {
// AdvertiseAddr's format may be "scheme://host:port" or "host:port"
host, port, err = net.SplitHostPort(utils.UnwrapScheme(c.AdvertiseAddr))
c.AdvertiseAddr = utils.UnwrapScheme(c.AdvertiseAddr)
// AdvertiseAddr's format should be "host:port"
host, port, err = net.SplitHostPort(c.AdvertiseAddr)
if err != nil {
return terror.ErrMasterAdvertiseAddrNotValid.Delegate(err, c.AdvertiseAddr)
}
Expand Down Expand Up @@ -294,10 +296,14 @@ func (c *Config) adjust() error {

if c.PeerUrls == "" {
c.PeerUrls = defaultPeerUrls
} else {
c.PeerUrls = utils.WrapSchemes(c.PeerUrls, c.SSLCA != "")
}

if c.AdvertisePeerUrls == "" {
c.AdvertisePeerUrls = c.PeerUrls
} else {
c.AdvertisePeerUrls = utils.WrapSchemes(c.AdvertisePeerUrls, c.SSLCA != "")
}

if c.InitialCluster == "" {
Expand All @@ -306,12 +312,18 @@ func (c *Config) adjust() error {
items[i] = fmt.Sprintf("%s=%s", c.Name, item)
}
c.InitialCluster = strings.Join(items, ",")
} else {
c.InitialCluster = utils.WrapSchemesForInitialCluster(c.InitialCluster, c.SSLCA != "")
}

if c.InitialClusterState == "" {
c.InitialClusterState = defaultInitialClusterState
}

if c.Join != "" {
c.Join = utils.WrapSchemes(c.Join, c.SSLCA != "")
}

return err
}

Expand Down
10 changes: 8 additions & 2 deletions dm/worker/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ func (c *Config) Parse(arguments []string) error {

// adjust adjusts the config.
func (c *Config) adjust() error {
host, port, err := net.SplitHostPort(utils.UnwrapScheme(c.WorkerAddr))
c.WorkerAddr = utils.UnwrapScheme(c.WorkerAddr)
host, port, err := net.SplitHostPort(c.WorkerAddr)
if err != nil {
return terror.ErrWorkerHostPortNotValid.Delegate(err, c.WorkerAddr)
}
Expand All @@ -180,7 +181,8 @@ func (c *Config) adjust() error {
}
c.AdvertiseAddr = c.WorkerAddr
} else {
host, port, err = net.SplitHostPort(utils.UnwrapScheme(c.AdvertiseAddr))
c.AdvertiseAddr = utils.UnwrapScheme(c.AdvertiseAddr)
host, port, err = net.SplitHostPort(c.AdvertiseAddr)
if err != nil {
return terror.ErrWorkerHostPortNotValid.Delegate(err, c.AdvertiseAddr)
}
Expand All @@ -194,6 +196,10 @@ func (c *Config) adjust() error {
c.Name = c.AdvertiseAddr
}

if c.Join != "" {
c.Join = utils.WrapSchemes(c.Join, c.SSLCA != "")
}

return nil
}

Expand Down
3 changes: 2 additions & 1 deletion dm/worker/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/pingcap/dm/pkg/ha"
"github.com/pingcap/dm/pkg/log"
"github.com/pingcap/dm/pkg/terror"
"github.com/pingcap/dm/pkg/utils"
)

// GetJoinURLs gets the endpoints from the join address.
Expand All @@ -53,7 +54,7 @@ func (s *Server) JoinMaster(endpoints []string) error {

for _, endpoint := range endpoints {
ctx1, cancel1 := context.WithTimeout(ctx, 3*time.Second)
conn, err := grpc.DialContext(ctx1, endpoint, grpc.WithBlock(), tls.ToGRPCDialOption(), grpc.WithBackoffMaxDelay(3*time.Second))
conn, err := grpc.DialContext(ctx1, utils.UnwrapScheme(endpoint), grpc.WithBlock(), tls.ToGRPCDialOption(), grpc.WithBackoffMaxDelay(3*time.Second))
csuzhangxc marked this conversation as resolved.
Show resolved Hide resolved
cancel1()
if err != nil {
if conn != nil {
Expand Down
40 changes: 40 additions & 0 deletions pkg/utils/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,43 @@ func UnwrapScheme(s string) string {
}
return s
}

func wrapScheme(s string, https bool) string {
if s == "" {
return s
}
if strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://") {
return s
}
if https {
return "https://" + s
}
return "http://" + s
}

// WrapSchemes adds http or https scheme to input if missing. input could be a comma-separated list
// if input has wrong scheme, don't correct it (maybe user deliberately?)
func WrapSchemes(s string, https bool) string {
items := strings.Split(s, ",")
output := make([]string, 0, len(items))
for _, s := range items {
output = append(output, wrapScheme(s, https))
}
return strings.Join(output, ",")
}

// WrapSchemesForInitialCluster acts like WrapSchemes, except input is "name=URL,..."
func WrapSchemesForInitialCluster(s string, https bool) string {
items := strings.Split(s, ",")
output := make([]string, 0, len(items))
for _, item := range items {
kv := strings.Split(item, "=")
if len(kv) != 2 {
output = append(output, item)
continue
}

output = append(output, kv[0]+"="+wrapScheme(kv[1], https))
}
return strings.Join(output, ",")
}
38 changes: 38 additions & 0 deletions pkg/utils/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,46 @@ func (t *testUtilsSuite) TestUnwrapScheme(c *C) {
"httpsdfpoje.com",
"httpsdfpoje.com",
},
{
"",
"",
},
}
for _, ca := range cases {
c.Assert(UnwrapScheme(ca.old), Equals, ca.new)
}
}

func (t *testUtilsSuite) TestWrapSchemes(c *C) {
cases := []struct {
old string
http string
https string
}{
{
"0.0.0.0:123",
"http://0.0.0.0:123",
"https://0.0.0.0:123",
},
{
"abc.com:123",
"http://abc.com:123",
"https://abc.com:123",
},
{
// if input has wrong scheme, don't correct it (maybe user deliberately?)
"abc.com:123,http://abc.com:123,0.0.0.0:123,https://0.0.0.0:123",
"http://abc.com:123,http://abc.com:123,http://0.0.0.0:123,https://0.0.0.0:123",
"https://abc.com:123,http://abc.com:123,https://0.0.0.0:123,https://0.0.0.0:123",
},
{
"",
"",
"",
},
}
for _, ca := range cases {
c.Assert(WrapSchemes(ca.old, false), Equals, ca.http)
c.Assert(WrapSchemes(ca.old, true), Equals, ca.https)
}
}
2 changes: 1 addition & 1 deletion tests/ha_master/conf/dm-worker1.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
name = "worker1"
join = "localhost:8261,localhost:8361,localhost:8461,localhost:8561,localhost:8661"
join = "localhost:8261,http://localhost:8361,localhost:8461,localhost:8561,localhost:8661"