Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move unique functionality into getGroups to reduce calls to google #2628

Merged
merged 1 commit into from
Dec 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 19 additions & 24 deletions connector/google/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector

var groups []string
if s.Groups && c.adminSrv != nil {
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership)
checkedGroups := make(map[string]struct{})
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return identity, fmt.Errorf("google: could not retrieve groups: %v", err)
}
Expand All @@ -253,7 +254,7 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector

// getGroups creates a connection to the admin directory service and lists
// all groups the user is a member of
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool) ([]string, error) {
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
var userGroups []string
var err error
groupsList := &admin.Groups{}
Expand All @@ -265,26 +266,33 @@ func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership
}

for _, group := range groupsList.Groups {
if _, exists := checkedGroups[group.Email]; exists {
continue
}

checkedGroups[group.Email] = struct{}{}
// TODO (joelspeed): Make desired group key configurable
userGroups = append(userGroups, group.Email)

// getGroups takes a user's email/alias as well as a group's email/alias
if fetchTransitiveGroupMembership {
transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership)
if err != nil {
return nil, fmt.Errorf("could not list transitive groups: %v", err)
}
if !fetchTransitiveGroupMembership {
continue
}

userGroups = append(userGroups, transitiveGroups...)
// getGroups takes a user's email/alias as well as a group's email/alias
transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return nil, fmt.Errorf("could not list transitive groups: %v", err)
}

userGroups = append(userGroups, transitiveGroups...)
}

if groupsList.NextPageToken == "" {
break
}
}

return uniqueGroups(userGroups), nil
return userGroups, nil
}

// createDirectoryService sets up super user impersonation and creates an admin client for calling
Expand Down Expand Up @@ -316,7 +324,7 @@ func createDirectoryService(serviceAccountFilePath, email string, logger log.Log
}
config, err := google.JWTConfigFromJSON(jsonCredentials, admin.AdminDirectoryGroupReadonlyScope)
if err != nil {
return nil, fmt.Errorf("unable to parse credentials to config: %v", err)
return nil, fmt.Errorf("unable to parse client secret file to config: %v", err)
}

// Only attempt impersonation when there is a user configured
Expand All @@ -326,16 +334,3 @@ func createDirectoryService(serviceAccountFilePath, email string, logger log.Log

return admin.NewService(ctx, option.WithHTTPClient(config.Client(ctx)))
}

// uniqueGroups returns the unique groups of a slice
func uniqueGroups(groups []string) []string {
keys := make(map[string]struct{})
unique := []string{}
for _, group := range groups {
if _, exists := keys[group]; !exists {
keys[group] = struct{}{}
unique = append(unique, group)
}
}
return unique
}
106 changes: 99 additions & 7 deletions connector/google/google_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package google

import (
"context"
"encoding/json"
"fmt"
"net/http"
Expand All @@ -10,17 +11,38 @@ import (

"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/option"
)

var (
// groups_0
// ┌───────┤
// groups_2 groups_1
// │ ├────────┐
// └── user_1 user_2
testGroups = map[string][]*admin.Group{
"user_1@dexidp.com": {{Email: "groups_2@dexidp.com"}, {Email: "groups_1@dexidp.com"}},
"user_2@dexidp.com": {{Email: "groups_1@dexidp.com"}},
"groups_1@dexidp.com": {{Email: "groups_0@dexidp.com"}},
"groups_2@dexidp.com": {{Email: "groups_0@dexidp.com"}},
"groups_0@dexidp.com": {},
}
callCounter = make(map[string]int)
)

func testSetup(t *testing.T) *httptest.Server {
mux := http.NewServeMux()
// TODO: mock calls
// mux.HandleFunc("/admin/directory/v1/groups", func(w http.ResponseWriter, r *http.Request) {
// w.Header().Add("Content-Type", "application/json")
// json.NewEncoder(w).Encode(&admin.Groups{
// Groups: []*admin.Group{},
// })
// })

mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
userKey := r.URL.Query().Get("userKey")
if groups, ok := testGroups[userKey]; ok {
json.NewEncoder(w).Encode(admin.Groups{Groups: groups})
callCounter[userKey]++
}
})

return httptest.NewServer(mux)
}

Expand Down Expand Up @@ -144,3 +166,73 @@ func TestOpen(t *testing.T) {
})
}
}

func TestGetGroups(t *testing.T) {
ts := testSetup(t)
defer ts.Close()

serviceAccountFilePath, err := tempServiceAccountKey()
assert.Nil(t, err)

os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", serviceAccountFilePath)
conn, err := newConnector(&Config{
ClientID: "testClient",
ClientSecret: "testSecret",
RedirectURI: ts.URL + "/callback",
Scopes: []string{"openid", "groups"},
AdminEmail: "admin@dexidp.com",
}, ts.URL)
assert.Nil(t, err)

conn.adminSrv, err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
assert.Nil(t, err)
type testCase struct {
userKey string
fetchTransitiveGroupMembership bool
shouldErr bool
expectedGroups []string
}

for name, testCase := range map[string]testCase{
"user1_non_transitive_lookup": {
userKey: "user_1@dexidp.com",
fetchTransitiveGroupMembership: false,
shouldErr: false,
expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"},
},
"user1_transitive_lookup": {
userKey: "user_1@dexidp.com",
fetchTransitiveGroupMembership: true,
shouldErr: false,
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"},
},
"user2_non_transitive_lookup": {
userKey: "user_2@dexidp.com",
fetchTransitiveGroupMembership: false,
shouldErr: false,
expectedGroups: []string{"groups_1@dexidp.com"},
},
"user2_transitive_lookup": {
userKey: "user_2@dexidp.com",
fetchTransitiveGroupMembership: true,
shouldErr: false,
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"},
},
} {
testCase := testCase
callCounter = map[string]int{}
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})

groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup)
if testCase.shouldErr {
assert.NotNil(err)
} else {
assert.Nil(err)
}
assert.ElementsMatch(testCase.expectedGroups, groups)
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
})
}
}