Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refresh token only once for all concurrent requests #2692

Merged
merged 1 commit into from
Oct 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 124 additions & 81 deletions server/refreshhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ type refreshError struct {
desc string
}

func (r *refreshError) Error() string {
return fmt.Sprintf("refresh token error: status %d, %q %s", r.code, r.msg, r.desc)
}

func newInternalServerError() *refreshError {
return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
}
Expand Down Expand Up @@ -60,10 +64,23 @@ func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.Refr
return token, nil
}

type refreshContext struct {
storageToken *storage.RefreshToken
requestToken *internal.RefreshToken

connector Connector
connectorData []byte

scopes []string
}

// getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info
func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*storage.RefreshToken, *refreshError) {
func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*refreshContext, *refreshError) {
refreshCtx := refreshContext{requestToken: token}

invalidErr := newBadRequestError("Refresh token is invalid or has already been claimed by another client.")

// Get RefreshToken
refresh, err := s.storage.GetRefresh(token.RefreshId)
if err != nil {
if err != storage.ErrNotFound {
Expand Down Expand Up @@ -103,7 +120,31 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref
return nil, expiredErr
}

return &refresh, nil
refreshCtx.storageToken = &refresh

// Get Connector
refreshCtx.connector, err = s.getConnector(refresh.ConnectorID)
if err != nil {
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
return nil, newInternalServerError()
}

// Get Connector Data
session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID)
switch {
case err != nil:
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
return nil, newInternalServerError()
}
case len(refresh.ConnectorData) > 0:
// Use the old connector data if it exists, should be deleted once used
refreshCtx.connectorData = refresh.ConnectorData
default:
refreshCtx.connectorData = session.ConnectorData
}

return &refreshCtx, nil
}

func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) {
Expand Down Expand Up @@ -138,59 +179,23 @@ func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken
return requestedScopes, nil
}

func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) {
var connectorData []byte

session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID)
switch {
case err != nil:
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
return connector.Identity{}, newInternalServerError()
}
case len(refresh.ConnectorData) > 0:
// Use the old connector data if it exists, should be deleted once used
connectorData = refresh.ConnectorData
default:
connectorData = session.ConnectorData
}

conn, err := s.getConnector(refresh.ConnectorID)
if err != nil {
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
return connector.Identity{}, newInternalServerError()
}

ident := connector.Identity{
UserID: refresh.Claims.UserID,
Username: refresh.Claims.Username,
PreferredUsername: refresh.Claims.PreferredUsername,
Email: refresh.Claims.Email,
EmailVerified: refresh.Claims.EmailVerified,
Groups: refresh.Claims.Groups,
ConnectorData: connectorData,
}

// user's token was previously updated by a connector and is allowed to reuse
// it is excessive to refresh identity in upstream
if s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) && token.Token == refresh.ObsoleteToken {
return ident, nil
}

func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext, ident connector.Identity) (connector.Identity, *refreshError) {
// Can the connector refresh the identity? If so, attempt to refresh the data
// in the connector.
//
// TODO(ericchiang): We may want a strict mode where connectors that don't implement
// this interface can't perform refreshing.
if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident)
if refreshConn, ok := rCtx.connector.Connector.(connector.RefreshConnector); ok {
s.logger.Debugf("connector data before refresh: %s", ident.ConnectorData)

newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident)
if err != nil {
s.logger.Errorf("failed to refresh identity: %v", err)
return connector.Identity{}, newInternalServerError()
return ident, newInternalServerError()
}
ident = newIdent
}

return newIdent, nil
}
return ident, nil
}

Expand All @@ -200,8 +205,14 @@ func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident conne
if old.Refresh[refresh.ClientID].ID != refresh.ID {
return old, errors.New("refresh token invalid")
}

old.Refresh[refresh.ClientID].LastUsed = lastUsed
old.ConnectorData = ident.ConnectorData
if len(ident.ConnectorData) > 0 {
old.ConnectorData = ident.ConnectorData
}

s.logger.Debugf("saved connector data: %s %s", ident.UserID, ident.ConnectorData)

return old, nil
}

Expand All @@ -217,33 +228,74 @@ func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident conne
}

