Skip to content

Commit

Permalink
added back sign in url
Browse files Browse the repository at this point in the history
  • Loading branch information
Shraya Ramani committed Oct 17, 2018
1 parent f1c9135 commit cbdfeab
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 94 deletions.
4 changes: 2 additions & 2 deletions internal/proxy/oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
Host: req.Host,
Path: "/",
}
fullURL := providers.GetSignOutURL(p.provider.Data(), redirectURL)
fullURL := p.provider.GetSignOutURL(redirectURL)
http.Redirect(rw, req, fullURL.String(), http.StatusFound)
}

Expand Down Expand Up @@ -708,7 +708,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request, tags
return
}

signinURL := providers.GetSignInURL(p.provider.Data(), callbackURL, encryptedState)
signinURL := p.provider.GetSignInURL(callbackURL, encryptedState)
logger.WithSignInURL(signinURL).Info("starting OAuth flow")
http.Redirect(rw, req, signinURL.String(), http.StatusFound)
}
Expand Down
34 changes: 4 additions & 30 deletions internal/proxy/oauthproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest {
})

pcTest.proxy.provider = &providers.TestProvider{
ValidToken: opts.providerValidateCookieResponse,
ValidateSessionFunc: func(*providers.SessionState, []string) bool { return opts.providerValidateCookieResponse },
}

pcTest.rw = httptest.NewRecorder()
Expand Down Expand Up @@ -1059,32 +1059,6 @@ func TestHeadersSentToUpstreams(t *testing.T) {

}

type testAuthenticateProvider struct {
*providers.ProviderData
refreshSessionFunc func(*providers.SessionState, []string) (bool, error)
validateSessionFunc func(*providers.SessionState, []string) bool
redeemFunc func(string, string) (*providers.SessionState, error)
}

func (tap *testAuthenticateProvider) RefreshSession(s *providers.SessionState, g []string) (bool, error) {
return tap.refreshSessionFunc(s, g)
}

func (tap *testAuthenticateProvider) ValidateSessionState(s *providers.SessionState, g []string) bool {
return tap.validateSessionFunc(s, g)
}

func (tap *testAuthenticateProvider) Redeem(redirectURL string, token string) (*providers.SessionState, error) {
return tap.redeemFunc(redirectURL, token)
}

func (tap *testAuthenticateProvider) UserGroups(string, []string) ([]string, error) {
return nil, nil
}
func (tap *testAuthenticateProvider) ValidateGroup(string, []string) ([]string, bool, error) {
return nil, false, nil
}

