Skip to content

Commit

Permalink
sso: update google rfs
Browse files Browse the repository at this point in the history
  • Loading branch information
jphines committed Jun 19, 2019
1 parent ec71e44 commit 14f2a96
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 38 deletions.
2 changes: 1 addition & 1 deletion internal/auth/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1630,5 +1630,5 @@ func TestGoogleGroupInvalidFile(t *testing.T) {
},
)
testutil.NotEqual(t, nil, err)
testutil.Equal(t, "invalid Google credentials file: file_doesnt_exist.json", err.Error())
testutil.Equal(t, "could not read google credentials file", err.Error())
}
3 changes: 3 additions & 0 deletions internal/auth/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import (
//
// PROVIDER_*_GOOGLE_CREDENTIALS
// PROVIDER_*_GOOGLE_IMPERSONATE
// PROVIDER_*_GOOGLE_PROMPT
// PROVIDER_*_GOOGLE_DOMAIN
//
// PROVIDER_*_OKTA_URL
// PROVIDER_*_OKTA_SERVER
Expand Down Expand Up @@ -227,6 +229,7 @@ type GoogleProviderConfig struct {
Credentials string `mapstructure:"credentials"`
Impersonate string `mapstructure:"impersonate"`
ApprovalPrompt string `mapstructure:"prompt"`
HostedDomain string `mapstructure:"domain"`
}

func (gpc GoogleProviderConfig) Validate() error {
Expand Down
18 changes: 6 additions & 12 deletions internal/auth/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/base64"
"fmt"
"net/url"
"os"
"path"

"github.com/buzzfeed/sso/internal/auth/providers"
Expand All @@ -28,16 +27,12 @@ func newProvider(pc ProviderConfig, sc SessionConfig) (providers.Provider, error
switch pc.ProviderType {
case providers.GoogleProviderName: // Google
gpc := pc.GoogleProviderConfig
p.ApprovalPrompt = gpc.ApprovalPrompt

if gpc.Credentials != "" {
_, err := os.Open(gpc.Credentials)
if err != nil {
return nil, fmt.Errorf("invalid Google credentials file: %s", gpc.Credentials)
}
}

googleProvider, err := providers.NewGoogleProvider(p, gpc.Impersonate, gpc.Credentials)
googleProvider, err := providers.NewGoogleProvider(p,
gpc.ApprovalPrompt,
gpc.HostedDomain,
gpc.Impersonate,
gpc.Credentials,
)
if err != nil {
return nil, err
}
Expand All @@ -48,7 +43,6 @@ func newProvider(pc ProviderConfig, sc SessionConfig) (providers.Provider, error
singleFlightProvider = providers.NewSingleFlightProvider(googleProvider)
case providers.OktaProviderName:
opc := pc.OktaProviderConfig

oktaProvider, err := providers.NewOktaProvider(p, opc.OrgURL, opc.ServerID)
if err != nil {
return nil, err
Expand Down
34 changes: 22 additions & 12 deletions internal/auth/providers/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ type GoogleProvider struct {
AdminService AdminService
cb *circuit.Breaker
GroupsCache groups.MemberSetCache

Prompt string
HostedDomain string
}

// NewGoogleProvider returns a new GoogleProvider and sets the provider url endpoints.
func NewGoogleProvider(p *ProviderData, impersonateUser, credsFilePath string) (*GoogleProvider, error) {
func NewGoogleProvider(p *ProviderData, prompt, hd, impersonate, credentials string) (*GoogleProvider, error) {
p.ProviderName = "Google"
p.SignInURL = &url.URL{Scheme: "https",
Host: "accounts.google.com",
Expand All @@ -49,19 +52,20 @@ func NewGoogleProvider(p *ProviderData, impersonateUser, credsFilePath string) (
Host: "www.googleapis.com",
Path: "/oauth2/v3/tokeninfo",
}
p.ProfileURL = &url.URL{}

if p.Scope == "" {
p.Scope = "profile email"
}
if p.ApprovalPrompt == "" {
p.ApprovalPrompt = "consent"
}

// not used for google
p.ProfileURL = &url.URL{}
if prompt == "" {
prompt = "consent"
}

googleProvider := &GoogleProvider{
ProviderData: p,
Prompt: prompt,
HostedDomain: hd,
}

googleProvider.cb = circuit.NewBreaker(&circuit.Options{
Expand All @@ -74,17 +78,19 @@ func NewGoogleProvider(p *ProviderData, impersonateUser, credsFilePath string) (
time.Duration(200)*time.Second, time.Duration(500)*time.Millisecond,
),
})
if credsFilePath != "" {
credsReader, err := os.Open(credsFilePath)

if credentials != "" {
credsReader, err := os.Open(credentials)
if err != nil {
return nil, errors.New("could not read google credentials file")
}

googleProvider.AdminService = &GoogleAdminService{
adminService: getAdminService(impersonateUser, credsReader),
adminService: getAdminService(impersonate, credsReader),
cb: googleProvider.cb,
}
}

return googleProvider, nil
}

Expand Down Expand Up @@ -136,9 +142,13 @@ func (p *GoogleProvider) GetSignInURL(redirectURI, state string) string {
params.Set("response_type", "code")
params.Set("redirect_uri", redirectURI)
params.Set("scope", p.Scope)
params.Set("access_token", "offline")
params.Add("state", state)
params.Set("prompt", p.ApprovalPrompt)
params.Set("access_type", "offline")
params.Set("state", state)
params.Set("prompt", p.Prompt)

if p.HostedDomain != "" {
params.Set("hd", p.HostedDomain)
}

a.RawQuery = params.Encode()
return a.String()
Expand Down
2 changes: 1 addition & 1 deletion internal/auth/providers/google_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func newGoogleProvider(providerData *ProviderData) *GoogleProvider {
ValidateURL: &url.URL{},
Scope: ""}
}
provider, _ := NewGoogleProvider(providerData, "", "")
provider, _ := NewGoogleProvider(providerData, "", "", "", "")
return provider
}

Expand Down
25 changes: 14 additions & 11 deletions internal/auth/providers/provider_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@ import (
// ProviderData holds the fields associated with providers
// necessary to implement the Provider interface.
type ProviderData struct {
ProviderName string
ProviderSlug string
ClientID string
ClientSecret string
SignInURL *url.URL
RedeemURL *url.URL
RevokeURL *url.URL
ProfileURL *url.URL
ValidateURL *url.URL
Scope string
ApprovalPrompt string
ProviderName string
ProviderSlug string

ClientID string
ClientSecret string

SignInURL *url.URL
RedeemURL *url.URL
RevokeURL *url.URL
ValidateURL *url.URL
ProfileURL *url.URL

Scope string

SessionLifetimeTTL time.Duration
}

Expand Down
1 change: 0 additions & 1 deletion internal/auth/providers/provider_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ func (p *ProviderData) GetSignInURL(redirectURI, state string) string {
a = *p.SignInURL
params, _ := url.ParseQuery(a.RawQuery)
params.Set("redirect_uri", redirectURI)
params.Set("approval_prompt", p.ApprovalPrompt)
params.Add("scope", p.Scope)
params.Set("client_id", p.ClientID)
params.Set("response_type", "code")
Expand Down

0 comments on commit 14f2a96

Please sign in to comment.