// updateRefreshToken updates refresh token and offline session in the storage
func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) {
newToken := token
if s.refreshTokenPolicy.RotationEnabled() {
newToken = &internal.RefreshToken{
RefreshId: refresh.ID,
Token: storage.NewID(),
}
func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (*internal.RefreshToken, connector.Identity, *refreshError) {
var rerr *refreshError

newToken := &internal.RefreshToken{
Token: rCtx.requestToken.Token,
RefreshId: rCtx.requestToken.RefreshId,
}

lastUsed := s.now()

ident := connector.Identity{
UserID: rCtx.storageToken.Claims.UserID,
Username: rCtx.storageToken.Claims.Username,
PreferredUsername: rCtx.storageToken.Claims.PreferredUsername,
Email: rCtx.storageToken.Claims.Email,
EmailVerified: rCtx.storageToken.Claims.EmailVerified,
Groups: rCtx.storageToken.Claims.Groups,
ConnectorData: rCtx.connectorData,
}

refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
if s.refreshTokenPolicy.RotationEnabled() {
if old.Token != token.Token {
if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == token.Token {
newToken.Token = old.Token
// Do not update last used time for offline session if token is allowed to be reused
lastUsed = old.LastUsed
return old, nil
}
rotationEnabled := s.refreshTokenPolicy.RotationEnabled()
reusingAllowed := s.refreshTokenPolicy.AllowedToReuse(old.LastUsed)

switch {
case !rotationEnabled && reusingAllowed:
// If rotation is disabled and the offline session was updated not so long ago - skip further actions.
return old, nil

case rotationEnabled && reusingAllowed:
if old.Token != rCtx.requestToken.Token && old.ObsoleteToken != rCtx.requestToken.Token {
return old, errors.New("refresh token claimed twice")
}

// Return previously generated token for all requests with an obsolete tokens
if old.ObsoleteToken == rCtx.requestToken.Token {
newToken.Token = old.Token
}

// Do not update last used time for offline session if token is allowed to be reused
lastUsed = old.LastUsed
ident.ConnectorData = nil
return old, nil

case rotationEnabled && !reusingAllowed:
if old.Token != rCtx.requestToken.Token {
return old, errors.New("refresh token claimed twice")
}

// Issue new refresh token
old.ObsoleteToken = old.Token
newToken.Token = storage.NewID()
}

old.Token = newToken.Token
old.LastUsed = lastUsed

// ConnectorData has been moved to OfflineSession
old.ConnectorData = []byte{}

// Call only once if there is a request which is not in the reuse interval.
// This is required to avoid multiple calls to the external IdP for concurrent requests.
// Dex will call the connector's Refresh method only once if request is not in reuse interval.
ident, rerr = s.refreshWithConnector(ctx, rCtx, ident)
if rerr != nil {
return old, rerr
}

// Update the claims of the refresh token.
//
// UserID intentionally ignored for now.
Expand All @@ -252,26 +304,23 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora
old.Claims.Email = ident.Email
old.Claims.EmailVerified = ident.EmailVerified
old.Claims.Groups = ident.Groups
old.LastUsed = lastUsed

// ConnectorData has been moved to OfflineSession
old.ConnectorData = []byte{}
return old, nil
}

// Update refresh token in the storage.
err := s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater)
err := s.storage.UpdateRefreshToken(rCtx.storageToken.ID, refreshTokenUpdater)
if err != nil {
s.logger.Errorf("failed to update refresh token: %v", err)
return nil, newInternalServerError()
return nil, ident, newInternalServerError()
}

rerr := s.updateOfflineSession(refresh, ident, lastUsed)
rerr = s.updateOfflineSession(rCtx.storageToken, ident, lastUsed)
if rerr != nil {
return nil, rerr
return nil, ident, rerr
}

return newToken, nil
return newToken, ident, nil
}

// handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6
Expand All @@ -283,19 +332,19 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return
}

refresh, rerr := s.getRefreshTokenFromStorage(client.ID, token)
rCtx, rerr := s.getRefreshTokenFromStorage(client.ID, token)
if rerr != nil {
s.refreshTokenErrHelper(w, rerr)
return
}

scopes, rerr := s.getRefreshScopes(r, refresh)
rCtx.scopes, rerr = s.getRefreshScopes(r, rCtx.storageToken)
if rerr != nil {
s.refreshTokenErrHelper(w, rerr)
return
}

ident, rerr := s.refreshWithConnector(r.Context(), token, refresh, scopes)
newToken, ident, rerr := s.updateRefreshToken(r.Context(), rCtx)
if rerr != nil {
s.refreshTokenErrHelper(w, rerr)
return
Expand All @@ -310,26 +359,20 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
Groups: ident.Groups,
}

accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
accessToken, err := s.newAccessToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.refreshTokenErrHelper(w, newInternalServerError())
return
}

idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID)
idToken, expiry, err := s.newIDToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create ID token: %v", err)
s.refreshTokenErrHelper(w, newInternalServerError())
return
}

newToken, rerr := s.updateRefreshToken(token, refresh, ident)
if rerr != nil {
s.refreshTokenErrHelper(w, rerr)
return
}

rawNewToken, err := internal.Marshal(newToken)
if err != nil {
s.logger.Errorf("failed to marshal refresh token: %v", err)
Expand Down
Loading