diff --git a/filters/fadein/fadein.go b/filters/fadein/fadein.go index 64105c3202..6080c8d2ac 100644 --- a/filters/fadein/fadein.go +++ b/filters/fadein/fadein.go @@ -2,14 +2,12 @@ package fadein import ( "fmt" - "net" - "net/url" - "strings" "time" log "github.com/sirupsen/logrus" "github.com/zalando/skipper/eskip" "github.com/zalando/skipper/filters" + snet "github.com/zalando/skipper/net" "github.com/zalando/skipper/routing" ) @@ -101,42 +99,6 @@ func NewEndpointCreated() filters.Spec { func (endpointCreated) Name() string { return filters.EndpointCreatedName } -func normalizeSchemeHost(s, h string) (string, string, error) { - // endpoint address cannot contain path, the rest is not case sensitive - s, h = strings.ToLower(s), strings.ToLower(h) - - hh, p, err := net.SplitHostPort(h) - if err != nil { - // what is the actual right way of doing this, considering IPv6 addresses, too? - if !strings.Contains(err.Error(), "missing port") { - return "", "", err - } - - p = "" - } else { - h = hh - } - - switch { - case p == "" && s == "http": - p = "80" - case p == "" && s == "https": - p = "443" - } - - h = net.JoinHostPort(h, p) - return s, h, nil -} - -func normalizeEndpoint(e string) (string, string, error) { - u, err := url.Parse(e) - if err != nil { - return "", "", err - } - - return normalizeSchemeHost(u.Scheme, u.Host) -} - func endpointKey(scheme, host string) string { return fmt.Sprintf("%s://%s", scheme, host) } @@ -151,7 +113,7 @@ func (endpointCreated) CreateFilter(args []interface{}) (filters.Filter, error) return nil, filters.ErrInvalidFilterParameters } - s, h, err := normalizeEndpoint(e) + s, h, err := snet.SchemeHost(e) if err != nil { return nil, err } diff --git a/loadbalancer/algorithm.go b/loadbalancer/algorithm.go index 37b897ef5c..38a703f382 100644 --- a/loadbalancer/algorithm.go +++ b/loadbalancer/algorithm.go @@ -5,10 +5,7 @@ import ( "fmt" "math" "math/rand" - "net" - "net/url" "sort" - "strings" "sync" "sync/atomic" "time" @@ -428,12 +425,7 @@ func (a Algorithm) String() string { func parseEndpoints(r *routing.Route) error { r.LBEndpoints = make([]routing.LBEndpoint, len(r.Route.LBEndpoints)) for i, e := range r.Route.LBEndpoints { - eu, err := url.ParseRequestURI(e) - if err != nil { - return err - } - - scheme, host, err := normalizeSchemeHost(eu.Scheme, eu.Host) + scheme, host, err := snet.SchemeHost(e) if err != nil { return err } @@ -463,33 +455,6 @@ func setAlgorithm(r *routing.Route) error { return nil } -func normalizeSchemeHost(s, h string) (string, string, error) { - // endpoint address cannot contain path, the rest is not case sensitive - s, h = strings.ToLower(s), strings.ToLower(h) - - hh, p, err := net.SplitHostPort(h) - if err != nil { - // what is the actual right way of doing this, considering IPv6 addresses, too? - if !strings.Contains(err.Error(), "missing port") { - return "", "", err - } - - p = "" - } else { - h = hh - } - - switch { - case p == "" && s == "http": - p = "80" - case p == "" && s == "https": - p = "443" - } - - h = net.JoinHostPort(h, p) - return s, h, nil -} - // Do implements routing.PostProcessor func (p *algorithmProvider) Do(r []*routing.Route) []*routing.Route { rr := make([]*routing.Route, 0, len(r)) diff --git a/net/net.go b/net/net.go index 661244d4e1..2e1e7990f9 100644 --- a/net/net.go +++ b/net/net.go @@ -1,9 +1,11 @@ package net import ( + "fmt" "net" "net/http" "net/netip" + "net/url" "strings" "go4.org/netipx" @@ -154,3 +156,48 @@ func ParseIPCIDRs(cidrs []string) (*netipx.IPSet, error) { return ips, nil } + +// SchemeHost parses URI string (without #fragment part) and returns schema used in this URI as first return value and +// host[:port] part as second return value. Port is never omitted for HTTP(S): if no port is specified in URI, default port for given +// schema is used. If URI is invalid, error is returned. +func SchemeHost(input string) (string, string, error) { + u, err := url.ParseRequestURI(input) + if err != nil { + return "", "", err + } + if u.Scheme == "" { + return "", "", fmt.Errorf(`parse %q: missing scheme`, input) + } + if u.Host == "" { + return "", "", fmt.Errorf(`parse %q: missing host`, input) + } + + // endpoint address cannot contain path, the rest is not case sensitive + s, h := strings.ToLower(u.Scheme), strings.ToLower(u.Host) + + hh, p, err := net.SplitHostPort(h) + if err != nil { + if strings.Contains(err.Error(), "missing port") { + // Trim is needed to remove brackets from IPv6 addresses, JoinHostPort will add them in case of any IPv6 address, + // so we need to remove them to avoid duplicate pairs of brackets. + h = strings.Trim(h, "[]") + switch s { + case "http": + p = "80" + case "https": + p = "443" + default: + p = "" + } + } else { + return "", "", err + } + } else { + h = hh + } + + if p != "" { + h = net.JoinHostPort(h, p) + } + return s, h, nil +} diff --git a/net/net_test.go b/net/net_test.go index a513f16b04..1f2412ac38 100644 --- a/net/net_test.go +++ b/net/net_test.go @@ -1,14 +1,41 @@ package net import ( + "fmt" "net" "net/http" "net/netip" + "path/filepath" "reflect" + "runtime" "strings" "testing" + + "github.com/stretchr/testify/assert" ) +type tc[T any] struct { + location string + in T +} + +// https://github.com/golang/go/issues/52751 +func testCase[T any](in T) tc[T] { + _, file, line, _ := runtime.Caller(1) + location := fmt.Sprintf("%s:%d", filepath.Base(file), line) + return tc[T]{location: location, in: in} +} + +func (tc *tc[T]) logLocation(t *testing.T) { + t.Helper() + t.Cleanup(func() { + t.Helper() + if t.Failed() { + t.Logf("Test case location: %s", tc.location) + } + }) +} + func TestRemoteAddr(t *testing.T) { for _, tt := range []struct { name string @@ -236,3 +263,263 @@ func TestIPNetsDoNotContain(t *testing.T) { }) } } + +type TestSchemeHostItem struct { + input string + scheme string + host string + err string +} + +func TestSchemeHost(t *testing.T) { + for _, ti := range []tc[TestSchemeHostItem]{ + testCase(TestSchemeHostItem{ + input: "http://example.com", + scheme: "http", + host: "example.com:80", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "http://example.com:80", + scheme: "http", + host: "example.com:80", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "http://example.com:8080", + scheme: "http", + host: "example.com:8080", + err: "", + }), + + testCase(TestSchemeHostItem{ + input: "https://example.com", + scheme: "https", + host: "example.com:443", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "https://example.com:443", + scheme: "https", + host: "example.com:443", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "https://example.com:8080", + scheme: "https", + host: "example.com:8080", + err: "", + }), + + testCase(TestSchemeHostItem{ + input: "fastcgi://example.com", + scheme: "fastcgi", + host: "example.com", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "fastcgi://example.com:9000", + scheme: "fastcgi", + host: "example.com:9000", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "fastcgi://example.com:8080", + scheme: "fastcgi", + host: "example.com:8080", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "fastcgi://foo/bar", + scheme: "fastcgi", + host: "foo", + err: "", + }), + + testCase(TestSchemeHostItem{ + input: "postgres://example.com", + scheme: "postgres", + host: "example.com", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "postgres://example.com:5432", + scheme: "postgres", + host: "example.com:5432", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "postgresql://example.com", + scheme: "postgresql", + host: "example.com", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "postgresql://example.com:5432", + scheme: "postgresql", + host: "example.com:5432", + err: "", + }), + + testCase(TestSchemeHostItem{ + input: "someprotocol://example.com", + scheme: "someprotocol", + host: "example.com", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "someprotocol://example.com:12345", + scheme: "someprotocol", + host: "example.com:12345", + err: "", + }), + + testCase(TestSchemeHostItem{ + input: "example.com", + scheme: "", + host: "", + err: `parse "example.com": invalid URI for request`, + }), + testCase(TestSchemeHostItem{ + input: "example.com/", + scheme: "", + host: "", + err: `parse "example.com/": invalid URI for request`, + }), + testCase(TestSchemeHostItem{ + input: "example.com:80", + scheme: "", + host: "", + err: `parse "example.com:80": missing host`, + }), + + testCase(TestSchemeHostItem{ + input: "hTTP://exAMPLe.com", + scheme: "http", + host: "example.com:80", + err: "", + }), + + testCase(TestSchemeHostItem{ + input: "http://example.com/foo/bar", + scheme: "http", + host: "example.com:80", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "http://example.com:80/foo/bar", + scheme: "http", + host: "example.com:80", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "http://example.com:8080/foo/bar", + scheme: "http", + host: "example.com:8080", + err: "", + }), + + testCase(TestSchemeHostItem{ + input: "http://example.com?foo=bar", + scheme: "http", + host: "example.com:80", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "http://example.com:80?foo=bar", + scheme: "http", + host: "example.com:80", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "http://example.com:8080?foo=bar", + scheme: "http", + host: "example.com:8080", + err: "", + }), + + testCase(TestSchemeHostItem{ + input: "http://192.168.0.1", + scheme: "http", + host: "192.168.0.1:80", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "http://192.168.0.1:80", + scheme: "http", + host: "192.168.0.1:80", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "http://192.168.0.1:8080", + scheme: "http", + host: "192.168.0.1:8080", + err: "", + }), + + testCase(TestSchemeHostItem{ + input: "http://[2001:db8:3333:4444:5555:6666:7777:8888]", + scheme: "http", + host: "[2001:db8:3333:4444:5555:6666:7777:8888]:80", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "http://[2001:db8:3333:4444:5555:6666:7777:8888]:80", + scheme: "http", + host: "[2001:db8:3333:4444:5555:6666:7777:8888]:80", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "http://[2001:db8:3333:4444:5555:6666:7777:8888]:8080", + scheme: "http", + host: "[2001:db8:3333:4444:5555:6666:7777:8888]:8080", + err: "", + }), + + testCase(TestSchemeHostItem{ + input: "fastcgi://192.168.0.1", + scheme: "fastcgi", + host: "192.168.0.1", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "fastcgi://192.168.0.1:9000", + scheme: "fastcgi", + host: "192.168.0.1:9000", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "fastcgi://[2001:db8:3333:4444:5555:6666:7777:8888]", + scheme: "fastcgi", + host: "2001:db8:3333:4444:5555:6666:7777:8888", + err: "", + }), + testCase(TestSchemeHostItem{ + input: "fastcgi://[2001:db8:3333:4444:5555:6666:7777:8888]:9000", + scheme: "fastcgi", + host: "[2001:db8:3333:4444:5555:6666:7777:8888]:9000", + err: "", + }), + + testCase(TestSchemeHostItem{ + input: "/foo", + scheme: "", + host: "", + err: `parse "/foo": missing scheme`, + }), + } { + t.Run(ti.in.input, func(t *testing.T) { + ti.logLocation(t) + + scheme, host, err := SchemeHost(ti.in.input) + if ti.in.err != "" { + assert.EqualError(t, err, ti.in.err) + } else { + if assert.NoError(t, err) { + assert.Equal(t, ti.in.scheme, scheme) + assert.Equal(t, ti.in.host, host) + } + } + }) + } +} diff --git a/routing/datasource.go b/routing/datasource.go index 68b3ad465d..36f8b7ce91 100644 --- a/routing/datasource.go +++ b/routing/datasource.go @@ -3,7 +3,6 @@ package routing import ( "errors" "fmt" - "net/url" "sort" "sync" "time" @@ -11,6 +10,7 @@ import ( "github.com/zalando/skipper/eskip" "github.com/zalando/skipper/filters" "github.com/zalando/skipper/logging" + "github.com/zalando/skipper/net" "github.com/zalando/skipper/predicates" ) @@ -203,12 +203,7 @@ func splitBackend(r *eskip.Route) (string, string, error) { return "", "", nil } - bu, err := url.ParseRequestURI(r.Backend) - if err != nil { - return "", "", err - } - - return bu.Scheme, bu.Host, nil + return net.SchemeHost(r.Backend) } // creates a filter instance based on its definition and its