diff --git a/internal/proxy/oauthproxy.go b/internal/proxy/oauthproxy.go index 8ae55e2c..c61a0d8b 100755 --- a/internal/proxy/oauthproxy.go +++ b/internal/proxy/oauthproxy.go @@ -138,14 +138,18 @@ func (t *upstreamTransport) RoundTrip(req *http.Request) (*http.Response, error) // NewReverseProxy creates a reverse proxy to a specified url. // It adds an X-Forwarded-Host header that is the request's host. -func NewReverseProxy(to *url.URL) *httputil.ReverseProxy { +func NewReverseProxy(to *url.URL, config *UpstreamConfig) *httputil.ReverseProxy { proxy := httputil.NewSingleHostReverseProxy(to) proxy.Transport = &upstreamTransport{} director := proxy.Director proxy.Director = func(req *http.Request) { - req.Header.Add("X-Forwarded-Host", req.Host) + if req.Header.Get("X-Forwarded-Host") == "" { + req.Header.Set("X-Forwarded-Host", req.Host) + } director(req) - req.Host = to.Host + if !config.PreserveHost { + req.Host = to.Host + } } return proxy } @@ -153,7 +157,7 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy { // NewRewriteReverseProxy creates a reverse proxy that is capable of creating upstream // urls on the fly based on a from regex and a templated to field. // It adds an X-Forwarded-Host header to the the upstream's request. -func NewRewriteReverseProxy(route *RewriteRoute) *httputil.ReverseProxy { +func NewRewriteReverseProxy(route *RewriteRoute, config *UpstreamConfig) *httputil.ReverseProxy { proxy := &httputil.ReverseProxy{} proxy.Transport = &upstreamTransport{} proxy.Director = func(req *http.Request) { @@ -172,9 +176,13 @@ func NewRewriteReverseProxy(route *RewriteRoute) *httputil.ReverseProxy { } director := httputil.NewSingleHostReverseProxy(target).Director - req.Header.Add("X-Forwarded-Host", req.Host) + if req.Header.Get("X-Forwarded-Host") == "" { + req.Header.Set("X-Forwarded-Host", req.Host) + } director(req) - req.Host = target.Host + if !config.PreserveHost { + req.Host = target.Host + } } return proxy } @@ -283,11 +291,11 @@ func NewOAuthProxy(opts *Options, optFuncs ...func(*OAuthProxy) error) (*OAuthPr for _, upstreamConfig := range opts.upstreamConfigs { switch route := upstreamConfig.Route.(type) { case *SimpleRoute: - reverseProxy := NewReverseProxy(route.ToURL) + reverseProxy := NewReverseProxy(route.ToURL, upstreamConfig) handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig) p.Handle(route.FromURL.Host, handler, tags, upstreamConfig) case *RewriteRoute: - reverseProxy := NewRewriteReverseProxy(route) + reverseProxy := NewRewriteReverseProxy(route, upstreamConfig) handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig) p.HandleRegex(route.FromRegex, handler, tags, upstreamConfig) default: diff --git a/internal/proxy/oauthproxy_test.go b/internal/proxy/oauthproxy_test.go index 36be46f0..768d447f 100644 --- a/internal/proxy/oauthproxy_test.go +++ b/internal/proxy/oauthproxy_test.go @@ -67,7 +67,7 @@ func TestNewReverseProxy(t *testing.T) { backendHost := net.JoinHostPort(backendHostname, backendPort) proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") - proxyHandler := NewReverseProxy(proxyURL) + proxyHandler := NewReverseProxy(proxyURL, &UpstreamConfig{PreserveHost: false}) frontend := httptest.NewServer(proxyHandler) defer frontend.Close() @@ -99,7 +99,7 @@ func TestNewRewriteReverseProxy(t *testing.T) { }, } - rewriteProxy := NewRewriteReverseProxy(route) + rewriteProxy := NewRewriteReverseProxy(route, &UpstreamConfig{PreserveHost: false}) frontend := httptest.NewServer(rewriteProxy) defer frontend.Close() @@ -146,7 +146,7 @@ func TestNewReverseProxyHostname(t *testing.T) { t.Fatalf("expected to parse to url: %s", err) } - reverseProxy := NewReverseProxy(toURL) + reverseProxy := NewReverseProxy(toURL, &UpstreamConfig{PreserveHost: false}) from := httptest.NewServer(reverseProxy) defer from.Close() @@ -191,6 +191,129 @@ func TestNewReverseProxyHostname(t *testing.T) { } +func TestNewReverseProxyPreserveHost(t *testing.T) { + type respStruct struct { + Host string `json:"host"` + XForwardedHost string `json:"x-forwarded-host"` + } + + to := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + body, err := json.Marshal( + &respStruct{ + Host: r.Host, + XForwardedHost: r.Header.Get("X-Forwarded-Host"), + }, + ) + if err != nil { + t.Fatalf("expected to marshal json: %s", err) + } + rw.Write(body) + })) + defer to.Close() + + toURL, err := url.Parse(to.URL) + if err != nil { + t.Fatalf("expected to parse to url: %s", err) + } + + reverseProxy := NewReverseProxy(toURL, &UpstreamConfig{PreserveHost: true}) + from := httptest.NewServer(reverseProxy) + defer from.Close() + + fromURL, err := url.Parse(from.URL) + if err != nil { + t.Fatalf("expected to parse from url: %s", err) + } + + want := &respStruct{ + Host: fromURL.Host, + XForwardedHost: "something", + } + + req, err := http.NewRequest("GET", from.URL, strings.NewReader("")) + if err != nil { + t.Fatalf("expected to be able to make req: %s", err) + } + req.Header.Set("X-Forwarded-Host", "something") + + res, err := http.DefaultTransport.RoundTrip(req) + if err != nil { + t.Fatalf("expected to be able to get res: %s", err) + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("expected to read body: %s", err) + } + + got := &respStruct{} + err = json.Unmarshal(body, got) + if err != nil { + t.Fatalf("expected to decode json: %s", err) + } + + if !reflect.DeepEqual(want, got) { + t.Logf(" got host: %v", got.Host) + t.Logf("want host: %v", want.Host) + + t.Logf(" got X-Forwarded-Host: %v", got.XForwardedHost) + t.Logf("want X-Forwarded-Host: %v", want.XForwardedHost) + + t.Errorf("got unexpected response for Host or X-Forwarded-Host header") + } + if res.Header.Get("Cookie") != "" { + t.Errorf("expected Cookie header to be empty but was %s", res.Header.Get("Cookie")) + } + +} + +func TestNewRewriteReverseProxyPreserveHost(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(200) + rw.Write([]byte(req.Host)) + })) + defer upstream.Close() + + parsedUpstreamURL, err := url.Parse(upstream.URL) + if err != nil { + t.Fatalf("expected to parse upstream URL err:%q", err) + } + + route := &RewriteRoute{ + FromRegex: regexp.MustCompile("(.*)"), + ToTemplate: &url.URL{ + Scheme: parsedUpstreamURL.Scheme, + Opaque: parsedUpstreamURL.Host, + }, + } + + rewriteProxy := NewRewriteReverseProxy(route, &UpstreamConfig{PreserveHost: true}) + + frontend := httptest.NewServer(rewriteProxy) + defer frontend.Close() + + frontendURL, err := url.Parse(frontend.URL) + if err != nil { + t.Fatalf("expected to parse frontend url: %s", err) + } + + resp, err := http.Get(frontend.URL) + if err != nil { + t.Fatalf("expected to make successful request err:%q", err) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("expected to read body err:%q", err) + } + + if string(body) != frontendURL.Host { + t.Logf("got %v", string(body)) + t.Logf("want %v", frontendURL.Host) + t.Fatalf("got unexpected response from upstream") + } +} + func TestDeleteSSOHeader(t *testing.T) { testCases := []struct { name string @@ -290,7 +413,7 @@ func TestEncodedSlashes(t *testing.T) { defer backend.Close() b, _ := url.Parse(backend.URL) - proxyHandler := NewReverseProxy(b) + proxyHandler := NewReverseProxy(b, &UpstreamConfig{PreserveHost: false}) frontend := httptest.NewServer(proxyHandler) defer frontend.Close() diff --git a/internal/proxy/proxy_config.go b/internal/proxy/proxy_config.go index 38d84c19..750a6296 100644 --- a/internal/proxy/proxy_config.go +++ b/internal/proxy/proxy_config.go @@ -48,6 +48,7 @@ type UpstreamConfig struct { SkipAuthCompiledRegex []*regexp.Regexp AllowedGroups []string + PreserveHost bool HMACAuth hmacauth.HmacAuth Timeout time.Duration FlushInterval time.Duration @@ -75,6 +76,7 @@ type OptionsConfig struct { HeaderOverrides map[string]string `yaml:"header_overrides"` SkipAuthRegex []string `yaml:"skip_auth_regex"` AllowedGroups []string `yaml:"allowed_groups"` + PreserveHost bool `yaml:"preserve_host"` Timeout time.Duration `yaml:"timeout"` FlushInterval time.Duration `yaml:"flush_interval"` } @@ -358,6 +360,7 @@ func parseOptionsConfig(proxy *UpstreamConfig) error { proxy.Timeout = proxy.RouteConfig.Options.Timeout proxy.FlushInterval = proxy.RouteConfig.Options.FlushInterval proxy.HeaderOverrides = proxy.RouteConfig.Options.HeaderOverrides + proxy.PreserveHost = proxy.RouteConfig.Options.PreserveHost proxy.RouteConfig.Options = nil