Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow override of domains and whitelist in rules #132

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ All options can be supplied in any of the following ways, in the following prece
- ``Path(`path`, `/articles/{category}/{id:[0-9]+}`, ...)``
- ``PathPrefix(`/products/`, `/articles/{category}/{id:[0-9]+}`)``
- ``Query(`foo=bar`, `bar=baz`)``
- `whitelist` - optional, same usage as [`whitelist`](#whitelist).
- `domains` - optional, same usage as [`domain`](#domain).

For example:
```
Expand Down Expand Up @@ -328,7 +330,7 @@ You can restrict who can login with the following parameters:
* `domain` - Use this to limit logins to a specific domain, e.g. test.com only
* `whitelist` - Use this to only allow specific users to login e.g. thom@test.com only

Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3).
Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3). If you set `domains` or `whitelist` on a rule, the global configuration is ignored.

### Forwarded Headers

Expand Down
57 changes: 43 additions & 14 deletions internal/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,31 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) {
// ValidateEmail checks if the given email address matches either a whitelisted
// email address, as defined by the "whitelist" config parameter. Or is part of
// a permitted domain, as defined by the "domains" config parameter
func ValidateEmail(email string) bool {
func ValidateEmail(email string, ruleName *string) bool {
if ruleName != nil {
rule, ruleExists := config.Rules[*ruleName]

// Do we need to apply rule-level validation?
if ruleExists && (len(rule.Whitelist) > 0 || len(rule.Domains) > 0) {
if len(rule.Whitelist) > 0 && ValidateWhitelist(email, rule.Whitelist) {
return true
} else if config.MatchWhitelistOrDomain && len(rule.Domains) > 0 && ValidateDomains(email, rule.Domains) {
return true
} else {
return false
}
}
}

// Do we have any validation to perform?
if len(config.Whitelist) == 0 && len(config.Domains) == 0 {
return true
}

// Email whitelist validation
if len(config.Whitelist) > 0 {
for _, whitelist := range config.Whitelist {
if email == whitelist {
return true
}
if ValidateWhitelist(email, config.Whitelist) {
return true
}

// If we're not matching *either*, stop here
Expand All @@ -80,18 +93,34 @@ func ValidateEmail(email string) bool {
}

// Domain validation
if len(config.Domains) > 0 {
parts := strings.Split(email, "@")
if len(parts) < 2 {
return false
}
for _, domain := range config.Domains {
if domain == parts[1] {
return true
}
if len(config.Domains) > 0 && ValidateDomains(email, config.Domains) {
return true
}

return false
}

// ValidateWhitelist checks if the email is in whitelist
func ValidateWhitelist(email string, whitelist CommaSeparatedList) bool {
for _, whitelist := range whitelist {
if email == whitelist {
return true
}
}
return false
}

// ValidateDomains checks if the email matches a whitelisted domain
func ValidateDomains(email string, domains CommaSeparatedList) bool {
parts := strings.Split(email, "@")
if len(parts) < 2 {
return false
}
for _, domain := range domains {
if domain == parts[1] {
return true
}
}
return false
}

Expand Down
76 changes: 64 additions & 12 deletions internal/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,58 +65,110 @@ func TestAuthValidateCookie(t *testing.T) {
func TestAuthValidateEmail(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
var ruleName string

// Should allow any
v := ValidateEmail("test@test.com")
v := ValidateEmail("test@test.com", nil)
assert.True(v, "should allow any domain if email domain is not defined")
v = ValidateEmail("one@two.com")
v = ValidateEmail("one@two.com", nil)
assert.True(v, "should allow any domain if email domain is not defined")

// Should block non matching domain
config.Domains = []string{"test.com"}
v = ValidateEmail("one@two.com")
v = ValidateEmail("one@two.com", nil)
assert.False(v, "should not allow user from another domain")

// Should allow matching domain
config.Domains = []string{"test.com"}
v = ValidateEmail("test@test.com")
v = ValidateEmail("test@test.com", nil)
assert.True(v, "should allow user from allowed domain")

// Should block non whitelisted email address
config.Domains = []string{}
config.Whitelist = []string{"test@test.com"}
v = ValidateEmail("one@two.com")
v = ValidateEmail("one@two.com", nil)
assert.False(v, "should not allow user not in whitelist")

// Should allow matching whitelisted email address
config.Domains = []string{}
config.Whitelist = []string{"test@test.com"}
v = ValidateEmail("test@test.com")
v = ValidateEmail("test@test.com", nil)
assert.True(v, "should allow user in whitelist")

// Should allow only matching email address when
// MatchWhitelistOrDomain is disabled
config.Domains = []string{"example.com"}
config.Whitelist = []string{"test@test.com"}
config.MatchWhitelistOrDomain = false
v = ValidateEmail("test@test.com")
v = ValidateEmail("test@test.com", nil)
assert.True(v, "should allow user in whitelist")
v = ValidateEmail("test@example.com")
v = ValidateEmail("test@example.com", nil)
assert.False(v, "should not allow user from valid domain")
v = ValidateEmail("one@two.com")
v = ValidateEmail("one@two.com", nil)
assert.False(v, "should not allow user not in either")

// Should allow only matching email address when
// MatchWhitelistOrDomain is disabled with Rules
config.Domains = []string{"example.com"}
config.Whitelist = []string{"test@test.com"}
config.Rules = map[string]*Rule{"test": NewRule()}
ruleName = "test"
config.Rules[ruleName].Whitelist = []string{"test@testrule.com"}
config.Rules[ruleName].Domains = []string{"testruledomain.com"}
config.MatchWhitelistOrDomain = false
v = ValidateEmail("test@testrule.com", &ruleName)
assert.True(v, "should allow user in rule whitelist")
v = ValidateEmail("test@test.com", &ruleName)
assert.False(v, "should not allow user in global whitelist")
v = ValidateEmail("test@testruledomain.com", &ruleName)
assert.False(v, "should not allow user in rule domain list")
v = ValidateEmail("test@example.com", &ruleName)
assert.False(v, "should not allow user from valid global domain")
v = ValidateEmail("one@two.com", &ruleName)
assert.False(v, "should not allow user not in either")

// Should allow either matching domain or email address when
// MatchWhitelistOrDomain is enabled
config.Domains = []string{"example.com"}
config.Whitelist = []string{"test@test.com"}
config.MatchWhitelistOrDomain = true
v = ValidateEmail("test@test.com")
v = ValidateEmail("test@test.com", nil)
assert.True(v, "should allow user in whitelist")
v = ValidateEmail("test@example.com")
v = ValidateEmail("test@example.com", nil)
assert.True(v, "should allow user from valid domain")
v = ValidateEmail("one@two.com")
v = ValidateEmail("one@two.com", nil)
assert.False(v, "should not allow user not in either")

// Should allow either matching domain or email address when
// MatchWhitelistOrDomain is enabled with Rules
config.Domains = []string{"example.com"}
config.Whitelist = []string{"test@test.com"}
config.Rules = map[string]*Rule{"test": NewRule()}
ruleName = "test"
config.Rules[ruleName].Whitelist = []string{"test@testrule.com"}
config.Rules[ruleName].Domains = []string{"testruledomain.com"}
config.MatchWhitelistOrDomain = true
v = ValidateEmail("test@testrule.com", &ruleName)
assert.True(v, "should allow user in rule whitelist")
v = ValidateEmail("test@test.com", &ruleName)
assert.False(v, "should not allow user in global whitelist")
v = ValidateEmail("test@testruledomain.com", &ruleName)
assert.True(v, "should allow user in rule domain list")
v = ValidateEmail("test@example.com", &ruleName)
assert.False(v, "should not allow user from valid global domain")
v = ValidateEmail("one@two.com", &ruleName)
assert.False(v, "should not allow user not in either")

// Rules should use global whitelist/domains config when not specified
config.Domains = []string{"example.com"}
config.Whitelist = []string{"test@test.com"}
config.Rules = map[string]*Rule{"test": NewRule()}
ruleName = "test"
config.MatchWhitelistOrDomain = true
v = ValidateEmail("test@test.com", &ruleName)
assert.True(v, "should allow user in global whitelist")
v = ValidateEmail("test@example.com", &ruleName)
assert.True(v, "should allow user from valid global domain")
}

func TestRedirectUri(t *testing.T) {
Expand Down
16 changes: 13 additions & 3 deletions internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args [
rule.Rule = val
case "provider":
rule.Provider = val
case "whitelist":
list := CommaSeparatedList{}
list.UnmarshalFlag(val)
rule.Whitelist = list
case "domains":
list := CommaSeparatedList{}
list.UnmarshalFlag(val)
rule.Domains = list
default:
return args, fmt.Errorf("invalid route param: %v", option)
}
Expand Down Expand Up @@ -325,9 +333,11 @@ func (c *Config) setupProvider(name string) error {

// Rule holds defined rules
type Rule struct {
Action string
Rule string
Provider string
Action string
Rule string
Provider string
Whitelist CommaSeparatedList
Domains CommaSeparatedList
}

// NewRule creates a new rule object
Expand Down
20 changes: 10 additions & 10 deletions internal/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ func (s *Server) buildRoutes() {
for name, rule := range config.Rules {
matchRule := rule.formattedRule()
if rule.Action == "allow" {
s.router.AddRoute(matchRule, 1, s.AllowHandler(name))
s.router.AddRoute(matchRule, 1, s.AllowHandler(&name))
} else {
s.router.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, name))
s.router.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, &name))
}
}

Expand All @@ -46,9 +46,9 @@ func (s *Server) buildRoutes() {

// Add a default handler
if config.DefaultAction == "allow" {
s.router.NewRoute().Handler(s.AllowHandler("default"))
s.router.NewRoute().Handler(s.AllowHandler(nil))
} else {
s.router.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, "default"))
s.router.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, nil))
}
}

Expand All @@ -65,15 +65,15 @@ func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
}

// AllowHandler Allows requests
func (s *Server) AllowHandler(rule string) http.HandlerFunc {
func (s *Server) AllowHandler(rule *string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s.logger(r, "Allow", rule, "Allowing request")
w.WriteHeader(200)
}
}

// AuthHandler Authenticates requests
func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
func (s *Server) AuthHandler(providerName string, rule *string) http.HandlerFunc {
p, _ := config.GetConfiguredProvider(providerName)

return func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -101,7 +101,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
}

// Validate user
valid := ValidateEmail(email)
valid := ValidateEmail(email, rule)
if !valid {
logger.WithField("email", email).Warn("Invalid email")
http.Error(w, "Not authorized", 401)
Expand All @@ -119,7 +119,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
func (s *Server) AuthCallbackHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Logging setup
logger := s.logger(r, "AuthCallback", "default", "Handling callback")
logger := s.logger(r, "AuthCallback", nil, "Handling callback")

// Check for CSRF cookie
c, err := r.Cookie(config.CSRFCookieName)
Expand Down Expand Up @@ -189,7 +189,7 @@ func (s *Server) LogoutHandler() http.HandlerFunc {
// Clear cookie
http.SetCookie(w, ClearCookie(r))

logger := s.logger(r, "Logout", "default", "Handling logout")
logger := s.logger(r, "Logout", nil, "Handling logout")
logger.Info("Logged out user")

if config.LogoutRedirect != "" {
Expand Down Expand Up @@ -229,7 +229,7 @@ func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *ht
}).Debug("Set CSRF cookie and redirected to provider login url")
}

func (s *Server) logger(r *http.Request, handler, rule, msg string) *logrus.Entry {
func (s *Server) logger(r *http.Request, handler string, rule *string, msg string) *logrus.Entry {
// Create logger
logger := log.WithFields(logrus.Fields{
"handler": handler,
Expand Down