diff --git a/cookie/nonce.go b/cookie/nonce.go new file mode 100644 index 000000000..3012ce2d1 --- /dev/null +++ b/cookie/nonce.go @@ -0,0 +1,16 @@ +package cookie + +import ( + "crypto/rand" + "fmt" +) + +func Nonce() (nonce string, err error) { + b := make([]byte, 16) + _, err = rand.Read(b) + if err != nil { + return + } + nonce = fmt.Sprintf("%x", b) + return +} diff --git a/oauthproxy.go b/oauthproxy.go index 5d325e14a..38bd0a76b 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -35,14 +35,16 @@ var SignatureHeaders []string = []string{ } type OAuthProxy struct { - CookieSeed string - CookieName string - CookieDomain string - CookieSecure bool - CookieHttpOnly bool - CookieExpire time.Duration - CookieRefresh time.Duration - Validator func(string) bool + CookieSeed string + CookieName string + CSRFCookieName string + SessionCookieName string + CookieDomain string + CookieSecure bool + CookieHttpOnly bool + CookieExpire time.Duration + CookieRefresh time.Duration + Validator func(string) bool RobotsPath string PingPath string @@ -173,14 +175,16 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { } return &OAuthProxy{ - CookieName: opts.CookieName, - CookieSeed: opts.CookieSecret, - CookieDomain: opts.CookieDomain, - CookieSecure: opts.CookieSecure, - CookieHttpOnly: opts.CookieHttpOnly, - CookieExpire: opts.CookieExpire, - CookieRefresh: opts.CookieRefresh, - Validator: validator, + CookieName: opts.CookieName, + CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"), + SessionCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "session"), + CookieSeed: opts.CookieSecret, + CookieDomain: opts.CookieDomain, + CookieSecure: opts.CookieSecure, + CookieHttpOnly: opts.CookieHttpOnly, + CookieExpire: opts.CookieExpire, + CookieRefresh: opts.CookieRefresh, + Validator: validator, RobotsPath: "/robots.txt", PingPath: "/ping", @@ -245,7 +249,22 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e return } -func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { +func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { + if value != "" { + value = cookie.SignedValue(p.CookieSeed, p.SessionCookieName, value, now) + if len(value) > 4096 { + // Cookies cannot be larger than 4kb + log.Printf("WARNING - Cookie Size: %d bytes", len(value)) + } + } + return p.makeCookie(req, p.SessionCookieName, value, expiration, now) +} + +func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { + return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) +} + +func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { domain := req.Host if h, _, err := net.SplitHostPort(domain); err == nil { domain = h @@ -257,15 +276,8 @@ func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time domain = p.CookieDomain } - if value != "" { - value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) - if len(value) > 4096 { - // Cookies cannot be larger than 4kb - log.Printf("WARNING - Cookie Size: %d bytes", len(value)) - } - } return &http.Cookie{ - Name: p.CookieName, + Name: name, Value: value, Path: "/", Domain: domain, @@ -275,20 +287,28 @@ func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time } } -func (p *OAuthProxy) ClearCookie(rw http.ResponseWriter, req *http.Request) { - http.SetCookie(rw, p.MakeCookie(req, "", time.Hour*-1, time.Now())) +func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) { + http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now())) +} + +func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) { + http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now())) } -func (p *OAuthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val string) { - http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now())) +func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { + http.SetCookie(rw, p.MakeSessionCookie(req, "", time.Hour*-1, time.Now())) +} + +func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { + http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now())) } func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { var age time.Duration - c, err := req.Cookie(p.CookieName) + c, err := req.Cookie(p.SessionCookieName) if err != nil { // always http.ErrNoCookie - return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName) + return nil, age, fmt.Errorf("Cookie %q not present", p.SessionCookieName) } val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire) if !ok { @@ -309,7 +329,7 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *p if err != nil { return err } - p.SetCookie(rw, req, value) + p.SetSessionCookie(rw, req, value) return nil } @@ -339,7 +359,7 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m } func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { - p.ClearCookie(rw, req) + p.ClearSessionCookie(rw, req) rw.WriteHeader(code) redirect_url := req.URL.RequestURI() @@ -384,20 +404,18 @@ func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st return "", false } -func (p *OAuthProxy) GetRedirect(req *http.Request) (string, error) { - err := req.ParseForm() - +func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) { + err = req.ParseForm() if err != nil { - return "", err + return } - redirect := req.FormValue("rd") - - if redirect == "" { + redirect = req.Form.Get("rd") + if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { redirect = "/" } - return redirect, err + return } func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) { @@ -459,18 +477,24 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { } func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { - p.ClearCookie(rw, req) + p.ClearSessionCookie(rw, req) http.Redirect(rw, req, "/", 302) } func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { + nonce, err := cookie.Nonce() + if err != nil { + p.ErrorPage(rw, 500, "Internal Error", err.Error()) + return + } + p.SetCSRFCookie(rw, req, nonce) redirect, err := p.GetRedirect(req) if err != nil { p.ErrorPage(rw, 500, "Internal Error", err.Error()) return } redirectURI := p.GetRedirectURI(req.Host) - http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302) + http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302) } func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { @@ -495,8 +519,26 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { return } - redirect := req.Form.Get("state") - if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { + s := strings.SplitN(req.Form.Get("state"), ":", 2) + if len(s) != 2 { + p.ErrorPage(rw, 500, "Internal Error", "Invalid State") + return + } + nonce := s[0] + redirect := s[1] + c, err := req.Cookie(p.CSRFCookieName) + if err != nil { + p.ErrorPage(rw, 403, "Permission Denied", err.Error()) + return + } + p.ClearCSRFCookie(rw, req) + if c.Value != nonce { + log.Printf("%s csrf token mismatch, potential attack", remoteAddr) + p.ErrorPage(rw, 403, "Permission Denied", "csrf failed") + return + } + + if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { redirect = "/" } @@ -595,7 +637,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int } if clearSession { - p.ClearCookie(rw, req) + p.ClearSessionCookie(rw, req) } if session == nil { diff --git a/providers/provider_default.go b/providers/provider_default.go index 6b8ec401e..1d1daea44 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -8,7 +8,6 @@ import ( "io/ioutil" "net/http" "net/url" - "strings" "github.com/bitly/oauth2_proxy/cookie" ) @@ -79,7 +78,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er } // GetLoginURL with typical oauth parameters -func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string { +func (p *ProviderData) GetLoginURL(redirectURI, state string) string { var a url.URL a = *p.LoginURL params, _ := url.ParseQuery(a.RawQuery) @@ -88,9 +87,7 @@ func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string { params.Add("scope", p.Scope) params.Set("client_id", p.ClientID) params.Set("response_type", "code") - if strings.HasPrefix(finalRedirect, "/") && !strings.HasPrefix(finalRedirect,"//") { - params.Add("state", finalRedirect) - } + params.Add("state", state) a.RawQuery = params.Encode() return a.String() }