Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proxy: transition idps ux flow #218

Merged
merged 2 commits into from
Jun 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 3 additions & 26 deletions internal/pkg/sessions/session_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ package sessions

import (
"errors"
"fmt"
"strconv"
"strings"
"time"

"github.com/buzzfeed/sso/internal/pkg/aead"
Expand All @@ -17,6 +14,9 @@ var (

// SessionState is our object that keeps track of a user's session state
type SessionState struct {
ProviderSlug string `json:"slug"`
ProviderType string `json:"type"`

AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`

Expand Down Expand Up @@ -73,26 +73,3 @@ func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) {
func ExtendDeadline(ttl time.Duration) time.Time {
return time.Now().Add(ttl).Truncate(time.Second)
}

// NewSessionState creates a new session state
// TODO: remove this file when we transition out of backup using the payloads encryption
func NewSessionState(value string, lifetimeTTL time.Duration) (*SessionState, error) {
parts := strings.Split(value, "|")
if len(parts) != 4 {
err := fmt.Errorf("invalid number of fields (got %d expected 4)", len(parts))
return nil, err
}

ts, err := strconv.Atoi(parts[2])
if err != nil {
return nil, err
}

return &SessionState{
Email: parts[0],
AccessToken: parts[1],
RefreshDeadline: time.Unix(int64(ts), 0),
RefreshToken: parts[3],
LifetimeDeadline: ExtendDeadline(lifetimeTTL),
}, nil
}
3 changes: 3 additions & 0 deletions internal/pkg/sessions/session_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ func TestSessionStateSerialization(t *testing.T) {
}

want := &SessionState{
ProviderSlug: "slug",
ProviderType: "sso",

AccessToken: "token1234",
RefreshToken: "refresh4321",

Expand Down
32 changes: 24 additions & 8 deletions internal/proxy/oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ var SignatureHeaders = []string{

// Errors
var (
ErrLifetimeExpired = errors.New("user lifetime expired")
ErrUserNotAuthorized = errors.New("user not authorized")
ErrLifetimeExpired = errors.New("user lifetime expired")
ErrUserNotAuthorized = errors.New("user not authorized")
ErrWrongIdentityProvider = errors.New("user authenticated with wrong identity provider")
)

type ErrOAuthProxyMisconfigured struct {
Expand Down Expand Up @@ -655,23 +656,29 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
// No cookie is set, start the oauth flow
p.OAuthStart(rw, req, tags)
return
case ErrUserNotAuthorized:
tags = append(tags, "error:user_unauthorized")
p.StatsdClient.Incr("application_error", tags, 1.0)
// We know the user is not authorized for the request, we show them a forbidden page
p.ErrorPage(rw, req, http.StatusForbidden, "Forbidden", "You're not authorized to view this page")
return
case ErrLifetimeExpired:
// User's lifetime expired, we trigger the start of the oauth flow
p.OAuthStart(rw, req, tags)
return
case ErrWrongIdentityProvider:
// User is authenticated with the incorrect provider. This most common non-malicious
// case occurs when an upstream has been transitioned to a different provider but
// the user has a stale sesssion.
p.OAuthStart(rw, req, tags)
return
case sessions.ErrInvalidSession:
// The user session is invalid and we can't decode it.
// This can happen for a variety of reasons but the most common non-malicious
// case occurs when the session encoding schema changes. We manage this ux
// by triggering the start of the oauth flow.
p.OAuthStart(rw, req, tags)
return
case ErrUserNotAuthorized:
tags = append(tags, "error:user_unauthorized")
p.StatsdClient.Incr("application_error", tags, 1.0)
// We know the user is not authorized for the request, we show them a forbidden page
p.ErrorPage(rw, req, http.StatusForbidden, "Forbidden", "You're not authorized to view this page")
return
default:
logger.Error(err, "unknown error authenticating user")
tags = append(tags, "error:internal_error")
Expand Down Expand Up @@ -709,6 +716,15 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er
return err
}

// check if this session belongs to the correct identity provider application.
// this case exists primarly to allow us to gracefully manage a clean ux during
// transitions from one provider to another by gracefully restarting the authentication process.
if session.ProviderSlug != p.provider.Data().ProviderSlug {
logger.WithUser(session.Email).Info(
"authenticated with incorrect identity provider; restarting authentication")
return ErrWrongIdentityProvider
}

// Lifetime period is the entire duration in which the session is valid.
// This should be set to something like 14 to 30 days.
if session.LifetimePeriodExpired() {
Expand Down
223 changes: 214 additions & 9 deletions internal/proxy/oauthproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,15 @@ func TestAuthOnlyEndpoint(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
providerURL, _ := url.Parse("http://localhost/")
tp := providers.NewTestProvider(providerURL, "")
tp.RefreshSessionFunc = func(*sessions.SessionState, []string) (bool, error) { return true, nil }
tp.ValidateSessionFunc = func(*sessions.SessionState, []string) bool { return true }

proxy, close := testNewOAuthProxy(t,
setSessionStore(tc.sessionStore),
setValidator(func(_ string) bool { return tc.validEmail }),
SetProvider(&providers.TestProvider{
RefreshSessionFunc: func(*sessions.SessionState, []string) (bool, error) { return true, nil },
ValidateSessionFunc: func(*sessions.SessionState, []string) bool { return true },
}),
SetProvider(tp),
)
defer close()

Expand Down Expand Up @@ -571,16 +573,31 @@ func TestAuthenticate(t *testing.T) {
CookieExpectation: NewCookie,
ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return true },
},
{
Name: "wrong identity provider, user OK, do not authenticate",
SessionStore: &sessions.MockSessionStore{
Session: &sessions.SessionState{
ProviderSlug: "example",
Email: "email1@example.com",
AccessToken: "my_access_token",
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour),
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
},
},
ExpectedErr: ErrWrongIdentityProvider,
CookieExpectation: ClearCookie,
},
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
provider := &providers.TestProvider{
RefreshSessionFunc: tc.RefreshSessionFunc,
ValidateSessionFunc: tc.ValidateSessionFunc,
}
providerURL, _ := url.Parse("http://localhost/")
tp := providers.NewTestProvider(providerURL, "")
tp.RefreshSessionFunc = tc.RefreshSessionFunc
tp.ValidateSessionFunc = tc.ValidateSessionFunc

proxy, close := testNewOAuthProxy(t,
SetProvider(provider),
SetProvider(tp),
setSessionStore(tc.SessionStore),
)
defer close()
Expand Down Expand Up @@ -608,6 +625,194 @@ func TestAuthenticate(t *testing.T) {
}
}

func TestAuthenticationUXFlows(t *testing.T) {
var (
ErrRefreshFailed = errors.New("refresh failed")
LoadCookieFailed = errors.New("load cookie fail")
SaveCookieFailed = errors.New("save cookie fail")
)
testCases := []struct {
Name string

SessionStore *sessions.MockSessionStore
RefreshSessionFunc func(*sessions.SessionState, []string) (bool, error)
ValidateSessionFunc func(*sessions.SessionState, []string) bool

ExpectStatusCode int
}{
{
Name: "missing deadlines, redirect to sign-in",
SessionStore: &sessions.MockSessionStore{
Session: &sessions.SessionState{
Email: "email1@example.com",
AccessToken: "my_access_token",
},
},
ExpectStatusCode: http.StatusFound,
},
{
Name: "session unmarshaling fails, show error",
SessionStore: &sessions.MockSessionStore{
Session: &sessions.SessionState{},
LoadError: LoadCookieFailed,
},
ExpectStatusCode: http.StatusInternalServerError,
},
{
Name: "authenticate successfully, expect ok",
SessionStore: &sessions.MockSessionStore{
Session: &sessions.SessionState{
Email: "email1@example.com",
AccessToken: "my_access_token",
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour),
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
},
},
ExpectStatusCode: http.StatusOK,
},
{
Name: "lifetime expired, redirect to sign-in",
SessionStore: &sessions.MockSessionStore{
Session: &sessions.SessionState{
Email: "email1@example.com",
AccessToken: "my_access_token",
LifetimeDeadline: time.Now().Add(time.Duration(-24) * time.Hour),
RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour),
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
},
},
ExpectStatusCode: http.StatusFound,
},
{
Name: "refresh expired, refresh fails, show error",
SessionStore: &sessions.MockSessionStore{
Session: &sessions.SessionState{
Email: "email1@example.com",
AccessToken: "my_access_token",
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour),
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
},
},
RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return false, ErrRefreshFailed },
ExpectStatusCode: http.StatusInternalServerError,
},
{
Name: "refresh expired, user not OK, deny",
SessionStore: &sessions.MockSessionStore{
Session: &sessions.SessionState{
Email: "email1@example.com",
AccessToken: "my_access_token",
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour),
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
},
},
RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return false, nil },
ExpectStatusCode: http.StatusForbidden,
},
{
Name: "refresh expired, user OK, expect ok",
SessionStore: &sessions.MockSessionStore{
Session: &sessions.SessionState{
Email: "email1@example.com",
AccessToken: "my_access_token",
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour),
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
},
},
RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return true, nil },
ExpectStatusCode: http.StatusOK,
},
{
Name: "refresh expired, refresh and user OK, error saving session, show error",
SessionStore: &sessions.MockSessionStore{
Session: &sessions.SessionState{
Email: "email1@example.com",
AccessToken: "my_access_token",
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour),
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
},
SaveError: SaveCookieFailed,
},
RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return true, nil },
ExpectStatusCode: http.StatusInternalServerError,
},
{
Name: "validation expired, user not OK, deny",
SessionStore: &sessions.MockSessionStore{
Session: &sessions.SessionState{
Email: "email1@example.com",
AccessToken: "my_access_token",
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour),
ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute),
},
},
ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return false },
ExpectStatusCode: http.StatusForbidden,
},
{
Name: "validation expired, user OK, expect ok",
SessionStore: &sessions.MockSessionStore{
Session: &sessions.SessionState{
Email: "email1@example.com",
AccessToken: "my_access_token",
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour),
ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute),
},
},
ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return true },
ExpectStatusCode: http.StatusOK,
},
{
Name: "wrong identity provider, redirect to sign-in",
SessionStore: &sessions.MockSessionStore{
Session: &sessions.SessionState{
ProviderSlug: "example",
Email: "email1@example.com",
AccessToken: "my_access_token",
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour),
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
},
},
ExpectStatusCode: http.StatusFound,
},
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
providerURL, _ := url.Parse("http://localhost/")
tp := providers.NewTestProvider(providerURL, "")
tp.RefreshSessionFunc = tc.RefreshSessionFunc
tp.ValidateSessionFunc = tc.ValidateSessionFunc

proxy, close := testNewOAuthProxy(t,
SetProvider(tp),
setSessionStore(tc.SessionStore),
)
defer close()

req := httptest.NewRequest("GET", "https://localhost", nil)
rw := httptest.NewRecorder()

proxy.Proxy(rw, req)

res := rw.Result()

if tc.ExpectStatusCode != res.StatusCode {
t.Errorf("have: %v", res.StatusCode)
t.Errorf("want: %v", tc.ExpectStatusCode)
t.Fatalf("expected status codes to be equal")
}
})
}
}

func TestProxyXHRErrorHandling(t *testing.T) {
testCases := []struct {
Name string
Expand Down
3 changes: 3 additions & 0 deletions internal/proxy/providers/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ func (p *SSOProvider) Redeem(redirectURL, code string) (*sessions.SessionState,

user := strings.Split(jsonResponse.Email, "@")[0]
return &sessions.SessionState{
ProviderSlug: p.ProviderData.ProviderSlug,
ProviderType: "sso",

AccessToken: jsonResponse.AccessToken,
RefreshToken: jsonResponse.RefreshToken,

Expand Down