Skip to content

Commit

Permalink
Configurable TLS verification of upstream servers
Browse files Browse the repository at this point in the history
  • Loading branch information
sporkmonger committed Sep 25, 2018
1 parent 0fcc468 commit e2eed96
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 15 deletions.
34 changes: 25 additions & 9 deletions internal/proxy/oauthproxy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package proxy

import (
"crypto/tls"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -131,11 +132,26 @@ func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// upstreamTransport is used to ensure that upstreams cannot override the
// security headers applied by sso_proxy
type upstreamTransport struct{}
type upstreamTransport struct {
InsecureSkipVerify bool
}

// RoundTrip round trips the request and deletes security headers before returning the response.
func (t *upstreamTransport) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := http.DefaultTransport.RoundTrip(req)
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: t.InsecureSkipVerify},
ExpectContinueTimeout: 1 * time.Second,
}
resp, err := transport.RoundTrip(req)
if err != nil {
logger := log.NewLogEntry()
logger.Error(err, "error in upstreamTransport RoundTrip")
Expand All @@ -149,9 +165,9 @@ func (t *upstreamTransport) RoundTrip(req *http.Request) (*http.Response, error)

// NewReverseProxy creates a reverse proxy to a specified url.
// It adds an X-Forwarded-Host header that is the request's host.
func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
func NewReverseProxy(to *url.URL, config *UpstreamConfig) *httputil.ReverseProxy {
proxy := httputil.NewSingleHostReverseProxy(to)
proxy.Transport = &upstreamTransport{}
proxy.Transport = &upstreamTransport{InsecureSkipVerify: config.TLSSkipVerify}
director := proxy.Director
proxy.Director = func(req *http.Request) {
req.Header.Add("X-Forwarded-Host", req.Host)
Expand All @@ -164,9 +180,9 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
// NewRewriteReverseProxy creates a reverse proxy that is capable of creating upstream
// urls on the fly based on a from regex and a templated to field.
// It adds an X-Forwarded-Host header to the the upstream's request.
func NewRewriteReverseProxy(route *RewriteRoute) *httputil.ReverseProxy {
func NewRewriteReverseProxy(route *RewriteRoute, config *UpstreamConfig) *httputil.ReverseProxy {
proxy := &httputil.ReverseProxy{}
proxy.Transport = &upstreamTransport{}
proxy.Transport = &upstreamTransport{InsecureSkipVerify: config.TLSSkipVerify}
proxy.Director = func(req *http.Request) {
// we do this to rewrite requests
rewritten := route.FromRegex.ReplaceAllString(req.Host, route.ToTemplate.Opaque)
Expand Down Expand Up @@ -296,15 +312,15 @@ func NewOAuthProxy(opts *Options, optFuncs ...func(*OAuthProxy) error) (*OAuthPr
for _, upstreamConfig := range opts.upstreamConfigs {
switch route := upstreamConfig.Route.(type) {
case *SimpleRoute:
reverseProxy := NewReverseProxy(route.ToURL)
reverseProxy := NewReverseProxy(route.ToURL, upstreamConfig)
handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig)
p.Handle(route.FromURL.Host, handler, tags, upstreamConfig)
case *RewriteRoute:
reverseProxy := NewRewriteReverseProxy(route)
reverseProxy := NewRewriteReverseProxy(route, upstreamConfig)
handler, tags := NewReverseProxyHandler(reverseProxy, opts, upstreamConfig)
p.HandleRegex(route.FromRegex, handler, tags, upstreamConfig)
default:
return nil, fmt.Errorf("unkown route type")
return nil, fmt.Errorf("unknown route type")
}
}

Expand Down
72 changes: 68 additions & 4 deletions internal/proxy/oauthproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func TestNewReverseProxy(t *testing.T) {
backendHost := net.JoinHostPort(backendHostname, backendPort)
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")

proxyHandler := NewReverseProxy(proxyURL)
proxyHandler := NewReverseProxy(proxyURL, &UpstreamConfig{TLSSkipVerify: false})
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()

Expand Down Expand Up @@ -109,7 +109,7 @@ func TestNewRewriteReverseProxy(t *testing.T) {
},
}

rewriteProxy := NewRewriteReverseProxy(route)
rewriteProxy := NewRewriteReverseProxy(route, &UpstreamConfig{TLSSkipVerify: false})

frontend := httptest.NewServer(rewriteProxy)
defer frontend.Close()
Expand Down Expand Up @@ -156,7 +156,7 @@ func TestNewReverseProxyHostname(t *testing.T) {
t.Fatalf("expected to parse to url: %s", err)
}

reverseProxy := NewReverseProxy(toURL)
reverseProxy := NewReverseProxy(toURL, &UpstreamConfig{TLSSkipVerify: false})
from := httptest.NewServer(reverseProxy)
defer from.Close()

Expand Down Expand Up @@ -201,6 +201,70 @@ func TestNewReverseProxyHostname(t *testing.T) {

}

func TestNewReverseProxyTLSSkipVerify(t *testing.T) {
type respStruct struct {
HandshakeComplete bool `json:"handshake-complete"`
}

testCases := []struct {
name string
skipVerify bool
expectedStatus int
}{
{
name: "skip verify true",
skipVerify: true,
expectedStatus: 200,
},
{
name: "skip verify false",
skipVerify: false,
expectedStatus: 502,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
to := httptest.NewTLSServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
body, err := json.Marshal(
// Doesn't really matter what's sent since we should 502
&respStruct{
HandshakeComplete: r.TLS.HandshakeComplete,
},
)
if err != nil {
t.Fatalf("expected to marshal json: %s", err)
}
rw.Write(body)
}))
defer to.Close()

toURL, err := url.Parse(to.URL)
if err != nil {
t.Fatalf("expected to parse to url: %s", err)
}

reverseProxy := NewReverseProxy(toURL, &UpstreamConfig{TLSSkipVerify: tc.skipVerify})
from := httptest.NewServer(reverseProxy)
defer from.Close()

res, err := http.Get(from.URL)
if err != nil {
t.Fatalf("expected to be able to make req: %s", err)
}

if res.StatusCode != tc.expectedStatus {
t.Logf(" got status code: %v", res.StatusCode)
t.Logf("want status code: %d", tc.expectedStatus)

t.Errorf("got unexpected response code for tls failure")
}
if res.Header.Get("Cookie") != "" {
t.Errorf("expected Cookie header to be empty but was %s", res.Header.Get("Cookie"))
}
})
}
}

func TestDeleteSSOHeader(t *testing.T) {
testCases := []struct {
name string
Expand Down Expand Up @@ -300,7 +364,7 @@ func TestEncodedSlashes(t *testing.T) {
defer backend.Close()

b, _ := url.Parse(backend.URL)
proxyHandler := NewReverseProxy(b)
proxyHandler := NewReverseProxy(b, &UpstreamConfig{TLSSkipVerify: false})
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()

Expand Down
5 changes: 4 additions & 1 deletion internal/proxy/proxy_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type UpstreamConfig struct {

SkipAuthCompiledRegex []*regexp.Regexp
AllowedGroups []string
TLSSkipVerify bool
HMACAuth hmacauth.HmacAuth
Timeout time.Duration
FlushInterval time.Duration
Expand Down Expand Up @@ -79,6 +80,7 @@ type OptionsConfig struct {
HeaderOverrides map[string]string `yaml:"header_overrides"`
SkipAuthRegex []string `yaml:"skip_auth_regex"`
AllowedGroups []string `yaml:"allowed_groups"`
TLSSkipVerify bool `yaml:"tls_skip_verify"`
Timeout time.Duration `yaml:"timeout"`
FlushInterval time.Duration `yaml:"flush_interval"`
}
Expand Down Expand Up @@ -109,7 +111,7 @@ func loadServiceConfigs(raw []byte, cluster, scheme string, configVars map[strin
// we don't set this to the len(serviceConfig) since not all service configs
// are configured for all clusters, leaving nil tail pointers in the slice.
configs := make([]*UpstreamConfig, 0)
// resovle overrides
// resolve overrides
for _, service := range serviceConfigs {
proxy, err := resolveUpstreamConfig(service, cluster)
if err != nil {
Expand Down Expand Up @@ -362,6 +364,7 @@ func parseOptionsConfig(proxy *UpstreamConfig) error {
proxy.Timeout = proxy.RouteConfig.Options.Timeout
proxy.FlushInterval = proxy.RouteConfig.Options.FlushInterval
proxy.HeaderOverrides = proxy.RouteConfig.Options.HeaderOverrides
proxy.TLSSkipVerify = proxy.RouteConfig.Options.TLSSkipVerify

proxy.RouteConfig.Options = nil

Expand Down
Loading

0 comments on commit e2eed96

Please sign in to comment.