diff --git a/internal/proxy/oauthproxy.go b/internal/proxy/oauthproxy.go index b3394582..5a780c44 100755 --- a/internal/proxy/oauthproxy.go +++ b/internal/proxy/oauthproxy.go @@ -185,7 +185,9 @@ func NewReverseProxy(to *url.URL, config *UpstreamConfig) *httputil.ReverseProxy proxy.Director = func(req *http.Request) { req.Header.Add("X-Forwarded-Host", req.Host) director(req) - req.Host = to.Host + if !config.PreserveHost { + req.Host = to.Host + } } return proxy } @@ -214,7 +216,9 @@ func NewRewriteReverseProxy(route *RewriteRoute, config *UpstreamConfig) *httput req.Header.Add("X-Forwarded-Host", req.Host) director(req) - req.Host = target.Host + if !config.PreserveHost { + req.Host = target.Host + } } return proxy } diff --git a/internal/proxy/oauthproxy_test.go b/internal/proxy/oauthproxy_test.go index 09760e8b..c540b226 100644 --- a/internal/proxy/oauthproxy_test.go +++ b/internal/proxy/oauthproxy_test.go @@ -65,191 +65,277 @@ func testSession() *providers.SessionState { } func TestNewReverseProxy(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - hostname, _, _ := net.SplitHostPort(r.Host) - w.Write([]byte(hostname)) - })) - defer backend.Close() - - backendURL, _ := url.Parse(backend.URL) - backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) - backendHost := net.JoinHostPort(backendHostname, backendPort) - proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") - - proxyHandler := NewReverseProxy(proxyURL, &UpstreamConfig{TLSSkipVerify: false}) - frontend := httptest.NewServer(proxyHandler) - defer frontend.Close() - - getReq, _ := http.NewRequest("GET", frontend.URL, nil) - res, _ := http.DefaultClient.Do(getReq) - bodyBytes, _ := ioutil.ReadAll(res.Body) - if g, e := string(bodyBytes), backendHostname; g != e { - t.Errorf("got body %q; expected %q", g, e) - } -} - -func TestNewRewriteReverseProxy(t *testing.T) { - upstream := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(200) - rw.Write([]byte(req.Host)) - })) - defer upstream.Close() - - parsedUpstreamURL, err := url.Parse(upstream.URL) - if err != nil { - t.Fatalf("expected to parse upstream URL err:%q", err) + type respStruct struct { + Host string `json:"host"` + XForwardedHost string `json:"x-forwarded-host"` } - route := &RewriteRoute{ - FromRegex: regexp.MustCompile("(.*)"), - ToTemplate: &url.URL{ - Scheme: parsedUpstreamURL.Scheme, - Opaque: parsedUpstreamURL.Host, + testCases := []struct { + name string + useTLS bool + skipVerify bool + preserveHost bool + expectedStatus int + }{ + { + name: "tls true skip verify false preserve host false", + useTLS: true, + skipVerify: false, + preserveHost: false, + expectedStatus: 502, + }, + { + name: "tls true skip verify true preserve host false", + useTLS: true, + skipVerify: true, + preserveHost: false, + expectedStatus: 200, + }, + { + name: "tls true skip verify false preserve host true", + useTLS: true, + skipVerify: false, + preserveHost: true, + expectedStatus: 502, + }, + { + name: "tls true skip verify true preserve host true", + useTLS: true, + skipVerify: true, + preserveHost: true, + expectedStatus: 200, }, - } - - rewriteProxy := NewRewriteReverseProxy(route, &UpstreamConfig{TLSSkipVerify: false}) - - frontend := httptest.NewServer(rewriteProxy) - defer frontend.Close() - - resp, err := http.Get(frontend.URL) - if err != nil { - t.Fatalf("expected to make successful request err:%q", err) - } - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("expected to read body err:%q", err) - } - if string(body) != parsedUpstreamURL.Host { - t.Logf("got %v", string(body)) - t.Logf("want %v", parsedUpstreamURL.Host) - t.Fatalf("got unexpected response from upstream") + { + name: "tls false skip verify false preserve host false", + useTLS: false, + skipVerify: false, + preserveHost: false, + expectedStatus: 200, + }, + { + name: "tls false skip verify true preserve host false", + useTLS: false, + skipVerify: true, + preserveHost: false, + expectedStatus: 200, + }, + { + name: "tls false skip verify false preserve host true", + useTLS: false, + skipVerify: false, + preserveHost: true, + expectedStatus: 200, + }, + { + name: "tls false skip verify true preserve host true", + useTLS: false, + skipVerify: true, + preserveHost: true, + expectedStatus: 200, + }, } -} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var newServer func(http.Handler) *httptest.Server + if tc.useTLS { + newServer = httptest.NewTLSServer + } else { + newServer = httptest.NewServer + } + to := newServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + body, err := json.Marshal( + &respStruct{ + Host: r.Host, + XForwardedHost: r.Header.Get("X-Forwarded-Host"), + }, + ) + if err != nil { + t.Fatalf("expected to marshal json: %s", err) + } + rw.Write(body) + })) + defer to.Close() -func TestNewReverseProxyHostname(t *testing.T) { - type respStruct struct { - Host string `json:"host"` - XForwardedHost string `json:"x-forwarded-host"` - } + toURL, err := url.Parse(to.URL) + if err != nil { + t.Fatalf("expected to parse to url: %s", err) + } - to := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - body, err := json.Marshal( - &respStruct{ - Host: r.Host, - XForwardedHost: r.Header.Get("X-Forwarded-Host"), - }, - ) - if err != nil { - t.Fatalf("expected to marshal json: %s", err) - } - rw.Write(body) - })) - defer to.Close() + reverseProxy := NewReverseProxy(toURL, &UpstreamConfig{TLSSkipVerify: tc.skipVerify, PreserveHost: tc.preserveHost}) + from := httptest.NewServer(reverseProxy) + defer from.Close() - toURL, err := url.Parse(to.URL) - if err != nil { - t.Fatalf("expected to parse to url: %s", err) - } + fromURL, err := url.Parse(from.URL) + if err != nil { + t.Fatalf("expected to parse from url: %s", err) + } - reverseProxy := NewReverseProxy(toURL, &UpstreamConfig{TLSSkipVerify: false}) - from := httptest.NewServer(reverseProxy) - defer from.Close() + want := &respStruct{ + Host: toURL.Host, + XForwardedHost: fromURL.Host, + } + if tc.preserveHost { + want.Host = fromURL.Host + } - fromURL, err := url.Parse(from.URL) - if err != nil { - t.Fatalf("expected to parse from url: %s", err) - } + res, err := http.Get(from.URL) + if err != nil { + t.Fatalf("expected to be able to make req: %s", err) + } - want := &respStruct{ - Host: toURL.Host, - XForwardedHost: fromURL.Host, - } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("expected to read body: %s", err) + } - res, err := http.Get(from.URL) - if err != nil { - t.Fatalf("expected to be able to make req: %s", err) - } + if res.StatusCode != tc.expectedStatus { + t.Logf(" got status code: %v", res.StatusCode) + t.Logf("want status code: %d", tc.expectedStatus) - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatalf("expected to read body: %s", err) - } + t.Errorf("got unexpected response code for tls failure") + } - got := &respStruct{} - err = json.Unmarshal(body, got) - if err != nil { - t.Fatalf("expected to decode json: %s", err) - } + if res.StatusCode >= 200 && res.StatusCode < 300 { + got := &respStruct{} + err = json.Unmarshal(body, got) + if err != nil { + t.Fatalf("expected to decode json: %s", err) + } - if !reflect.DeepEqual(want, got) { - t.Logf(" got host: %v", got.Host) - t.Logf("want host: %v", want.Host) + if !reflect.DeepEqual(want, got) { + t.Logf(" got host: %v", got.Host) + t.Logf("want host: %v", want.Host) - t.Logf(" got X-Forwarded-Host: %v", got.XForwardedHost) - t.Logf("want X-Forwarded-Host: %v", want.XForwardedHost) + t.Logf(" got X-Forwarded-Host: %v", got.XForwardedHost) + t.Logf("want X-Forwarded-Host: %v", want.XForwardedHost) - t.Errorf("got unexpected response for Host or X-Forwarded-Host header") - } - if res.Header.Get("Cookie") != "" { - t.Errorf("expected Cookie header to be empty but was %s", res.Header.Get("Cookie")) + t.Errorf("got unexpected response for Host or X-Forwarded-Host header") + } + if res.Header.Get("Cookie") != "" { + t.Errorf("expected Cookie header to be empty but was %s", res.Header.Get("Cookie")) + } + } + }) } - } -func TestNewReverseProxyTLSSkipVerify(t *testing.T) { +func TestNewRewriteReverseProxy(t *testing.T) { type respStruct struct { - HandshakeComplete bool `json:"handshake-complete"` + Host string `json:"host"` + XForwardedHost string `json:"x-forwarded-host"` } testCases := []struct { name string + useTLS bool skipVerify bool + preserveHost bool expectedStatus int }{ { - name: "skip verify true", + name: "tls true skip verify false preserve host false", + useTLS: true, + skipVerify: false, + preserveHost: false, + expectedStatus: 502, + }, + { + name: "tls true skip verify true preserve host false", + useTLS: true, skipVerify: true, + preserveHost: false, expectedStatus: 200, }, { - name: "skip verify false", + name: "tls true skip verify false preserve host true", + useTLS: true, skipVerify: false, + preserveHost: true, expectedStatus: 502, }, + { + name: "tls true skip verify true preserve host true", + useTLS: true, + skipVerify: true, + preserveHost: true, + expectedStatus: 200, + }, + + { + name: "tls false skip verify false preserve host false", + useTLS: false, + skipVerify: false, + preserveHost: false, + expectedStatus: 200, + }, + { + name: "tls false skip verify true preserve host false", + useTLS: false, + skipVerify: true, + preserveHost: false, + expectedStatus: 200, + }, + { + name: "tls false skip verify false preserve host true", + useTLS: false, + skipVerify: false, + preserveHost: true, + expectedStatus: 200, + }, + { + name: "tls false skip verify true preserve host true", + useTLS: false, + skipVerify: true, + preserveHost: true, + expectedStatus: 200, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - to := httptest.NewTLSServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - body, err := json.Marshal( - // Doesn't really matter what's sent since we should 502 - &respStruct{ - HandshakeComplete: r.TLS.HandshakeComplete, - }, - ) - if err != nil { - t.Fatalf("expected to marshal json: %s", err) - } - rw.Write(body) + var newServer func(http.Handler) *httptest.Server + if tc.useTLS { + newServer = httptest.NewTLSServer + } else { + newServer = httptest.NewServer + } + upstream := newServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(200) + rw.Write([]byte(req.Host)) })) - defer to.Close() + defer upstream.Close() - toURL, err := url.Parse(to.URL) + parsedUpstreamURL, err := url.Parse(upstream.URL) if err != nil { - t.Fatalf("expected to parse to url: %s", err) + t.Fatalf("expected to parse upstream URL err:%q", err) } - reverseProxy := NewReverseProxy(toURL, &UpstreamConfig{TLSSkipVerify: tc.skipVerify}) - from := httptest.NewServer(reverseProxy) - defer from.Close() + route := &RewriteRoute{ + FromRegex: regexp.MustCompile("(.*)"), + ToTemplate: &url.URL{ + Scheme: parsedUpstreamURL.Scheme, + Opaque: parsedUpstreamURL.Host, + }, + } - res, err := http.Get(from.URL) + rewriteProxy := NewRewriteReverseProxy(route, &UpstreamConfig{TLSSkipVerify: tc.skipVerify, PreserveHost: tc.preserveHost}) + + frontend := httptest.NewServer(rewriteProxy) + defer frontend.Close() + + frontendURL, err := url.Parse(frontend.URL) if err != nil { - t.Fatalf("expected to be able to make req: %s", err) + t.Fatalf("expected to parse frontend url: %s", err) + } + + res, err := http.Get(frontend.URL) + if err != nil { + t.Fatalf("expected to make successful request err:%q", err) + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("expected to read body err:%q", err) } if res.StatusCode != tc.expectedStatus { @@ -258,8 +344,25 @@ func TestNewReverseProxyTLSSkipVerify(t *testing.T) { t.Errorf("got unexpected response code for tls failure") } - if res.Header.Get("Cookie") != "" { - t.Errorf("expected Cookie header to be empty but was %s", res.Header.Get("Cookie")) + + if res.StatusCode >= 200 && res.StatusCode < 300 { + if tc.preserveHost { + if string(body) != frontendURL.Host { + t.Logf("got %v", string(body)) + t.Logf("want %v", frontendURL.Host) + t.Fatalf("got unexpected response from upstream") + } + } else { + if string(body) != parsedUpstreamURL.Host { + t.Logf("got %v", string(body)) + t.Logf("want %v", parsedUpstreamURL.Host) + t.Fatalf("got unexpected response from upstream") + } + } + + if res.Header.Get("Cookie") != "" { + t.Errorf("expected Cookie header to be empty but was %s", res.Header.Get("Cookie")) + } } }) } @@ -324,7 +427,7 @@ func TestRoundTrip(t *testing.T) { }{ { name: "no error", - url: "https://www.example.com/", + url: "http://www.example.com/", }, { name: "with error", @@ -364,7 +467,7 @@ func TestEncodedSlashes(t *testing.T) { defer backend.Close() b, _ := url.Parse(backend.URL) - proxyHandler := NewReverseProxy(b, &UpstreamConfig{TLSSkipVerify: false}) + proxyHandler := NewReverseProxy(b, &UpstreamConfig{TLSSkipVerify: false, PreserveHost: false}) frontend := httptest.NewServer(proxyHandler) defer frontend.Close() diff --git a/internal/proxy/proxy_config.go b/internal/proxy/proxy_config.go index 38f4ca9a..8a8f3409 100644 --- a/internal/proxy/proxy_config.go +++ b/internal/proxy/proxy_config.go @@ -53,6 +53,7 @@ type UpstreamConfig struct { SkipAuthCompiledRegex []*regexp.Regexp AllowedGroups []string TLSSkipVerify bool + PreserveHost bool HMACAuth hmacauth.HmacAuth Timeout time.Duration FlushInterval time.Duration @@ -82,6 +83,7 @@ type OptionsConfig struct { SkipAuthRegex []string `yaml:"skip_auth_regex"` AllowedGroups []string `yaml:"allowed_groups"` TLSSkipVerify bool `yaml:"tls_skip_verify"` + PreserveHost bool `yaml:"preserve_host"` Timeout time.Duration `yaml:"timeout"` FlushInterval time.Duration `yaml:"flush_interval"` SkipRequestSigning bool `yaml:"skip_request_signing"` @@ -367,6 +369,7 @@ func parseOptionsConfig(proxy *UpstreamConfig) error { proxy.FlushInterval = proxy.RouteConfig.Options.FlushInterval proxy.HeaderOverrides = proxy.RouteConfig.Options.HeaderOverrides proxy.TLSSkipVerify = proxy.RouteConfig.Options.TLSSkipVerify + proxy.PreserveHost = proxy.RouteConfig.Options.PreserveHost proxy.SkipRequestSigning = proxy.RouteConfig.Options.SkipRequestSigning proxy.RouteConfig.Options = nil diff --git a/internal/proxy/proxy_config_test.go b/internal/proxy/proxy_config_test.go index 506b20f1..4a87deaf 100644 --- a/internal/proxy/proxy_config_test.go +++ b/internal/proxy/proxy_config_test.go @@ -292,6 +292,36 @@ func TestUpstreamConfigFlushInterval(t *testing.T) { } } +func TestUpstreamConfigPreserveHost(t *testing.T) { + wantPreserveHost := true + templateVars := map[string]string{ + "cluster": "sso", + "root_domain": "dev", + } + upstreamConfigs, err := loadServiceConfigs([]byte(` +- service: foo + default: + from: foo.{{cluster}}.{{root_domain}} + to: foo-internal.{{cluster}}.{{root_domain}} + options: + preserve_host: true +`), "sso", "http", templateVars) + if err != nil { + t.Fatalf("expected to parse upstream configs: %s", err) + } + + if len(upstreamConfigs) == 0 { + t.Fatalf("expected service config") + } + + upstreamConfig := upstreamConfigs[0] + if upstreamConfig.PreserveHost != wantPreserveHost { + t.Logf("want: %v", wantPreserveHost) + t.Logf(" got: %v", upstreamConfig.PreserveHost) + t.Errorf("got unexpected configured timeout") + } +} + func TestUpstreamConfigHeaderOverrides(t *testing.T) { wantHeaders := map[string]string{ "X-Frame-Options": "DENY",