Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: sign out all sessions #60

Open
wants to merge 9 commits into
base: pics
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/docs/configuration/alpha_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ Provider holds all configuration for a single provider
| `allowedGroups` | _[]string_ | AllowedGroups is a list of restrict logins to members of this group |
| `code_challenge_method` | _string_ | The code challenge method |
| `backendLogoutURL` | _string_ | URL to call to perform backend logout, `{id_token}` would be replaced by the actual `id_token` if available in the session |
| `backendLogoutAllSessionsURL` | _string_ | URL to call to perform backend logout, `{user_id}` would be replaced by the actual `user_id` if available in the session IntrospectClaims |

### ProviderType
#### (`string` alias)
Expand Down
62 changes: 44 additions & 18 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,16 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) {

// The userinfo and logout endpoints needs to load sessions before handling the request
s.Path(userInfoPath).Handler(p.sessionChain.ThenFunc(p.UserInfo))
s.Path(signOutPath).Handler(p.sessionChain.ThenFunc(p.SignOut))
s.Path(signOutPath).Handler(p.sessionChain.ThenFunc(
func(w http.ResponseWriter, r *http.Request) {
p.SignOut(w, r, false)
},
))
s.Path(picsSignOutAllDevicesPath).Handler(p.sessionChain.ThenFunc(
func(w http.ResponseWriter, r *http.Request) {
p.SignOut(w, r, true)
},
))
}

// buildPreAuthChain constructs a chain that should process every request before
Expand Down Expand Up @@ -758,7 +767,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) {
}

// SignOut sends a response to clear the authentication cookie
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request, signOutAllSessions bool) {
redirect, err := p.appDirector.GetRedirect(req)
if err != nil {
logger.Errorf("Error obtaining redirect: %v", err)
Expand All @@ -772,12 +781,12 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
return
}

p.backendLogout(rw, req)
p.backendLogout(rw, req, signOutAllSessions)

http.Redirect(rw, req, redirect, http.StatusFound)
}

func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) {
func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request, signOutAllSessions bool) {
session, err := p.getAuthenticatedSession(rw, req)
if err != nil {
logger.Errorf("error getting authenticated session during backend logout: %v", err)
Expand All @@ -789,22 +798,39 @@ func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) {
}

providerData := p.provider.Data()
if providerData.BackendLogoutURL == "" {
return
}
var resp *http.Response
if signOutAllSessions {
if providerData.BackendLogoutAllSessionsURL == "" {
return
}

backendLogoutURL := strings.ReplaceAll(providerData.BackendLogoutURL, "{id_token}", session.IDToken)
// security exception because URL is dynamic ({id_token} replacement) but
// base is not end-user provided but comes from configuration somewhat secure
resp, err := http.Get(backendLogoutURL) // #nosec G107
if err != nil {
logger.Errorf("error while calling backend logout: %v", err)
return
}
resp, err := PicsSignOutAllSessions(providerData.BackendLogoutAllSessionsURL, session.IntrospectClaims, session.AccessToken)
if err != nil {
logger.Errorf("error while calling backend logout all sessions: %v", err)
return
}

defer resp.Body.Close()
if resp.StatusCode != 200 {
logger.Errorf("error while calling backend logout url, returned error code %v", resp.StatusCode)
if resp.StatusCode() != 200 {
logger.Errorf("error while calling backend logout url, returned error code %v", resp.StatusCode())
}
} else {
if providerData.BackendLogoutURL == "" {
return
}

backendLogoutURL := strings.ReplaceAll(providerData.BackendLogoutURL, "{id_token}", session.IDToken)
// security exception because URL is dynamic ({id_token} replacement) but
// base is not end-user provided but comes from configuration somewhat secure
resp, err = http.Get(backendLogoutURL) // #nosec G107
if err != nil {
logger.Errorf("error while calling backend logout: %v", err)
return
}

defer resp.Body.Close()
if resp.StatusCode != 200 {
logger.Errorf("error while calling backend logout url, returned error code %v", resp.StatusCode)
}
}
}

Expand Down
59 changes: 59 additions & 0 deletions pics_oauthproxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package main

import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"

"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
)

const (
picsSignOutAllDevicesPath = "/sign_out_all_sessions"
)

func PicsSignOutAllSessions(backendLogoutAllSessionsURL string, introspectClaims string, accessToken string) (resp requests.Result, err error) {
userID, err := getUserID(introspectClaims)
if err != nil {
return nil, fmt.Errorf("error getting userID from instrospect claims: %v", err)
}

backendLogoutURL := strings.ReplaceAll(backendLogoutAllSessionsURL, "{user_id}", userID)
resp = requests.New(backendLogoutURL).
WithMethod("POST").
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("API-Version", "1").
SetHeader("Accept", "application/json").
Do()

if resp.Error() != nil {
return nil, fmt.Errorf("error logging out from IAM: %v", err)
}

return resp, err
}

func getUserID(introspectClaims string) (string, error) {
decodedClaims, err := base64.StdEncoding.DecodeString(introspectClaims)
if err != nil {
logger.Errorf("error decoding claims: %v", err)
return "", err
}

var claims map[string]interface{}
err = json.Unmarshal(decodedClaims, &claims)
if err != nil {
logger.Errorf("error unmarshalling claims: %v", err)
return "", err
}

userID, ok := claims["sub"].(string)
if !ok {
logger.Errorf("error extracting 'sub' from claims")
return "", err
}

return userID, nil
}
55 changes: 55 additions & 0 deletions pics_oauthproxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package main