func TestAuthenticate(t *testing.T) {
// Constants to represent possible cookie behaviors.
const (
Expand Down Expand Up @@ -1234,9 +1208,9 @@ func TestAuthenticate(t *testing.T) {
opts.upstreamConfigs = generateTestUpstreamConfigs("foo-internal.sso.dev")
opts.Validate()
proxy, _ := NewOAuthProxy(opts, testValidatorFunc(true), testCookieCipher(tc.Cipher))
proxy.provider = &testAuthenticateProvider{
refreshSessionFunc: tc.RefreshSessionFunc,
validateSessionFunc: tc.ValidateSessionFunc,
proxy.provider = &providers.TestProvider{
RefreshSessionFunc: tc.RefreshSessionFunc,
ValidateSessionFunc: tc.ValidateSessionFunc,
}

value, err := providers.MarshalSession(tc.Session, proxy.CookieCipher)
Expand Down
47 changes: 2 additions & 45 deletions internal/proxy/providers/providers.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
package providers

import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/url"
"time"

"github.com/datadog/datadog-go/statsd"
)
Expand All @@ -19,49 +14,11 @@ type Provider interface {
UserGroups(string, []string) ([]string, error)
ValidateSessionState(*SessionState, []string) bool
RefreshSession(*SessionState, []string) (bool, error)
GetSignInURL(redirectURL *url.URL, state string) *url.URL
GetSignOutURL(redirectURL *url.URL) *url.URL
}

// New returns a new sso Provider
func New(provider string, p *ProviderData, sc *statsd.Client) Provider {
return NewSSOProvider(p, sc)
}

// GetSignInURL with typical oauth parameters
func GetSignInURL(data *ProviderData, redirectURL *url.URL, state string) *url.URL {
var a url.URL
a = *data.SignInURL
now := time.Now()
rawRedirect := redirectURL.String()
params, _ := url.ParseQuery(a.RawQuery)
params.Set("redirect_uri", rawRedirect)
params.Add("scope", data.Scope)
params.Set("client_id", data.ClientID)
params.Set("response_type", "code")
params.Add("state", state)
params.Set("ts", fmt.Sprint(now.Unix()))
params.Set("sig", signRedirectURL(data.ClientSecret, rawRedirect, now))
a.RawQuery = params.Encode()
return &a
}

// GetSignOutURL creates and returns the sign out URL, given a redirectURL
func GetSignOutURL(data *ProviderData, redirectURL *url.URL) *url.URL {
var a url.URL
a = *data.SignOutURL
now := time.Now()
rawRedirect := redirectURL.String()
params, _ := url.ParseQuery(a.RawQuery)
params.Add("redirect_uri", rawRedirect)
params.Set("ts", fmt.Sprint(now.Unix()))
params.Set("sig", signRedirectURL(data.ClientSecret, rawRedirect, now))
a.RawQuery = params.Encode()
return &a
}

// signRedirectURL signs the redirect url string, given a timestamp, and returns it
func signRedirectURL(clientSecret, rawRedirect string, timestamp time.Time) string {
h := hmac.New(sha256.New, []byte(clientSecret))
h.Write([]byte(rawRedirect))
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
return base64.URLEncoding.EncodeToString(h.Sum(nil))
}
11 changes: 11 additions & 0 deletions internal/proxy/providers/singleflight_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package providers
import (
"errors"
"fmt"
"net/url"
"sort"
"strings"

Expand Down Expand Up @@ -127,3 +128,13 @@ func (p *SingleFlightProvider) RefreshSession(s *SessionState, allowedGroups []s

return r, nil
}

// GetSignInURL calls the GetSignInURL for the provider, which will return the sign in url
func (p *SingleFlightProvider) GetSignInURL(redirectURI *url.URL, finalRedirect string) *url.URL {
return p.provider.GetSignInURL(redirectURI, finalRedirect)
}

// GetSignOutURL calls the GetSignOutURL for the provider, which will return the sign out url
func (p *SingleFlightProvider) GetSignOutURL(redirectURI *url.URL) *url.URL {
return p.provider.GetSignOutURL(redirectURI)
}
43 changes: 43 additions & 0 deletions internal/proxy/providers/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package providers

import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -383,3 +386,43 @@ func (p *SSOProvider) ValidateSessionState(s *SessionState, allowedGroups []stri

return true
}

// signRedirectURL signs the redirect url string, given a timestamp, and returns it
func signRedirectURL(clientSecret, rawRedirect string, timestamp time.Time) string {
h := hmac.New(sha256.New, []byte(clientSecret))
h.Write([]byte(rawRedirect))
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
return base64.URLEncoding.EncodeToString(h.Sum(nil))
}

// GetSignInURL with typical oauth parameters
func (p *SSOProvider) GetSignInURL(redirectURL *url.URL, state string) *url.URL {
var a url.URL
a = *p.Data().SignInURL
now := time.Now()
rawRedirect := redirectURL.String()
params, _ := url.ParseQuery(a.RawQuery)
params.Set("redirect_uri", rawRedirect)
params.Add("scope", p.Data().Scope)
params.Set("client_id", p.Data().ClientID)
params.Set("response_type", "code")
params.Add("state", state)
params.Set("ts", fmt.Sprint(now.Unix()))
params.Set("sig", signRedirectURL(p.Data().ClientSecret, rawRedirect, now))
a.RawQuery = params.Encode()
return &a
}

// GetSignOutURL creates and returns the sign out URL, given a redirectURL
func (p *SSOProvider) GetSignOutURL(redirectURL *url.URL) *url.URL {
var a url.URL
a = *p.Data().SignOutURL
now := time.Now()
rawRedirect := redirectURL.String()
params, _ := url.ParseQuery(a.RawQuery)
params.Add("redirect_uri", rawRedirect)
params.Set("ts", fmt.Sprint(now.Unix()))
params.Set("sig", signRedirectURL(p.Data().ClientSecret, rawRedirect, now))
a.RawQuery = params.Encode()
return &a
}
42 changes: 25 additions & 17 deletions internal/proxy/providers/test_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ import (

// TestProvider is a mock provider
type TestProvider struct {
RefreshSessionFunc func(*SessionState, []string) (bool, error)
ValidateSessionFunc func(*SessionState, []string) bool
RedeemFunc func(string, string) (*SessionState, error)
UserGroupsFunc func(string, []string) ([]string, error)
ValidateGroupsFunc func(string, []string) ([]string, bool, error)
*ProviderData
EmailAddress string
ValidToken bool
ValidGroup bool
Refreshed bool
Session *SessionState
Groups []string
}

// NewTestProvider returns a new TestProvider
Expand Down Expand Up @@ -42,31 +41,40 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
},
Scope: "profile.email",
},
EmailAddress: emailAddress,
}
}

// ValidateSessionState mocks the ValidateSessionState function
func (tp *TestProvider) ValidateSessionState(*SessionState, []string) bool {
return tp.ValidToken
func (tp *TestProvider) ValidateSessionState(s *SessionState, groups []string) bool {
return tp.ValidateSessionFunc(s, groups)
}

// Redeem mocks the provider Redeem function
func (tp *TestProvider) Redeem(string, string) (*SessionState, error) {
return tp.Session, nil
func (tp *TestProvider) Redeem(redirectURL string, token string) (*SessionState, error) {
return tp.RedeemFunc(redirectURL, token)
}

// RefreshSession mocks the RefreshSession function
func (tp *TestProvider) RefreshSession(*SessionState, []string) (bool, error) {
return tp.Refreshed, nil
func (tp *TestProvider) RefreshSession(s *SessionState, g []string) (bool, error) {
return tp.RefreshSessionFunc(s, g)
}

// UserGroups mocks the UserGroups function
func (tp *TestProvider) UserGroups(string, []string) ([]string, error) {
return tp.Groups, nil
func (tp *TestProvider) UserGroups(email string, groups []string) ([]string, error) {
return tp.UserGroupsFunc(email, groups)
}

// ValidateGroup mocks the ValidateGroup function
func (tp *TestProvider) ValidateGroup(string, []string) ([]string, bool, error) {
return tp.Groups, tp.ValidGroup, nil
func (tp *TestProvider) ValidateGroup(email string, groups []string) ([]string, bool, error) {
return tp.ValidateGroupsFunc(email, groups)
}

// GetSignOutURL mocks GetSignOutURL function
func (tp *TestProvider) GetSignOutURL(redirectURL *url.URL) *url.URL {
return tp.Data().SignOutURL
}

// GetSignInURL mocks GetSignInURL
func (tp *TestProvider) GetSignInURL(redirectURL *url.URL, state string) *url.URL {
return tp.Data().SignInURL
}

0 comments on commit cbdfeab

Please sign in to comment.