Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support concurrent CSRF cookies by using a prefix of nonce #187

Merged
merged 3 commits into from
Sep 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions internal/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,23 +170,31 @@ func ClearCookie(r *http.Request) *http.Cookie {
}
}

func buildCSRFCookieName(nonce string) string {
return config.CSRFCookieName + "_" + nonce[:6]
}

// MakeCSRFCookie makes a csrf cookie (used during login only)
//
// Note, CSRF cookies live shorter than auth cookies, a fixed 1h.
// That's because some CSRF cookies may belong to auth flows that don't complete
// and thus may not get cleared by ClearCookie.
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
return &http.Cookie{
Name: config.CSRFCookieName,
Name: buildCSRFCookieName(nonce),
Value: nonce,
Path: "/",
Domain: csrfCookieDomain(r),
HttpOnly: true,
Secure: !config.InsecureCookie,
Expires: cookieExpiry(),
Expires: time.Now().Local().Add(time.Hour * 1),
}
}

// ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie
func ClearCSRFCookie(r *http.Request) *http.Cookie {
func ClearCSRFCookie(r *http.Request, c *http.Cookie) *http.Cookie {
return &http.Cookie{
Name: config.CSRFCookieName,
Name: c.Name,
Value: "",
Path: "/",
Domain: csrfCookieDomain(r),
Expand All @@ -196,18 +204,18 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie {
}
}

// ValidateCSRFCookie validates the csrf cookie against state
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) {
state := r.URL.Query().Get("state")
// FindCSRFCookie extracts the CSRF cookie from the request based on state.
func FindCSRFCookie(r *http.Request, state string) (c *http.Cookie, err error) {
// Check for CSRF cookie
return r.Cookie(buildCSRFCookieName(state))
}

// ValidateCSRFCookie validates the csrf cookie against state
func ValidateCSRFCookie(c *http.Cookie, state string) (valid bool, provider string, redirect string, err error) {
if len(c.Value) != 32 {
return false, "", "", errors.New("Invalid CSRF cookie value")
}

if len(state) < 34 {
return false, "", "", errors.New("Invalid CSRF state value")
}

// Check nonce match
if c.Value != state[:32] {
return false, "", "", errors.New("CSRF cookie does not match state")
Expand All @@ -229,6 +237,14 @@ func MakeState(r *http.Request, p provider.Provider, nonce string) string {
return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r))
}

// ValidateState checks whether the state is of right length.
func ValidateState(state string) error {
if len(state) < 34 {
return errors.New("Invalid CSRF state value")
}
return nil
}

// Nonce generates a random nonce
func Nonce() (error, string) {
nonce := make([]byte, 16)
Expand Down
67 changes: 34 additions & 33 deletions internal/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package tfa

import (
"fmt"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -217,29 +216,30 @@ func TestAuthMakeCSRFCookie(t *testing.T) {

// No cookie domain or auth url
c := MakeCSRFCookie(r, "12345678901234567890123456789012")
assert.Equal("_forward_auth_csrf_123456", c.Name)
assert.Equal("app.example.com", c.Domain)

// With cookie domain but no auth url
config = &Config{
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
}
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
c = MakeCSRFCookie(r, "12222278901234567890123456789012")
assert.Equal("_forward_auth_csrf_122222", c.Name)
assert.Equal("app.example.com", c.Domain)

// With cookie domain and auth url
config = &Config{
AuthHost: "auth.example.com",
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
}
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
config.AuthHost = "auth.example.com"
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
c = MakeCSRFCookie(r, "12333378901234567890123456789012")
assert.Equal("_forward_auth_csrf_123333", c.Name)
assert.Equal("example.com", c.Domain)
}

func TestAuthClearCSRFCookie(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
r, _ := http.NewRequest("GET", "http://example.com", nil)

c := ClearCSRFCookie(r)
c := ClearCSRFCookie(r, &http.Cookie{Name: "someCsrfCookie"})
assert.Equal("someCsrfCookie", c.Name)
if c.Value != "" {
t.Error("ClearCSRFCookie should create cookie with empty value")
}
Expand All @@ -249,56 +249,57 @@ func TestAuthValidateCSRFCookie(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
c := &http.Cookie{}

newCsrfRequest := func(state string) *http.Request {
u := fmt.Sprintf("http://example.com?state=%s", state)
r, _ := http.NewRequest("GET", u, nil)
return r
}
state := ""

// Should require 32 char string
r := newCsrfRequest("")
state = ""
c.Value = ""
valid, _, _, err := ValidateCSRFCookie(r, c)
valid, _, _, err := ValidateCSRFCookie(c, state)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF cookie value", err.Error())
}
c.Value = "123456789012345678901234567890123"
valid, _, _, err = ValidateCSRFCookie(r, c)
valid, _, _, err = ValidateCSRFCookie(c, state)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF cookie value", err.Error())
}

// Should require valid state
r = newCsrfRequest("12345678901234567890123456789012:")
c.Value = "12345678901234567890123456789012"
valid, _, _, err = ValidateCSRFCookie(r, c)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF state value", err.Error())
}

// Should require provider
r = newCsrfRequest("12345678901234567890123456789012:99")
state = "12345678901234567890123456789012:99"
c.Value = "12345678901234567890123456789012"
valid, _, _, err = ValidateCSRFCookie(r, c)
valid, _, _, err = ValidateCSRFCookie(c, state)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF state format", err.Error())
}

// Should allow valid state
r = newCsrfRequest("12345678901234567890123456789012:p99:url123")
state = "12345678901234567890123456789012:p99:url123"
c.Value = "12345678901234567890123456789012"
valid, provider, redirect, err := ValidateCSRFCookie(r, c)
valid, provider, redirect, err := ValidateCSRFCookie(c, state)
assert.True(valid, "valid request should return valid")
assert.Nil(err, "valid request should not return an error")
assert.Equal("p99", provider, "valid request should return correct provider")
assert.Equal("url123", redirect, "valid request should return correct redirect")
}

func TestValidateState(t *testing.T) {
assert := assert.New(t)

// Should require valid state
state := "12345678901234567890123456789012:"
err := ValidateState(state)
if assert.Error(err) {
assert.Equal("Invalid CSRF state value", err.Error())
}
// Should pass this state
state = "12345678901234567890123456789012:p99:url123"
err = ValidateState(state)
assert.Nil(err, "valid request should not return an error")
}

func TestMakeState(t *testing.T) {
assert := assert.New(t)

Expand Down
18 changes: 14 additions & 4 deletions internal/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
// Logging setup
logger := s.logger(r, "AuthCallback", "default", "Handling callback")

// Check state
state := r.URL.Query().Get("state")
if err := ValidateState(state); err != nil {
logger.WithFields(logrus.Fields{
"error": err,
}).Warn("Error validating state")
http.Error(w, "Not authorized", 401)
return
}

// Check for CSRF cookie
c, err := r.Cookie(config.CSRFCookieName)
c, err := FindCSRFCookie(r, state)
if err != nil {
logger.Info("Missing csrf cookie")
http.Error(w, "Not authorized", 401)
return
}

// Validate state
valid, providerName, redirect, err := ValidateCSRFCookie(r, c)
// Validate CSRF cookie against state
valid, providerName, redirect, err := ValidateCSRFCookie(c, state)
if !valid {
logger.WithFields(logrus.Fields{
"error": err,
Expand All @@ -153,7 +163,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
}

// Clear CSRF cookie
http.SetCookie(w, ClearCSRFCookie(r))
http.SetCookie(w, ClearCSRFCookie(r, c))

// Exchange code for token
token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code"))
Expand Down
2 changes: 1 addition & 1 deletion internal/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func TestServerAuthHandlerExpired(t *testing.T) {
// Check for CSRF cookie
var cookie *http.Cookie
for _, c := range res.Cookies() {
if c.Name == config.CSRFCookieName {
if strings.HasPrefix(c.Name, config.CSRFCookieName) {
cookie = c
}
}
Expand Down