Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write normalized scheme and host to routing.Route fields #2824

Merged
merged 1 commit into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 2 additions & 40 deletions filters/fadein/fadein.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@ package fadein

import (
"fmt"
"net"
"net/url"
"strings"
"time"

log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
snet "github.com/zalando/skipper/net"
"github.com/zalando/skipper/routing"
)

Expand Down Expand Up @@ -101,42 +99,6 @@ func NewEndpointCreated() filters.Spec {

func (endpointCreated) Name() string { return filters.EndpointCreatedName }

func normalizeSchemeHost(s, h string) (string, string, error) {
// endpoint address cannot contain path, the rest is not case sensitive
s, h = strings.ToLower(s), strings.ToLower(h)

hh, p, err := net.SplitHostPort(h)
if err != nil {
// what is the actual right way of doing this, considering IPv6 addresses, too?
if !strings.Contains(err.Error(), "missing port") {
return "", "", err
}

p = ""
} else {
h = hh
}

switch {
case p == "" && s == "http":
p = "80"
case p == "" && s == "https":
p = "443"
}

h = net.JoinHostPort(h, p)
return s, h, nil
}

func normalizeEndpoint(e string) (string, string, error) {
u, err := url.Parse(e)
if err != nil {
return "", "", err
}

return normalizeSchemeHost(u.Scheme, u.Host)
}

func endpointKey(scheme, host string) string {
return fmt.Sprintf("%s://%s", scheme, host)
}
Expand All @@ -151,7 +113,7 @@ func (endpointCreated) CreateFilter(args []interface{}) (filters.Filter, error)
return nil, filters.ErrInvalidFilterParameters
}

s, h, err := normalizeEndpoint(e)
s, h, err := snet.SchemeHost(e)
if err != nil {
return nil, err
}
Expand Down
37 changes: 1 addition & 36 deletions loadbalancer/algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ import (
"fmt"
"math"
"math/rand"
"net"
"net/url"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -428,12 +425,7 @@ func (a Algorithm) String() string {
func parseEndpoints(r *routing.Route) error {
r.LBEndpoints = make([]routing.LBEndpoint, len(r.Route.LBEndpoints))
for i, e := range r.Route.LBEndpoints {
eu, err := url.ParseRequestURI(e)
if err != nil {
return err
}

scheme, host, err := normalizeSchemeHost(eu.Scheme, eu.Host)
scheme, host, err := snet.SchemeHost(e)
if err != nil {
return err
}
Expand Down Expand Up @@ -463,33 +455,6 @@ func setAlgorithm(r *routing.Route) error {
return nil
}

func normalizeSchemeHost(s, h string) (string, string, error) {
// endpoint address cannot contain path, the rest is not case sensitive
s, h = strings.ToLower(s), strings.ToLower(h)

hh, p, err := net.SplitHostPort(h)
if err != nil {
// what is the actual right way of doing this, considering IPv6 addresses, too?
if !strings.Contains(err.Error(), "missing port") {
return "", "", err
}

p = ""
} else {
h = hh
}

switch {
case p == "" && s == "http":
p = "80"
case p == "" && s == "https":
p = "443"
}

h = net.JoinHostPort(h, p)
return s, h, nil
}

// Do implements routing.PostProcessor
func (p *algorithmProvider) Do(r []*routing.Route) []*routing.Route {
rr := make([]*routing.Route, 0, len(r))
Expand Down
47 changes: 47 additions & 0 deletions net/net.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package net

import (
"fmt"
"net"
"net/http"
"net/netip"
"net/url"
"strings"

"go4.org/netipx"
Expand Down Expand Up @@ -154,3 +156,48 @@ func ParseIPCIDRs(cidrs []string) (*netipx.IPSet, error) {

return ips, nil
}

// SchemeHost parses URI string (without #fragment part) and returns schema used in this URI as first return value and
// host[:port] part as second return value. Port is never omitted for HTTP(S): if no port is specified in URI, default port for given
// schema is used. If URI is invalid, error is returned.
func SchemeHost(input string) (string, string, error) {
u, err := url.ParseRequestURI(input)
if err != nil {
return "", "", err
}
if u.Scheme == "" {
return "", "", fmt.Errorf(`parse %q: missing scheme`, input)
}
if u.Host == "" {
return "", "", fmt.Errorf(`parse %q: missing host`, input)
}

// endpoint address cannot contain path, the rest is not case sensitive
s, h := strings.ToLower(u.Scheme), strings.ToLower(u.Host)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: I think variable names here could be more descriptive

scheme, host := strings.ToLower(u.Scheme), strings.ToLower(u.Host)
isolatedHost, port, err := net.SplitHostPort(h)

I feel like the current variable names is easy to forget and harder to debug later


hh, p, err := net.SplitHostPort(h)
if err != nil {
if strings.Contains(err.Error(), "missing port") {
// Trim is needed to remove brackets from IPv6 addresses, JoinHostPort will add them in case of any IPv6 address,
// so we need to remove them to avoid duplicate pairs of brackets.
h = strings.Trim(h, "[]")
switch s {
case "http":
p = "80"
case "https":
p = "443"
default:
p = ""
}
} else {
return "", "", err
}
} else {
h = hh
}

if p != "" {
h = net.JoinHostPort(h, p)
}
return s, h, nil
}
Loading