import (
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func createIntrospectClaims() string {
claims := map[string]interface{}{
"sub": "1234567890",
}
claimsBytes, err := json.Marshal(claims)
if err != nil {
return ""
}

return base64.StdEncoding.EncodeToString(claimsBytes)
}

func Test_PicsSignOutAllSessionsReturnsErrorWhenUserIDIsNotFound(t *testing.T) {
_, err := PicsSignOutAllSessions("http://localhost:8080/test", "", "")

assert.Error(t, err)
}

func Test_getUserID(t *testing.T) {
introspectClaims := createIntrospectClaims()
userID, err := getUserID(introspectClaims)

assert.NoError(t, err)
assert.Equal(t, "1234567890", userID)
}

func Test_PicsSignOutAllSessionsReturns200Ok(t *testing.T) {
introspectClaims := createIntrospectClaims()
accessToken := "validAccessToken"

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "Bearer "+accessToken, r.Header.Get("Authorization"))
assert.Equal(t, "1", r.Header.Get("API-Version"))
assert.Equal(t, "application/json", r.Header.Get("Accept"))
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

resp, err := PicsSignOutAllSessions(server.URL+"/{user_id}", introspectClaims, accessToken)

assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode())
}
5 changes: 5 additions & 0 deletions pkg/apis/options/legacy_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,8 @@ type LegacyProvider struct {
AllowedRoles []string `flag:"allowed-role" cfg:"allowed_roles"`
BackendLogoutURL string `flag:"backend-logout-url" cfg:"backend_logout_url"`

BackendLogoutAllSessionsURL string `flag:"backend-logout-all-sessions-url" cfg:"backend_logout_all_sessions_url"`

AcrValues string `flag:"acr-values" cfg:"acr_values"`
JWTKey string `flag:"jwt-key" cfg:"jwt_key"`
JWTKeyFile string `flag:"jwt-key-file" cfg:"jwt_key_file"`
Expand Down Expand Up @@ -613,6 +615,7 @@ func legacyProviderFlagSet() *pflag.FlagSet {
flagSet.StringSlice("allowed-group", []string{}, "restrict logins to members of this group (may be given multiple times)")
flagSet.StringSlice("allowed-role", []string{}, "(keycloak-oidc) restrict logins to members of these roles (may be given multiple times)")
flagSet.String("backend-logout-url", "", "url to perform a backend logout, {id_token} can be used as placeholder for the id_token")
flagSet.String("backend-logout-all-sessions-url", "", "url to perform a backend logout, {user_id} can be used as placeholder for the user_id")

return flagSet
}
Expand Down Expand Up @@ -693,6 +696,8 @@ func (l *LegacyProvider) convert() (Providers, error) {
AllowedGroups: l.AllowedGroups,
CodeChallengeMethod: l.CodeChallengeMethod,
BackendLogoutURL: l.BackendLogoutURL,

BackendLogoutAllSessionsURL: l.BackendLogoutAllSessionsURL,
}

// This part is out of the switch section for all providers that support OIDC
Expand Down
3 changes: 3 additions & 0 deletions pkg/apis/options/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ type Provider struct {

// URL to call to perform backend logout, `{id_token}` would be replaced by the actual `id_token` if available in the session
BackendLogoutURL string `json:"backendLogoutURL"`

// URL to call to perform backend logout, `{user_id}` would be replaced by the actual `user_id` if available in the session IntrospectClaims
BackendLogoutAllSessionsURL string `json:"backendLogoutAllSessionsURL"`
}

// ProviderType is used to enumerate the different provider type options
Expand Down
3 changes: 2 additions & 1 deletion pkg/middleware/stored_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
if err := s.refreshSession(rw, req, session); err != nil {
// If a preemptive refresh fails, we still keep the session
// if validateSession succeeds.
logger.Errorf("Unable to refresh session: %v", err)
// PICS: We will clean the session if the refresh fails.
return fmt.Errorf("unable to refresh session: %v", err)
}

// Validate all sessions after any Redeem/Refresh operation (fail or success)
Expand Down
9 changes: 2 additions & 7 deletions pkg/middleware/stored_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,17 +295,12 @@ var _ = Describe("Stored Session Suite", func() {
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{
Entry("when the provider refresh fails", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=RefreshError"},
},
existingSession: nil,
expectedSession: &sessionsapi.SessionState{
RefreshToken: "RefreshError",
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
Expand Down
2 changes: 2 additions & 0 deletions providers/provider_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ type ProviderData struct {
loginURLParameterOverrides map[string]*regexp.Regexp

BackendLogoutURL string

BackendLogoutAllSessionsURL string
}

// Data returns the ProviderData
Expand Down
1 change: 1 addition & 0 deletions providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData,
p.setAllowedGroups(providerConfig.AllowedGroups)

p.BackendLogoutURL = providerConfig.BackendLogoutURL
p.BackendLogoutAllSessionsURL = providerConfig.BackendLogoutAllSessionsURL

return p, nil
}
Expand Down
Loading