diff --git a/README.md b/README.md index bbca3eb3..e03ee155 100644 --- a/README.md +++ b/README.md @@ -300,6 +300,8 @@ All options can be supplied in any of the following ways, in the following prece - ``Path(`path`, `/articles/{category}/{id:[0-9]+}`, ...)`` - ``PathPrefix(`/products/`, `/articles/{category}/{id:[0-9]+}`)`` - ``Query(`foo=bar`, `bar=baz`)`` + - `whitelist` - optional, same usage as [`whitelist`](#whitelist). + - `domains` - optional, same usage as [`domain`](#domain). For example: ``` @@ -328,7 +330,7 @@ You can restrict who can login with the following parameters: * `domain` - Use this to limit logins to a specific domain, e.g. test.com only * `whitelist` - Use this to only allow specific users to login e.g. thom@test.com only -Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3). +Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3). If you set `domains` or `whitelist` on a rule, the global configuration is ignored. ### Forwarded Headers diff --git a/internal/auth.go b/internal/auth.go index cefd2315..4184bf74 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -59,7 +59,22 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) { // ValidateEmail checks if the given email address matches either a whitelisted // email address, as defined by the "whitelist" config parameter. Or is part of // a permitted domain, as defined by the "domains" config parameter -func ValidateEmail(email string) bool { +func ValidateEmail(email string, ruleName *string) bool { + if ruleName != nil { + rule, ruleExists := config.Rules[*ruleName] + + // Do we need to apply rule-level validation? + if ruleExists && (len(rule.Whitelist) > 0 || len(rule.Domains) > 0) { + if len(rule.Whitelist) > 0 && ValidateWhitelist(email, rule.Whitelist) { + return true + } else if config.MatchWhitelistOrDomain && len(rule.Domains) > 0 && ValidateDomains(email, rule.Domains) { + return true + } else { + return false + } + } + } + // Do we have any validation to perform? if len(config.Whitelist) == 0 && len(config.Domains) == 0 { return true @@ -67,10 +82,8 @@ func ValidateEmail(email string) bool { // Email whitelist validation if len(config.Whitelist) > 0 { - for _, whitelist := range config.Whitelist { - if email == whitelist { - return true - } + if ValidateWhitelist(email, config.Whitelist) { + return true } // If we're not matching *either*, stop here @@ -80,18 +93,34 @@ func ValidateEmail(email string) bool { } // Domain validation - if len(config.Domains) > 0 { - parts := strings.Split(email, "@") - if len(parts) < 2 { - return false - } - for _, domain := range config.Domains { - if domain == parts[1] { - return true - } + if len(config.Domains) > 0 && ValidateDomains(email, config.Domains) { + return true + } + + return false +} + +// ValidateWhitelist checks if the email is in whitelist +func ValidateWhitelist(email string, whitelist CommaSeparatedList) bool { + for _, whitelist := range whitelist { + if email == whitelist { + return true } } + return false +} +// ValidateDomains checks if the email matches a whitelisted domain +func ValidateDomains(email string, domains CommaSeparatedList) bool { + parts := strings.Split(email, "@") + if len(parts) < 2 { + return false + } + for _, domain := range domains { + if domain == parts[1] { + return true + } + } return false } diff --git a/internal/auth_test.go b/internal/auth_test.go index 840337c9..a358ec50 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -65,33 +65,34 @@ func TestAuthValidateCookie(t *testing.T) { func TestAuthValidateEmail(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) + var ruleName string // Should allow any - v := ValidateEmail("test@test.com") + v := ValidateEmail("test@test.com", nil) assert.True(v, "should allow any domain if email domain is not defined") - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", nil) assert.True(v, "should allow any domain if email domain is not defined") // Should block non matching domain config.Domains = []string{"test.com"} - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", nil) assert.False(v, "should not allow user from another domain") // Should allow matching domain config.Domains = []string{"test.com"} - v = ValidateEmail("test@test.com") + v = ValidateEmail("test@test.com", nil) assert.True(v, "should allow user from allowed domain") // Should block non whitelisted email address config.Domains = []string{} config.Whitelist = []string{"test@test.com"} - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", nil) assert.False(v, "should not allow user not in whitelist") // Should allow matching whitelisted email address config.Domains = []string{} config.Whitelist = []string{"test@test.com"} - v = ValidateEmail("test@test.com") + v = ValidateEmail("test@test.com", nil) assert.True(v, "should allow user in whitelist") // Should allow only matching email address when @@ -99,11 +100,31 @@ func TestAuthValidateEmail(t *testing.T) { config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.MatchWhitelistOrDomain = false - v = ValidateEmail("test@test.com") + v = ValidateEmail("test@test.com", nil) assert.True(v, "should allow user in whitelist") - v = ValidateEmail("test@example.com") + v = ValidateEmail("test@example.com", nil) assert.False(v, "should not allow user from valid domain") - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", nil) + assert.False(v, "should not allow user not in either") + + // Should allow only matching email address when + // MatchWhitelistOrDomain is disabled with Rules + config.Domains = []string{"example.com"} + config.Whitelist = []string{"test@test.com"} + config.Rules = map[string]*Rule{"test": NewRule()} + ruleName = "test" + config.Rules[ruleName].Whitelist = []string{"test@testrule.com"} + config.Rules[ruleName].Domains = []string{"testruledomain.com"} + config.MatchWhitelistOrDomain = false + v = ValidateEmail("test@testrule.com", &ruleName) + assert.True(v, "should allow user in rule whitelist") + v = ValidateEmail("test@test.com", &ruleName) + assert.False(v, "should not allow user in global whitelist") + v = ValidateEmail("test@testruledomain.com", &ruleName) + assert.False(v, "should not allow user in rule domain list") + v = ValidateEmail("test@example.com", &ruleName) + assert.False(v, "should not allow user from valid global domain") + v = ValidateEmail("one@two.com", &ruleName) assert.False(v, "should not allow user not in either") // Should allow either matching domain or email address when @@ -111,12 +132,43 @@ func TestAuthValidateEmail(t *testing.T) { config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.MatchWhitelistOrDomain = true - v = ValidateEmail("test@test.com") + v = ValidateEmail("test@test.com", nil) assert.True(v, "should allow user in whitelist") - v = ValidateEmail("test@example.com") + v = ValidateEmail("test@example.com", nil) assert.True(v, "should allow user from valid domain") - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", nil) + assert.False(v, "should not allow user not in either") + + // Should allow either matching domain or email address when + // MatchWhitelistOrDomain is enabled with Rules + config.Domains = []string{"example.com"} + config.Whitelist = []string{"test@test.com"} + config.Rules = map[string]*Rule{"test": NewRule()} + ruleName = "test" + config.Rules[ruleName].Whitelist = []string{"test@testrule.com"} + config.Rules[ruleName].Domains = []string{"testruledomain.com"} + config.MatchWhitelistOrDomain = true + v = ValidateEmail("test@testrule.com", &ruleName) + assert.True(v, "should allow user in rule whitelist") + v = ValidateEmail("test@test.com", &ruleName) + assert.False(v, "should not allow user in global whitelist") + v = ValidateEmail("test@testruledomain.com", &ruleName) + assert.True(v, "should allow user in rule domain list") + v = ValidateEmail("test@example.com", &ruleName) + assert.False(v, "should not allow user from valid global domain") + v = ValidateEmail("one@two.com", &ruleName) assert.False(v, "should not allow user not in either") + + // Rules should use global whitelist/domains config when not specified + config.Domains = []string{"example.com"} + config.Whitelist = []string{"test@test.com"} + config.Rules = map[string]*Rule{"test": NewRule()} + ruleName = "test" + config.MatchWhitelistOrDomain = true + v = ValidateEmail("test@test.com", &ruleName) + assert.True(v, "should allow user in global whitelist") + v = ValidateEmail("test@example.com", &ruleName) + assert.True(v, "should allow user from valid global domain") } func TestRedirectUri(t *testing.T) { diff --git a/internal/config.go b/internal/config.go index 8be0ae96..7af22c0b 100644 --- a/internal/config.go +++ b/internal/config.go @@ -210,6 +210,14 @@ func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args [ rule.Rule = val case "provider": rule.Provider = val + case "whitelist": + list := CommaSeparatedList{} + list.UnmarshalFlag(val) + rule.Whitelist = list + case "domains": + list := CommaSeparatedList{} + list.UnmarshalFlag(val) + rule.Domains = list default: return args, fmt.Errorf("invalid route param: %v", option) } @@ -325,9 +333,11 @@ func (c *Config) setupProvider(name string) error { // Rule holds defined rules type Rule struct { - Action string - Rule string - Provider string + Action string + Rule string + Provider string + Whitelist CommaSeparatedList + Domains CommaSeparatedList } // NewRule creates a new rule object diff --git a/internal/server.go b/internal/server.go index 3fd76500..042628fe 100644 --- a/internal/server.go +++ b/internal/server.go @@ -32,9 +32,9 @@ func (s *Server) buildRoutes() { for name, rule := range config.Rules { matchRule := rule.formattedRule() if rule.Action == "allow" { - s.router.AddRoute(matchRule, 1, s.AllowHandler(name)) + s.router.AddRoute(matchRule, 1, s.AllowHandler(&name)) } else { - s.router.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, name)) + s.router.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, &name)) } } @@ -46,9 +46,9 @@ func (s *Server) buildRoutes() { // Add a default handler if config.DefaultAction == "allow" { - s.router.NewRoute().Handler(s.AllowHandler("default")) + s.router.NewRoute().Handler(s.AllowHandler(nil)) } else { - s.router.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, "default")) + s.router.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, nil)) } } @@ -65,7 +65,7 @@ func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) { } // AllowHandler Allows requests -func (s *Server) AllowHandler(rule string) http.HandlerFunc { +func (s *Server) AllowHandler(rule *string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { s.logger(r, "Allow", rule, "Allowing request") w.WriteHeader(200) @@ -73,7 +73,7 @@ func (s *Server) AllowHandler(rule string) http.HandlerFunc { } // AuthHandler Authenticates requests -func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc { +func (s *Server) AuthHandler(providerName string, rule *string) http.HandlerFunc { p, _ := config.GetConfiguredProvider(providerName) return func(w http.ResponseWriter, r *http.Request) { @@ -101,7 +101,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc { } // Validate user - valid := ValidateEmail(email) + valid := ValidateEmail(email, rule) if !valid { logger.WithField("email", email).Warn("Invalid email") http.Error(w, "Not authorized", 401) @@ -119,7 +119,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc { func (s *Server) AuthCallbackHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Logging setup - logger := s.logger(r, "AuthCallback", "default", "Handling callback") + logger := s.logger(r, "AuthCallback", nil, "Handling callback") // Check for CSRF cookie c, err := r.Cookie(config.CSRFCookieName) @@ -189,7 +189,7 @@ func (s *Server) LogoutHandler() http.HandlerFunc { // Clear cookie http.SetCookie(w, ClearCookie(r)) - logger := s.logger(r, "Logout", "default", "Handling logout") + logger := s.logger(r, "Logout", nil, "Handling logout") logger.Info("Logged out user") if config.LogoutRedirect != "" { @@ -229,7 +229,7 @@ func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *ht }).Debug("Set CSRF cookie and redirected to provider login url") } -func (s *Server) logger(r *http.Request, handler, rule, msg string) *logrus.Entry { +func (s *Server) logger(r *http.Request, handler string, rule *string, msg string) *logrus.Entry { // Create logger logger := log.WithFields(logrus.Fields{ "handler": handler,