From 1c544f533d8271d5841d9036537ecb44064e3481 Mon Sep 17 00:00:00 2001 From: Matt Hoey Date: Fri, 12 Aug 2022 16:02:28 -0700 Subject: [PATCH] Move unique functionality into getGroups to reduce calls to google Signed-off-by: Matt Hoey --- connector/google/google.go | 43 ++++++------- connector/google/google_test.go | 106 +++++++++++++++++++++++++++++--- 2 files changed, 118 insertions(+), 31 deletions(-) diff --git a/connector/google/google.go b/connector/google/google.go index 3f79a8a227..f80c3586ab 100644 --- a/connector/google/google.go +++ b/connector/google/google.go @@ -227,7 +227,8 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector var groups []string if s.Groups && c.adminSrv != nil { - groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership) + checkedGroups := make(map[string]struct{}) + groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups) if err != nil { return identity, fmt.Errorf("google: could not retrieve groups: %v", err) } @@ -253,7 +254,7 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector // getGroups creates a connection to the admin directory service and lists // all groups the user is a member of -func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool) ([]string, error) { +func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) { var userGroups []string var err error groupsList := &admin.Groups{} @@ -265,18 +266,25 @@ func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership } for _, group := range groupsList.Groups { + if _, exists := checkedGroups[group.Email]; exists { + continue + } + + checkedGroups[group.Email] = struct{}{} // TODO (joelspeed): Make desired group key configurable userGroups = append(userGroups, group.Email) - // getGroups takes a user's email/alias as well as a group's email/alias - if fetchTransitiveGroupMembership { - transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership) - if err != nil { - return nil, fmt.Errorf("could not list transitive groups: %v", err) - } + if !fetchTransitiveGroupMembership { + continue + } - userGroups = append(userGroups, transitiveGroups...) + // getGroups takes a user's email/alias as well as a group's email/alias + transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups) + if err != nil { + return nil, fmt.Errorf("could not list transitive groups: %v", err) } + + userGroups = append(userGroups, transitiveGroups...) } if groupsList.NextPageToken == "" { @@ -284,7 +292,7 @@ func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership } } - return uniqueGroups(userGroups), nil + return userGroups, nil } // createDirectoryService sets up super user impersonation and creates an admin client for calling @@ -316,7 +324,7 @@ func createDirectoryService(serviceAccountFilePath, email string, logger log.Log } config, err := google.JWTConfigFromJSON(jsonCredentials, admin.AdminDirectoryGroupReadonlyScope) if err != nil { - return nil, fmt.Errorf("unable to parse credentials to config: %v", err) + return nil, fmt.Errorf("unable to parse client secret file to config: %v", err) } // Only attempt impersonation when there is a user configured @@ -326,16 +334,3 @@ func createDirectoryService(serviceAccountFilePath, email string, logger log.Log return admin.NewService(ctx, option.WithHTTPClient(config.Client(ctx))) } - -// uniqueGroups returns the unique groups of a slice -func uniqueGroups(groups []string) []string { - keys := make(map[string]struct{}) - unique := []string{} - for _, group := range groups { - if _, exists := keys[group]; !exists { - keys[group] = struct{}{} - unique = append(unique, group) - } - } - return unique -} diff --git a/connector/google/google_test.go b/connector/google/google_test.go index b0c4f3a2f8..83c4cba1a0 100644 --- a/connector/google/google_test.go +++ b/connector/google/google_test.go @@ -1,6 +1,7 @@ package google import ( + "context" "encoding/json" "fmt" "net/http" @@ -10,17 +11,38 @@ import ( "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + admin "google.golang.org/api/admin/directory/v1" + "google.golang.org/api/option" +) + +var ( + // groups_0 + // ┌───────┤ + // groups_2 groups_1 + // │ ├────────┐ + // └── user_1 user_2 + testGroups = map[string][]*admin.Group{ + "user_1@dexidp.com": {{Email: "groups_2@dexidp.com"}, {Email: "groups_1@dexidp.com"}}, + "user_2@dexidp.com": {{Email: "groups_1@dexidp.com"}}, + "groups_1@dexidp.com": {{Email: "groups_0@dexidp.com"}}, + "groups_2@dexidp.com": {{Email: "groups_0@dexidp.com"}}, + "groups_0@dexidp.com": {}, + } + callCounter = make(map[string]int) ) func testSetup(t *testing.T) *httptest.Server { mux := http.NewServeMux() - // TODO: mock calls - // mux.HandleFunc("/admin/directory/v1/groups", func(w http.ResponseWriter, r *http.Request) { - // w.Header().Add("Content-Type", "application/json") - // json.NewEncoder(w).Encode(&admin.Groups{ - // Groups: []*admin.Group{}, - // }) - // }) + + mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") + userKey := r.URL.Query().Get("userKey") + if groups, ok := testGroups[userKey]; ok { + json.NewEncoder(w).Encode(admin.Groups{Groups: groups}) + callCounter[userKey]++ + } + }) + return httptest.NewServer(mux) } @@ -144,3 +166,73 @@ func TestOpen(t *testing.T) { }) } } + +func TestGetGroups(t *testing.T) { + ts := testSetup(t) + defer ts.Close() + + serviceAccountFilePath, err := tempServiceAccountKey() + assert.Nil(t, err) + + os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", serviceAccountFilePath) + conn, err := newConnector(&Config{ + ClientID: "testClient", + ClientSecret: "testSecret", + RedirectURI: ts.URL + "/callback", + Scopes: []string{"openid", "groups"}, + AdminEmail: "admin@dexidp.com", + }, ts.URL) + assert.Nil(t, err) + + conn.adminSrv, err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) + assert.Nil(t, err) + type testCase struct { + userKey string + fetchTransitiveGroupMembership bool + shouldErr bool + expectedGroups []string + } + + for name, testCase := range map[string]testCase{ + "user1_non_transitive_lookup": { + userKey: "user_1@dexidp.com", + fetchTransitiveGroupMembership: false, + shouldErr: false, + expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"}, + }, + "user1_transitive_lookup": { + userKey: "user_1@dexidp.com", + fetchTransitiveGroupMembership: true, + shouldErr: false, + expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"}, + }, + "user2_non_transitive_lookup": { + userKey: "user_2@dexidp.com", + fetchTransitiveGroupMembership: false, + shouldErr: false, + expectedGroups: []string{"groups_1@dexidp.com"}, + }, + "user2_transitive_lookup": { + userKey: "user_2@dexidp.com", + fetchTransitiveGroupMembership: true, + shouldErr: false, + expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"}, + }, + } { + testCase := testCase + callCounter = map[string]int{} + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + lookup := make(map[string]struct{}) + + groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup) + if testCase.shouldErr { + assert.NotNil(err) + } else { + assert.Nil(err) + } + assert.ElementsMatch(testCase.expectedGroups, groups) + t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter) + }) + } +}