From 0da0c7e1998ecc4c60cb2638fce3b124e08d59f3 Mon Sep 17 00:00:00 2001 From: Justin Hines Date: Mon, 24 Jun 2019 20:36:08 -0400 Subject: [PATCH 1/2] proxy: add idp transition ux flow --- internal/pkg/sessions/session_state.go | 29 ++--------------- internal/pkg/sessions/session_state_test.go | 3 ++ internal/proxy/oauthproxy.go | 32 ++++++++++++++----- internal/proxy/oauthproxy_test.go | 35 +++++++++++++++------ internal/proxy/providers/sso.go | 3 ++ 5 files changed, 59 insertions(+), 43 deletions(-) diff --git a/internal/pkg/sessions/session_state.go b/internal/pkg/sessions/session_state.go index 532f9e3e..f6044d7e 100644 --- a/internal/pkg/sessions/session_state.go +++ b/internal/pkg/sessions/session_state.go @@ -2,9 +2,6 @@ package sessions import ( "errors" - "fmt" - "strconv" - "strings" "time" "github.com/buzzfeed/sso/internal/pkg/aead" @@ -17,6 +14,9 @@ var ( // SessionState is our object that keeps track of a user's session state type SessionState struct { + ProviderSlug string `json:"slug"` + ProviderType string `json:"type"` + AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` @@ -73,26 +73,3 @@ func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) { func ExtendDeadline(ttl time.Duration) time.Time { return time.Now().Add(ttl).Truncate(time.Second) } - -// NewSessionState creates a new session state -// TODO: remove this file when we transition out of backup using the payloads encryption -func NewSessionState(value string, lifetimeTTL time.Duration) (*SessionState, error) { - parts := strings.Split(value, "|") - if len(parts) != 4 { - err := fmt.Errorf("invalid number of fields (got %d expected 4)", len(parts)) - return nil, err - } - - ts, err := strconv.Atoi(parts[2]) - if err != nil { - return nil, err - } - - return &SessionState{ - Email: parts[0], - AccessToken: parts[1], - RefreshDeadline: time.Unix(int64(ts), 0), - RefreshToken: parts[3], - LifetimeDeadline: ExtendDeadline(lifetimeTTL), - }, nil -} diff --git a/internal/pkg/sessions/session_state_test.go b/internal/pkg/sessions/session_state_test.go index 975ddb2c..9db8d607 100644 --- a/internal/pkg/sessions/session_state_test.go +++ b/internal/pkg/sessions/session_state_test.go @@ -16,6 +16,9 @@ func TestSessionStateSerialization(t *testing.T) { } want := &SessionState{ + ProviderSlug: "slug", + ProviderType: "sso", + AccessToken: "token1234", RefreshToken: "refresh4321", diff --git a/internal/proxy/oauthproxy.go b/internal/proxy/oauthproxy.go index 80f6d2a5..4018a784 100644 --- a/internal/proxy/oauthproxy.go +++ b/internal/proxy/oauthproxy.go @@ -38,8 +38,9 @@ var SignatureHeaders = []string{ // Errors var ( - ErrLifetimeExpired = errors.New("user lifetime expired") - ErrUserNotAuthorized = errors.New("user not authorized") + ErrLifetimeExpired = errors.New("user lifetime expired") + ErrUserNotAuthorized = errors.New("user not authorized") + ErrWrongIdentityProvider = errors.New("user authenticated with wrong identity provider") ) type ErrOAuthProxyMisconfigured struct { @@ -655,16 +656,16 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { // No cookie is set, start the oauth flow p.OAuthStart(rw, req, tags) return - case ErrUserNotAuthorized: - tags = append(tags, "error:user_unauthorized") - p.StatsdClient.Incr("application_error", tags, 1.0) - // We know the user is not authorized for the request, we show them a forbidden page - p.ErrorPage(rw, req, http.StatusForbidden, "Forbidden", "You're not authorized to view this page") - return case ErrLifetimeExpired: // User's lifetime expired, we trigger the start of the oauth flow p.OAuthStart(rw, req, tags) return + case ErrWrongIdentityProvider: + // User is authenticated with the incorrect provider. This most common non-malicious + // case occurs when an upstream has been transitioned to a different provider but + // the user has a stale sesssion. + p.OAuthStart(rw, req, tags) + return case sessions.ErrInvalidSession: // The user session is invalid and we can't decode it. // This can happen for a variety of reasons but the most common non-malicious @@ -672,6 +673,12 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { // by triggering the start of the oauth flow. p.OAuthStart(rw, req, tags) return + case ErrUserNotAuthorized: + tags = append(tags, "error:user_unauthorized") + p.StatsdClient.Incr("application_error", tags, 1.0) + // We know the user is not authorized for the request, we show them a forbidden page + p.ErrorPage(rw, req, http.StatusForbidden, "Forbidden", "You're not authorized to view this page") + return default: logger.Error(err, "unknown error authenticating user") tags = append(tags, "error:internal_error") @@ -709,6 +716,15 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er return err } + // check if this session belongs to the correct identity provider application. + // this case exists primarly to allow us to gracefully manage a clean ux during + // transitions from one provider to another by gracefully restarting the authentication process. + if session.ProviderSlug != p.provider.Data().ProviderSlug { + logger.WithUser(session.Email).Info( + "authenticated with incorrect identity provider; restarting authentication") + return ErrWrongIdentityProvider + } + // Lifetime period is the entire duration in which the session is valid. // This should be set to something like 14 to 30 days. if session.LifetimePeriodExpired() { diff --git a/internal/proxy/oauthproxy_test.go b/internal/proxy/oauthproxy_test.go index 554f5aea..a9c6dc31 100644 --- a/internal/proxy/oauthproxy_test.go +++ b/internal/proxy/oauthproxy_test.go @@ -263,13 +263,15 @@ func TestAuthOnlyEndpoint(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + providerURL, _ := url.Parse("http://localhost/") + tp := providers.NewTestProvider(providerURL, "") + tp.RefreshSessionFunc = func(*sessions.SessionState, []string) (bool, error) { return true, nil } + tp.ValidateSessionFunc = func(*sessions.SessionState, []string) bool { return true } + proxy, close := testNewOAuthProxy(t, setSessionStore(tc.sessionStore), setValidator(func(_ string) bool { return tc.validEmail }), - SetProvider(&providers.TestProvider{ - RefreshSessionFunc: func(*sessions.SessionState, []string) (bool, error) { return true, nil }, - ValidateSessionFunc: func(*sessions.SessionState, []string) bool { return true }, - }), + SetProvider(tp), ) defer close() @@ -571,16 +573,31 @@ func TestAuthenticate(t *testing.T) { CookieExpectation: NewCookie, ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return true }, }, + { + Name: "wrong identity provider, user OK, do not authenticate", + SessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + ProviderSlug: "example", + Email: "email1@example.com", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + }, + }, + ExpectedErr: ErrWrongIdentityProvider, + CookieExpectation: ClearCookie, + }, } for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { - provider := &providers.TestProvider{ - RefreshSessionFunc: tc.RefreshSessionFunc, - ValidateSessionFunc: tc.ValidateSessionFunc, - } + providerURL, _ := url.Parse("http://localhost/") + tp := providers.NewTestProvider(providerURL, "") + tp.RefreshSessionFunc = tc.RefreshSessionFunc + tp.ValidateSessionFunc = tc.ValidateSessionFunc proxy, close := testNewOAuthProxy(t, - SetProvider(provider), + SetProvider(tp), setSessionStore(tc.SessionStore), ) defer close() diff --git a/internal/proxy/providers/sso.go b/internal/proxy/providers/sso.go index a6f97c52..44eaff27 100644 --- a/internal/proxy/providers/sso.go +++ b/internal/proxy/providers/sso.go @@ -158,6 +158,9 @@ func (p *SSOProvider) Redeem(redirectURL, code string) (*sessions.SessionState, user := strings.Split(jsonResponse.Email, "@")[0] return &sessions.SessionState{ + ProviderSlug: p.ProviderData.ProviderSlug, + ProviderType: "sso", + AccessToken: jsonResponse.AccessToken, RefreshToken: jsonResponse.RefreshToken, From 0441b455bcccf9c508d89c5ea6c85a5a0eb611de Mon Sep 17 00:00:00 2001 From: Justin Hines Date: Tue, 25 Jun 2019 15:16:23 -0400 Subject: [PATCH 2/2] proxy: add tests around ux flows --- internal/proxy/oauthproxy_test.go | 188 ++++++++++++++++++++++++++++++ 1 file changed, 188 insertions(+) diff --git a/internal/proxy/oauthproxy_test.go b/internal/proxy/oauthproxy_test.go index a9c6dc31..b565d09d 100644 --- a/internal/proxy/oauthproxy_test.go +++ b/internal/proxy/oauthproxy_test.go @@ -625,6 +625,194 @@ func TestAuthenticate(t *testing.T) { } } +func TestAuthenticationUXFlows(t *testing.T) { + var ( + ErrRefreshFailed = errors.New("refresh failed") + LoadCookieFailed = errors.New("load cookie fail") + SaveCookieFailed = errors.New("save cookie fail") + ) + testCases := []struct { + Name string + + SessionStore *sessions.MockSessionStore + RefreshSessionFunc func(*sessions.SessionState, []string) (bool, error) + ValidateSessionFunc func(*sessions.SessionState, []string) bool + + ExpectStatusCode int + }{ + { + Name: "missing deadlines, redirect to sign-in", + SessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + Email: "email1@example.com", + AccessToken: "my_access_token", + }, + }, + ExpectStatusCode: http.StatusFound, + }, + { + Name: "session unmarshaling fails, show error", + SessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{}, + LoadError: LoadCookieFailed, + }, + ExpectStatusCode: http.StatusInternalServerError, + }, + { + Name: "authenticate successfully, expect ok", + SessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + Email: "email1@example.com", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + }, + }, + ExpectStatusCode: http.StatusOK, + }, + { + Name: "lifetime expired, redirect to sign-in", + SessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + Email: "email1@example.com", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(-24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + }, + }, + ExpectStatusCode: http.StatusFound, + }, + { + Name: "refresh expired, refresh fails, show error", + SessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + Email: "email1@example.com", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + }, + }, + RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return false, ErrRefreshFailed }, + ExpectStatusCode: http.StatusInternalServerError, + }, + { + Name: "refresh expired, user not OK, deny", + SessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + Email: "email1@example.com", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + }, + }, + RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return false, nil }, + ExpectStatusCode: http.StatusForbidden, + }, + { + Name: "refresh expired, user OK, expect ok", + SessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + Email: "email1@example.com", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + }, + }, + RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return true, nil }, + ExpectStatusCode: http.StatusOK, + }, + { + Name: "refresh expired, refresh and user OK, error saving session, show error", + SessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + Email: "email1@example.com", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + }, + SaveError: SaveCookieFailed, + }, + RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return true, nil }, + ExpectStatusCode: http.StatusInternalServerError, + }, + { + Name: "validation expired, user not OK, deny", + SessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + Email: "email1@example.com", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute), + }, + }, + ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return false }, + ExpectStatusCode: http.StatusForbidden, + }, + { + Name: "validation expired, user OK, expect ok", + SessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + Email: "email1@example.com", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute), + }, + }, + ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return true }, + ExpectStatusCode: http.StatusOK, + }, + { + Name: "wrong identity provider, redirect to sign-in", + SessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + ProviderSlug: "example", + Email: "email1@example.com", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + }, + }, + ExpectStatusCode: http.StatusFound, + }, + } + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + providerURL, _ := url.Parse("http://localhost/") + tp := providers.NewTestProvider(providerURL, "") + tp.RefreshSessionFunc = tc.RefreshSessionFunc + tp.ValidateSessionFunc = tc.ValidateSessionFunc + + proxy, close := testNewOAuthProxy(t, + SetProvider(tp), + setSessionStore(tc.SessionStore), + ) + defer close() + + req := httptest.NewRequest("GET", "https://localhost", nil) + rw := httptest.NewRecorder() + + proxy.Proxy(rw, req) + + res := rw.Result() + + if tc.ExpectStatusCode != res.StatusCode { + t.Errorf("have: %v", res.StatusCode) + t.Errorf("want: %v", tc.ExpectStatusCode) + t.Fatalf("expected status codes to be equal") + } + }) + } +} + func TestProxyXHRErrorHandling(t *testing.T) { testCases := []struct { Name string