diff --git a/cmd/sso-auth/main.go b/cmd/sso-auth/main.go index 2fcf215b..367ea33c 100644 --- a/cmd/sso-auth/main.go +++ b/cmd/sso-auth/main.go @@ -7,7 +7,6 @@ import ( "github.com/buzzfeed/sso/internal/auth" log "github.com/buzzfeed/sso/internal/pkg/logging" - "github.com/buzzfeed/sso/internal/pkg/options" ) func init() { @@ -29,31 +28,28 @@ func main() { os.Exit(1) } - emailValidator := func(p *auth.Authenticator) error { - if len(opts.EmailAddresses) != 0 { - p.Validator = options.NewEmailAddressValidator(opts.EmailAddresses) - } else { - p.Validator = options.NewEmailDomainValidator(opts.EmailDomains) - } - return nil + statsdClient, err := auth.NewStatsdClient(opts.StatsdHost, opts.StatsdPort) + if err != nil { + logger.Error(err, "error creating statsd client") + os.Exit(1) } - authenticator, err := auth.NewAuthenticator(opts, emailValidator, auth.AssignProvider(opts), auth.SetCookieStore(opts), auth.AssignStatsdClient(opts)) + authMux, err := auth.NewAuthenticatorMux(opts, statsdClient) if err != nil { - logger.Error(err, "error creating new Authenticator") + logger.Error(err, "error creating new AuthenticatorMux") os.Exit(1) } - defer authenticator.Stop() + defer authMux.Stop() // we leave the message field blank, which will inherit the stdlib timeout page which is sufficient // and better than other naive messages we would currently place here - timeoutHandler := http.TimeoutHandler(authenticator.ServeMux, opts.RequestTimeout, "") + timeoutHandler := http.TimeoutHandler(authMux, opts.RequestTimeout, "") s := &http.Server{ Addr: fmt.Sprintf(":%d", opts.Port), ReadTimeout: opts.TCPReadTimeout, WriteTimeout: opts.TCPWriteTimeout, - Handler: auth.NewLoggingHandler(os.Stdout, timeoutHandler, opts.RequestLogging, authenticator.StatsdClient), + Handler: auth.NewLoggingHandler(os.Stdout, timeoutHandler, opts.RequestLogging, statsdClient), } logger.Fatal(s.ListenAndServe()) diff --git a/internal/auth/authenticator.go b/internal/auth/authenticator.go index 57633086..f2fbd1a9 100644 --- a/internal/auth/authenticator.go +++ b/internal/auth/authenticator.go @@ -87,7 +87,7 @@ type getProfileResponse struct { } // SetCookieStore sets the cookie store to use a miscreant cipher -func SetCookieStore(opts *Options) func(*Authenticator) error { +func SetCookieStore(opts *Options, providerSlug string) func(*Authenticator) error { return func(a *Authenticator) error { decodedAuthCodeSecret, err := base64.StdEncoding.DecodeString(opts.AuthCodeSecret) if err != nil { @@ -98,7 +98,8 @@ func SetCookieStore(opts *Options) func(*Authenticator) error { return err } - cookieStore, err := sessions.NewCookieStore(opts.CookieName, + cookieName := fmt.Sprintf("%s_%s", opts.CookieName, providerSlug) + cookieStore, err := sessions.NewCookieStore(cookieName, sessions.CreateMiscreantCookieCipher(opts.decodedCookieSecret), func(c *sessions.CookieStore) error { c.CookieDomain = opts.CookieDomain @@ -123,9 +124,6 @@ func SetCookieStore(opts *Options) func(*Authenticator) error { func NewAuthenticator(opts *Options, optionFuncs ...func(*Authenticator) error) (*Authenticator, error) { logger := log.NewLogEntry() - redirectURL := opts.redirectURL - redirectURL.Path = "/oauth2/callback" - templates := templates.NewHTMLTemplate() proxyRootDomains := []string{} @@ -144,7 +142,6 @@ func NewAuthenticator(opts *Options, optionFuncs ...func(*Authenticator) error) Host: opts.Host, CookieSecure: opts.CookieSecure, - redirectURL: redirectURL, SetXAuthRequest: opts.SetXAuthRequest, PassUserHeaders: opts.PassUserHeaders, SkipProviderButton: opts.SkipProviderButton, @@ -166,39 +163,19 @@ func NewAuthenticator(opts *Options, optionFuncs ...func(*Authenticator) error) } func (p *Authenticator) newMux() http.Handler { - logger := log.NewLogEntry() - - mux := http.NewServeMux() - - // we setup global endpoints that should respond to any hostname - mux.HandleFunc("/ping", p.withMethods(p.PingPage, "GET")) - // we setup our service mux to handle service routes that use the required host header serviceMux := http.NewServeMux() - serviceMux.HandleFunc("/ping", p.withMethods(p.PingPage, "GET")) serviceMux.HandleFunc("/robots.txt", p.withMethods(p.RobotsTxt, "GET")) serviceMux.HandleFunc("/start", p.withMethods(p.OAuthStart, "GET")) serviceMux.HandleFunc("/sign_in", p.withMethods(p.validateClientID(p.validateRedirectURI(p.validateSignature(p.SignIn))), "GET")) serviceMux.HandleFunc("/sign_out", p.withMethods(p.validateRedirectURI(p.validateSignature(p.SignOut)), "GET", "POST")) - serviceMux.HandleFunc("/oauth2/callback", p.withMethods(p.OAuthCallback, "GET")) + serviceMux.HandleFunc("/callback", p.withMethods(p.OAuthCallback, "GET")) serviceMux.HandleFunc("/profile", p.withMethods(p.validateClientID(p.validateClientSecret(p.GetProfile)), "GET")) serviceMux.HandleFunc("/validate", p.withMethods(p.validateClientID(p.validateClientSecret(p.ValidateToken)), "GET")) serviceMux.HandleFunc("/redeem", p.withMethods(p.validateClientID(p.validateClientSecret(p.Redeem)), "POST")) serviceMux.HandleFunc("/refresh", p.withMethods(p.validateClientID(p.validateClientSecret(p.Refresh)), "POST")) - fsHandler, err := loadFSHandler() - if err != nil { - logger.Fatal(err) - } - serviceMux.Handle("/static/", http.StripPrefix("/static/", fsHandler)) - - // NOTE: we have to include trailing slash for the router to match the host header - host := p.Host - if !strings.HasSuffix(host, "/") { - host = fmt.Sprintf("%s/", host) - } - mux.Handle(host, serviceMux) // setup our service mux to only handle our required host header - return setHeaders(mux) + return setHeaders(serviceMux) } // GetRedirectURI returns the redirect url for a given OAuthProxy, @@ -215,13 +192,8 @@ func (p *Authenticator) RobotsTxt(rw http.ResponseWriter, req *http.Request) { fmt.Fprintf(rw, "User-agent: *\nDisallow: /") } -// PingPage handles the /ping route -func (p *Authenticator) PingPage(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(http.StatusOK) - fmt.Fprintf(rw, "OK") -} - type signInResp struct { + ProviderSlug string ProviderName string EmailDomains []string Redirect string @@ -237,13 +209,20 @@ func (p *Authenticator) SignInPage(rw http.ResponseWriter, req *http.Request, co // We don't want to rely on req.Host, as that can be attacked via Host header injection // This ends up looking like: // https://sso-auth.example.com/sign_in?client_id=...&redirect_uri=... - redirectURL := p.redirectURL.ResolveReference(req.URL) + path := strings.TrimPrefix(req.URL.Path, "/") + redirectURL := p.redirectURL.ResolveReference( + &url.URL{ + Path: path, + RawQuery: req.URL.RawQuery, + }, + ) // validateRedirectURI middleware already ensures that this is a valid URL destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri")) t := signInResp{ ProviderName: p.provider.Data().ProviderName, + ProviderSlug: p.provider.Data().ProviderSlug, EmailDomains: p.EmailDomains, Redirect: redirectURL.String(), Destination: destinationURL.Host, @@ -474,13 +453,14 @@ func (p *Authenticator) SignOut(rw http.ResponseWriter, req *http.Request) { } type signOutResp struct { - Version string - Redirect string - Signature string - Timestamp string - Message string - Destination string - Email string + ProviderSlug string + Version string + Redirect string + Signature string + Timestamp string + Message string + Destination string + Email string } // SignOutPage renders a sign out page with a message @@ -504,13 +484,14 @@ func (p *Authenticator) SignOutPage(rw http.ResponseWriter, req *http.Request, m } t := signOutResp{ - Version: VERSION, - Redirect: redirectURI, - Signature: signature, - Timestamp: timestamp, - Message: message, - Destination: destinationURL.Host, - Email: session.Email, + ProviderSlug: p.provider.Data().ProviderSlug, + Version: VERSION, + Redirect: redirectURI, + Signature: signature, + Timestamp: timestamp, + Message: message, + Destination: destinationURL.Host, + Email: session.Email, } p.templates.ExecuteTemplate(rw, "sign_out.html", t) return diff --git a/internal/auth/authenticator_test.go b/internal/auth/authenticator_test.go index 26beff6c..07da9413 100644 --- a/internal/auth/authenticator_test.go +++ b/internal/auth/authenticator_test.go @@ -66,6 +66,28 @@ func setTestProvider(provider *providers.TestProvider) func(*Authenticator) erro } } +func setMockValidator(response bool) func(*Authenticator) error { + return func(a *Authenticator) error { + a.Validator = func(string) bool { return response } + return nil + } +} + +func setRedirectURL(redirectURL *url.URL) func(*Authenticator) error { + return func(a *Authenticator) error { + a.redirectURL = redirectURL + return nil + } +} + +func assignProvider(opts *Options) func(*Authenticator) error { + return func(a *Authenticator) error { + var err error + a.provider, err = newProvider(opts) + return err + } +} + // generated using `openssl rand 32 -base64` var testEncodedCookieSecret = "x7xzsM1Ky4vGQPwqy6uTztfr3jtm/pIdRbJXgE0q8kU=" var testAuthCodeSecret = "qICChm3wdjbjcWymm7PefwtPP6/PZv+udkFEubTeE38=" @@ -83,6 +105,10 @@ func testOpts(t *testing.T, proxyClientID, proxyClientSecret string) *Options { opts.AuthCodeSecret = testAuthCodeSecret opts.ProxyRootDomains = []string{"example.com"} opts.Host = "/" + opts.EmailDomains = []string{"*"} + opts.StatsdPort = 8125 + opts.StatsdHost = "localhost" + opts.RedirectURL = "http://example.com" return opts } @@ -429,10 +455,13 @@ func TestSignIn(t *testing.T) { t.Run(tc.name, func(t *testing.T) { opts := testOpts(t, "test", "secret") opts.Validate() - auth, err := NewAuthenticator(opts, func(p *Authenticator) error { - p.Validator = func(string) bool { return tc.validEmail } - return nil - }, setMockSessionStore(tc.mockSessionStore), setMockTempl(), setMockAuthCodeCipher(tc.mockAuthCodeCipher, nil)) + auth, err := NewAuthenticator(opts, + setMockValidator(tc.validEmail), + setMockSessionStore(tc.mockSessionStore), + setMockTempl(), + setRedirectURL(opts.redirectURL), + setMockAuthCodeCipher(tc.mockAuthCodeCipher, nil), + ) testutil.Ok(t, err) // set test provider @@ -1477,7 +1506,7 @@ func TestGlobalHeaders(t *testing.T) { testCases := []struct { path string }{ - {"/oauth2/callback"}, + {"/callback"}, {"/ping"}, {"/profile"}, {"/redeem"}, @@ -1545,14 +1574,16 @@ func TestOAuthStart(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { opts := testOpts(t, "abced", "testtest") - opts.RedirectURL = "https://example.com/oauth2/callback" + opts.RedirectURL = "https://example.com/" opts.Validate() u, _ := url.Parse("http://example.com") provider := providers.NewTestProvider(u) - proxy, _ := NewAuthenticator(opts, setTestProvider(provider), func(p *Authenticator) error { - p.Validator = func(string) bool { return true } - return nil - }, setMockCSRFStore(&sessions.MockCSRFStore{})) + proxy, _ := NewAuthenticator(opts, + setTestProvider(provider), + setMockValidator(true), + setRedirectURL(opts.redirectURL), + setMockCSRFStore(&sessions.MockCSRFStore{}), + ) params := url.Values{} if tc.RedirectURI != "" { @@ -1585,67 +1616,11 @@ func TestOAuthStart(t *testing.T) { } } -func TestHostHeader(t *testing.T) { - testCases := []struct { - Name string - Host string - RequestHost string - Path string - ExpectedStatusCode int - }{ - { - Name: "reject requests with an invalid hostname", - Host: "example.com", - RequestHost: "unknown.com", - Path: "/robots.txt", - ExpectedStatusCode: http.StatusNotFound, - }, - { - Name: "allow requests to any hostname to /ping", - Host: "example.com", - RequestHost: "unknown.com", - Path: "/ping", - ExpectedStatusCode: http.StatusOK, - }, - { - Name: "allow requests with a valid hostname", - Host: "example.com", - RequestHost: "example.com", - Path: "/robots.txt", - ExpectedStatusCode: http.StatusOK, - }, - } - for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - opts := testOpts(t, "abced", "testtest") - opts.Host = tc.Host - opts.Validate() - - proxy, _ := NewAuthenticator(opts, func(p *Authenticator) error { - p.Validator = func(string) bool { return true } - return nil - }) - - uri := fmt.Sprintf("http://%s%s", tc.RequestHost, tc.Path) - rw := httptest.NewRecorder() - req, _ := http.NewRequest("GET", uri, nil) - proxy.ServeMux.ServeHTTP(rw, req) - if rw.Code != tc.ExpectedStatusCode { - t.Errorf("got unexpected status code") - t.Errorf("want %v", tc.ExpectedStatusCode) - t.Errorf(" got %v", rw.Code) - t.Errorf(" headers %v", rw) - t.Errorf(" body: %q", rw.Body) - } - }) - } -} - func TestGoogleProviderApiSettings(t *testing.T) { opts := testOpts(t, "abced", "testtest") opts.Provider = "google" opts.Validate() - proxy, _ := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error { + proxy, _ := NewAuthenticator(opts, assignProvider(opts), func(p *Authenticator) error { p.Validator = func(string) bool { return true } return nil }) @@ -1664,7 +1639,7 @@ func TestGoogleGroupInvalidFile(t *testing.T) { opts.GoogleAdminEmail = "admin@example.com" opts.GoogleServiceAccountJSON = "file_doesnt_exist.json" opts.Validate() - _, err := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error { + _, err := NewAuthenticator(opts, assignProvider(opts), func(p *Authenticator) error { p.Validator = func(string) bool { return true } return nil }) @@ -1676,7 +1651,7 @@ func TestUnimplementedProvider(t *testing.T) { opts := testOpts(t, "abced", "testtest") opts.Provider = "null_provider" opts.Validate() - _, err := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error { + _, err := NewAuthenticator(opts, assignProvider(opts), func(p *Authenticator) error { p.Validator = func(string) bool { return true } return nil }) diff --git a/internal/auth/metrics.go b/internal/auth/metrics.go index 388a71bc..764c54b1 100644 --- a/internal/auth/metrics.go +++ b/internal/auth/metrics.go @@ -11,7 +11,7 @@ import ( "github.com/datadog/datadog-go/statsd" ) -func newStatsdClient(host string, port int) (*statsd.Client, error) { +func NewStatsdClient(host string, port int) (*statsd.Client, error) { client, err := statsd.New(net.JoinHostPort(host, strconv.Itoa(port))) if err != nil { return nil, err @@ -27,16 +27,16 @@ func newStatsdClient(host string, port int) (*statsd.Client, error) { func GetActionTag(req *http.Request) string { // only log metrics for these paths and actions pathToAction := map[string]string{ - "/robots.txt": "robots", - "/start": "start", - "/sign_in": "sign_in", - "/sign_out": "sign_out", - "/oauth2/callback": "callback", - "/profile": "profile", - "/validate": "validate", - "/redeem": "redeem", - "/refresh": "refresh", - "/ping": "ping", + "/robots.txt": "robots", + "/start": "start", + "/sign_in": "sign_in", + "/sign_out": "sign_out", + "/callback": "callback", + "/profile": "profile", + "/validate": "validate", + "/redeem": "redeem", + "/refresh": "refresh", + "/ping": "ping", } // get the action from the url path path := req.URL.Path diff --git a/internal/auth/metrics_test.go b/internal/auth/metrics_test.go index 4a28011b..f707d1b8 100644 --- a/internal/auth/metrics_test.go +++ b/internal/auth/metrics_test.go @@ -14,7 +14,7 @@ import ( func newTestStatsdClient(t *testing.T) (*statsd.Client, string, int) { - client, err := newStatsdClient("127.0.0.1", 8125) + client, err := NewStatsdClient("127.0.0.1", 8125) if err != nil { t.Fatalf("error starting new statsd client %s", err.Error()) } @@ -69,7 +69,7 @@ func TestNewStatsd(t *testing.T) { t.Fatalf("error while instantiating config options: %s", err.Error()) } opts.Validate() - client, err := newStatsdClient(tc.host, tc.port) + client, err := NewStatsdClient(tc.host, tc.port) if err != nil { t.Fatalf("error starting new statsd client: %s", err.Error()) } @@ -226,12 +226,12 @@ func TestGetActionTag(t *testing.T) { }, { name: "request with callback in the path", - url: "/oauth2/callback", + url: "/callback", expectedAction: "callback", }, { name: "request with sign_out in the path with query parameters", - url: "/oauth2/callback?query=parameter", + url: "/callback?query=parameter", expectedAction: "callback", }, { diff --git a/internal/auth/mux.go b/internal/auth/mux.go new file mode 100644 index 00000000..5f035fef --- /dev/null +++ b/internal/auth/mux.go @@ -0,0 +1,109 @@ +package auth + +import ( + "fmt" + "net/http" + "strings" + + "github.com/buzzfeed/sso/internal/auth/providers" + log "github.com/buzzfeed/sso/internal/pkg/logging" + "github.com/buzzfeed/sso/internal/pkg/options" + + "github.com/datadog/datadog-go/statsd" +) + +type AuthenticatorMux struct { + mux *http.ServeMux + authenticators []*Authenticator +} + +func NewAuthenticatorMux(opts *Options, statsdClient *statsd.Client) (*AuthenticatorMux, error) { + logger := log.NewLogEntry() + + emailValidator := func(p *Authenticator) error { + if len(opts.EmailAddresses) != 0 { + p.Validator = options.NewEmailAddressValidator(opts.EmailAddresses) + } else { + p.Validator = options.NewEmailDomainValidator(opts.EmailDomains) + } + return nil + } + + // one day, we will contruct more providers here + idp, err := newProvider(opts) + if err != nil { + logger.Error(err, "error creating new Identity Provider") + return nil, err + } + identityProviders := []providers.Provider{idp} + authenticators := []*Authenticator{} + + idpMux := http.NewServeMux() + for _, idp := range identityProviders { + idpSlug := idp.Data().ProviderSlug + authenticator, err := NewAuthenticator(opts, + emailValidator, + SetProvider(idp), + SetCookieStore(opts, idpSlug), + SetStatsdClient(statsdClient), + SetRedirectURL(opts, idpSlug), + ) + if err != nil { + logger.Error(err, "error creating new Authenticator") + return nil, err + } + + authenticators = append(authenticators, authenticator) + + // setup our mux with the idpslug as the first part of the path + idpMux.Handle( + fmt.Sprintf("/%s/", idpSlug), + http.StripPrefix(fmt.Sprintf("/%s", idpSlug), authenticator.ServeMux), + ) + + // we setup default routes for the default provider, mainly helpful for transitionary services + if idpSlug == opts.DefaultProvider { + idpMux.Handle("/", authenticator.ServeMux) + } + } + + // load static files + fsHandler, err := loadFSHandler() + if err != nil { + logger.Fatal(err) + } + idpMux.Handle("/static/", http.StripPrefix("/static/", fsHandler)) + + // NOTE: we have to include trailing slash for the router to match the host header + host := opts.Host + if !strings.HasSuffix(host, "/") { + host = fmt.Sprintf("%s/", host) + } + + mux := http.NewServeMux() + // We setup our IPD Mux only on our declared host header + mux.Handle(host, idpMux) // setup our service mux to only handle our required host header + // Ping should respond to all requests + mux.HandleFunc("/ping", PingHandler) + + return &AuthenticatorMux{ + mux: mux, + authenticators: authenticators, + }, nil +} + +func (a *AuthenticatorMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + a.mux.ServeHTTP(w, r) +} + +func (a *AuthenticatorMux) Stop() { + for _, authenticator := range a.authenticators { + authenticator.Stop() + } +} + +// PingHandler handles the /ping route +func PingHandler(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + fmt.Fprintf(rw, http.StatusText(http.StatusOK)) +} diff --git a/internal/auth/mux_test.go b/internal/auth/mux_test.go new file mode 100644 index 00000000..51ea27e6 --- /dev/null +++ b/internal/auth/mux_test.go @@ -0,0 +1,76 @@ +package auth + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestHostHeader(t *testing.T) { + testCases := []struct { + Name string + Host string + RequestHost string + Path string + ExpectedStatusCode int + }{ + { + Name: "reject requests with an invalid hostname", + Host: "example.com", + RequestHost: "unknown.com", + Path: "/static/sso.css", + ExpectedStatusCode: http.StatusNotFound, + }, + { + Name: "allow requests to any hostname to /ping", + Host: "example.com", + RequestHost: "unknown.com", + Path: "/ping", + ExpectedStatusCode: http.StatusOK, + }, + { + Name: "allow requests with a valid hostname", + Host: "example.com", + RequestHost: "example.com", + Path: "/static/sso.css", + ExpectedStatusCode: http.StatusOK, + }, + } + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + opts := testOpts(t, "abced", "testtest") + opts.Host = tc.Host + err := opts.Validate() + if err != nil { + t.Fatalf("unexpected opts error: %v", err) + } + + opts.redirectURL = &url.URL{ + Host: tc.Host, + Path: "/callback", + Scheme: "https", + } + + authMux, err := NewAuthenticatorMux(opts, nil) + if err != nil { + t.Fatalf("unexpected err creating auth mux: %v", err) + } + + uri := fmt.Sprintf("http://%s%s", tc.RequestHost, tc.Path) + + rw := httptest.NewRecorder() + req := httptest.NewRequest("GET", uri, nil) + + authMux.ServeHTTP(rw, req) + if rw.Code != tc.ExpectedStatusCode { + t.Errorf("got unexpected status code") + t.Errorf("want %v", tc.ExpectedStatusCode) + t.Errorf(" got %v", rw.Code) + t.Errorf(" headers %v", rw) + t.Errorf(" body: %q", rw.Body) + } + }) + } +} diff --git a/internal/auth/options.go b/internal/auth/options.go index 32f2d1f0..e2b7bfbe 100644 --- a/internal/auth/options.go +++ b/internal/auth/options.go @@ -7,13 +7,15 @@ import ( "net/http" "net/url" "os" + "path" "reflect" "strings" "time" "github.com/buzzfeed/sso/internal/auth/providers" "github.com/buzzfeed/sso/internal/pkg/groups" - log "github.com/buzzfeed/sso/internal/pkg/logging" + + "github.com/datadog/datadog-go/statsd" "github.com/spf13/viper" ) @@ -38,6 +40,7 @@ import ( // CookieRefresh - duration - refresh the cookie after this duration default 0 // CookieSecure - bool - set secure (HTTPS) cookie flag // CookieHTTPOnly - bool - set httponly cookie flag +// DefaultProvider - string - specify the default provider // RequestTimeout - duration - overall request timeout // AuthCodeSecret - string - the seed string for secure auth codes (optionally base64 encoded) // GroupCacheProviderTTL - time.Duration - cache TTL for the group-cache provider used for on-demand group caching @@ -104,8 +107,11 @@ type Options struct { // These options allow for other providers besides Google, with potential overrides. Provider string `mapstructure:"provider"` + ProviderSlug string `mapstructure:"provider_slug"` ProviderServerID string `mapstructure:"provider_server_id"` + DefaultProvider string `mapstructure:"default_provider"` + SignInURL string `mapstructure:"signin_url"` RedeemURL string `mapstructure:"redeem_url"` RevokeURL string `mapstructure:"revoke_url"` @@ -187,6 +193,7 @@ func setDefaults(v *viper.Viper) { "pass_user_headers": true, "set_xauthrequest": false, "provider": "google", + "provider_slug": "google", "provider_server_id": "default", "approval_prompt": "force", "request_logging": true, @@ -307,6 +314,7 @@ func validateCookieName(o *Options, msgs []string) []string { func newProvider(o *Options) (providers.Provider, error) { p := &providers.ProviderData{ + ProviderSlug: o.ProviderSlug, Scope: o.Scope, ClientID: o.ClientID, ClientSecret: o.ClientSecret, @@ -365,32 +373,36 @@ func newProvider(o *Options) (providers.Provider, error) { return singleFlightProvider, nil } -// AssignProvider is a function that takes an Options struct and assigns the -// appropriate provider to the proxy. Should be called prior to -// AssignStatsdClient. -func AssignProvider(opts *Options) func(*Authenticator) error { - return func(proxy *Authenticator) error { - var err error - proxy.provider, err = newProvider(opts) - return err +// SetProvider is a function that takes a provider and assigns it to the authenticator. +func SetProvider(provider providers.Provider) func(*Authenticator) error { + return func(a *Authenticator) error { + a.provider = provider + return nil } } -// AssignStatsdClient is function that takes in an Options struct and assigns a statsd client -// to the proxy and provider. -func AssignStatsdClient(opts *Options) func(*Authenticator) error { - return func(proxy *Authenticator) error { - logger := log.NewLogEntry() +// SetStatsdClient is function that takes in a statsd client and assigns it to the +// authenticator and provider. +func SetStatsdClient(statsdClient *statsd.Client) func(*Authenticator) error { + return func(a *Authenticator) error { + a.StatsdClient = statsdClient - StatsdClient, err := newStatsdClient(opts.StatsdHost, opts.StatsdPort) - if err != nil { - return fmt.Errorf("error setting up statsd client error=%s", err) + if a.provider != nil { + a.provider.SetStatsdClient(statsdClient) } - logger.WithStatsdHost(opts.StatsdHost).WithStatsdPort(opts.StatsdPort).Info( - "statsd client is running") - proxy.StatsdClient = StatsdClient - proxy.provider.SetStatsdClient(StatsdClient) + return nil + } +} + +// SetRedirectURL takes an options struct and identity provider slug to construct the +// url callback using the slug and configured redirect url. +func SetRedirectURL(opts *Options, slug string) func(*Authenticator) error { + return func(a *Authenticator) error { + redirectURL := new(url.URL) + *redirectURL = *opts.redirectURL + redirectURL.Path = path.Join(slug, "callback") + a.redirectURL = redirectURL return nil } } diff --git a/internal/auth/options_test.go b/internal/auth/options_test.go index 9c8e581c..6e637864 100644 --- a/internal/auth/options_test.go +++ b/internal/auth/options_test.go @@ -68,10 +68,10 @@ func TestInitializedOptions(t *testing.T) { // seems to parse damn near anything. func TestRedirectURL(t *testing.T) { o := testOptions(t) - o.RedirectURL = "https://myhost.com/oauth2/callback" + o.RedirectURL = "https://myhost.com/callback" testutil.Equal(t, nil, o.Validate()) expected := &url.URL{ - Scheme: "https", Host: "myhost.com", Path: "/oauth2/callback"} + Scheme: "https", Host: "myhost.com", Path: "/callback"} testutil.Equal(t, expected, o.redirectURL) } diff --git a/internal/auth/providers/provider_data.go b/internal/auth/providers/provider_data.go index 76626965..707c2867 100644 --- a/internal/auth/providers/provider_data.go +++ b/internal/auth/providers/provider_data.go @@ -9,6 +9,7 @@ import ( // necessary to implement the Provider interface. type ProviderData struct { ProviderName string + ProviderSlug string ClientID string ClientSecret string SignInURL *url.URL diff --git a/internal/auth/static_files_test.go b/internal/auth/static_files_test.go index 4bc82b9e..1fcd0193 100644 --- a/internal/auth/static_files_test.go +++ b/internal/auth/static_files_test.go @@ -10,7 +10,10 @@ import ( func TestStaticFiles(t *testing.T) { opts := testOpts(t, "abced", "testtest") opts.Validate() - proxy, _ := NewAuthenticator(opts) + authMux, err := NewAuthenticatorMux(opts, nil) + if err != nil { + t.Fatalf("unexpected error creating auth mux: %v", err) + } testCases := []struct { name string @@ -45,11 +48,13 @@ func TestStaticFiles(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { rw := httptest.NewRecorder() - req, _ := http.NewRequest("GET", tc.path, nil) - proxy.ServeMux.ServeHTTP(rw, req) + req := httptest.NewRequest("GET", tc.path, nil) + + authMux.ServeHTTP(rw, req) if rw.Code != tc.expectedStatus { t.Errorf("expected response %v, got %v\n%v", tc.expectedStatus, rw.Code, rw.HeaderMap) } + if tc.expectedContent != "" && !strings.Contains(rw.Body.String(), tc.expectedContent) { t.Errorf("substring %q not found in response body:\n%s", tc.expectedContent, rw.Body.String()) } diff --git a/internal/pkg/templates/templates.go b/internal/pkg/templates/templates.go index 60512cf2..33b2960e 100644 --- a/internal/pkg/templates/templates.go +++ b/internal/pkg/templates/templates.go @@ -64,7 +64,7 @@ Secured by SSO{{end}}`)) {{template "sign_in_message.html" .}} -
@@ -117,7 +117,7 @@ Secured by SSO{{end}}`))You're currently signed in as {{.Email}}. This will also sign you out of other internal apps.
-