Skip to content

Commit

Permalink
Adds Groups claims and OIDC refresh tokens
Browse files Browse the repository at this point in the history
This builds on the terrific work in https://github.com/dexidp/dex/pull/1180/files
and dexidp#1065.

This makes some minor changes that bring the approach up-to-date with
current dex versions.
  • Loading branch information
dskatz committed Nov 6, 2019
1 parent b7184be commit 4596f65
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 28 deletions.
78 changes: 72 additions & 6 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ package oidc

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/coreos/go-oidc"
oidc "github.com/coreos/go-oidc"
"golang.org/x/oauth2"

"github.com/dexidp/dex/connector"
Expand Down Expand Up @@ -60,6 +62,10 @@ var brokenAuthHeaderDomains = []string{
"oktapreview.com",
}

type connectorData struct {
RefreshToken []byte
}

// Detect auth header provider issues for known providers. This lets users
// avoid having to explicitly set "basicAuthUnsupported" in their config.
//
Expand Down Expand Up @@ -167,14 +173,20 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
}

var opts []oauth2.AuthCodeOption
if len(c.hostedDomains) > 0 {
preferredDomain := c.hostedDomains[0]
if len(c.hostedDomains) > 1 {
preferredDomain = "*"
}
return c.oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", preferredDomain)), nil
//return c.oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", preferredDomain)), nil
opts = append(opts, oauth2.SetAuthURLParam("hd", preferredDomain))
}
return c.oauth2Config.AuthCodeURL(state), nil
if s.OfflineAccess {
opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
}

return c.oauth2Config.AuthCodeURL(state, opts...), nil
}

type oauth2Error struct {
Expand All @@ -198,30 +210,66 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
if err != nil {
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
}
return c.createIdentity(r.Context(), identity, token)
}

func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) {

rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return identity, errors.New("oidc: no id_token in token response")
}
idToken, err := c.verifier.Verify(r.Context(), rawIDToken)
idToken, err := c.verifier.Verify(ctx, rawIDToken)
if err != nil {
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
}

var claims map[string]interface{}

