diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index b5e075add1..67db78ead1 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -3,14 +3,16 @@ package oidc import ( "context" + "encoding/json" "errors" "fmt" "net/http" "net/url" "strings" "sync" + "time" - "github.com/coreos/go-oidc" + oidc "github.com/coreos/go-oidc" "golang.org/x/oauth2" "github.com/dexidp/dex/connector" @@ -60,6 +62,10 @@ var brokenAuthHeaderDomains = []string{ "oktapreview.com", } +type connectorData struct { + RefreshToken []byte +} + // Detect auth header provider issues for known providers. This lets users // avoid having to explicitly set "basicAuthUnsupported" in their config. // @@ -167,14 +173,20 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } + var opts []oauth2.AuthCodeOption if len(c.hostedDomains) > 0 { preferredDomain := c.hostedDomains[0] if len(c.hostedDomains) > 1 { preferredDomain = "*" } - return c.oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", preferredDomain)), nil + //return c.oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", preferredDomain)), nil + opts = append(opts, oauth2.SetAuthURLParam("hd", preferredDomain)) } - return c.oauth2Config.AuthCodeURL(state), nil + if s.OfflineAccess { + opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) + } + + return c.oauth2Config.AuthCodeURL(state, opts...), nil } type oauth2Error struct { @@ -198,30 +210,66 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide if err != nil { return identity, fmt.Errorf("oidc: failed to get token: %v", err) } + return c.createIdentity(r.Context(), identity, token) +} + +func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) { rawIDToken, ok := token.Extra("id_token").(string) if !ok { return identity, errors.New("oidc: no id_token in token response") } - idToken, err := c.verifier.Verify(r.Context(), rawIDToken) + idToken, err := c.verifier.Verify(ctx, rawIDToken) if err != nil { return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) } var claims map[string]interface{} + + var claimsMapped struct { + Username string `json:"name"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + HostedDomain string `json:"hd"` + Groups []string `json:"groups"` + } + if err := idToken.Claims(&claims); err != nil { return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) } + if err := idToken.Claims(&claimsMapped); err != nil { + return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) + } + // We immediately want to run getUserInfo if configured before we validate the claims if c.getUserInfo { - userInfo, err := c.provider.UserInfo(r.Context(), oauth2.StaticTokenSource(token)) + userInfo, err := c.provider.UserInfo(ctx, oauth2.StaticTokenSource(token)) + c.logger.Debugf("Got userinfo: %v", userInfo.Claims) if err != nil { return identity, fmt.Errorf("oidc: error loading userinfo: %v", err) } if err := userInfo.Claims(&claims); err != nil { return identity, fmt.Errorf("oidc: failed to decode userinfo claims: %v", err) } + + if err := userInfo.Claims(&claimsMapped); err != nil { + return identity, fmt.Errorf("oidc: failed to decode userinfo claims: %v", err) + } + + } + + c.logger.Debugf("Claims %v", claimsMapped.Groups) + + cd := connectorData{ + RefreshToken: []byte(token.RefreshToken), + } + + c.logger.Debugf("refresh token: %v", string(cd.RefreshToken)) + + connData, err := json.Marshal(&cd) + if err != nil { + return identity, fmt.Errorf("oidc: failed to encode connector data: %v", err) } userNameKey := "name" @@ -265,6 +313,8 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide Username: name, Email: email, EmailVerified: emailVerified, + Groups: claimsMapped.Groups, + ConnectorData: connData, } if c.userIDKey != "" { @@ -280,5 +330,21 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide // Refresh is implemented for backwards compatibility, even though it's a no-op. func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) { - return identity, nil + cd := connectorData{} + err := json.Unmarshal(identity.ConnectorData, &cd) + if err != nil { + return identity, fmt.Errorf("oidc: failed to unmarshal connector data: %v", err) + } + + t := &oauth2.Token{ + RefreshToken: string(cd.RefreshToken), + Expiry: time.Now().Add(-time.Hour), + } + + token, err := c.oauth2Config.TokenSource(ctx, t).Token() + if err != nil { + return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err) + } + + return c.createIdentity(ctx, identity, token) } diff --git a/server/handlers.go b/server/handlers.go index e70206709d..00fbcc97e0 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -505,7 +505,50 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth s.logger.Infof("login successful: connector %q, username=%q, preferred_username=%q, email=%q, groups=%q", authReq.ConnectorID, claims.Username, claims.PreferredUsername, email, claims.Groups) - return path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID, nil + //return path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID, nil + + returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID + _, ok := conn.(connector.RefreshConnector) + if !ok { + return returnURL, nil + } + + // Try to retrieve an existing OfflineSession object for the corresponding user. + if session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID); err != nil { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get offline session: %v", err) + return "", err + } + s.logger.Debugf("Getting offline session for %s", identity.UserID) + offlineSessions := storage.OfflineSessions{ + UserID: identity.UserID, + ConnID: authReq.ConnectorID, + Refresh: make(map[string]*storage.RefreshTokenRef), + ConnectorData: identity.ConnectorData, + } + + // Create a new OfflineSession object for the user and add a reference object for + // the newly received refreshtoken. + if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil { + s.logger.Errorf("failed to create offline session: %v", err) + return "", err + } + s.logger.Debugf("Creating OfflineSession for %s ", identity.UserID) + } else { + // Update existing OfflineSession obj with new RefreshTokenRef. + if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if len(identity.ConnectorData) > 0 { + old.ConnectorData = identity.ConnectorData + } + return old, nil + }); err != nil { + s.logger.Errorf("failed to update offline session: %v", err) + return "", err + } + } + + return returnURL, nil + } func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { @@ -962,6 +1005,19 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie scopes = requestedScopes } + var connectorData []byte + if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get offline session: %v", err) + return + } + } else if len(refresh.ConnectorData) > 0 { + // Use the old connector data if it exists, should be deleted once used + connectorData = refresh.ConnectorData + } else { + connectorData = session.ConnectorData + } + conn, err := s.getConnector(refresh.ConnectorID) if err != nil { s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) @@ -974,7 +1030,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie Email: refresh.Claims.Email, EmailVerified: refresh.Claims.EmailVerified, Groups: refresh.Claims.Groups, - ConnectorData: refresh.ConnectorData, + ConnectorData: connectorData, } // Can the connector refresh the identity? If so, attempt to refresh the data @@ -1039,8 +1095,11 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie old.Claims.Email = ident.Email old.Claims.EmailVerified = ident.EmailVerified old.Claims.Groups = ident.Groups - old.ConnectorData = ident.ConnectorData + //old.ConnectorData = ident.ConnectorData old.LastUsed = lastUsed + + //ConnectorData has been moved to OfflineSession + old.ConnectorData = []byte{} return old, nil } @@ -1051,6 +1110,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return old, errors.New("refresh token invalid") } old.Refresh[refresh.ClientID].LastUsed = lastUsed + old.ConnectorData = ident.ConnectorData return old, nil }); err != nil { s.logger.Errorf("failed to update offline session: %v", err) diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index a13998077c..75b8db9ed2 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -518,6 +518,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { UserID: userID1, ConnID: "Conn1", Refresh: make(map[string]*storage.RefreshTokenRef), + ConnectorData: []byte(`{"some":"data"}`), } // Creating an OfflineSession with an empty Refresh list to ensure that @@ -535,6 +536,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { UserID: userID2, ConnID: "Conn2", Refresh: make(map[string]*storage.RefreshTokenRef), + ConnectorData: []byte(`{"some":"data"}`), } if err := s.CreateOfflineSessions(session2); err != nil { diff --git a/storage/etcd/types.go b/storage/etcd/types.go index 8063c69f59..a16eae8e94 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -188,24 +188,27 @@ type Keys struct { // OfflineSessions is a mirrored struct from storage with JSON struct tags type OfflineSessions struct { - UserID string `json:"user_id,omitempty"` - ConnID string `json:"conn_id,omitempty"` - Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` + UserID string `json:"user_id,omitempty"` + ConnID string `json:"conn_id,omitempty"` + Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` + ConnectorData []byte `json:"connectorData,omitempty"` } func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { return OfflineSessions{ - UserID: o.UserID, - ConnID: o.ConnID, - Refresh: o.Refresh, + UserID: o.UserID, + ConnID: o.ConnID, + Refresh: o.Refresh, + ConnectorData: o.ConnectorData, } } func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { s := storage.OfflineSessions{ - UserID: o.UserID, - ConnID: o.ConnID, - Refresh: o.Refresh, + UserID: o.UserID, + ConnID: o.ConnID, + Refresh: o.Refresh, + ConnectorData: o.ConnectorData, } if s.Refresh == nil { // Server code assumes this will be non-nil. diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index a42238b38c..5eda178111 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -552,9 +552,10 @@ type OfflineSessions struct { k8sapi.TypeMeta `json:",inline"` k8sapi.ObjectMeta `json:"metadata,omitempty"` - UserID string `json:"userID,omitempty"` - ConnID string `json:"connID,omitempty"` - Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` + UserID string `json:"userID,omitempty"` + ConnID string `json:"connID,omitempty"` + Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` + ConnectorData []byte `json:"connectorData,omitempty"` } func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { @@ -567,17 +568,19 @@ func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) Offline Name: cli.offlineTokenName(o.UserID, o.ConnID), Namespace: cli.namespace, }, - UserID: o.UserID, - ConnID: o.ConnID, - Refresh: o.Refresh, + UserID: o.UserID, + ConnID: o.ConnID, + Refresh: o.Refresh, + ConnectorData: o.ConnectorData, } } func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { s := storage.OfflineSessions{ - UserID: o.UserID, - ConnID: o.ConnID, - Refresh: o.Refresh, + UserID: o.UserID, + ConnID: o.ConnID, + Refresh: o.Refresh, + ConnectorData: o.ConnectorData, } if s.Refresh == nil { // Server code assumes this will be non-nil. diff --git a/storage/storage.go b/storage/storage.go index 235f74e07a..b23e0f911a 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -272,7 +272,8 @@ type OfflineSessions struct { // Refresh is a hash table of refresh token reference objects // indexed by the ClientID of the refresh token. - Refresh map[string]*RefreshTokenRef + Refresh map[string]*RefreshTokenRef + ConnectorData []byte } // Password is an email to password mapping managed by the storage.