Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sso: send access token on group information query #171

Merged
merged 2 commits into from
Apr 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion internal/auth/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -796,13 +796,20 @@ func (p *Authenticator) GetProfile(rw http.ResponseWriter, req *http.Request) {
return
}

accessToken := req.Header.Get("X-Access-Token")
if accessToken == "" {
// TODO: This should be an error in the future, but is OK to be missing for now
// we track an error to see observe how often this occurs
p.StatsdClient.Incr("application_error", append(tags, "error:missing_access_token"), 1.0)
}

groupsFormValue := req.FormValue("groups")
allowedGroups := []string{}
if groupsFormValue != "" {
allowedGroups = strings.Split(groupsFormValue, ",")
}

groups, err := p.provider.ValidateGroupMembership(email, allowedGroups)
groups, err := p.provider.ValidateGroupMembership(email, allowedGroups, accessToken)
if err != nil {
tags = append(tags, "error:groups_resource")
p.StatsdClient.Incr("provider_error", tags, 1.0)
Expand Down
2 changes: 1 addition & 1 deletion internal/auth/providers/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func (p *GoogleProvider) PopulateMembers(group string) (groups.MemberSet, error)

// ValidateGroupMembership takes in an email and the allowed groups and returns the groups that the email is part of in that list.
// If `allGroups` is an empty list, returns an empty list.
func (p *GoogleProvider) ValidateGroupMembership(email string, allGroups []string) ([]string, error) {
func (p *GoogleProvider) ValidateGroupMembership(email string, allGroups []string, _ string) ([]string, error) {
logger := log.NewLogEntry()

groups := []string{}
Expand Down
2 changes: 1 addition & 1 deletion internal/auth/providers/google_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ func TestValidateGroupMembers(t *testing.T) {
GroupsCache: &groups.MockCache{ListMembershipsFunc: tc.listMembershipsFunc, Refreshed: true},
}

groups, err := p.ValidateGroupMembership("email", tc.inputAllowedGroups)
groups, err := p.ValidateGroupMembership("email", tc.inputAllowedGroups, "accessToken")

if err != nil {
if tc.expectedErrorString != err.Error() {
Expand Down
2 changes: 1 addition & 1 deletion internal/auth/providers/provider_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (p *ProviderData) Revoke(s *sessions.SessionState) error {
}

// ValidateGroupMembership returns an ErrNotImplemented.
func (p *ProviderData) ValidateGroupMembership(string, []string) ([]string, error) {
func (p *ProviderData) ValidateGroupMembership(string, []string, string) ([]string, error) {
return nil, ErrNotImplemented
}

Expand Down
2 changes: 1 addition & 1 deletion internal/auth/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type Provider interface {
ValidateSessionState(*sessions.SessionState) bool
GetSignInURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)
ValidateGroupMembership(string, []string) ([]string, error)
ValidateGroupMembership(string, []string, string) ([]string, error)
Revoke(*sessions.SessionState) error
RefreshAccessToken(string) (string, time.Duration, error)
Stop()
Expand Down
4 changes: 2 additions & 2 deletions internal/auth/providers/singleflight_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ func (p *SingleFlightProvider) RefreshSessionIfNeeded(s *sessions.SessionState)
}

// ValidateGroupMembership wraps the provider's GroupsResource function in a single flight call.
func (p *SingleFlightProvider) ValidateGroupMembership(email string, allowedGroups []string) ([]string, error) {
func (p *SingleFlightProvider) ValidateGroupMembership(email string, allowedGroups []string, accessToken string) ([]string, error) {
sort.Strings(allowedGroups)
response, err := p.do("ValidateGroupMembership", fmt.Sprintf("%s:%s", email, strings.Join(allowedGroups, ",")),
func() (interface{}, error) {
return p.provider.ValidateGroupMembership(email, allowedGroups)
return p.provider.ValidateGroupMembership(email, allowedGroups, accessToken)
})
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion internal/auth/providers/test_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (tp *TestProvider) Revoke(*sessions.SessionState) error {
}

// ValidateGroupMembership returns the mock provider's GroupsError if not nil, or the Groups field value.
func (tp *TestProvider) ValidateGroupMembership(string, []string) ([]string, error) {
func (tp *TestProvider) ValidateGroupMembership(string, []string, string) ([]string, error) {
return tp.Groups, tp.GroupsError
}

Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
}
allowedGroups := route.upstreamConfig.AllowedGroups

inGroups, validGroup, err := p.provider.ValidateGroup(session.Email, allowedGroups)
inGroups, validGroup, err := p.provider.ValidateGroup(session.Email, allowedGroups, session.AccessToken)
if err != nil {
tags = append(tags, "error:user_group_failed")
p.StatsdClient.Incr("provider_error", tags, 1.0)
Expand Down
4 changes: 2 additions & 2 deletions internal/proxy/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
type Provider interface {
Data() *ProviderData
Redeem(string, string) (*sessions.SessionState, error)
ValidateGroup(string, []string) ([]string, bool, error)
UserGroups(string, []string) ([]string, error)
ValidateGroup(string, []string, string) ([]string, bool, error)
UserGroups(string, []string, string) ([]string, error)
ValidateSessionState(*sessions.SessionState, []string) bool
GetSignInURL(redirectURL *url.URL, finalRedirect string) *url.URL
GetSignOutURL(redirectURL *url.URL) *url.URL
Expand Down
8 changes: 4 additions & 4 deletions internal/proxy/providers/singleflight_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,16 @@ func (p *SingleFlightProvider) Redeem(redirectURL, code string) (*sessions.Sessi
}

// ValidateGroup takes an email, allowedGroups, and userGroups and passes it to the provider's ValidateGroup function and returns the response
func (p *SingleFlightProvider) ValidateGroup(email string, allowedGroups []string) ([]string, bool, error) {
return p.provider.ValidateGroup(email, allowedGroups)
func (p *SingleFlightProvider) ValidateGroup(email string, allowedGroups []string, accessToken string) ([]string, bool, error) {
return p.provider.ValidateGroup(email, allowedGroups, accessToken)
}

// UserGroups takes an email and passes it to the provider's UserGroups function and returns the response
func (p *SingleFlightProvider) UserGroups(email string, groups []string) ([]string, error) {
func (p *SingleFlightProvider) UserGroups(email string, groups []string, accessToken string) ([]string, error) {
// sort the groups so that other requests may be able to use the cached request
sort.Strings(groups)
response, err := p.do("UserGroups", fmt.Sprintf("%s:%s", email, strings.Join(groups, ",")), func() (interface{}, error) {
return p.provider.UserGroups(email, groups)
return p.provider.UserGroups(email, groups, accessToken)
})
if err != nil {
return nil, err
Expand Down
14 changes: 9 additions & 5 deletions internal/proxy/providers/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func (p *SSOProvider) Redeem(redirectURL, code string) (*sessions.SessionState,

// ValidateGroup does a GET request to the profile url and returns true if the user belongs to
// an authorized group.
func (p *SSOProvider) ValidateGroup(email string, allowedGroups []string) ([]string, bool, error) {
func (p *SSOProvider) ValidateGroup(email string, allowedGroups []string, accessToken string) ([]string, bool, error) {
logger := log.NewLogEntry()

logger.WithUser(email).WithAllowedGroups(allowedGroups).Info("validating groups")
Expand All @@ -177,7 +177,7 @@ func (p *SSOProvider) ValidateGroup(email string, allowedGroups []string) ([]str
return inGroups, true, nil
}

userGroups, err := p.UserGroups(email, allowedGroups)
userGroups, err := p.UserGroups(email, allowedGroups, accessToken)
if err != nil {
return nil, false, err
}
Expand All @@ -196,7 +196,7 @@ func (p *SSOProvider) ValidateGroup(email string, allowedGroups []string) ([]str
}

// UserGroups takes an email and returns the UserGroups for that email
func (p *SSOProvider) UserGroups(email string, groups []string) ([]string, error) {
func (p *SSOProvider) UserGroups(email string, groups []string, accessToken string) ([]string, error) {
params := url.Values{}
params.Add("email", email)
params.Add("client_id", p.ClientID)
Expand All @@ -206,7 +206,10 @@ func (p *SSOProvider) UserGroups(email string, groups []string) ([]string, error
if err != nil {
return nil, err
}

req.Header.Set("X-Client-Secret", p.ClientSecret)
req.Header.Set("X-Access-Token", accessToken)

resp, err := httpClient.Do(req)
if err != nil {
return nil, err
Expand Down Expand Up @@ -260,7 +263,7 @@ func (p *SSOProvider) RefreshSession(s *sessions.SessionState, allowedGroups []s
return false, err
}

inGroups, validGroup, err := p.ValidateGroup(s.Email, allowedGroups)
inGroups, validGroup, err := p.ValidateGroup(s.Email, allowedGroups, newToken)
if err != nil {
// When we detect that the auth provider is not explicitly denying
// authentication, and is merely unavailable, we refresh and continue
Expand Down Expand Up @@ -342,6 +345,7 @@ func (p *SSOProvider) ValidateSessionState(s *sessions.SessionState, allowedGrou
logger.WithUser(s.Email).Error(err, "error validating session state")
return false
}

req.Header.Set("X-Client-Secret", p.ClientSecret)
req.Header.Set("X-Access-Token", s.AccessToken)

Expand All @@ -367,7 +371,7 @@ func (p *SSOProvider) ValidateSessionState(s *sessions.SessionState, allowedGrou
}

// check the user is in the proper group(s)
inGroups, validGroup, err := p.ValidateGroup(s.Email, allowedGroups)
inGroups, validGroup, err := p.ValidateGroup(s.Email, allowedGroups, s.AccessToken)
if err != nil {
// When we detect that the auth provider is not explicitly denying
// authentication, and is merely unavailable, we validate and continue
Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/providers/sso_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func TestSSOProviderGroups(t *testing.T) {
}
p.ProfileURL, server = newTestServer(profileStatus, body)
defer server.Close()
inGroups, valid, err := p.ValidateGroup(tc.Email, tc.ProxyGroupIds)
inGroups, valid, err := p.ValidateGroup(tc.Email, tc.ProxyGroupIds, "accessToken")
testutil.Equal(t, tc.ExpectError, err)
if err == nil {
testutil.Equal(t, tc.ExpectedValid, valid)
Expand Down
12 changes: 6 additions & 6 deletions internal/proxy/providers/test_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ type TestProvider struct {
RefreshSessionFunc func(*sessions.SessionState, []string) (bool, error)
ValidateSessionFunc func(*sessions.SessionState, []string) bool
RedeemFunc func(string, string) (*sessions.SessionState, error)
UserGroupsFunc func(string, []string) ([]string, error)
ValidateGroupsFunc func(string, []string) ([]string, bool, error)
UserGroupsFunc func(string, []string, string) ([]string, error)
ValidateGroupsFunc func(string, []string, string) ([]string, bool, error)
*ProviderData
}

Expand Down Expand Up @@ -62,13 +62,13 @@ func (tp *TestProvider) RefreshSession(s *sessions.SessionState, g []string) (bo
}

// UserGroups mocks the UserGroups function
func (tp *TestProvider) UserGroups(email string, groups []string) ([]string, error) {
return tp.UserGroupsFunc(email, groups)
func (tp *TestProvider) UserGroups(email string, groups []string, accessToken string) ([]string, error) {
return tp.UserGroupsFunc(email, groups, accessToken)
}

// ValidateGroup mocks the ValidateGroup function
func (tp *TestProvider) ValidateGroup(email string, groups []string) ([]string, bool, error) {
return tp.ValidateGroupsFunc(email, groups)
func (tp *TestProvider) ValidateGroup(email string, groups []string, accessToken string) ([]string, bool, error) {
return tp.ValidateGroupsFunc(email, groups, accessToken)
}

// GetSignOutURL mocks GetSignOutURL function
Expand Down