diff --git a/pkg/netutil/netutil.go b/pkg/netutil/netutil.go index bd8c1fc9320..0e8f6c666f6 100644 --- a/pkg/netutil/netutil.go +++ b/pkg/netutil/netutil.go @@ -174,21 +174,13 @@ func URLStringsEqual(ctx context.Context, lg *zap.Logger, a []string, b []string if len(a) != len(b) { return false, fmt.Errorf("len(%q) != len(%q)", a, b) } - urlsA := make([]url.URL, 0) - for _, str := range a { - u, err := url.Parse(str) - if err != nil { - return false, fmt.Errorf("failed to parse %q", str) - } - urlsA = append(urlsA, *u) + urlsA, err := stringsToURLs(a) + if err != nil { + return false, err } - urlsB := make([]url.URL, 0) - for _, str := range b { - u, err := url.Parse(str) - if err != nil { - return false, fmt.Errorf("failed to parse %q", str) - } - urlsB = append(urlsB, *u) + urlsB, err := stringsToURLs(b) + if err != nil { + return false, err } return urlsEqual(ctx, lg, urlsA, urlsB) } @@ -201,6 +193,18 @@ func urlsToStrings(us []url.URL) []string { return rs } +func stringsToURLs(us []string) ([]url.URL, error) { + urls := make([]url.URL, 0, len(us)) + for _, str := range us { + u, err := url.Parse(str) + if err != nil { + return nil, fmt.Errorf("failed to parse %q", str) + } + urls = append(urls, *u) + } + return urls, nil +} + func IsNetworkTimeoutError(err error) bool { nerr, ok := err.(net.Error) return ok && nerr.Timeout() diff --git a/pkg/netutil/netutil_test.go b/pkg/netutil/netutil_test.go index 42b05ca295a..7d1d17aa269 100644 --- a/pkg/netutil/netutil_test.go +++ b/pkg/netutil/netutil_test.go @@ -17,6 +17,7 @@ package netutil import ( "context" "errors" + "fmt" "net" "net/url" "reflect" @@ -292,11 +293,33 @@ func TestURLsEqual(t *testing.T) { } } func TestURLStringsEqual(t *testing.T) { - result, err := URLStringsEqual(context.TODO(), zap.NewExample(), []string{"http://127.0.0.1:8080"}, []string{"http://127.0.0.1:8080"}) - if !result { - t.Errorf("unexpected result %v", result) + defer func() { resolveTCPAddr = resolveTCPAddrDefault }() + errOnResolve := func(ctx context.Context, addr string) (*net.TCPAddr, error) { + return nil, fmt.Errorf("unexpected attempt to resolve: %q", addr) + } + cases := []struct { + urlsA []string + urlsB []string + resolver func(ctx context.Context, addr string) (*net.TCPAddr, error) + }{ + {[]string{"http://127.0.0.1:8080"}, []string{"http://127.0.0.1:8080"}, resolveTCPAddrDefault}, + {[]string{ + "http://host1:8080", + "http://host2:8080", + }, []string{ + "http://host1:8080", + "http://host2:8080", + }, errOnResolve}, } - if err != nil { - t.Errorf("unexpected error %v", err) + for idx, c := range cases { + t.Logf("TestURLStringsEqual, case #%d", idx) + resolveTCPAddr = c.resolver + result, err := URLStringsEqual(context.TODO(), zap.NewExample(), c.urlsA, c.urlsB) + if !result { + t.Errorf("unexpected result %v", result) + } + if err != nil { + t.Errorf("unexpected error %v", err) + } } }