Skip to content

Commit

Permalink
sso_proxy: go vet
Browse files Browse the repository at this point in the history
  • Loading branch information
jphines committed Apr 1, 2019
1 parent a0cd908 commit c75733d
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion internal/proxy/oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
}
allowedGroups := route.upstreamConfig.AllowedGroups

inGroups, validGroup, err := p.provider.ValidateGroup(session.Email, allowedGroups)
inGroups, validGroup, err := p.provider.ValidateGroup(session.Email, allowedGroups, session.AccessToken)
if err != nil {
tags = append(tags, "error:user_group_failed")
p.StatsdClient.Incr("provider_error", tags, 1.0)
Expand Down
4 changes: 2 additions & 2 deletions internal/proxy/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
type Provider interface {
Data() *ProviderData
Redeem(string, string) (*sessions.SessionState, error)
ValidateGroup(string, []string) ([]string, bool, error)
UserGroups(string, []string) ([]string, error)
ValidateGroup(string, []string, string) ([]string, bool, error)
UserGroups(string, []string, string) ([]string, error)
ValidateSessionState(*sessions.SessionState, []string) bool
GetSignInURL(redirectURL *url.URL, finalRedirect string) *url.URL
GetSignOutURL(redirectURL *url.URL) *url.URL
Expand Down
8 changes: 4 additions & 4 deletions internal/proxy/providers/singleflight_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,16 @@ func (p *SingleFlightProvider) Redeem(redirectURL, code string) (*sessions.Sessi
}

// ValidateGroup takes an email, allowedGroups, and userGroups and passes it to the provider's ValidateGroup function and returns the response
func (p *SingleFlightProvider) ValidateGroup(email string, allowedGroups []string) ([]string, bool, error) {
return p.provider.ValidateGroup(email, allowedGroups)
func (p *SingleFlightProvider) ValidateGroup(email string, allowedGroups []string, accessToken string) ([]string, bool, error) {
return p.provider.ValidateGroup(email, allowedGroups, accessToken)
}

// UserGroups takes an email and passes it to the provider's UserGroups function and returns the response
func (p *SingleFlightProvider) UserGroups(email string, groups []string) ([]string, error) {
func (p *SingleFlightProvider) UserGroups(email string, groups []string, accessToken string) ([]string, error) {
// sort the groups so that other requests may be able to use the cached request
sort.Strings(groups)
response, err := p.do("UserGroups", fmt.Sprintf("%s:%s", email, strings.Join(groups, ",")), func() (interface{}, error) {
return p.provider.UserGroups(email, groups)
return p.provider.UserGroups(email, groups, accessToken)
})
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/providers/sso_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func TestSSOProviderGroups(t *testing.T) {
}
p.ProfileURL, server = newTestServer(profileStatus, body)
defer server.Close()
inGroups, valid, err := p.ValidateGroup(tc.Email, tc.ProxyGroupIds)
inGroups, valid, err := p.ValidateGroup(tc.Email, tc.ProxyGroupIds, "accessToken")
testutil.Equal(t, tc.ExpectError, err)
if err == nil {
testutil.Equal(t, tc.ExpectedValid, valid)
Expand Down
12 changes: 6 additions & 6 deletions internal/proxy/providers/test_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ type TestProvider struct {
RefreshSessionFunc func(*sessions.SessionState, []string) (bool, error)
ValidateSessionFunc func(*sessions.SessionState, []string) bool
RedeemFunc func(string, string) (*sessions.SessionState, error)
UserGroupsFunc func(string, []string) ([]string, error)
ValidateGroupsFunc func(string, []string) ([]string, bool, error)
UserGroupsFunc func(string, []string, string) ([]string, error)
ValidateGroupsFunc func(string, []string, string) ([]string, bool, error)
*ProviderData
}

Expand Down Expand Up @@ -62,13 +62,13 @@ func (tp *TestProvider) RefreshSession(s *sessions.SessionState, g []string) (bo
}

// UserGroups mocks the UserGroups function
func (tp *TestProvider) UserGroups(email string, groups []string) ([]string, error) {
return tp.UserGroupsFunc(email, groups)
func (tp *TestProvider) UserGroups(email string, groups []string, accessToken string) ([]string, error) {
return tp.UserGroupsFunc(email, groups, accessToken)
}

// ValidateGroup mocks the ValidateGroup function
func (tp *TestProvider) ValidateGroup(email string, groups []string) ([]string, bool, error) {
return tp.ValidateGroupsFunc(email, groups)
func (tp *TestProvider) ValidateGroup(email string, groups []string, accessToken string) ([]string, bool, error) {
return tp.ValidateGroupsFunc(email, groups, accessToken)
}

// GetSignOutURL mocks GetSignOutURL function
Expand Down

0 comments on commit c75733d

Please sign in to comment.