Skip to content

Commit

Permalink
Add preserve host configuration option
Browse files Browse the repository at this point in the history
  • Loading branch information
sporkmonger committed Sep 18, 2018
1 parent d5fed14 commit 40c5338
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 12 deletions.
24 changes: 16 additions & 8 deletions internal/proxy/oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,22 +138,26 @@ func (t *upstreamTransport) RoundTrip(req *http.Request) (*http.Response, error)

// NewReverseProxy creates a reverse proxy to a specified url.
// It adds an X-Forwarded-Host header that is the request's host.
func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
func NewReverseProxy(to *url.URL, config *UpstreamConfig) *httputil.ReverseProxy {
proxy := httputil.NewSingleHostReverseProxy(to)
proxy.Transport = &upstreamTransport{}
director := proxy.Director
proxy.Director = func(req *http.Request) {
req.Header.Add("X-Forwarded-Host", req.Host)
if req.Header.Get("X-Forwarded-Host") == "" {
req.Header.Set("X-Forwarded-Host", req.Host)
}
director(req)
req.Host = to.Host
if !config.PreserveHost {
req.Host = to.Host
}
}
return proxy
}

// NewRewriteReverseProxy creates a reverse proxy that is capable of creating upstream
// urls on the fly based on a from regex and a templated to field.
// It adds an X-Forwarded-Host header to the the upstream's request.
func NewRewriteReverseProxy(route *RewriteRoute) *httputil.ReverseProxy {
func NewRewriteReverseProxy(route *RewriteRoute, config *UpstreamConfig) *httputil.ReverseProxy {
proxy := &httputil.ReverseProxy{}
proxy.Transport = &upstreamTransport{}
proxy.Director = func(req *http.Request) {
Expand All @@ -172,9 +176,13 @@ func NewRewriteReverseProxy(route *RewriteRoute) *httputil.ReverseProxy {
}
director := httputil.NewSingleHostReverseProxy(target).Director

req.Header.Add("X-Forwarded-Host", req.Host)
if req.Header.Get("X-Forwarded-Host") == "" {
req.Header.Set("X-Forwarded-Host", req.Host)
}
director(req)
req.Host = target.Host
if !config.PreserveHost {
req.Host = target.Host
}
}
return proxy
}
Expand Down Expand Up @@ -283,11 +291,11 @@ func NewOAuthProxy(opts *Options, optFuncs ...func(*OAuthProxy) error) (*OAuthPr
for _, upstreamConfig := range opts.upstreamConfigs {
switch route := upstreamConfig.Route.(type) {
case *SimpleRoute:
reverseProxy := NewReverseProxy(route.ToURL)
reverseProxy := NewReverseProxy(route.ToURL, upstreamConfig)
handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig)
p.Handle(route.FromURL.Host, handler, tags, upstreamConfig)
case *RewriteRoute:
reverseProxy := NewRewriteReverseProxy(route)
reverseProxy := NewRewriteReverseProxy(route, upstreamConfig)
handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig)
p.HandleRegex(route.FromRegex, handler, tags, upstreamConfig)
default:
Expand Down
131 changes: 127 additions & 4 deletions internal/proxy/oauthproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestNewReverseProxy(t *testing.T) {
backendHost := net.JoinHostPort(backendHostname, backendPort)
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")

proxyHandler := NewReverseProxy(proxyURL)
proxyHandler := NewReverseProxy(proxyURL, &UpstreamConfig{PreserveHost: false})
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()

Expand Down Expand Up @@ -99,7 +99,7 @@ func TestNewRewriteReverseProxy(t *testing.T) {
},
}

rewriteProxy := NewRewriteReverseProxy(route)
rewriteProxy := NewRewriteReverseProxy(route, &UpstreamConfig{PreserveHost: false})

frontend := httptest.NewServer(rewriteProxy)
defer frontend.Close()
Expand Down Expand Up @@ -146,7 +146,7 @@ func TestNewReverseProxyHostname(t *testing.T) {
t.Fatalf("expected to parse to url: %s", err)
}

reverseProxy := NewReverseProxy(toURL)
reverseProxy := NewReverseProxy(toURL, &UpstreamConfig{PreserveHost: false})
from := httptest.NewServer(reverseProxy)
defer from.Close()

Expand Down Expand Up @@ -191,6 +191,129 @@ func TestNewReverseProxyHostname(t *testing.T) {

}

func TestNewReverseProxyPreserveHost(t *testing.T) {
type respStruct struct {
Host string `json:"host"`
XForwardedHost string `json:"x-forwarded-host"`
}

to := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
body, err := json.Marshal(
&respStruct{
Host: r.Host,
XForwardedHost: r.Header.Get("X-Forwarded-Host"),
},
)
if err != nil {
t.Fatalf("expected to marshal json: %s", err)
}
rw.Write(body)
}))
defer to.Close()

