Skip to content

Commit

Permalink
Support concurrent CSRF cookies by using a prefix of nonce (#187)
Browse files Browse the repository at this point in the history
* Support concurrent CSRF cookies by using a prefix of nonce.
* Move ValidateState out and make CSRF cookies last 1h
* add tests to check csrf cookie nam + minor tweaks

Co-authored-by: Michal Witkowski <michal@cerberus>
  • Loading branch information
thomseddon and Michal Witkowski authored Sep 23, 2020
1 parent 1743537 commit 41560fe
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 49 deletions.
38 changes: 27 additions & 11 deletions internal/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,23 +170,31 @@ func ClearCookie(r *http.Request) *http.Cookie {
}
}

func buildCSRFCookieName(nonce string) string {
return config.CSRFCookieName + "_" + nonce[:6]
}

// MakeCSRFCookie makes a csrf cookie (used during login only)
//
// Note, CSRF cookies live shorter than auth cookies, a fixed 1h.
// That's because some CSRF cookies may belong to auth flows that don't complete
// and thus may not get cleared by ClearCookie.
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
return &http.Cookie{
Name: config.CSRFCookieName,
Name: buildCSRFCookieName(nonce),
Value: nonce,
Path: "/",
Domain: csrfCookieDomain(r),
HttpOnly: true,
Secure: !config.InsecureCookie,
Expires: cookieExpiry(),
Expires: time.Now().Local().Add(time.Hour * 1),
}
}

// ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie
func ClearCSRFCookie(r *http.Request) *http.Cookie {
func ClearCSRFCookie(r *http.Request, c *http.Cookie) *http.Cookie {
return &http.Cookie{
Name: config.CSRFCookieName,
Name: c.Name,
Value: "",
Path: "/",
Domain: csrfCookieDomain(r),
Expand All @@ -196,18 +204,18 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie {
}
}

// ValidateCSRFCookie validates the csrf cookie against state
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) {
state := r.URL.Query().Get("state")
// FindCSRFCookie extracts the CSRF cookie from the request based on state.
func FindCSRFCookie(r *http.Request, state string) (c *http.Cookie, err error) {
// Check for CSRF cookie
return r.Cookie(buildCSRFCookieName(state))
}

// ValidateCSRFCookie validates the csrf cookie against state
func ValidateCSRFCookie(c *http.Cookie, state string) (valid bool, provider string, redirect string, err error) {
if len(c.Value) != 32 {
return false, "", "", errors.New("Invalid CSRF cookie value")
}

if len(state) < 34 {
return false, "", "", errors.New("Invalid CSRF state value")
}

// Check nonce match
if c.Value != state[:32] {
return false, "", "", errors.New("CSRF cookie does not match state")
Expand All @@ -229,6 +237,14 @@ func MakeState(r *http.Request, p provider.Provider, nonce string) string {
return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r))
}

// ValidateState checks whether the state is of right length.
func ValidateState(state string) error {
if len(state) < 34 {
return errors.New("Invalid CSRF state value")
}
return nil
}

// Nonce generates a random nonce
func Nonce() (error, string) {
nonce := make([]byte, 16)
Expand Down
67 changes: 34 additions & 33 deletions internal/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package tfa

import (
"fmt"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -217,29 +216,30 @@ func TestAuthMakeCSRFCookie(t *testing.T) {

// No cookie domain or auth url
c := MakeCSRFCookie(r, "12345678901234567890123456789012")
assert.Equal("_forward_auth_csrf_123456", c.Name)
assert.Equal("app.example.com", c.Domain)

// With cookie domain but no auth url
config = &Config{
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
}
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
c = MakeCSRFCookie(r, "12222278901234567890123456789012")
assert.Equal("_forward_auth_csrf_122222", c.Name)
assert.Equal("app.example.com", c.Domain)

// With cookie domain and auth url
config = &Config{
AuthHost: "auth.example.com",
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
}
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
config.AuthHost = "auth.example.com"
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
c = MakeCSRFCookie(r, "12333378901234567890123456789012")
assert.Equal("_forward_auth_csrf_123333", c.Name)
assert.Equal("example.com", c.Domain)
}

func TestAuthClearCSRFCookie(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
r, _ := http.NewRequest("GET", "http://example.com", nil)

c := ClearCSRFCookie(r)
c := ClearCSRFCookie(r, &http.Cookie{Name: "someCsrfCookie"})
assert.Equal("someCsrfCookie", c.Name)
if c.Value != "" {
t.Error("ClearCSRFCookie should create cookie with empty value")
}
Expand All @@ -249,56 +249,57 @@ func TestAuthValidateCSRFCookie(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
c := &http.Cookie{}

newCsrfRequest := func(state string) *http.Request {
u := fmt.Sprintf("http://example.com?state=%s", state)
r, _ := http.NewRequest("GET", u, nil)
return r
}
state := ""

// Should require 32 char string
r := newCsrfRequest("")
state = ""
c.Value = ""
valid, _, _, err := ValidateCSRFCookie(r, c)
valid, _, _, err := ValidateCSRFCookie(c, state)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF cookie value", err.Error())
}
c.Value = "123456789012345678901234567890123"
valid, _, _, err = ValidateCSRFCookie(r, c)
valid, _, _, err = ValidateCSRFCookie(c, state)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF cookie value", err.Error())
}

// Should require valid state
r = newCsrfRequest("12345678901234567890123456789012:")
c.Value = "12345678901234567890123456789012"
valid, _, _, err = ValidateCSRFCookie(r, c)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF state value", err.Error())
}

// Should require provider
r = newCsrfRequest("12345678901234567890123456789012:99")
state = "12345678901234567890123456789012:99"
c.Value = "12345678901234567890123456789012"
valid, _, _, err = ValidateCSRFCookie(r, c)
valid, _, _, err = ValidateCSRFCookie(c, state)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF state format", err.Error())
}

// Should allow valid state
r = newCsrfRequest("12345678901234567890123456789012:p99:url123")
state = "12345678901234567890123456789012:p99:url123"
c.Value = "12345678901234567890123456789012"
valid, provider, redirect, err := ValidateCSRFCookie(r, c)
valid, provider, redirect, err := ValidateCSRFCookie(c, state)
assert.True(valid, "valid request should return valid")
assert.Nil(err, "valid request should not return an error")
assert.Equal("p99", provider, "valid request should return correct provider")
assert.Equal("url123", redirect, "valid request should return correct redirect")
}

func TestValidateState(t *testing.T) {
assert := assert.New(t)

// Should require valid state
state := "12345678901234567890123456789012:"
err := ValidateState(state)
if assert.Error(err) {
assert.Equal("Invalid CSRF state value", err.Error())
}
// Should pass this state
state = "12345678901234567890123456789012:p99:url123"
err = ValidateState(state)
assert.Nil(err, "valid request should not return an error")
}

func TestMakeState(t *testing.T) {
assert := assert.New(t)

Expand Down
18 changes: 14 additions & 4 deletions internal/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
// Logging setup
logger := s.logger(r, "AuthCallback", "default", "Handling callback")

// Check state
state := r.URL.Query().Get("state")
if err := ValidateState(state); err != nil {
logger.WithFields(logrus.Fields{
"error": err,
}).Warn("Error validating state")
http.Error(w, "Not authorized", 401)
return
}

// Check for CSRF cookie
c, err := r.Cookie(config.CSRFCookieName)
c, err := FindCSRFCookie(r, state)
if err != nil {
logger.Info("Missing csrf cookie")
http.Error(w, "Not authorized", 401)
return
}

// Validate state
valid, providerName, redirect, err := ValidateCSRFCookie(r, c)
// Validate CSRF cookie against state
valid, providerName, redirect, err := ValidateCSRFCookie(c, state)
if !valid {
logger.WithFields(logrus.Fields{
"error": err,
Expand All @@ -153,7 +163,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
}

// Clear CSRF cookie
http.SetCookie(w, ClearCSRFCookie(r))
http.SetCookie(w, ClearCSRFCookie(r, c))

// Exchange code for token
token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code"))
Expand Down
2 changes: 1 addition & 1 deletion internal/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func TestServerAuthHandlerExpired(t *testing.T) {
// Check for CSRF cookie
var cookie *http.Cookie
for _, c := range res.Cookies() {
if c.Name == config.CSRFCookieName {
if strings.HasPrefix(c.Name, config.CSRFCookieName) {
cookie = c
}
}
Expand Down

0 comments on commit 41560fe

Please sign in to comment.