Skip to content

Commit

Permalink
improve provider name handling
Browse files Browse the repository at this point in the history
Add provider name into JWT token claims
to allow provider names with multiple underscore "_" symbols.
Forbid provider names containing URL reserved symbols.
  • Loading branch information
cyb3r4nt committed Aug 29, 2024
1 parent ec38494 commit f13649e
Show file tree
Hide file tree
Showing 22 changed files with 298 additions and 86 deletions.
34 changes: 34 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package auth
import (
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"time"

Expand Down Expand Up @@ -267,10 +269,42 @@ func (s *Service) addProviderByName(name string, p provider.Params) {
}

func (s *Service) addProvider(prov provider.Provider) {
if !s.isValidProviderName(prov.Name()) {
return
}
s.providers = append(s.providers, provider.NewService(prov))
s.authMiddleware.Providers = s.providers
}

func (s *Service) isValidProviderName(name string) bool {
if strings.TrimSpace(name) == "" {
s.logger.Logf("[ERROR] provider has been ignored because its name is empty")
return false
}

formatForbidden := func(name string) {
s.logger.Logf("[ERROR] provider has been ignored because its name contains forbidden characters: '%s'", name)
}

path, err := url.PathUnescape(name)
if err != nil || path != name {
formatForbidden(name)
return false
}
if name != url.PathEscape(name) {
formatForbidden(name)
return false
}
// net/url package does not escape everything (https://github.com/golang/go/issues/5684)
// It is better to reject all reserved characters from https://datatracker.ietf.org/doc/html/rfc3986#section-2.2
if regexp.MustCompile(`[:/?#\[\]@!$&'\(\)*+,;=]`).MatchString(name) {
formatForbidden(name)
return false
}

return true
}

// AddProvider adds provider for given name
func (s *Service) AddProvider(name, cid, csecret string) {
p := provider.Params{
Expand Down
36 changes: 32 additions & 4 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,34 @@ func TestIntegrationList(t *testing.T) {
assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b))
}

func TestIntegrationInvalidProviderNames(t *testing.T) {
invalidNames := []string{
"provider/with/slashes",
"provider with spaces",
" providerWithSpacesAround\t",
"providerWithReserved-$-Char",
"providerWithReserved-&-Char",
"providerWithReserved-+-Char",
"providerWithReserved-,-Char",
"providerWithReserved-:-Char",
"providerWithReserved-;-Char",
"providerWithReserved-=-Char",
"providerWithReserved-?-Char",
"providerWithReserved-@-Char",
"providerWith%2F-EscapedSequence",
"",
}
svc, teardown := prepService(t, func(svc *Service) {
for _, name := range invalidNames {
svc.AddCustomProvider(name, Client{"cid", "csecret"}, provider.CustomHandlerOpt{})
}
})
defer teardown()

require.Equal(t, 1, len(svc.Providers()))
require.Equal(t, "dev", svc.Providers()[0].Name())
}

func TestIntegrationUserInfo(t *testing.T) {
_, teardown := prepService(t)
defer teardown()
Expand Down Expand Up @@ -386,7 +414,7 @@ func TestDirectProvider(t *testing.T) {

func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
_, teardown := prepService(t, func(svc *Service) {
svc.AddDirectProviderWithUserIDFunc("directCustom",
svc.AddDirectProviderWithUserIDFunc("direct_custom",
provider.CredCheckerFunc(func(user, password string) (ok bool, err error) {
return user == "dev_direct" && password == "password", nil
}),
Expand All @@ -401,12 +429,12 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
jar, err := cookiejar.New(nil)
require.Nil(t, err)
client := &http.Client{Jar: jar, Timeout: 5 * time.Second}
resp, err := client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=bad")
resp, err := client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=bad")
require.Nil(t, err)
defer resp.Body.Close()
assert.Equal(t, 403, resp.StatusCode)

resp, err = client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=password")
resp, err = client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=password")
require.Nil(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
Expand All @@ -416,7 +444,7 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
t.Logf("resp %s", string(body))
t.Logf("headers: %+v", resp.Header)

assert.Contains(t, string(body), `"name":"dev_direct","id":"directCustom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`)
assert.Contains(t, string(body), `"name":"dev_direct","id":"direct_custom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`)

require.Equal(t, 2, len(resp.Cookies()))
assert.Equal(t, "JWT", resp.Cookies()[0].Name)
Expand Down
23 changes: 17 additions & 6 deletions middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler {
}

// check if user provider is allowed
if !a.isProviderAllowed(claims.User.ID) {
if !a.isProviderAllowed(&claims) {
onError(h, w, r, fmt.Errorf("user %s/%s provider is not allowed", claims.User.Name, claims.User.ID))
a.JWTService.Reset(w)
return
Expand All @@ -153,13 +153,24 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler {
return f
}

// isProviderAllowed checks if user provider is allowed, user id looks like "provider_1234567890"
// this check is needed to reject users from providers what are used to be allowed but not anymore.
// isProviderAllowed checks if user provider is allowed.
// If provider name is explicitly set in the token claims, then that provider is checked.
//
// If user id looks like "provider_1234567890",
// then there is an attempt to extract provider name from that user ID.
// Note that such read can fail if user id has multiple "_" separator symbols.
//
// This check is needed to reject users from providers what are used to be allowed but not anymore.
// Such users made token before the provider was disabled and should not be allowed to login anymore.
func (a *Authenticator) isProviderAllowed(userID string) bool {
userProvider := strings.Split(userID, "_")[0]
func (a *Authenticator) isProviderAllowed(claims *token.Claims) bool {
// TODO: remove this read when old tokens expire and all new tokens have a provider name in them
userIDProvider := strings.Split(claims.User.ID, "_")[0]
for _, p := range a.Providers {
if p.Name() == userProvider {
name := p.Name()
if claims.AuthProvider != nil && claims.AuthProvider.Name == name {
return true
}
if name == userIDProvider {
return true
}
}
Expand Down
55 changes: 26 additions & 29 deletions middleware/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ var testJwtNoUser = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI3ODkxOTE4Mj

var testJwtWithRole = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9LCJyb2xlIjoiZW1wbG95ZWUifX0.o95raB0aNl2TWUs43Tu6xyX5Y3Fa5wv6_6RFJuN-d6g"

var testJwtValidWithAuthProvider = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fSwiYXV0aF9wcm92aWRlciI6eyJuYW1lIjoicHJvdmlkZXIxIn19.iBKM9-lgejJNjcs-crj6gkEejnIJpavmaq8alenf0JA"

func TestAuthJWTCookie(t *testing.T) {
a := makeTestAuth(t)

Expand All @@ -51,56 +53,51 @@ func TestAuthJWTCookie(t *testing.T) {
client := &http.Client{Timeout: 5 * time.Second}
expiration := int(365 * 24 * time.Hour.Seconds()) //nolint

t.Run("valid token", func(t *testing.T) {
makeRequest := func(jwtCookie string, xsrfToken string) *http.Response {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
req.AddCookie(&http.Cookie{
Name: "JWT",
Value: jwtCookie,
HttpOnly: true,
Path: "/",
MaxAge: expiration,
Secure: false,
})
req.Header.Add("X-XSRF-TOKEN", xsrfToken)

resp, err := client.Do(req)
require.NoError(t, err)
return resp
}

t.Run("valid token", func(t *testing.T) {
resp := makeRequest(testJwtValid, "random id")
assert.Equal(t, 201, resp.StatusCode, "valid token user")
})

t.Run("valid token, wrong provider", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValidWrongProvider, HttpOnly: true, Path: "/",
MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
t.Run("valid token with auth_provider", func(t *testing.T) {
resp := makeRequest(testJwtValidWithAuthProvider, "random id")
assert.Equal(t, 201, resp.StatusCode, "valid token user")
})

resp, err := client.Do(req)
require.NoError(t, err)
t.Run("valid token, wrong provider", func(t *testing.T) {
resp := makeRequest(testJwtValidWrongProvider, "random id")
assert.Equal(t, 401, resp.StatusCode, "user name1/provider3_id1 provider is not allowed")
})

t.Run("xsrf mismatch", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "wrong id")
resp, err := client.Do(req)
require.NoError(t, err)
resp := makeRequest(testJwtValid, "wrong id")
assert.Equal(t, 401, resp.StatusCode, "xsrf mismatch")
})

t.Run("token expired and refreshed", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
resp, err := client.Do(req)
require.NoError(t, err)
resp := makeRequest(testJwtExpired, "random id")
assert.Equal(t, 201, resp.StatusCode, "token expired and refreshed")
})

t.Run("no user info in the token", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtNoUser, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
resp, err := client.Do(req)
require.NoError(t, err)
resp := makeRequest(testJwtNoUser, "random id")
assert.Equal(t, 401, resp.StatusCode, "no user info in the token")
})
}
Expand Down
6 changes: 6 additions & 0 deletions provider/apple.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ func (ah *AppleHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
ExpiresAt: time.Now().Add(30 * time.Minute).Unix(),
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
},
AuthProvider: &token.AuthProvider{
Name: ah.name,
},
}

if _, err = ah.JwtService.Set(w, claims); err != nil {
Expand Down Expand Up @@ -376,6 +379,9 @@ func (ah AppleHandler) AuthHandler(w http.ResponseWriter, r *http.Request) {
Audience: oauthClaims.Audience,
},
SessionOnly: false,
AuthProvider: &token.AuthProvider{
Name: ah.name,
},
}

if _, err = ah.JwtService.Set(w, claims); err != nil {
Expand Down
3 changes: 3 additions & 0 deletions provider/direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ func (p DirectHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
Audience: creds.Audience,
},
SessionOnly: sessOnly,
AuthProvider: &token.AuthProvider{
Name: p.ProviderName,
},
}

if _, err = p.TokenService.Set(w, claims); err != nil {
Expand Down
6 changes: 6 additions & 0 deletions provider/oauth1.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ func (h Oauth1Handler) LoginHandler(w http.ResponseWriter, r *http.Request) {
ExpiresAt: time.Now().Add(30 * time.Minute).Unix(),
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
},
AuthProvider: &token.AuthProvider{
Name: h.name,
},
}

if _, err = h.JwtService.Set(w, claims); err != nil {
Expand Down Expand Up @@ -146,6 +149,9 @@ func (h Oauth1Handler) AuthHandler(w http.ResponseWriter, r *http.Request) {
Audience: oauthClaims.Audience,
},
SessionOnly: oauthClaims.SessionOnly,
AuthProvider: &token.AuthProvider{
Name: h.name,
},
}

if _, err = h.JwtService.Set(w, claims); err != nil {
Expand Down
6 changes: 6 additions & 0 deletions provider/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ func (p Oauth2Handler) LoginHandler(w http.ResponseWriter, r *http.Request) {
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
},
NoAva: r.URL.Query().Get("noava") == "1",
AuthProvider: &token.AuthProvider{
Name: p.name,
},
}

if _, err := p.JwtService.Set(w, claims); err != nil {
Expand Down Expand Up @@ -215,6 +218,9 @@ func (p Oauth2Handler) AuthHandler(w http.ResponseWriter, r *http.Request) {
},
SessionOnly: oauthClaims.SessionOnly,
NoAva: oauthClaims.NoAva,
AuthProvider: &token.AuthProvider{
Name: p.name,
},
}

if _, err = p.JwtService.Set(w, claims); err != nil {
Expand Down
3 changes: 3 additions & 0 deletions provider/telegram.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,9 @@ func (th *TelegramHandler) LoginHandler(w http.ResponseWriter, r *http.Request)
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
},
SessionOnly: false, // TODO review?
AuthProvider: &authtoken.AuthProvider{
Name: th.Name(),
},
}

if _, err := th.TokenService.Set(w, claims); err != nil {
Expand Down
6 changes: 6 additions & 0 deletions provider/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ func (e VerifyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
Audience: confClaims.Audience,
},
SessionOnly: sessOnly,
AuthProvider: &token.AuthProvider{
Name: e.ProviderName,
},
}

if _, err = e.TokenService.Set(w, claims); err != nil {
Expand Down Expand Up @@ -152,6 +155,9 @@ func (e VerifyHandler) sendConfirmation(w http.ResponseWriter, r *http.Request)
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
Issuer: e.Issuer,
},
AuthProvider: &token.AuthProvider{
Name: e.ProviderName,
},
}

tkn, err := e.TokenService.Token(claims)
Expand Down
14 changes: 10 additions & 4 deletions token/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ type Service struct {
// Claims stores user info for token and state & from from login
type Claims struct {
jwt.StandardClaims
User *User `json:"user,omitempty"` // user info
SessionOnly bool `json:"sess_only,omitempty"`
Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake
NoAva bool `json:"no-ava,omitempty"` // disable avatar, always use identicon
User *User `json:"user,omitempty"` // user info
SessionOnly bool `json:"sess_only,omitempty"`
Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake
NoAva bool `json:"no-ava,omitempty"` // disable avatar, always use identicon
AuthProvider *AuthProvider `json:"auth_provider,omitempty"` // auth provider info
}

// Handshake used for oauth handshake
Expand All @@ -34,6 +35,11 @@ type Handshake struct {
ID string `json:"id,omitempty"`
}

// AuthProvider stores attributes of provider which has created a JWT token
type AuthProvider struct {
Name string `json:"name,omitempty"`
}

const (
// default names for cookies and headers
defaultJWTCookieName = "JWT"
Expand Down
Loading

0 comments on commit f13649e

Please sign in to comment.