toURL, err := url.Parse(to.URL)
if err != nil {
t.Fatalf("expected to parse to url: %s", err)
}

reverseProxy := NewReverseProxy(toURL, &UpstreamConfig{PreserveHost: true})
from := httptest.NewServer(reverseProxy)
defer from.Close()

fromURL, err := url.Parse(from.URL)
if err != nil {
t.Fatalf("expected to parse from url: %s", err)
}

want := &respStruct{
Host: fromURL.Host,
XForwardedHost: "something",
}

req, err := http.NewRequest("GET", from.URL, strings.NewReader(""))
if err != nil {
t.Fatalf("expected to be able to make req: %s", err)
}
req.Header.Set("X-Forwarded-Host", "something")

res, err := http.DefaultTransport.RoundTrip(req)
if err != nil {
t.Fatalf("expected to be able to get res: %s", err)
}

body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("expected to read body: %s", err)
}

got := &respStruct{}
err = json.Unmarshal(body, got)
if err != nil {
t.Fatalf("expected to decode json: %s", err)
}

if !reflect.DeepEqual(want, got) {
t.Logf(" got host: %v", got.Host)
t.Logf("want host: %v", want.Host)

t.Logf(" got X-Forwarded-Host: %v", got.XForwardedHost)
t.Logf("want X-Forwarded-Host: %v", want.XForwardedHost)

t.Errorf("got unexpected response for Host or X-Forwarded-Host header")
}
if res.Header.Get("Cookie") != "" {
t.Errorf("expected Cookie header to be empty but was %s", res.Header.Get("Cookie"))
}

}

func TestNewRewriteReverseProxyPreserveHost(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(200)
rw.Write([]byte(req.Host))
}))
defer upstream.Close()

parsedUpstreamURL, err := url.Parse(upstream.URL)
if err != nil {
t.Fatalf("expected to parse upstream URL err:%q", err)
}

route := &RewriteRoute{
FromRegex: regexp.MustCompile("(.*)"),
ToTemplate: &url.URL{
Scheme: parsedUpstreamURL.Scheme,
Opaque: parsedUpstreamURL.Host,
},
}

rewriteProxy := NewRewriteReverseProxy(route, &UpstreamConfig{PreserveHost: true})

frontend := httptest.NewServer(rewriteProxy)
defer frontend.Close()

frontendURL, err := url.Parse(frontend.URL)
if err != nil {
t.Fatalf("expected to parse frontend url: %s", err)
}

resp, err := http.Get(frontend.URL)
if err != nil {
t.Fatalf("expected to make successful request err:%q", err)
}

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("expected to read body err:%q", err)
}

if string(body) != frontendURL.Host {
t.Logf("got %v", string(body))
t.Logf("want %v", frontendURL.Host)
t.Fatalf("got unexpected response from upstream")
}
}

func TestDeleteSSOHeader(t *testing.T) {
testCases := []struct {
name string
Expand Down Expand Up @@ -290,7 +413,7 @@ func TestEncodedSlashes(t *testing.T) {
defer backend.Close()

b, _ := url.Parse(backend.URL)
proxyHandler := NewReverseProxy(b)
proxyHandler := NewReverseProxy(b, &UpstreamConfig{PreserveHost: false})
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()

Expand Down
3 changes: 3 additions & 0 deletions internal/proxy/proxy_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type UpstreamConfig struct {

SkipAuthCompiledRegex []*regexp.Regexp
AllowedGroups []string
PreserveHost bool
HMACAuth hmacauth.HmacAuth
Timeout time.Duration
FlushInterval time.Duration
Expand Down Expand Up @@ -75,6 +76,7 @@ type OptionsConfig struct {
HeaderOverrides map[string]string `yaml:"header_overrides"`
SkipAuthRegex []string `yaml:"skip_auth_regex"`
AllowedGroups []string `yaml:"allowed_groups"`
PreserveHost bool `yaml:"preserve_host"`
Timeout time.Duration `yaml:"timeout"`
FlushInterval time.Duration `yaml:"flush_interval"`
}
Expand Down Expand Up @@ -358,6 +360,7 @@ func parseOptionsConfig(proxy *UpstreamConfig) error {
proxy.Timeout = proxy.RouteConfig.Options.Timeout
proxy.FlushInterval = proxy.RouteConfig.Options.FlushInterval
proxy.HeaderOverrides = proxy.RouteConfig.Options.HeaderOverrides
proxy.PreserveHost = proxy.RouteConfig.Options.PreserveHost

proxy.RouteConfig.Options = nil

Expand Down

0 comments on commit 40c5338

Please sign in to comment.