var claimsMapped struct {
Username string `json:"name"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
HostedDomain string `json:"hd"`
Groups []string `json:"groups"`
}

if err := idToken.Claims(&claims); err != nil {
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
}

if err := idToken.Claims(&claimsMapped); err != nil {
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
}

// We immediately want to run getUserInfo if configured before we validate the claims
if c.getUserInfo {
userInfo, err := c.provider.UserInfo(r.Context(), oauth2.StaticTokenSource(token))
userInfo, err := c.provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
c.logger.Debugf("Got userinfo: %v", userInfo.Claims)
if err != nil {
return identity, fmt.Errorf("oidc: error loading userinfo: %v", err)
}
if err := userInfo.Claims(&claims); err != nil {
return identity, fmt.Errorf("oidc: failed to decode userinfo claims: %v", err)
}

if err := userInfo.Claims(&claimsMapped); err != nil {
return identity, fmt.Errorf("oidc: failed to decode userinfo claims: %v", err)
}

}

c.logger.Debugf("Claims %v", claimsMapped.Groups)

cd := connectorData{
RefreshToken: []byte(token.RefreshToken),
}

c.logger.Debugf("refresh token: %v", string(cd.RefreshToken))

connData, err := json.Marshal(&cd)
if err != nil {
return identity, fmt.Errorf("oidc: failed to encode connector data: %v", err)
}

userNameKey := "name"
Expand Down Expand Up @@ -265,6 +313,8 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
Username: name,
Email: email,
EmailVerified: emailVerified,
Groups: claimsMapped.Groups,
ConnectorData: connData,
}

if c.userIDKey != "" {
Expand All @@ -280,5 +330,21 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide

// Refresh is implemented for backwards compatibility, even though it's a no-op.
func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
return identity, nil
cd := connectorData{}
err := json.Unmarshal(identity.ConnectorData, &cd)
if err != nil {
return identity, fmt.Errorf("oidc: failed to unmarshal connector data: %v", err)
}

t := &oauth2.Token{
RefreshToken: string(cd.RefreshToken),
Expiry: time.Now().Add(-time.Hour),
}

token, err := c.oauth2Config.TokenSource(ctx, t).Token()
if err != nil {
return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err)
}

return c.createIdentity(ctx, identity, token)
}
66 changes: 63 additions & 3 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,50 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth
s.logger.Infof("login successful: connector %q, username=%q, preferred_username=%q, email=%q, groups=%q",
authReq.ConnectorID, claims.Username, claims.PreferredUsername, email, claims.Groups)

return path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID, nil
//return path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID, nil

returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID
_, ok := conn.(connector.RefreshConnector)
if !ok {
return returnURL, nil
}

// Try to retrieve an existing OfflineSession object for the corresponding user.
if session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
return "", err
}
s.logger.Debugf("Getting offline session for %s", identity.UserID)
offlineSessions := storage.OfflineSessions{
UserID: identity.UserID,
ConnID: authReq.ConnectorID,
Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: identity.ConnectorData,
}

// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
return "", err
}
s.logger.Debugf("Creating OfflineSession for %s ", identity.UserID)
} else {
// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if len(identity.ConnectorData) > 0 {
old.ConnectorData = identity.ConnectorData
}
return old, nil
}); err != nil {
s.logger.Errorf("failed to update offline session: %v", err)
return "", err
}
}

return returnURL, nil

}

func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -962,6 +1005,19 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
scopes = requestedScopes
}

var connectorData []byte
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
return
}
} else if len(refresh.ConnectorData) > 0 {
// Use the old connector data if it exists, should be deleted once used
connectorData = refresh.ConnectorData
} else {
connectorData = session.ConnectorData
}

conn, err := s.getConnector(refresh.ConnectorID)
if err != nil {
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
Expand All @@ -974,7 +1030,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
Email: refresh.Claims.Email,
EmailVerified: refresh.Claims.EmailVerified,
Groups: refresh.Claims.Groups,
ConnectorData: refresh.ConnectorData,
ConnectorData: connectorData,
}

// Can the connector refresh the identity? If so, attempt to refresh the data
Expand Down Expand Up @@ -1039,8 +1095,11 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
old.Claims.Email = ident.Email
old.Claims.EmailVerified = ident.EmailVerified
old.Claims.Groups = ident.Groups
old.ConnectorData = ident.ConnectorData
//old.ConnectorData = ident.ConnectorData
old.LastUsed = lastUsed

//ConnectorData has been moved to OfflineSession
old.ConnectorData = []byte{}
return old, nil
}

Expand All @@ -1051,6 +1110,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return old, errors.New("refresh token invalid")
}
old.Refresh[refresh.ClientID].LastUsed = lastUsed
old.ConnectorData = ident.ConnectorData
return old, nil
}); err != nil {
s.logger.Errorf("failed to update offline session: %v", err)
Expand Down
2 changes: 2 additions & 0 deletions storage/conformance/conformance.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
UserID: userID1,
ConnID: "Conn1",
Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: []byte(`{"some":"data"}`),
}

// Creating an OfflineSession with an empty Refresh list to ensure that
Expand All @@ -535,6 +536,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
UserID: userID2,
ConnID: "Conn2",
Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: []byte(`{"some":"data"}`),
}

if err := s.CreateOfflineSessions(session2); err != nil {
Expand Down
21 changes: 12 additions & 9 deletions storage/etcd/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,24 +188,27 @@ type Keys struct {

// OfflineSessions is a mirrored struct from storage with JSON struct tags
type OfflineSessions struct {
UserID string `json:"user_id,omitempty"`
ConnID string `json:"conn_id,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
UserID string `json:"user_id,omitempty"`
ConnID string `json:"conn_id,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
ConnectorData []byte `json:"connectorData,omitempty"`
}

func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
return OfflineSessions{
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
}
}

func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
s := storage.OfflineSessions{
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
}
if s.Refresh == nil {
// Server code assumes this will be non-nil.
Expand Down
21 changes: 12 additions & 9 deletions storage/kubernetes/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,9 +552,10 @@ type OfflineSessions struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`

UserID string `json:"userID,omitempty"`
ConnID string `json:"connID,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
UserID string `json:"userID,omitempty"`
ConnID string `json:"connID,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
ConnectorData []byte `json:"connectorData,omitempty"`
}

func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
Expand All @@ -567,17 +568,19 @@ func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) Offline
Name: cli.offlineTokenName(o.UserID, o.ConnID),
Namespace: cli.namespace,
},
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
}
}

func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
s := storage.OfflineSessions{
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
}
if s.Refresh == nil {
// Server code assumes this will be non-nil.
Expand Down
3 changes: 2 additions & 1 deletion storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ type OfflineSessions struct {

// Refresh is a hash table of refresh token reference objects
// indexed by the ClientID of the refresh token.
Refresh map[string]*RefreshTokenRef
Refresh map[string]*RefreshTokenRef
ConnectorData []byte
}

// Password is an email to password mapping managed by the storage.
Expand Down

0 comments on commit 4596f65

Please sign in to comment.