Skip to content

Commit

Permalink
auth: AuthenticatorMux to allow multiple authenticators/providers
Browse files Browse the repository at this point in the history
  • Loading branch information
jphines committed May 28, 2019
1 parent 66ed9e6 commit 7b67cfe
Show file tree
Hide file tree
Showing 17 changed files with 362 additions and 196 deletions.
22 changes: 9 additions & 13 deletions cmd/sso-auth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/buzzfeed/sso/internal/auth"
log "github.com/buzzfeed/sso/internal/pkg/logging"
"github.com/buzzfeed/sso/internal/pkg/options"
)

func init() {
Expand All @@ -29,31 +28,28 @@ func main() {
os.Exit(1)
}

emailValidator := func(p *auth.Authenticator) error {
if len(opts.EmailAddresses) != 0 {
p.Validator = options.NewEmailAddressValidator(opts.EmailAddresses)
} else {
p.Validator = options.NewEmailDomainValidator(opts.EmailDomains)
}
return nil
statsdClient, err := auth.NewStatsdClient(opts.StatsdHost, opts.StatsdPort)
if err != nil {
logger.Error(err, "error creating statsd client")
os.Exit(1)
}

authenticator, err := auth.NewAuthenticator(opts, emailValidator, auth.AssignProvider(opts), auth.SetCookieStore(opts), auth.AssignStatsdClient(opts))
authMux, err := auth.NewAuthenticatorMux(opts, statsdClient)
if err != nil {
logger.Error(err, "error creating new Authenticator")
logger.Error(err, "error creating new AuthenticatorMux")
os.Exit(1)
}
defer authenticator.Stop()
defer authMux.Stop()

// we leave the message field blank, which will inherit the stdlib timeout page which is sufficient
// and better than other naive messages we would currently place here
timeoutHandler := http.TimeoutHandler(authenticator.ServeMux, opts.RequestTimeout, "")
timeoutHandler := http.TimeoutHandler(authMux, opts.RequestTimeout, "")

s := &http.Server{
Addr: fmt.Sprintf(":%d", opts.Port),
ReadTimeout: opts.TCPReadTimeout,
WriteTimeout: opts.TCPWriteTimeout,
Handler: auth.NewLoggingHandler(os.Stdout, timeoutHandler, opts.RequestLogging, authenticator.StatsdClient),
Handler: auth.NewLoggingHandler(os.Stdout, timeoutHandler, opts.RequestLogging, statsdClient),
}

logger.Fatal(s.ListenAndServe())
Expand Down
79 changes: 30 additions & 49 deletions internal/auth/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ type getProfileResponse struct {
}

// SetCookieStore sets the cookie store to use a miscreant cipher
func SetCookieStore(opts *Options) func(*Authenticator) error {
func SetCookieStore(opts *Options, providerSlug string) func(*Authenticator) error {
return func(a *Authenticator) error {
decodedAuthCodeSecret, err := base64.StdEncoding.DecodeString(opts.AuthCodeSecret)
if err != nil {
Expand All @@ -98,7 +98,8 @@ func SetCookieStore(opts *Options) func(*Authenticator) error {
return err
}

cookieStore, err := sessions.NewCookieStore(opts.CookieName,
cookieName := fmt.Sprintf("%s_%s", opts.CookieName, providerSlug)
cookieStore, err := sessions.NewCookieStore(cookieName,
sessions.CreateMiscreantCookieCipher(opts.decodedCookieSecret),
func(c *sessions.CookieStore) error {
c.CookieDomain = opts.CookieDomain
Expand All @@ -123,9 +124,6 @@ func SetCookieStore(opts *Options) func(*Authenticator) error {
func NewAuthenticator(opts *Options, optionFuncs ...func(*Authenticator) error) (*Authenticator, error) {
logger := log.NewLogEntry()

redirectURL := opts.redirectURL
redirectURL.Path = "/oauth2/callback"

templates := templates.NewHTMLTemplate()

proxyRootDomains := []string{}
Expand All @@ -144,7 +142,6 @@ func NewAuthenticator(opts *Options, optionFuncs ...func(*Authenticator) error)
Host: opts.Host,
CookieSecure: opts.CookieSecure,

redirectURL: redirectURL,
SetXAuthRequest: opts.SetXAuthRequest,
PassUserHeaders: opts.PassUserHeaders,
SkipProviderButton: opts.SkipProviderButton,
Expand All @@ -166,39 +163,19 @@ func NewAuthenticator(opts *Options, optionFuncs ...func(*Authenticator) error)
}

func (p *Authenticator) newMux() http.Handler {
logger := log.NewLogEntry()

mux := http.NewServeMux()

// we setup global endpoints that should respond to any hostname
mux.HandleFunc("/ping", p.withMethods(p.PingPage, "GET"))

// we setup our service mux to handle service routes that use the required host header
serviceMux := http.NewServeMux()
serviceMux.HandleFunc("/ping", p.withMethods(p.PingPage, "GET"))
serviceMux.HandleFunc("/robots.txt", p.withMethods(p.RobotsTxt, "GET"))
serviceMux.HandleFunc("/start", p.withMethods(p.OAuthStart, "GET"))
serviceMux.HandleFunc("/sign_in", p.withMethods(p.validateClientID(p.validateRedirectURI(p.validateSignature(p.SignIn))), "GET"))
serviceMux.HandleFunc("/sign_out", p.withMethods(p.validateRedirectURI(p.validateSignature(p.SignOut)), "GET", "POST"))
serviceMux.HandleFunc("/oauth2/callback", p.withMethods(p.OAuthCallback, "GET"))
serviceMux.HandleFunc("/callback", p.withMethods(p.OAuthCallback, "GET"))
serviceMux.HandleFunc("/profile", p.withMethods(p.validateClientID(p.validateClientSecret(p.GetProfile)), "GET"))
serviceMux.HandleFunc("/validate", p.withMethods(p.validateClientID(p.validateClientSecret(p.ValidateToken)), "GET"))
serviceMux.HandleFunc("/redeem", p.withMethods(p.validateClientID(p.validateClientSecret(p.Redeem)), "POST"))
serviceMux.HandleFunc("/refresh", p.withMethods(p.validateClientID(p.validateClientSecret(p.Refresh)), "POST"))
fsHandler, err := loadFSHandler()
if err != nil {
logger.Fatal(err)
}
serviceMux.Handle("/static/", http.StripPrefix("/static/", fsHandler))

// NOTE: we have to include trailing slash for the router to match the host header
host := p.Host
if !strings.HasSuffix(host, "/") {
host = fmt.Sprintf("%s/", host)
}
mux.Handle(host, serviceMux) // setup our service mux to only handle our required host header

return setHeaders(mux)
return setHeaders(serviceMux)
}

// GetRedirectURI returns the redirect url for a given OAuthProxy,
Expand All @@ -215,13 +192,8 @@ func (p *Authenticator) RobotsTxt(rw http.ResponseWriter, req *http.Request) {
fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
}

// PingPage handles the /ping route
func (p *Authenticator) PingPage(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, "OK")
}

type signInResp struct {
ProviderSlug string
ProviderName string
EmailDomains []string
Redirect string
Expand All @@ -237,13 +209,20 @@ func (p *Authenticator) SignInPage(rw http.ResponseWriter, req *http.Request, co
// We don't want to rely on req.Host, as that can be attacked via Host header injection
// This ends up looking like:
// https://sso-auth.example.com/sign_in?client_id=...&redirect_uri=...
redirectURL := p.redirectURL.ResolveReference(req.URL)
path := strings.TrimPrefix(req.URL.Path, "/")
redirectURL := p.redirectURL.ResolveReference(
&url.URL{
Path: path,
RawQuery: req.URL.RawQuery,
},
)

// validateRedirectURI middleware already ensures that this is a valid URL
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))

t := signInResp{
ProviderName: p.provider.Data().ProviderName,
ProviderSlug: p.provider.Data().ProviderSlug,
EmailDomains: p.EmailDomains,
Redirect: redirectURL.String(),
Destination: destinationURL.Host,
Expand Down Expand Up @@ -474,13 +453,14 @@ func (p *Authenticator) SignOut(rw http.ResponseWriter, req *http.Request) {
}

type signOutResp struct {
Version string
Redirect string
Signature string
Timestamp string
Message string
Destination string
Email string
ProviderSlug string
Version string
Redirect string
Signature string
Timestamp string
Message string
Destination string
Email string
}

// SignOutPage renders a sign out page with a message
Expand All @@ -504,13 +484,14 @@ func (p *Authenticator) SignOutPage(rw http.ResponseWriter, req *http.Request, m
}

t := signOutResp{
Version: VERSION,
Redirect: redirectURI,
Signature: signature,
Timestamp: timestamp,
Message: message,
Destination: destinationURL.Host,
Email: session.Email,
ProviderSlug: p.provider.Data().ProviderSlug,
Version: VERSION,
Redirect: redirectURI,
Signature: signature,
Timestamp: timestamp,
Message: message,
Destination: destinationURL.Host,
Email: session.Email,
}
p.templates.ExecuteTemplate(rw, "sign_out.html", t)
return
Expand Down
113 changes: 44 additions & 69 deletions internal/auth/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,28 @@ func setTestProvider(provider *providers.TestProvider) func(*Authenticator) erro
}
}

func setMockValidator(response bool) func(*Authenticator) error {
return func(a *Authenticator) error {
a.Validator = func(string) bool { return response }
return nil
}
}

func setRedirectURL(redirectURL *url.URL) func(*Authenticator) error {
return func(a *Authenticator) error {
a.redirectURL = redirectURL
return nil
}
}

func assignProvider(opts *Options) func(*Authenticator) error {
return func(a *Authenticator) error {
var err error
a.provider, err = newProvider(opts)
return err
}
}

// generated using `openssl rand 32 -base64`
var testEncodedCookieSecret = "x7xzsM1Ky4vGQPwqy6uTztfr3jtm/pIdRbJXgE0q8kU="
var testAuthCodeSecret = "qICChm3wdjbjcWymm7PefwtPP6/PZv+udkFEubTeE38="
Expand All @@ -83,6 +105,10 @@ func testOpts(t *testing.T, proxyClientID, proxyClientSecret string) *Options {
opts.AuthCodeSecret = testAuthCodeSecret
opts.ProxyRootDomains = []string{"example.com"}
opts.Host = "/"
opts.EmailDomains = []string{"*"}
opts.StatsdPort = 8125
opts.StatsdHost = "localhost"
opts.RedirectURL = "http://example.com"
return opts
}

Expand Down Expand Up @@ -429,10 +455,13 @@ func TestSignIn(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
opts := testOpts(t, "test", "secret")
opts.Validate()
auth, err := NewAuthenticator(opts, func(p *Authenticator) error {
p.Validator = func(string) bool { return tc.validEmail }
return nil
}, setMockSessionStore(tc.mockSessionStore), setMockTempl(), setMockAuthCodeCipher(tc.mockAuthCodeCipher, nil))
auth, err := NewAuthenticator(opts,
setMockValidator(tc.validEmail),
setMockSessionStore(tc.mockSessionStore),
setMockTempl(),
setRedirectURL(opts.redirectURL),
setMockAuthCodeCipher(tc.mockAuthCodeCipher, nil),
)
testutil.Ok(t, err)

// set test provider
Expand Down Expand Up @@ -1477,7 +1506,7 @@ func TestGlobalHeaders(t *testing.T) {
testCases := []struct {
path string
}{
{"/oauth2/callback"},
{"/callback"},
{"/ping"},
{"/profile"},
{"/redeem"},
Expand Down Expand Up @@ -1545,14 +1574,16 @@ func TestOAuthStart(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) {

opts := testOpts(t, "abced", "testtest")
opts.RedirectURL = "https://example.com/oauth2/callback"
opts.RedirectURL = "https://example.com/"
opts.Validate()
u, _ := url.Parse("http://example.com")
provider := providers.NewTestProvider(u)
proxy, _ := NewAuthenticator(opts, setTestProvider(provider), func(p *Authenticator) error {
p.Validator = func(string) bool { return true }
return nil
}, setMockCSRFStore(&sessions.MockCSRFStore{}))
proxy, _ := NewAuthenticator(opts,
setTestProvider(provider),
setMockValidator(true),
setRedirectURL(opts.redirectURL),
setMockCSRFStore(&sessions.MockCSRFStore{}),
)

params := url.Values{}
if tc.RedirectURI != "" {
Expand Down Expand Up @@ -1585,67 +1616,11 @@ func TestOAuthStart(t *testing.T) {
}
}

func TestHostHeader(t *testing.T) {
testCases := []struct {
Name string
Host string
RequestHost string
Path string
ExpectedStatusCode int
}{
{
Name: "reject requests with an invalid hostname",
Host: "example.com",
RequestHost: "unknown.com",
Path: "/robots.txt",
ExpectedStatusCode: http.StatusNotFound,
},
{
Name: "allow requests to any hostname to /ping",
Host: "example.com",
RequestHost: "unknown.com",
Path: "/ping",
ExpectedStatusCode: http.StatusOK,
},
{
Name: "allow requests with a valid hostname",
Host: "example.com",
RequestHost: "example.com",
Path: "/robots.txt",
ExpectedStatusCode: http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
opts := testOpts(t, "abced", "testtest")
opts.Host = tc.Host
opts.Validate()

proxy, _ := NewAuthenticator(opts, func(p *Authenticator) error {
p.Validator = func(string) bool { return true }
return nil
})

uri := fmt.Sprintf("http://%s%s", tc.RequestHost, tc.Path)
rw := httptest.NewRecorder()
req, _ := http.NewRequest("GET", uri, nil)
proxy.ServeMux.ServeHTTP(rw, req)
if rw.Code != tc.ExpectedStatusCode {
t.Errorf("got unexpected status code")
t.Errorf("want %v", tc.ExpectedStatusCode)
t.Errorf(" got %v", rw.Code)
t.Errorf(" headers %v", rw)
t.Errorf(" body: %q", rw.Body)
}
})
}
}

func TestGoogleProviderApiSettings(t *testing.T) {
opts := testOpts(t, "abced", "testtest")
opts.Provider = "google"
opts.Validate()
proxy, _ := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error {
proxy, _ := NewAuthenticator(opts, assignProvider(opts), func(p *Authenticator) error {
p.Validator = func(string) bool { return true }
return nil
})
Expand All @@ -1664,7 +1639,7 @@ func TestGoogleGroupInvalidFile(t *testing.T) {
opts.GoogleAdminEmail = "admin@example.com"
opts.GoogleServiceAccountJSON = "file_doesnt_exist.json"
opts.Validate()
_, err := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error {
_, err := NewAuthenticator(opts, assignProvider(opts), func(p *Authenticator) error {
p.Validator = func(string) bool { return true }
return nil
})
Expand All @@ -1676,7 +1651,7 @@ func TestUnimplementedProvider(t *testing.T) {
opts := testOpts(t, "abced", "testtest")
opts.Provider = "null_provider"
opts.Validate()
_, err := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error {
_, err := NewAuthenticator(opts, assignProvider(opts), func(p *Authenticator) error {
p.Validator = func(string) bool { return true }
return nil
})
Expand Down
Loading

0 comments on commit 7b67cfe

Please sign in to comment.