diff --git a/internal/proxy/oauthproxy.go b/internal/proxy/oauthproxy.go index d807a340..6b0463ac 100755 --- a/internal/proxy/oauthproxy.go +++ b/internal/proxy/oauthproxy.go @@ -423,20 +423,21 @@ func (p *OAuthProxy) GetRedirectURL(host string) *url.URL { return &u } -func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) { +func (p *OAuthProxy) redeemCode(host, code string) (*providers.SessionState, error) { if code == "" { return nil, errors.New("missing code") } redirectURL := p.GetRedirectURL(host) - s, err = p.provider.Redeem(redirectURL.String(), code) + s, err := p.provider.Redeem(redirectURL.String(), code) if err != nil { - return + return s, err } if s.Email == "" { - s.Email, err = p.provider.GetEmailAddress(s) + return s, errors.New("invalid email address") } - return + + return s, nil } // MakeSessionCookie constructs a session cookie given the request, an expiration time and the current time. diff --git a/internal/proxy/oauthproxy_test.go b/internal/proxy/oauthproxy_test.go index fe539363..f97aad5a 100644 --- a/internal/proxy/oauthproxy_test.go +++ b/internal/proxy/oauthproxy_test.go @@ -438,51 +438,12 @@ func TestFavicon(t *testing.T) { testutil.Equal(t, http.StatusNotFound, rw.Code) } -type TestProvider struct { - *providers.ProviderData - EmailAddress string - ValidToken bool -} - -func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { - return &TestProvider{ - ProviderData: &providers.ProviderData{ - ProviderName: "Test Provider", - SignInURL: &url.URL{ - Scheme: "http", - Host: providerURL.Host, - Path: "/oauth/authorize", - }, - RedeemURL: &url.URL{ - Scheme: "http", - Host: providerURL.Host, - Path: "/oauth/token", - }, - ProfileURL: &url.URL{ - Scheme: "http", - Host: providerURL.Host, - Path: "/api/v1/profile", - }, - Scope: "profile.email", - }, - EmailAddress: emailAddress, - } -} - -func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) { - return tp.EmailAddress, nil -} - -func (tp *TestProvider) ValidateSessionState(session *providers.SessionState, g []string) bool { - return tp.ValidToken -} - type ProcessCookieTest struct { opts *Options proxy *OAuthProxy rw *httptest.ResponseRecorder req *http.Request - provider TestProvider + provider providers.TestProvider responseCode int validateUser bool } @@ -508,8 +469,8 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { return nil }) - pcTest.proxy.provider = &TestProvider{ - ValidToken: opts.providerValidateCookieResponse, + pcTest.proxy.provider = &providers.TestProvider{ + ValidateSessionFunc: func(*providers.SessionState, []string) bool { return opts.providerValidateCookieResponse }, } pcTest.rw = httptest.NewRecorder() @@ -695,7 +656,7 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { opts.Validate() upstreamURL, _ := url.Parse(upstream.URL) - opts.provider = NewTestProvider(upstreamURL, "") + opts.provider = providers.NewTestProvider(upstreamURL, "") proxy, _ := NewOAuthProxy(opts) rw := httptest.NewRecorder() @@ -749,7 +710,7 @@ func TestAuthSkipRequests(t *testing.T) { opts.Validate() upstreamURL, _ := url.Parse(upstream.URL) - opts.provider = NewTestProvider(upstreamURL, "") + opts.provider = providers.NewTestProvider(upstreamURL, "") proxy, _ := NewOAuthProxy(opts) @@ -829,7 +790,7 @@ func TestMultiAuthSkipRequests(t *testing.T) { opts.Validate() upstreamFooURL, _ := url.Parse(upstreamFoo.URL) - opts.provider = NewTestProvider(upstreamFooURL, "") + opts.provider = providers.NewTestProvider(upstreamFooURL, "") proxy, _ := NewOAuthProxy(opts) @@ -921,7 +882,7 @@ func NewSignatureTest(key string) *SignatureTest { } provider := httptest.NewServer(http.HandlerFunc(providerHandler)) providerURL, _ := url.Parse(provider.URL) - opts.provider = NewTestProvider(providerURL, "email1@example.com") + opts.provider = providers.NewTestProvider(providerURL, "email1@example.com") opts.upstreamConfigs = generateSignatureTestUpstreamConfigs(key, upstream.URL) opts.Validate() @@ -1041,7 +1002,7 @@ func TestHeadersSentToUpstreams(t *testing.T) { opts.upstreamConfigs = generateTestUpstreamConfigs(upstream.URL) opts.Validate() providerURL, _ := url.Parse("http://sso-auth.example.com/") - opts.provider = NewTestProvider(providerURL, "") + opts.provider = providers.NewTestProvider(providerURL, "") state := testSession() state.Email = "foo@example.com" @@ -1098,20 +1059,6 @@ func TestHeadersSentToUpstreams(t *testing.T) { } -type testAuthenticateProvider struct { - *providers.ProviderData - refreshSessionFunc func(*providers.SessionState, []string) (bool, error) - validateSessionFunc func(*providers.SessionState, []string) bool -} - -func (tap *testAuthenticateProvider) RefreshSession(s *providers.SessionState, g []string) (bool, error) { - return tap.refreshSessionFunc(s, g) -} - -func (tap *testAuthenticateProvider) ValidateSessionState(s *providers.SessionState, g []string) bool { - return tap.validateSessionFunc(s, g) -} - func TestAuthenticate(t *testing.T) { // Constants to represent possible cookie behaviors. const ( @@ -1261,9 +1208,9 @@ func TestAuthenticate(t *testing.T) { opts.upstreamConfigs = generateTestUpstreamConfigs("foo-internal.sso.dev") opts.Validate() proxy, _ := NewOAuthProxy(opts, testValidatorFunc(true), testCookieCipher(tc.Cipher)) - proxy.provider = &testAuthenticateProvider{ - refreshSessionFunc: tc.RefreshSessionFunc, - validateSessionFunc: tc.ValidateSessionFunc, + proxy.provider = &providers.TestProvider{ + RefreshSessionFunc: tc.RefreshSessionFunc, + ValidateSessionFunc: tc.ValidateSessionFunc, } value, err := providers.MarshalSession(tc.Session, proxy.CookieCipher) @@ -1518,7 +1465,7 @@ func TestPing(t *testing.T) { opts.Validate() providerURL, _ := url.Parse("http://sso-auth.example.com/") - opts.provider = NewTestProvider(providerURL, "") + opts.provider = providers.NewTestProvider(providerURL, "") proxy, _ := NewOAuthProxy(opts) state := testSession() @@ -1597,7 +1544,7 @@ func TestSecurityHeaders(t *testing.T) { opts.Validate() providerURL, _ := url.Parse("http://sso-auth.example.com/") - opts.provider = NewTestProvider(providerURL, "") + opts.provider = providers.NewTestProvider(providerURL, "") proxy, _ := NewOAuthProxy(opts, testValidatorFunc(true)) @@ -1741,7 +1688,7 @@ func TestHeaderOverrides(t *testing.T) { opts.Validate() providerURL, _ := url.Parse("http://sso-auth.example.com/") - opts.provider = NewTestProvider(providerURL, "") + opts.provider = providers.NewTestProvider(providerURL, "") proxy, _ := NewOAuthProxy(opts, testValidatorFunc(true)) @@ -1785,7 +1732,7 @@ func TestHTTPSRedirect(t *testing.T) { defer upstream.Close() providerURL, _ := url.Parse("http://sso-auth.example.com/") - provider := NewTestProvider(providerURL, "") + provider := providers.NewTestProvider(providerURL, "") state := testSession() testCases := []struct { diff --git a/internal/proxy/providers/internal_util.go b/internal/proxy/providers/internal_util.go index c3714427..daf65001 100644 --- a/internal/proxy/providers/internal_util.go +++ b/internal/proxy/providers/internal_util.go @@ -1,8 +1,6 @@ package providers import ( - "io/ioutil" - "net/http" "net/url" log "github.com/buzzfeed/sso/internal/pkg/logging" @@ -45,41 +43,3 @@ func stripParam(param, endpoint string) string { return endpoint } - -// validateToken returns true if token is valid -func validateToken(p Provider, accessToken string, header http.Header) bool { - logger := log.NewLogEntry() - - if accessToken == "" || p.Data().ValidateURL == nil { - return false - } - endpoint := p.Data().ValidateURL.String() - if len(header) == 0 { - params := url.Values{"access_token": {accessToken}} - endpoint = endpoint + "?" + params.Encode() - } - - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - logger.Error(err, "token validation request failed") - return false - } - req.Header = header - - resp, err := httpClient.Do(req) - if err != nil { - logger.Error(err, "token validation request failed") - return false - } - - body, _ := ioutil.ReadAll(resp.Body) - resp.Body.Close() - logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body) - - if resp.StatusCode == 200 { - return true - } - logger.WithHTTPStatus(resp.StatusCode).WithResponseBody(body).Info( - "token validation request failed") - return false -} diff --git a/internal/proxy/providers/internal_util_test.go b/internal/proxy/providers/internal_util_test.go deleted file mode 100644 index ca5a070b..00000000 --- a/internal/proxy/providers/internal_util_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package providers - -import ( - "errors" - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/buzzfeed/sso/internal/pkg/testutil" -) - -type ValidateSessionStateTestProvider struct { - *ProviderData -} - -func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { - return "", errors.New("not implemented") -} - -// Note that we're testing the internal validateToken() used to implement -// several Provider's ValidateSessionState() implementations -func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState, g []string) bool { - return false -} - -type ValidateSessionStateTest struct { - backend *httptest.Server - responseCode int - provider *ValidateSessionStateTestProvider - header http.Header -} - -func NewValidateSessionStateTest() *ValidateSessionStateTest { - var vtTest ValidateSessionStateTest - - vtTest.backend = httptest.NewServer( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/oauth/tokeninfo" { - w.WriteHeader(500) - w.Write([]byte("unknown URL")) - } - tokenParam := r.FormValue("access_token") - if tokenParam == "" { - missing := false - receivedHeaders := r.Header - for k := range vtTest.header { - received := receivedHeaders.Get(k) - expected := vtTest.header.Get(k) - if received == "" || received != expected { - missing = true - } - } - if missing { - w.WriteHeader(500) - w.Write([]byte("no token param and missing or incorrect headers")) - } - } - w.WriteHeader(vtTest.responseCode) - w.Write([]byte("only code matters; contents disregarded")) - - })) - backendURL, _ := url.Parse(vtTest.backend.URL) - vtTest.provider = &ValidateSessionStateTestProvider{ - ProviderData: &ProviderData{ - ValidateURL: &url.URL{ - Scheme: "http", - Host: backendURL.Host, - Path: "/oauth/tokeninfo", - }, - }, - } - vtTest.responseCode = 200 - return &vtTest -} - -func (vtTest *ValidateSessionStateTest) Close() { - vtTest.backend.Close() -} - -func TestValidateSessionStateValidToken(t *testing.T) { - vtTest := NewValidateSessionStateTest() - defer vtTest.Close() - testutil.Equal(t, true, validateToken(vtTest.provider, "foobar", nil)) -} - -func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { - vtTest := NewValidateSessionStateTest() - defer vtTest.Close() - vtTest.header = make(http.Header) - vtTest.header.Set("Authorization", "Bearer foobar") - testutil.Equal(t, true, - validateToken(vtTest.provider, "foobar", vtTest.header)) -} - -func TestValidateSessionStateEmptyToken(t *testing.T) { - vtTest := NewValidateSessionStateTest() - defer vtTest.Close() - testutil.Equal(t, false, validateToken(vtTest.provider, "", nil)) -} - -func TestValidateSessionStateEmptyValidateURL(t *testing.T) { - vtTest := NewValidateSessionStateTest() - defer vtTest.Close() - vtTest.provider.Data().ValidateURL = nil - testutil.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) -} - -func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { - vtTest := NewValidateSessionStateTest() - // Close immediately to simulate a network failure - vtTest.Close() - testutil.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) -} - -func TestValidateSessionStateExpiredToken(t *testing.T) { - vtTest := NewValidateSessionStateTest() - defer vtTest.Close() - vtTest.responseCode = 401 - testutil.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) -} - -func TestStripTokenNotPresent(t *testing.T) { - test := "http://local.test/api/test?a=1&b=2" - testutil.Equal(t, test, stripToken(test)) -} - -func TestStripToken(t *testing.T) { - test := "http://local.test/api/test?access_token=deadbeef&b=1&c=2" - expected := "http://local.test/api/test?access_token=dead...&b=1&c=2" - testutil.Equal(t, expected, stripToken(test)) -} diff --git a/internal/proxy/providers/provider_data.go b/internal/proxy/providers/provider_data.go index c6068962..f221dfd3 100644 --- a/internal/proxy/providers/provider_data.go +++ b/internal/proxy/providers/provider_data.go @@ -8,21 +8,19 @@ import ( // ProviderData holds the fields associated with providers // necessary to implement the Provider interface. type ProviderData struct { - ProviderName string - ProviderURL *url.URL - ProxyProviderURL *url.URL - ClientID string - ClientSecret string - SignInURL *url.URL - SignOutURL *url.URL - RedeemURL *url.URL - ProxyRedeemURL *url.URL - RefreshURL *url.URL - ProfileURL *url.URL - ProtectedResource *url.URL - ValidateURL *url.URL - Scope string - ApprovalPrompt string + ProviderName string + ProviderURL *url.URL + ProxyProviderURL *url.URL + ClientID string + ClientSecret string + SignInURL *url.URL + SignOutURL *url.URL + RedeemURL *url.URL + ProxyRedeemURL *url.URL + RefreshURL *url.URL + ProfileURL *url.URL + ValidateURL *url.URL + Scope string SessionValidTTL time.Duration SessionLifetimeTTL time.Duration diff --git a/internal/proxy/providers/provider_default.go b/internal/proxy/providers/provider_default.go deleted file mode 100644 index 02496b68..00000000 --- a/internal/proxy/providers/provider_default.go +++ /dev/null @@ -1,146 +0,0 @@ -package providers - -import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "net/http" - "net/url" - "time" -) - -// Redeem takes a redirectURL and code, creates some params and redeems the request -func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { - if code == "" { - err = errors.New("missing code") - return - } - - params := url.Values{} - params.Add("redirect_uri", redirectURL) - params.Add("client_id", p.ClientID) - params.Add("client_secret", p.ClientSecret) - params.Add("code", code) - params.Add("grant_type", "authorization_code") - if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { - params.Add("resource", p.ProtectedResource.String()) - } - - var req *http.Request - req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - var resp *http.Response - resp, err = httpClient.Do(req) - if err != nil { - return nil, err - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } - - // blindly try json and x-www-form-urlencoded - var jsonResponse struct { - AccessToken string `json:"access_token"` - } - err = json.Unmarshal(body, &jsonResponse) - if err == nil { - s = &SessionState{ - AccessToken: jsonResponse.AccessToken, - } - return - } - - var v url.Values - v, err = url.ParseQuery(string(body)) - if err != nil { - return - } - if a := v.Get("access_token"); a != "" { - s = &SessionState{AccessToken: a} - } else { - err = fmt.Errorf("no access token found %s", body) - } - return -} - -// GetSignInURL with typical oauth parameters -func (p *ProviderData) GetSignInURL(redirectURL *url.URL, state string) *url.URL { - var a url.URL - a = *p.SignInURL - now := time.Now() - rawRedirect := redirectURL.String() - params, _ := url.ParseQuery(a.RawQuery) - params.Set("redirect_uri", rawRedirect) - params.Add("scope", p.Scope) - params.Set("client_id", p.ClientID) - params.Set("response_type", "code") - params.Add("state", state) - params.Set("ts", fmt.Sprint(now.Unix())) - params.Set("sig", p.signRedirectURL(rawRedirect, now)) - a.RawQuery = params.Encode() - return &a -} - -// GetSignOutURL creates and returns the sign out URL, given a redirectURL -func (p *ProviderData) GetSignOutURL(redirectURL *url.URL) *url.URL { - var a url.URL - a = *p.SignOutURL - now := time.Now() - rawRedirect := redirectURL.String() - params, _ := url.ParseQuery(a.RawQuery) - params.Add("redirect_uri", rawRedirect) - params.Set("ts", fmt.Sprint(now.Unix())) - params.Set("sig", p.signRedirectURL(rawRedirect, now)) - a.RawQuery = params.Encode() - return &a -} - -// signRedirectURL signs the redirect url string, given a timestamp, and returns it -func (p *ProviderData) signRedirectURL(rawRedirect string, timestamp time.Time) string { - h := hmac.New(sha256.New, []byte(p.ClientSecret)) - h.Write([]byte(rawRedirect)) - h.Write([]byte(fmt.Sprint(timestamp.Unix()))) - return base64.URLEncoding.EncodeToString(h.Sum(nil)) -} - -// GetEmailAddress returns an email address or error -func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { - return "", errors.New("not implemented") -} - -// ValidateGroup validates that the provided email exists in the configured provider email group(s). -func (p *ProviderData) ValidateGroup(_ string, _ []string) ([]string, bool, error) { - return []string{}, true, nil -} - -// UserGroups returns a list of users -func (p *ProviderData) UserGroups(string, []string) ([]string, error) { - return []string{}, nil -} - -// ValidateSessionState calls to validate the token given the session and groups -func (p *ProviderData) ValidateSessionState(s *SessionState, groups []string) bool { - return validateToken(p, s.AccessToken, nil) -} - -// RefreshSession returns a boolean or error -func (p *ProviderData) RefreshSession(s *SessionState, group []string) (bool, error) { - return false, nil -} diff --git a/internal/proxy/providers/provider_default_test.go b/internal/proxy/providers/provider_default_test.go deleted file mode 100644 index 2aa48bfe..00000000 --- a/internal/proxy/providers/provider_default_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package providers - -import ( - "testing" - "time" - - "github.com/buzzfeed/sso/internal/pkg/testutil" -) - -func TestRefresh(t *testing.T) { - p := &ProviderData{} - refreshed, err := p.RefreshSession(&SessionState{ - RefreshDeadline: time.Now().Add(time.Duration(-11) * time.Minute), - }, - []string{}, - ) - testutil.Equal(t, false, refreshed) - testutil.Equal(t, nil, err) -} diff --git a/internal/proxy/providers/providers.go b/internal/proxy/providers/providers.go index 4d33a796..fb584529 100644 --- a/internal/proxy/providers/providers.go +++ b/internal/proxy/providers/providers.go @@ -9,14 +9,13 @@ import ( // Provider is an interface exposing functions necessary to authenticate with a given provider. type Provider interface { Data() *ProviderData - GetEmailAddress(*SessionState) (string, error) Redeem(string, string) (*SessionState, error) ValidateGroup(string, []string) ([]string, bool, error) UserGroups(string, []string) ([]string, error) ValidateSessionState(*SessionState, []string) bool - GetSignInURL(redirectURL *url.URL, finalRedirect string) *url.URL - GetSignOutURL(redirectURL *url.URL) *url.URL RefreshSession(*SessionState, []string) (bool, error) + GetSignInURL(redirectURL *url.URL, state string) *url.URL + GetSignOutURL(redirectURL *url.URL) *url.URL } // New returns a new sso Provider diff --git a/internal/proxy/providers/singleflight_middleware.go b/internal/proxy/providers/singleflight_middleware.go index 7de4d3ec..3cee4cbf 100644 --- a/internal/proxy/providers/singleflight_middleware.go +++ b/internal/proxy/providers/singleflight_middleware.go @@ -64,11 +64,6 @@ func (p *SingleFlightProvider) Data() *ProviderData { return p.provider.Data() } -// GetEmailAddress calls the provider function getEmailAddress -func (p *SingleFlightProvider) GetEmailAddress(s *SessionState) (string, error) { - return p.provider.GetEmailAddress(s) -} - // Redeem takes the redirectURL and a code and calls the provider function Redeem func (p *SingleFlightProvider) Redeem(redirectURL, code string) (*SessionState, error) { return p.provider.Redeem(redirectURL, code) @@ -116,16 +111,6 @@ func (p *SingleFlightProvider) ValidateSessionState(s *SessionState, allowedGrou return valid } -// GetSignInURL calls the GetSignInURL for the provider, which will return the sign in url -func (p *SingleFlightProvider) GetSignInURL(redirectURI *url.URL, finalRedirect string) *url.URL { - return p.provider.GetSignInURL(redirectURI, finalRedirect) -} - -// GetSignOutURL calls the GetSignOutURL for the provider, which will return the sign out url -func (p *SingleFlightProvider) GetSignOutURL(redirectURI *url.URL) *url.URL { - return p.provider.GetSignOutURL(redirectURI) -} - // RefreshSession takes in a SessionState and allowedGroups and // returns false if the session is not refreshed and true if it is. func (p *SingleFlightProvider) RefreshSession(s *SessionState, allowedGroups []string) (bool, error) { @@ -143,3 +128,13 @@ func (p *SingleFlightProvider) RefreshSession(s *SessionState, allowedGroups []s return r, nil } + +// GetSignInURL calls the GetSignInURL for the provider, which will return the sign in url +func (p *SingleFlightProvider) GetSignInURL(redirectURI *url.URL, finalRedirect string) *url.URL { + return p.provider.GetSignInURL(redirectURI, finalRedirect) +} + +// GetSignOutURL calls the GetSignOutURL for the provider, which will return the sign out url +func (p *SingleFlightProvider) GetSignOutURL(redirectURI *url.URL) *url.URL { + return p.provider.GetSignOutURL(redirectURI) +} diff --git a/internal/proxy/providers/sso.go b/internal/proxy/providers/sso.go index 5b8d005c..38fe23d5 100644 --- a/internal/proxy/providers/sso.go +++ b/internal/proxy/providers/sso.go @@ -2,6 +2,9 @@ package providers import ( "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -385,3 +388,41 @@ func (p *SSOProvider) ValidateSessionState(s *SessionState, allowedGroups []stri return true } + +// signRedirectURL signs the redirect url string, given a timestamp, and returns it +func (p *SSOProvider) signRedirectURL(rawRedirect string, timestamp time.Time) string { + h := hmac.New(sha256.New, []byte(p.Data().ClientSecret)) + h.Write([]byte(rawRedirect)) + h.Write([]byte(fmt.Sprint(timestamp.Unix()))) + return base64.URLEncoding.EncodeToString(h.Sum(nil)) +} + +// GetSignInURL with typical oauth parameters +func (p *SSOProvider) GetSignInURL(redirectURL *url.URL, state string) *url.URL { + a := *p.Data().SignInURL + now := time.Now() + rawRedirect := redirectURL.String() + params, _ := url.ParseQuery(a.RawQuery) + params.Set("redirect_uri", rawRedirect) + params.Add("scope", p.Data().Scope) + params.Set("client_id", p.Data().ClientID) + params.Set("response_type", "code") + params.Add("state", state) + params.Set("ts", fmt.Sprint(now.Unix())) + params.Set("sig", p.signRedirectURL(rawRedirect, now)) + a.RawQuery = params.Encode() + return &a +} + +// GetSignOutURL creates and returns the sign out URL, given a redirectURL +func (p *SSOProvider) GetSignOutURL(redirectURL *url.URL) *url.URL { + a := *p.Data().SignOutURL + now := time.Now() + rawRedirect := redirectURL.String() + params, _ := url.ParseQuery(a.RawQuery) + params.Add("redirect_uri", rawRedirect) + params.Set("ts", fmt.Sprint(now.Unix())) + params.Set("sig", p.signRedirectURL(rawRedirect, now)) + a.RawQuery = params.Encode() + return &a +} diff --git a/internal/proxy/providers/test_provider.go b/internal/proxy/providers/test_provider.go new file mode 100644 index 00000000..ef20e4ca --- /dev/null +++ b/internal/proxy/providers/test_provider.go @@ -0,0 +1,80 @@ +package providers + +import ( + "net/url" +) + +// TestProvider is a mock provider +type TestProvider struct { + RefreshSessionFunc func(*SessionState, []string) (bool, error) + ValidateSessionFunc func(*SessionState, []string) bool + RedeemFunc func(string, string) (*SessionState, error) + UserGroupsFunc func(string, []string) ([]string, error) + ValidateGroupsFunc func(string, []string) ([]string, bool, error) + *ProviderData +} + +// NewTestProvider returns a new TestProvider +func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { + return &TestProvider{ + ProviderData: &ProviderData{ + ProviderName: "Test Provider", + SignInURL: &url.URL{ + Scheme: "http", + Host: providerURL.Host, + Path: "/oauth/authorize", + }, + RedeemURL: &url.URL{ + Scheme: "http", + Host: providerURL.Host, + Path: "/oauth/token", + }, + ProfileURL: &url.URL{ + Scheme: "http", + Host: providerURL.Host, + Path: "/api/v1/profile", + }, + SignOutURL: &url.URL{ + Scheme: "http", + Host: providerURL.Host, + Path: "/oauth/sign_out", + }, + Scope: "profile.email", + }, + } +} + +// ValidateSessionState mocks the ValidateSessionState function +func (tp *TestProvider) ValidateSessionState(s *SessionState, groups []string) bool { + return tp.ValidateSessionFunc(s, groups) +} + +// Redeem mocks the provider Redeem function +func (tp *TestProvider) Redeem(redirectURL string, token string) (*SessionState, error) { + return tp.RedeemFunc(redirectURL, token) +} + +// RefreshSession mocks the RefreshSession function +func (tp *TestProvider) RefreshSession(s *SessionState, g []string) (bool, error) { + return tp.RefreshSessionFunc(s, g) +} + +// UserGroups mocks the UserGroups function +func (tp *TestProvider) UserGroups(email string, groups []string) ([]string, error) { + return tp.UserGroupsFunc(email, groups) +} + +// ValidateGroup mocks the ValidateGroup function +func (tp *TestProvider) ValidateGroup(email string, groups []string) ([]string, bool, error) { + return tp.ValidateGroupsFunc(email, groups) +} + +// GetSignOutURL mocks GetSignOutURL function +func (tp *TestProvider) GetSignOutURL(redirectURL *url.URL) *url.URL { + return tp.Data().SignOutURL +} + +// GetSignInURL mocks GetSignInURL +func (tp *TestProvider) GetSignInURL(redirectURL *url.URL, state string) *url.URL { + return tp.Data().SignInURL +}