diff --git a/internal/proxy/oauthproxy.go b/internal/proxy/oauthproxy.go index 87665308..6b0463ac 100755 --- a/internal/proxy/oauthproxy.go +++ b/internal/proxy/oauthproxy.go @@ -636,7 +636,7 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { Host: req.Host, Path: "/", } - fullURL := providers.GetSignOutURL(p.provider.Data(), redirectURL) + fullURL := p.provider.GetSignOutURL(redirectURL) http.Redirect(rw, req, fullURL.String(), http.StatusFound) } @@ -708,7 +708,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request, tags return } - signinURL := providers.GetSignInURL(p.provider.Data(), callbackURL, encryptedState) + signinURL := p.provider.GetSignInURL(callbackURL, encryptedState) logger.WithSignInURL(signinURL).Info("starting OAuth flow") http.Redirect(rw, req, signinURL.String(), http.StatusFound) } diff --git a/internal/proxy/oauthproxy_test.go b/internal/proxy/oauthproxy_test.go index c616a2a6..f97aad5a 100644 --- a/internal/proxy/oauthproxy_test.go +++ b/internal/proxy/oauthproxy_test.go @@ -470,7 +470,7 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { }) pcTest.proxy.provider = &providers.TestProvider{ - ValidToken: opts.providerValidateCookieResponse, + ValidateSessionFunc: func(*providers.SessionState, []string) bool { return opts.providerValidateCookieResponse }, } pcTest.rw = httptest.NewRecorder() @@ -1059,32 +1059,6 @@ func TestHeadersSentToUpstreams(t *testing.T) { } -type testAuthenticateProvider struct { - *providers.ProviderData - refreshSessionFunc func(*providers.SessionState, []string) (bool, error) - validateSessionFunc func(*providers.SessionState, []string) bool - redeemFunc func(string, string) (*providers.SessionState, error) -} - -func (tap *testAuthenticateProvider) RefreshSession(s *providers.SessionState, g []string) (bool, error) { - return tap.refreshSessionFunc(s, g) -} - -func (tap *testAuthenticateProvider) ValidateSessionState(s *providers.SessionState, g []string) bool { - return tap.validateSessionFunc(s, g) -} - -func (tap *testAuthenticateProvider) Redeem(redirectURL string, token string) (*providers.SessionState, error) { - return tap.redeemFunc(redirectURL, token) -} - -func (tap *testAuthenticateProvider) UserGroups(string, []string) ([]string, error) { - return nil, nil -} -func (tap *testAuthenticateProvider) ValidateGroup(string, []string) ([]string, bool, error) { - return nil, false, nil -} - func TestAuthenticate(t *testing.T) { // Constants to represent possible cookie behaviors. const ( @@ -1234,9 +1208,9 @@ func TestAuthenticate(t *testing.T) { opts.upstreamConfigs = generateTestUpstreamConfigs("foo-internal.sso.dev") opts.Validate() proxy, _ := NewOAuthProxy(opts, testValidatorFunc(true), testCookieCipher(tc.Cipher)) - proxy.provider = &testAuthenticateProvider{ - refreshSessionFunc: tc.RefreshSessionFunc, - validateSessionFunc: tc.ValidateSessionFunc, + proxy.provider = &providers.TestProvider{ + RefreshSessionFunc: tc.RefreshSessionFunc, + ValidateSessionFunc: tc.ValidateSessionFunc, } value, err := providers.MarshalSession(tc.Session, proxy.CookieCipher) diff --git a/internal/proxy/providers/providers.go b/internal/proxy/providers/providers.go index ac804103..fb584529 100644 --- a/internal/proxy/providers/providers.go +++ b/internal/proxy/providers/providers.go @@ -1,12 +1,7 @@ package providers import ( - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "fmt" "net/url" - "time" "github.com/datadog/datadog-go/statsd" ) @@ -19,49 +14,11 @@ type Provider interface { UserGroups(string, []string) ([]string, error) ValidateSessionState(*SessionState, []string) bool RefreshSession(*SessionState, []string) (bool, error) + GetSignInURL(redirectURL *url.URL, state string) *url.URL + GetSignOutURL(redirectURL *url.URL) *url.URL } // New returns a new sso Provider func New(provider string, p *ProviderData, sc *statsd.Client) Provider { return NewSSOProvider(p, sc) } - -// GetSignInURL with typical oauth parameters -func GetSignInURL(data *ProviderData, redirectURL *url.URL, state string) *url.URL { - var a url.URL - a = *data.SignInURL - now := time.Now() - rawRedirect := redirectURL.String() - params, _ := url.ParseQuery(a.RawQuery) - params.Set("redirect_uri", rawRedirect) - params.Add("scope", data.Scope) - params.Set("client_id", data.ClientID) - params.Set("response_type", "code") - params.Add("state", state) - params.Set("ts", fmt.Sprint(now.Unix())) - params.Set("sig", signRedirectURL(data.ClientSecret, rawRedirect, now)) - a.RawQuery = params.Encode() - return &a -} - -// GetSignOutURL creates and returns the sign out URL, given a redirectURL -func GetSignOutURL(data *ProviderData, redirectURL *url.URL) *url.URL { - var a url.URL - a = *data.SignOutURL - now := time.Now() - rawRedirect := redirectURL.String() - params, _ := url.ParseQuery(a.RawQuery) - params.Add("redirect_uri", rawRedirect) - params.Set("ts", fmt.Sprint(now.Unix())) - params.Set("sig", signRedirectURL(data.ClientSecret, rawRedirect, now)) - a.RawQuery = params.Encode() - return &a -} - -// signRedirectURL signs the redirect url string, given a timestamp, and returns it -func signRedirectURL(clientSecret, rawRedirect string, timestamp time.Time) string { - h := hmac.New(sha256.New, []byte(clientSecret)) - h.Write([]byte(rawRedirect)) - h.Write([]byte(fmt.Sprint(timestamp.Unix()))) - return base64.URLEncoding.EncodeToString(h.Sum(nil)) -} diff --git a/internal/proxy/providers/singleflight_middleware.go b/internal/proxy/providers/singleflight_middleware.go index 2cfa1a7b..3cee4cbf 100644 --- a/internal/proxy/providers/singleflight_middleware.go +++ b/internal/proxy/providers/singleflight_middleware.go @@ -3,6 +3,7 @@ package providers import ( "errors" "fmt" + "net/url" "sort" "strings" @@ -127,3 +128,13 @@ func (p *SingleFlightProvider) RefreshSession(s *SessionState, allowedGroups []s return r, nil } + +// GetSignInURL calls the GetSignInURL for the provider, which will return the sign in url +func (p *SingleFlightProvider) GetSignInURL(redirectURI *url.URL, finalRedirect string) *url.URL { + return p.provider.GetSignInURL(redirectURI, finalRedirect) +} + +// GetSignOutURL calls the GetSignOutURL for the provider, which will return the sign out url +func (p *SingleFlightProvider) GetSignOutURL(redirectURI *url.URL) *url.URL { + return p.provider.GetSignOutURL(redirectURI) +} diff --git a/internal/proxy/providers/sso.go b/internal/proxy/providers/sso.go index ffdb7e1c..1f7040cf 100644 --- a/internal/proxy/providers/sso.go +++ b/internal/proxy/providers/sso.go @@ -2,6 +2,9 @@ package providers import ( "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -383,3 +386,43 @@ func (p *SSOProvider) ValidateSessionState(s *SessionState, allowedGroups []stri return true } + +// signRedirectURL signs the redirect url string, given a timestamp, and returns it +func signRedirectURL(clientSecret, rawRedirect string, timestamp time.Time) string { + h := hmac.New(sha256.New, []byte(clientSecret)) + h.Write([]byte(rawRedirect)) + h.Write([]byte(fmt.Sprint(timestamp.Unix()))) + return base64.URLEncoding.EncodeToString(h.Sum(nil)) +} + +// GetSignInURL with typical oauth parameters +func (p *SSOProvider) GetSignInURL(redirectURL *url.URL, state string) *url.URL { + var a url.URL + a = *p.Data().SignInURL + now := time.Now() + rawRedirect := redirectURL.String() + params, _ := url.ParseQuery(a.RawQuery) + params.Set("redirect_uri", rawRedirect) + params.Add("scope", p.Data().Scope) + params.Set("client_id", p.Data().ClientID) + params.Set("response_type", "code") + params.Add("state", state) + params.Set("ts", fmt.Sprint(now.Unix())) + params.Set("sig", signRedirectURL(p.Data().ClientSecret, rawRedirect, now)) + a.RawQuery = params.Encode() + return &a +} + +// GetSignOutURL creates and returns the sign out URL, given a redirectURL +func (p *SSOProvider) GetSignOutURL(redirectURL *url.URL) *url.URL { + var a url.URL + a = *p.Data().SignOutURL + now := time.Now() + rawRedirect := redirectURL.String() + params, _ := url.ParseQuery(a.RawQuery) + params.Add("redirect_uri", rawRedirect) + params.Set("ts", fmt.Sprint(now.Unix())) + params.Set("sig", signRedirectURL(p.Data().ClientSecret, rawRedirect, now)) + a.RawQuery = params.Encode() + return &a +} diff --git a/internal/proxy/providers/test_provider.go b/internal/proxy/providers/test_provider.go index bb91e153..ef20e4ca 100644 --- a/internal/proxy/providers/test_provider.go +++ b/internal/proxy/providers/test_provider.go @@ -6,13 +6,12 @@ import ( // TestProvider is a mock provider type TestProvider struct { + RefreshSessionFunc func(*SessionState, []string) (bool, error) + ValidateSessionFunc func(*SessionState, []string) bool + RedeemFunc func(string, string) (*SessionState, error) + UserGroupsFunc func(string, []string) ([]string, error) + ValidateGroupsFunc func(string, []string) ([]string, bool, error) *ProviderData - EmailAddress string - ValidToken bool - ValidGroup bool - Refreshed bool - Session *SessionState - Groups []string } // NewTestProvider returns a new TestProvider @@ -42,31 +41,40 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { }, Scope: "profile.email", }, - EmailAddress: emailAddress, } } // ValidateSessionState mocks the ValidateSessionState function -func (tp *TestProvider) ValidateSessionState(*SessionState, []string) bool { - return tp.ValidToken +func (tp *TestProvider) ValidateSessionState(s *SessionState, groups []string) bool { + return tp.ValidateSessionFunc(s, groups) } // Redeem mocks the provider Redeem function -func (tp *TestProvider) Redeem(string, string) (*SessionState, error) { - return tp.Session, nil +func (tp *TestProvider) Redeem(redirectURL string, token string) (*SessionState, error) { + return tp.RedeemFunc(redirectURL, token) } // RefreshSession mocks the RefreshSession function -func (tp *TestProvider) RefreshSession(*SessionState, []string) (bool, error) { - return tp.Refreshed, nil +func (tp *TestProvider) RefreshSession(s *SessionState, g []string) (bool, error) { + return tp.RefreshSessionFunc(s, g) } // UserGroups mocks the UserGroups function -func (tp *TestProvider) UserGroups(string, []string) ([]string, error) { - return tp.Groups, nil +func (tp *TestProvider) UserGroups(email string, groups []string) ([]string, error) { + return tp.UserGroupsFunc(email, groups) } // ValidateGroup mocks the ValidateGroup function -func (tp *TestProvider) ValidateGroup(string, []string) ([]string, bool, error) { - return tp.Groups, tp.ValidGroup, nil +func (tp *TestProvider) ValidateGroup(email string, groups []string) ([]string, bool, error) { + return tp.ValidateGroupsFunc(email, groups) +} + +// GetSignOutURL mocks GetSignOutURL function +func (tp *TestProvider) GetSignOutURL(redirectURL *url.URL) *url.URL { + return tp.Data().SignOutURL +} + +// GetSignInURL mocks GetSignInURL +func (tp *TestProvider) GetSignInURL(redirectURL *url.URL, state string) *url.URL { + return tp.Data().SignInURL }