Skip to content

Commit

Permalink
revert google.go prior to adding my own tests and add better ones
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Hoey <matt.hoey@missionlane.com>
  • Loading branch information
snuggie12 committed Sep 14, 2022
1 parent 333c04c commit 7e914dd
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 45 deletions.
69 changes: 31 additions & 38 deletions connector/google/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
scopes = append(scopes, "profile", "email")
}

srv, err := createDirectoryService(c.ServiceAccountFilePath, c.AdminEmail, logger)
srv, err := createDirectoryService(c.ServiceAccountFilePath, c.AdminEmail)
if err != nil {
cancel()
return nil, fmt.Errorf("could not create directory service: %v", err)
Expand Down Expand Up @@ -220,7 +220,7 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
var groups []string
if s.Groups && c.adminSrv != nil {
checkedGroups := make(map[string]struct{})
groups, err = getGroups(c.getGroupsList, claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return identity, fmt.Errorf("google: could not retrieve groups: %v", err)
}
Expand All @@ -244,22 +244,15 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
return identity, nil
}

// getGroupsList returns a list of Groups from google
func (c *googleConnector) getGroupsList(email string, nextPageToken string) (*admin.Groups, error) {
groupsList, err := c.adminSrv.Groups.List().
UserKey(email).PageToken(nextPageToken).Do()
return groupsList, err
}

// getGroups creates a connection to the admin directory service and lists
// all groups the user is a member of
// to test functionality, first parameter is the function you want to run to fetch groups
func getGroups(getGroupsListFunc func(string, string) (*admin.Groups, error), email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
var userGroups []string
var err error
groupsList := &admin.Groups{}
for {
groupsList, err = getGroupsListFunc(email, groupsList.NextPageToken)
groupsList, err = c.adminSrv.Groups.List().
UserKey(email).PageToken(groupsList.NextPageToken).Do()
if err != nil {
return nil, fmt.Errorf("could not list groups: %v", err)
}
Expand All @@ -278,7 +271,7 @@ func getGroups(getGroupsListFunc func(string, string) (*admin.Groups, error), em
}

// getGroups takes a user's email/alias as well as a group's email/alias
transitiveGroups, err := getGroups(getGroupsListFunc, group.Email, fetchTransitiveGroupMembership, checkedGroups)
transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return nil, fmt.Errorf("could not list transitive groups: %v", err)
}
Expand All @@ -294,35 +287,35 @@ func getGroups(getGroupsListFunc func(string, string) (*admin.Groups, error), em
return userGroups, nil
}

// createDirectoryService sets up super user impersonation and creates an admin client for calling
// the google admin api. If no serviceAccountFilePath is defined, the application default credential
// is used.
func createDirectoryService(serviceAccountFilePath, email string, logger log.Logger) (*admin.Service, error) {
if email == "" {
return nil, fmt.Errorf("directory service requires adminEmail")
// createDirectoryService loads a google service account credentials file,
// sets up super user impersonation and creates an admin client for calling
// the google admin api
func createDirectoryService(serviceAccountFilePath string, email string) (*admin.Service, error) {
if serviceAccountFilePath == "" && email == "" {
return nil, nil
}

var jsonCredentials []byte
var err error

ctx := context.Background()
if serviceAccountFilePath == "" {
logger.Warn("the application default credential is used since the service account file path is not used")
credential, err := google.FindDefaultCredentials(ctx)
if err != nil {
return nil, fmt.Errorf("failed to fetch application default credentials: %w", err)
}
jsonCredentials = credential.JSON
} else {
jsonCredentials, err = os.ReadFile(serviceAccountFilePath)
if err != nil {
return nil, fmt.Errorf("error reading credentials from file: %v", err)
}
if serviceAccountFilePath == "" || email == "" {
return nil, fmt.Errorf("directory service requires both serviceAccountFilePath and adminEmail")
}
jsonCredentials, err := os.ReadFile(serviceAccountFilePath)
if err != nil {
return nil, fmt.Errorf("error reading credentials from file: %v", err)
}

config, err := google.JWTConfigFromJSON(jsonCredentials, admin.AdminDirectoryGroupReadonlyScope)
if err != nil {
return nil, fmt.Errorf("unable to parse credentials to config: %v", err)
return nil, fmt.Errorf("unable to parse client secret file to config: %v", err)
}

// Impersonate an admin. This is mandatory for the admin APIs.
config.Subject = email
return admin.NewService(ctx, option.WithHTTPClient(config.Client(ctx)))

ctx := context.Background()
client := config.Client(ctx)

srv, err := admin.NewService(ctx, option.WithHTTPClient(client))
if err != nil {
return nil, fmt.Errorf("unable to create directory service %v", err)
}
return srv, nil
}
107 changes: 100 additions & 7 deletions connector/google/google_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package google

import (
"context"
"encoding/json"
"fmt"
"net/http"
Expand All @@ -10,17 +11,39 @@ import (

"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"

admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/option"
)

var (
// groups_0
// ┌───────┤
// groups_2 groups_1
// │ ├────────┐
// └── user_1 user_2
testGroups = map[string][]*admin.Group{
"user_1@dexidp.com": {{Email: "groups_2@dexidp.com"}, {Email: "groups_1@dexidp.com"}},
"user_2@dexidp.com": {{Email: "groups_1@dexidp.com"}},
"groups_1@dexidp.com": {{Email: "groups_0@dexidp.com"}},
"groups_2@dexidp.com": {{Email: "groups_0@dexidp.com"}},
"groups_0@dexidp.com": {},
}
callCounter = make(map[string]int)
)

func testSetup(t *testing.T) *httptest.Server {
mux := http.NewServeMux()
// TODO: mock calls
// mux.HandleFunc("/admin/directory/v1/groups", func(w http.ResponseWriter, r *http.Request) {
// w.Header().Add("Content-Type", "application/json")
// json.NewEncoder(w).Encode(&admin.Groups{
// Groups: []*admin.Group{},
// })
// })

mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
userKey := r.URL.Query().Get("userKey")
if groups, ok := testGroups[userKey]; ok {
json.NewEncoder(w).Encode(admin.Groups{Groups: groups})
callCounter[userKey]++
}
})

return httptest.NewServer(mux)
}

Expand Down Expand Up @@ -143,3 +166,73 @@ func TestOpen(t *testing.T) {
})
}
}

func TestGetGroups(t *testing.T) {
ts := testSetup(t)
defer ts.Close()

serviceAccountFilePath, err := tempServiceAccountKey()
assert.Nil(t, err)

os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", serviceAccountFilePath)
conn, err := newConnector(&Config{
ClientID: "testClient",
ClientSecret: "testSecret",
RedirectURI: ts.URL + "/callback",
Scopes: []string{"openid", "groups"},
AdminEmail: "admin@dexidp.com",
}, ts.URL)
assert.Nil(t, err)

conn.adminSrv, err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
assert.Nil(t, err)
type testCase struct {
userKey string
fetchTransitiveGroupMembership bool
shouldErr bool
expectedGroups []string
}

for name, testCase := range map[string]testCase{
"user1_non_transitive_lookup": {
userKey: "user_1@dexidp.com",
fetchTransitiveGroupMembership: false,
shouldErr: false,
expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"},
},
"user1_transitive_lookup": {
userKey: "user_1@dexidp.com",
fetchTransitiveGroupMembership: true,
shouldErr: false,
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"},
},
"user2_non_transitive_lookup": {
userKey: "user_2@dexidp.com",
fetchTransitiveGroupMembership: false,
shouldErr: false,
expectedGroups: []string{"groups_1@dexidp.com"},
},
"user2_transitive_lookup": {
userKey: "user_2@dexidp.com",
fetchTransitiveGroupMembership: true,
shouldErr: false,
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"},
},
} {
testCase := testCase
callCounter = map[string]int{}
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})

groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup)
if testCase.shouldErr {
assert.NotNil(err)
} else {
assert.Nil(err)
}
assert.ElementsMatch(testCase.expectedGroups, groups)
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
})
}
}

0 comments on commit 7e914dd

Please sign in to comment.