diff --git a/docs/docs/configuration/alpha_config.md b/docs/docs/configuration/alpha_config.md index ac5d7d4fda..5def6d1402 100644 --- a/docs/docs/configuration/alpha_config.md +++ b/docs/docs/configuration/alpha_config.md @@ -445,6 +445,7 @@ Provider holds all configuration for a single provider | `allowedGroups` | _[]string_ | AllowedGroups is a list of restrict logins to members of this group | | `code_challenge_method` | _string_ | The code challenge method | | `backendLogoutURL` | _string_ | URL to call to perform backend logout, `{id_token}` would be replaced by the actual `id_token` if available in the session | +| `backendLogoutAllSessionsURL` | _string_ | URL to call to perform backend logout, `{user_id}` would be replaced by the actual `user_id` if available in the session IntrospectClaims | ### ProviderType #### (`string` alias) diff --git a/oauthproxy.go b/oauthproxy.go index cfec6a1ec1..6798b3ecff 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -358,7 +358,16 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) { // The userinfo and logout endpoints needs to load sessions before handling the request s.Path(userInfoPath).Handler(p.sessionChain.ThenFunc(p.UserInfo)) - s.Path(signOutPath).Handler(p.sessionChain.ThenFunc(p.SignOut)) + s.Path(signOutPath).Handler(p.sessionChain.ThenFunc( + func(w http.ResponseWriter, r *http.Request) { + p.SignOut(w, r, false) + }, + )) + s.Path(picsSignOutAllDevicesPath).Handler(p.sessionChain.ThenFunc( + func(w http.ResponseWriter, r *http.Request) { + p.SignOut(w, r, true) + }, + )) } // buildPreAuthChain constructs a chain that should process every request before @@ -758,7 +767,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { } // SignOut sends a response to clear the authentication cookie -func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { +func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request, signOutAllSessions bool) { redirect, err := p.appDirector.GetRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) @@ -772,12 +781,12 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { return } - p.backendLogout(rw, req) + p.backendLogout(rw, req, signOutAllSessions) http.Redirect(rw, req, redirect, http.StatusFound) } -func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) { +func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request, signOutAllSessions bool) { session, err := p.getAuthenticatedSession(rw, req) if err != nil { logger.Errorf("error getting authenticated session during backend logout: %v", err) @@ -789,22 +798,39 @@ func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) { } providerData := p.provider.Data() - if providerData.BackendLogoutURL == "" { - return - } + var resp *http.Response + if signOutAllSessions { + if providerData.BackendLogoutAllSessionsURL == "" { + return + } - backendLogoutURL := strings.ReplaceAll(providerData.BackendLogoutURL, "{id_token}", session.IDToken) - // security exception because URL is dynamic ({id_token} replacement) but - // base is not end-user provided but comes from configuration somewhat secure - resp, err := http.Get(backendLogoutURL) // #nosec G107 - if err != nil { - logger.Errorf("error while calling backend logout: %v", err) - return - } + resp, err := PicsSignOutAllSessions(providerData.BackendLogoutAllSessionsURL, session.IntrospectClaims, session.AccessToken) + if err != nil { + logger.Errorf("error while calling backend logout all sessions: %v", err) + return + } - defer resp.Body.Close() - if resp.StatusCode != 200 { - logger.Errorf("error while calling backend logout url, returned error code %v", resp.StatusCode) + if resp.StatusCode() != 200 { + logger.Errorf("error while calling backend logout url, returned error code %v", resp.StatusCode()) + } + } else { + if providerData.BackendLogoutURL == "" { + return + } + + backendLogoutURL := strings.ReplaceAll(providerData.BackendLogoutURL, "{id_token}", session.IDToken) + // security exception because URL is dynamic ({id_token} replacement) but + // base is not end-user provided but comes from configuration somewhat secure + resp, err = http.Get(backendLogoutURL) // #nosec G107 + if err != nil { + logger.Errorf("error while calling backend logout: %v", err) + return + } + + defer resp.Body.Close() + if resp.StatusCode != 200 { + logger.Errorf("error while calling backend logout url, returned error code %v", resp.StatusCode) + } } } diff --git a/pics_oauthproxy.go b/pics_oauthproxy.go new file mode 100644 index 0000000000..9464dc227a --- /dev/null +++ b/pics_oauthproxy.go @@ -0,0 +1,59 @@ +package main + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" +) + +const ( + picsSignOutAllDevicesPath = "/sign_out_all_sessions" +) + +func PicsSignOutAllSessions(backendLogoutAllSessionsURL string, introspectClaims string, accessToken string) (resp requests.Result, err error) { + userID, err := getUserID(introspectClaims) + if err != nil { + return nil, fmt.Errorf("error getting userID from instrospect claims: %v", err) + } + + backendLogoutURL := strings.ReplaceAll(backendLogoutAllSessionsURL, "{user_id}", userID) + resp = requests.New(backendLogoutURL). + WithMethod("POST"). + SetHeader("Authorization", "Bearer "+accessToken). + SetHeader("API-Version", "1"). + SetHeader("Accept", "application/json"). + Do() + + if resp.Error() != nil { + return nil, fmt.Errorf("error logging out from IAM: %v", err) + } + + return resp, err +} + +func getUserID(introspectClaims string) (string, error) { + decodedClaims, err := base64.StdEncoding.DecodeString(introspectClaims) + if err != nil { + logger.Errorf("error decoding claims: %v", err) + return "", err + } + + var claims map[string]interface{} + err = json.Unmarshal(decodedClaims, &claims) + if err != nil { + logger.Errorf("error unmarshalling claims: %v", err) + return "", err + } + + userID, ok := claims["sub"].(string) + if !ok { + logger.Errorf("error extracting 'sub' from claims") + return "", err + } + + return userID, nil +} diff --git a/pics_oauthproxy_test.go b/pics_oauthproxy_test.go new file mode 100644 index 0000000000..0990679479 --- /dev/null +++ b/pics_oauthproxy_test.go @@ -0,0 +1,55 @@ +package main + +import ( + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func createIntrospectClaims() string { + claims := map[string]interface{}{ + "sub": "1234567890", + } + claimsBytes, err := json.Marshal(claims) + if err != nil { + return "" + } + + return base64.StdEncoding.EncodeToString(claimsBytes) +} + +func Test_PicsSignOutAllSessionsReturnsErrorWhenUserIDIsNotFound(t *testing.T) { + _, err := PicsSignOutAllSessions("http://localhost:8080/test", "", "") + + assert.Error(t, err) +} + +func Test_getUserID(t *testing.T) { + introspectClaims := createIntrospectClaims() + userID, err := getUserID(introspectClaims) + + assert.NoError(t, err) + assert.Equal(t, "1234567890", userID) +} + +func Test_PicsSignOutAllSessionsReturns200Ok(t *testing.T) { + introspectClaims := createIntrospectClaims() + accessToken := "validAccessToken" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "Bearer "+accessToken, r.Header.Get("Authorization")) + assert.Equal(t, "1", r.Header.Get("API-Version")) + assert.Equal(t, "application/json", r.Header.Get("Accept")) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + resp, err := PicsSignOutAllSessions(server.URL+"/{user_id}", introspectClaims, accessToken) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode()) +} diff --git a/pkg/apis/options/legacy_options.go b/pkg/apis/options/legacy_options.go index 3105528c8f..8fc5110efa 100644 --- a/pkg/apis/options/legacy_options.go +++ b/pkg/apis/options/legacy_options.go @@ -545,6 +545,8 @@ type LegacyProvider struct { AllowedRoles []string `flag:"allowed-role" cfg:"allowed_roles"` BackendLogoutURL string `flag:"backend-logout-url" cfg:"backend_logout_url"` + BackendLogoutAllSessionsURL string `flag:"backend-logout-all-sessions-url" cfg:"backend_logout_all_sessions_url"` + AcrValues string `flag:"acr-values" cfg:"acr_values"` JWTKey string `flag:"jwt-key" cfg:"jwt_key"` JWTKeyFile string `flag:"jwt-key-file" cfg:"jwt_key_file"` @@ -613,6 +615,7 @@ func legacyProviderFlagSet() *pflag.FlagSet { flagSet.StringSlice("allowed-group", []string{}, "restrict logins to members of this group (may be given multiple times)") flagSet.StringSlice("allowed-role", []string{}, "(keycloak-oidc) restrict logins to members of these roles (may be given multiple times)") flagSet.String("backend-logout-url", "", "url to perform a backend logout, {id_token} can be used as placeholder for the id_token") + flagSet.String("backend-logout-all-sessions-url", "", "url to perform a backend logout, {user_id} can be used as placeholder for the user_id") return flagSet } @@ -693,6 +696,8 @@ func (l *LegacyProvider) convert() (Providers, error) { AllowedGroups: l.AllowedGroups, CodeChallengeMethod: l.CodeChallengeMethod, BackendLogoutURL: l.BackendLogoutURL, + + BackendLogoutAllSessionsURL: l.BackendLogoutAllSessionsURL, } // This part is out of the switch section for all providers that support OIDC diff --git a/pkg/apis/options/providers.go b/pkg/apis/options/providers.go index a90b584c40..aefd3cc2c7 100644 --- a/pkg/apis/options/providers.go +++ b/pkg/apis/options/providers.go @@ -88,6 +88,9 @@ type Provider struct { // URL to call to perform backend logout, `{id_token}` would be replaced by the actual `id_token` if available in the session BackendLogoutURL string `json:"backendLogoutURL"` + + // URL to call to perform backend logout, `{user_id}` would be replaced by the actual `user_id` if available in the session IntrospectClaims + BackendLogoutAllSessionsURL string `json:"backendLogoutAllSessionsURL"` } // ProviderType is used to enumerate the different provider type options diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index f861c756fa..92dbdb4e25 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -190,7 +190,8 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req if err := s.refreshSession(rw, req, session); err != nil { // If a preemptive refresh fails, we still keep the session // if validateSession succeeds. - logger.Errorf("Unable to refresh session: %v", err) + // PICS: We will clean the session if the refresh fails. + return fmt.Errorf("unable to refresh session: %v", err) } // Validate all sessions after any Redeem/Refresh operation (fail or success) diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 904c2028fa..2d3a6f669d 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -295,17 +295,12 @@ var _ = Describe("Stored Session Suite", func() { refreshSession: defaultRefreshFunc, validateSession: defaultValidateFunc, }), - Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{ + Entry("when the provider refresh fails", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=RefreshError"}, }, existingSession: nil, - expectedSession: &sessionsapi.SessionState{ - RefreshToken: "RefreshError", - CreatedAt: &createdPast, - ExpiresOn: &createdFuture, - Lock: &sessionsapi.NoOpLock{}, - }, + expectedSession: nil, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, refreshSession: defaultRefreshFunc, diff --git a/providers/provider_data.go b/providers/provider_data.go index a967f17d9a..4744f67433 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -60,6 +60,8 @@ type ProviderData struct { loginURLParameterOverrides map[string]*regexp.Regexp BackendLogoutURL string + + BackendLogoutAllSessionsURL string } // Data returns the ProviderData diff --git a/providers/providers.go b/providers/providers.go index 153db12e4e..1c950aaf63 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -163,6 +163,7 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData, p.setAllowedGroups(providerConfig.AllowedGroups) p.BackendLogoutURL = providerConfig.BackendLogoutURL + p.BackendLogoutAllSessionsURL = providerConfig.BackendLogoutAllSessionsURL return p, nil }