diff --git a/internal/auth.go b/internal/auth.go index cefd2315..1fbcec07 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -170,23 +170,31 @@ func ClearCookie(r *http.Request) *http.Cookie { } } +func buildCSRFCookieName(nonce string) string { + return config.CSRFCookieName + "_" + nonce[:6] +} + // MakeCSRFCookie makes a csrf cookie (used during login only) +// +// Note, CSRF cookies live shorter than auth cookies, a fixed 1h. +// That's because some CSRF cookies may belong to auth flows that don't complete +// and thus may not get cleared by ClearCookie. func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie { return &http.Cookie{ - Name: config.CSRFCookieName, + Name: buildCSRFCookieName(nonce), Value: nonce, Path: "/", Domain: csrfCookieDomain(r), HttpOnly: true, Secure: !config.InsecureCookie, - Expires: cookieExpiry(), + Expires: time.Now().Local().Add(time.Hour * 1), } } // ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie -func ClearCSRFCookie(r *http.Request) *http.Cookie { +func ClearCSRFCookie(r *http.Request, c *http.Cookie) *http.Cookie { return &http.Cookie{ - Name: config.CSRFCookieName, + Name: c.Name, Value: "", Path: "/", Domain: csrfCookieDomain(r), @@ -196,18 +204,22 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie { } } -// ValidateCSRFCookie validates the csrf cookie against state -func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) { - state := r.URL.Query().Get("state") +// FindCSRFCookie extracts the CSRF cookie from the request based on state. +func FindCSRFCookie(r *http.Request, state string) (c *http.Cookie, err error) { + // Check for CSRF cookie + c, err = r.Cookie(buildCSRFCookieName(state)) + if err != nil { + return nil, err + } + return c, nil +} +// ValidateCSRFCookie validates the csrf cookie against state +func ValidateCSRFCookie(c *http.Cookie, state string) (valid bool, provider string, redirect string, err error) { if len(c.Value) != 32 { return false, "", "", errors.New("Invalid CSRF cookie value") } - if len(state) < 34 { - return false, "", "", errors.New("Invalid CSRF state value") - } - // Check nonce match if c.Value != state[:32] { return false, "", "", errors.New("CSRF cookie does not match state") @@ -229,6 +241,14 @@ func MakeState(r *http.Request, p provider.Provider, nonce string) string { return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r)) } +// ValidateState checks whether the state is of right length. +func ValidateState(state string) error { + if len(state) < 34 { + return errors.New("Invalid CSRF state value") + } + return nil +} + // Nonce generates a random nonce func Nonce() (error, string) { nonce := make([]byte, 16) diff --git a/internal/auth_test.go b/internal/auth_test.go index 14ee1ce6..981936a1 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -1,7 +1,6 @@ package tfa import ( - "fmt" "net/http" "net/url" "strings" @@ -236,10 +235,12 @@ func TestAuthMakeCSRFCookie(t *testing.T) { } func TestAuthClearCSRFCookie(t *testing.T) { + assert := assert.New(t) config, _ = NewConfig([]string{}) r, _ := http.NewRequest("GET", "http://example.com", nil) - c := ClearCSRFCookie(r) + c := ClearCSRFCookie(r, &http.Cookie{Name: "someCsrfCookie"}) + assert.Equal("someCsrfCookie", c.Name) if c.Value != "" { t.Error("ClearCSRFCookie should create cookie with empty value") } @@ -249,56 +250,57 @@ func TestAuthValidateCSRFCookie(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) c := &http.Cookie{} - - newCsrfRequest := func(state string) *http.Request { - u := fmt.Sprintf("http://example.com?state=%s", state) - r, _ := http.NewRequest("GET", u, nil) - return r - } + state := "" // Should require 32 char string - r := newCsrfRequest("") + state = "" c.Value = "" - valid, _, _, err := ValidateCSRFCookie(r, c) + valid, _, _, err := ValidateCSRFCookie(c, state) assert.False(valid) if assert.Error(err) { assert.Equal("Invalid CSRF cookie value", err.Error()) } c.Value = "123456789012345678901234567890123" - valid, _, _, err = ValidateCSRFCookie(r, c) + valid, _, _, err = ValidateCSRFCookie(c, state) assert.False(valid) if assert.Error(err) { assert.Equal("Invalid CSRF cookie value", err.Error()) } - // Should require valid state - r = newCsrfRequest("12345678901234567890123456789012:") - c.Value = "12345678901234567890123456789012" - valid, _, _, err = ValidateCSRFCookie(r, c) - assert.False(valid) - if assert.Error(err) { - assert.Equal("Invalid CSRF state value", err.Error()) - } - // Should require provider - r = newCsrfRequest("12345678901234567890123456789012:99") + state = "12345678901234567890123456789012:99" c.Value = "12345678901234567890123456789012" - valid, _, _, err = ValidateCSRFCookie(r, c) + valid, _, _, err = ValidateCSRFCookie(c, state) assert.False(valid) if assert.Error(err) { assert.Equal("Invalid CSRF state format", err.Error()) } // Should allow valid state - r = newCsrfRequest("12345678901234567890123456789012:p99:url123") + state = "12345678901234567890123456789012:p99:url123" c.Value = "12345678901234567890123456789012" - valid, provider, redirect, err := ValidateCSRFCookie(r, c) + valid, provider, redirect, err := ValidateCSRFCookie(c, state) assert.True(valid, "valid request should return valid") assert.Nil(err, "valid request should not return an error") assert.Equal("p99", provider, "valid request should return correct provider") assert.Equal("url123", redirect, "valid request should return correct redirect") } +func TestValidateState(t *testing.T) { + assert := assert.New(t) + + // Should require valid state + state := "12345678901234567890123456789012:" + err := ValidateState(state) + if assert.Error(err) { + assert.Equal("Invalid CSRF state value", err.Error()) + } + // Should pass this state + state = "12345678901234567890123456789012:p99:url123" + err = ValidateState(state) + assert.Nil(err, "valid request should not return an error") +} + func TestMakeState(t *testing.T) { assert := assert.New(t) diff --git a/internal/server.go b/internal/server.go index 8ac03131..b7cac0b9 100644 --- a/internal/server.go +++ b/internal/server.go @@ -120,9 +120,17 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Logging setup logger := s.logger(r, "AuthCallback", "default", "Handling callback") + state := r.URL.Query().Get("state") + if err := ValidateState(state); err != nil { + logger.WithFields(logrus.Fields{ + "error": err, + }).Warn("Bad CSRF state") + http.Error(w, "Not authorized", 401) + return + } // Check for CSRF cookie - c, err := r.Cookie(config.CSRFCookieName) + c, err := FindCSRFCookie(r, state) if err != nil { logger.Info("Missing csrf cookie") http.Error(w, "Not authorized", 401) @@ -130,7 +138,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { } // Validate state - valid, providerName, redirect, err := ValidateCSRFCookie(r, c) + valid, providerName, redirect, err := ValidateCSRFCookie(c, state) if !valid { logger.WithFields(logrus.Fields{ "error": err, @@ -153,7 +161,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { } // Clear CSRF cookie - http.SetCookie(w, ClearCSRFCookie(r)) + http.SetCookie(w, ClearCSRFCookie(r, c)) // Exchange code for token token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code")) diff --git a/internal/server_test.go b/internal/server_test.go index 2e543400..8ec0f01d 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -98,7 +98,7 @@ func TestServerAuthHandlerExpired(t *testing.T) { // Check for CSRF cookie var cookie *http.Cookie for _, c := range res.Cookies() { - if c.Name == config.CSRFCookieName { + if strings.HasPrefix(c.Name, config.CSRFCookieName) { cookie = c } }