diff --git a/README.md b/README.md index 9a42b1a4..7f92c38c 100644 --- a/README.md +++ b/README.md @@ -262,6 +262,8 @@ All options can be supplied in any of the following ways, in the following prece - ``Path(`path`, `/articles/{category}/{id:[0-9]+}`, ...)`` - ``PathPrefix(`/products/`, `/articles/{category}/{id:[0-9]+}`)`` - ``Query(`foo=bar`, `bar=baz`)`` + - `whitelist` - optional, same usage as [`whitelist`](#whitelist). + - `domains` - optional, same usage as [`domain`](#domain). For example: ``` @@ -283,7 +285,7 @@ You can restrict who can login with the following parameters: * `domain` - Use this to limit logins to a specific domain, e.g. test.com only * `whitelist` - Use this to only allow specific users to login e.g. thom@test.com only -Note, if you pass `whitelist` then only this is checked and `domain` is effectively ignored. +Note, if you pass `whitelist` then only this is checked and `domain` is effectively ignored. If you set `domains` or `whitelist` on a rules, the global configuration is ignored. ### Forwarded Headers diff --git a/internal/auth.go b/internal/auth.go index b5723de2..e830ad78 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -56,24 +56,18 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) { } // Validate email -func ValidateEmail(email string) bool { +func ValidateEmail(email string, rule string) bool { found := false - if len(config.Whitelist) > 0 { - for _, whitelist := range config.Whitelist { - if email == whitelist { - found = true - } - } + + _, ruleExists := config.Rules[rule] + if ruleExists && len(config.Rules[rule].Whitelist) > 0 { + found = ValidateWhitelist(email, config.Rules[rule].Whitelist) + } else if ruleExists && len(config.Rules[rule].Domains) > 0 { + found = ValidateDomains(email, config.Rules[rule].Domains) + } else if len(config.Whitelist) > 0 { + found = ValidateWhitelist(email, config.Whitelist) } else if len(config.Domains) > 0 { - parts := strings.Split(email, "@") - if len(parts) < 2 { - return false - } - for _, domain := range config.Domains { - if domain == parts[1] { - found = true - } - } + found = ValidateDomains(email, config.Domains) } else { return true } @@ -81,6 +75,32 @@ func ValidateEmail(email string) bool { return found } +// Validate email is in whitelist +func ValidateWhitelist(email string, whitelist CommaSeparatedList) bool { + found := false + for _, whitelist := range whitelist { + if email == whitelist { + found = true + } + } + return found +} + +// Validate email match a domains +func ValidateDomains(email string, domains CommaSeparatedList) bool { + found := false + parts := strings.Split(email, "@") + if len(parts) < 2 { + return false + } + for _, domain := range domains { + if domain == parts[1] { + found = true + } + } + return found +} + // OAuth Methods // Get login url diff --git a/internal/auth_test.go b/internal/auth_test.go index 9a914989..fdccfce5 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -67,32 +67,65 @@ func TestAuthValidateEmail(t *testing.T) { config, _ = NewConfig([]string{}) // Should allow any - v := ValidateEmail("test@test.com") + v := ValidateEmail("test@test.com", "default") assert.True(v, "should allow any domain if email domain is not defined") - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", "default") assert.True(v, "should allow any domain if email domain is not defined") // Should block non matching domain config.Domains = []string{"test.com"} - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", "default") assert.False(v, "should not allow user from another domain") // Should allow matching domain config.Domains = []string{"test.com"} - v = ValidateEmail("test@test.com") + v = ValidateEmail("test@test.com", "default") assert.True(v, "should allow user from allowed domain") // Should block non whitelisted email address config.Domains = []string{} config.Whitelist = []string{"test@test.com"} - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", "default") assert.False(v, "should not allow user not in whitelist") // Should allow matching whitelisted email address config.Domains = []string{} config.Whitelist = []string{"test@test.com"} - v = ValidateEmail("test@test.com") + v = ValidateEmail("test@test.com", "default") assert.True(v, "should allow user in whitelist") + + // Should allow matching whitelisted in rules email address + config.Domains = []string{"globaltestdomain.com"} + config.Whitelist = []string{} + config.Rules = map[string]*Rule{"test": NewRule()} + config.Rules["test"].Whitelist = []string{"test@test.com"} + // Validation for user in the rule whitelist + v = ValidateEmail("test@test.com", "test") + assert.True(v, "should allow user in rule whitelist") + // Validation for user not in the rule whitelist + v = ValidateEmail("test2@test.com", "test") + assert.False(v, "should not allow user not in rule whitelist") + // Validation for user in global domain but not in rule + v = ValidateEmail("test@globaltestdomain.com", "test") + assert.False(v, "should not allow user in global but not in rule") + // Validation for user in the whitelist, but that not this rule + v = ValidateEmail("test@test.com", "default") + assert.False(v, "should not allow user not in the rule whitelisted") + + // Should allow matching domains + config.Domains = []string{"globaltestdomain.com"} + config.Whitelist = []string{} + config.Rules = map[string]*Rule{"test": NewRule()} + config.Rules["test"].Domains = []string{"test.com"} + // Validation for user in the rule domains + v = ValidateEmail("test@test.com", "test") + assert.True(v, "should allow user in rule domains") + // Validation for user not in the rule whitelist + v = ValidateEmail("test@test2.com", "test") + assert.False(v, "should not allow user not in rule domains") + // Validation for user in the whitelist, but that not this rule + v = ValidateEmail("test@test.com", "default") + assert.False(v, "should not allow user not in the rule") } // TODO: Split google tests out diff --git a/internal/config.go b/internal/config.go index 5ea19a6d..02f1d0b0 100644 --- a/internal/config.go +++ b/internal/config.go @@ -213,6 +213,14 @@ func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args [ rule.Rule = val case "provider": rule.Provider = val + case "whitelist": + list := CommaSeparatedList{} + list.UnmarshalFlag(val) + rule.Whitelist = list + case "domains": + list := CommaSeparatedList{} + list.UnmarshalFlag(val) + rule.Domains = list default: return args, fmt.Errorf("inavlid route param: %v", option) } @@ -266,9 +274,11 @@ func (c Config) String() string { } type Rule struct { - Action string - Rule string - Provider string + Action string + Rule string + Provider string + Whitelist CommaSeparatedList + Domains CommaSeparatedList } func NewRule() *Rule { diff --git a/internal/server.go b/internal/server.go index 3119876b..831b8793 100644 --- a/internal/server.go +++ b/internal/server.go @@ -90,7 +90,7 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc { } // Validate user - valid := ValidateEmail(email) + valid := ValidateEmail(email, rule) if !valid { logger.WithFields(logrus.Fields{ "email": email,