From 13c697febf18c0994bbcadc38315299433c378c2 Mon Sep 17 00:00:00 2001 From: Mathieu Cantin Date: Tue, 6 Aug 2019 11:09:13 -0400 Subject: [PATCH 1/5] Allow to override domains and whitelist in rules --- README.md | 4 +++- internal/auth.go | 52 ++++++++++++++++++++++++++++++------------- internal/auth_test.go | 45 ++++++++++++++++++++++++++++++++----- internal/config.go | 16 ++++++++++--- internal/server.go | 2 +- 5 files changed, 92 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 9a42b1a4..7f92c38c 100644 --- a/README.md +++ b/README.md @@ -262,6 +262,8 @@ All options can be supplied in any of the following ways, in the following prece - ``Path(`path`, `/articles/{category}/{id:[0-9]+}`, ...)`` - ``PathPrefix(`/products/`, `/articles/{category}/{id:[0-9]+}`)`` - ``Query(`foo=bar`, `bar=baz`)`` + - `whitelist` - optional, same usage as [`whitelist`](#whitelist). + - `domains` - optional, same usage as [`domain`](#domain). For example: ``` @@ -283,7 +285,7 @@ You can restrict who can login with the following parameters: * `domain` - Use this to limit logins to a specific domain, e.g. test.com only * `whitelist` - Use this to only allow specific users to login e.g. thom@test.com only -Note, if you pass `whitelist` then only this is checked and `domain` is effectively ignored. +Note, if you pass `whitelist` then only this is checked and `domain` is effectively ignored. If you set `domains` or `whitelist` on a rules, the global configuration is ignored. ### Forwarded Headers diff --git a/internal/auth.go b/internal/auth.go index b5723de2..e830ad78 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -56,24 +56,18 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) { } // Validate email -func ValidateEmail(email string) bool { +func ValidateEmail(email string, rule string) bool { found := false - if len(config.Whitelist) > 0 { - for _, whitelist := range config.Whitelist { - if email == whitelist { - found = true - } - } + + _, ruleExists := config.Rules[rule] + if ruleExists && len(config.Rules[rule].Whitelist) > 0 { + found = ValidateWhitelist(email, config.Rules[rule].Whitelist) + } else if ruleExists && len(config.Rules[rule].Domains) > 0 { + found = ValidateDomains(email, config.Rules[rule].Domains) + } else if len(config.Whitelist) > 0 { + found = ValidateWhitelist(email, config.Whitelist) } else if len(config.Domains) > 0 { - parts := strings.Split(email, "@") - if len(parts) < 2 { - return false - } - for _, domain := range config.Domains { - if domain == parts[1] { - found = true - } - } + found = ValidateDomains(email, config.Domains) } else { return true } @@ -81,6 +75,32 @@ func ValidateEmail(email string) bool { return found } +// Validate email is in whitelist +func ValidateWhitelist(email string, whitelist CommaSeparatedList) bool { + found := false + for _, whitelist := range whitelist { + if email == whitelist { + found = true + } + } + return found +} + +// Validate email match a domains +func ValidateDomains(email string, domains CommaSeparatedList) bool { + found := false + parts := strings.Split(email, "@") + if len(parts) < 2 { + return false + } + for _, domain := range domains { + if domain == parts[1] { + found = true + } + } + return found +} + // OAuth Methods // Get login url diff --git a/internal/auth_test.go b/internal/auth_test.go index 9a914989..fdccfce5 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -67,32 +67,65 @@ func TestAuthValidateEmail(t *testing.T) { config, _ = NewConfig([]string{}) // Should allow any - v := ValidateEmail("test@test.com") + v := ValidateEmail("test@test.com", "default") assert.True(v, "should allow any domain if email domain is not defined") - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", "default") assert.True(v, "should allow any domain if email domain is not defined") // Should block non matching domain config.Domains = []string{"test.com"} - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", "default") assert.False(v, "should not allow user from another domain") // Should allow matching domain config.Domains = []string{"test.com"} - v = ValidateEmail("test@test.com") + v = ValidateEmail("test@test.com", "default") assert.True(v, "should allow user from allowed domain") // Should block non whitelisted email address config.Domains = []string{} config.Whitelist = []string{"test@test.com"} - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", "default") assert.False(v, "should not allow user not in whitelist") // Should allow matching whitelisted email address config.Domains = []string{} config.Whitelist = []string{"test@test.com"} - v = ValidateEmail("test@test.com") + v = ValidateEmail("test@test.com", "default") assert.True(v, "should allow user in whitelist") + + // Should allow matching whitelisted in rules email address + config.Domains = []string{"globaltestdomain.com"} + config.Whitelist = []string{} + config.Rules = map[string]*Rule{"test": NewRule()} + config.Rules["test"].Whitelist = []string{"test@test.com"} + // Validation for user in the rule whitelist + v = ValidateEmail("test@test.com", "test") + assert.True(v, "should allow user in rule whitelist") + // Validation for user not in the rule whitelist + v = ValidateEmail("test2@test.com", "test") + assert.False(v, "should not allow user not in rule whitelist") + // Validation for user in global domain but not in rule + v = ValidateEmail("test@globaltestdomain.com", "test") + assert.False(v, "should not allow user in global but not in rule") + // Validation for user in the whitelist, but that not this rule + v = ValidateEmail("test@test.com", "default") + assert.False(v, "should not allow user not in the rule whitelisted") + + // Should allow matching domains + config.Domains = []string{"globaltestdomain.com"} + config.Whitelist = []string{} + config.Rules = map[string]*Rule{"test": NewRule()} + config.Rules["test"].Domains = []string{"test.com"} + // Validation for user in the rule domains + v = ValidateEmail("test@test.com", "test") + assert.True(v, "should allow user in rule domains") + // Validation for user not in the rule whitelist + v = ValidateEmail("test@test2.com", "test") + assert.False(v, "should not allow user not in rule domains") + // Validation for user in the whitelist, but that not this rule + v = ValidateEmail("test@test.com", "default") + assert.False(v, "should not allow user not in the rule") } // TODO: Split google tests out diff --git a/internal/config.go b/internal/config.go index 5ea19a6d..02f1d0b0 100644 --- a/internal/config.go +++ b/internal/config.go @@ -213,6 +213,14 @@ func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args [ rule.Rule = val case "provider": rule.Provider = val + case "whitelist": + list := CommaSeparatedList{} + list.UnmarshalFlag(val) + rule.Whitelist = list + case "domains": + list := CommaSeparatedList{} + list.UnmarshalFlag(val) + rule.Domains = list default: return args, fmt.Errorf("inavlid route param: %v", option) } @@ -266,9 +274,11 @@ func (c Config) String() string { } type Rule struct { - Action string - Rule string - Provider string + Action string + Rule string + Provider string + Whitelist CommaSeparatedList + Domains CommaSeparatedList } func NewRule() *Rule { diff --git a/internal/server.go b/internal/server.go index 3119876b..831b8793 100644 --- a/internal/server.go +++ b/internal/server.go @@ -90,7 +90,7 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc { } // Validate user - valid := ValidateEmail(email) + valid := ValidateEmail(email, rule) if !valid { logger.WithFields(logrus.Fields{ "email": email, From 8e24d4d7cdcd1d13bb1d077b06d83585f5a7bb8d Mon Sep 17 00:00:00 2001 From: Pete Shaw Date: Mon, 8 Jun 2020 12:57:24 +0100 Subject: [PATCH 2/5] Fix lint for function comments --- internal/auth.go | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/internal/auth.go b/internal/auth.go index 95466f9f..eacab507 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -57,49 +57,44 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) { } // ValidateEmail verifies that an email is permitted by the current config -func ValidateEmail(email string, rule string) bool { - found := false - - _, ruleExists := config.Rules[rule] - if ruleExists && len(config.Rules[rule].Whitelist) > 0 { - found = ValidateWhitelist(email, config.Rules[rule].Whitelist) - } else if ruleExists && len(config.Rules[rule].Domains) > 0 { - found = ValidateDomains(email, config.Rules[rule].Domains) +func ValidateEmail(email string, ruleName string) bool { + rule, ruleExists := config.Rules[ruleName] + + if ruleExists && len(rule.Whitelist) > 0 { + return ValidateWhitelist(email, rule.Whitelist) + } else if ruleExists && len(rule.Domains) > 0 { + return ValidateDomains(email, rule.Domains) } else if len(config.Whitelist) > 0 { - found = ValidateWhitelist(email, config.Whitelist) + return ValidateWhitelist(email, config.Whitelist) } else if len(config.Domains) > 0 { - found = ValidateDomains(email, config.Domains) + return ValidateDomains(email, config.Domains) } else { return true } - - return found } -// Validate email is in whitelist +// ValidateWhitelist checks if the email is in whitelist func ValidateWhitelist(email string, whitelist CommaSeparatedList) bool { - found := false for _, whitelist := range whitelist { if email == whitelist { - found = true + return true } } - return found + return false } -// Validate email match a domains +// ValidateDomains checks if the email matches a whitelisted domain func ValidateDomains(email string, domains CommaSeparatedList) bool { - found := false parts := strings.Split(email, "@") if len(parts) < 2 { return false } for _, domain := range domains { if domain == parts[1] { - found = true + return true } } - return found + return false } // Utility methods From ececa20edc71693ee3e6126b24f5a6dc84a9db2d Mon Sep 17 00:00:00 2001 From: Pete Shaw Date: Wed, 10 Jun 2020 10:12:48 +0100 Subject: [PATCH 3/5] Simplify bool logic --- internal/auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/auth.go b/internal/auth.go index 03d4be94..a91d9c9a 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -63,7 +63,7 @@ func ValidateEmail(email string, ruleName string) bool { rule, ruleExists := config.Rules[ruleName] // Do we need to apply rule-level validation? - if ruleExists && !(len(rule.Whitelist) == 0 && len(rule.Domains) == 0) { + if ruleExists && (len(rule.Whitelist) > 0 || len(rule.Domains) > 0) { if len(rule.Whitelist) > 0 && ValidateWhitelist(email, rule.Whitelist) { return true } else if config.MatchWhitelistOrDomain && len(rule.Domains) > 0 && ValidateDomains(email, rule.Domains) { From 0d371c41c1af0de22acf7a9bb9fd36a2683f1bb8 Mon Sep 17 00:00:00 2001 From: Pete Shaw Date: Wed, 10 Jun 2020 10:31:10 +0100 Subject: [PATCH 4/5] Fixup readme typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6bb5c5c6..e03ee155 100644 --- a/README.md +++ b/README.md @@ -330,7 +330,7 @@ You can restrict who can login with the following parameters: * `domain` - Use this to limit logins to a specific domain, e.g. test.com only * `whitelist` - Use this to only allow specific users to login e.g. thom@test.com only -Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3). If you set `domains` or `whitelist` on a rules, the global configuration is ignored. +Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3). If you set `domains` or `whitelist` on a rule, the global configuration is ignored. ### Forwarded Headers From db030979cce9fb8e061ecb23627bab039dfe1a8e Mon Sep 17 00:00:00 2001 From: Pete Shaw Date: Wed, 10 Jun 2020 10:32:38 +0100 Subject: [PATCH 5/5] Use pointers for ruleName This removes the chance for the _default_ (no rule applied) case to overlap with a user-specified rule --- internal/auth.go | 24 +++++++++-------- internal/auth_test.go | 60 +++++++++++++++++++++++-------------------- internal/server.go | 18 ++++++------- 3 files changed, 54 insertions(+), 48 deletions(-) diff --git a/internal/auth.go b/internal/auth.go index a91d9c9a..4184bf74 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -59,17 +59,19 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) { // ValidateEmail checks if the given email address matches either a whitelisted // email address, as defined by the "whitelist" config parameter. Or is part of // a permitted domain, as defined by the "domains" config parameter -func ValidateEmail(email string, ruleName string) bool { - rule, ruleExists := config.Rules[ruleName] - - // Do we need to apply rule-level validation? - if ruleExists && (len(rule.Whitelist) > 0 || len(rule.Domains) > 0) { - if len(rule.Whitelist) > 0 && ValidateWhitelist(email, rule.Whitelist) { - return true - } else if config.MatchWhitelistOrDomain && len(rule.Domains) > 0 && ValidateDomains(email, rule.Domains) { - return true - } else { - return false +func ValidateEmail(email string, ruleName *string) bool { + if ruleName != nil { + rule, ruleExists := config.Rules[*ruleName] + + // Do we need to apply rule-level validation? + if ruleExists && (len(rule.Whitelist) > 0 || len(rule.Domains) > 0) { + if len(rule.Whitelist) > 0 && ValidateWhitelist(email, rule.Whitelist) { + return true + } else if config.MatchWhitelistOrDomain && len(rule.Domains) > 0 && ValidateDomains(email, rule.Domains) { + return true + } else { + return false + } } } diff --git a/internal/auth_test.go b/internal/auth_test.go index d8cab899..a358ec50 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -65,33 +65,34 @@ func TestAuthValidateCookie(t *testing.T) { func TestAuthValidateEmail(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) + var ruleName string // Should allow any - v := ValidateEmail("test@test.com", "default") + v := ValidateEmail("test@test.com", nil) assert.True(v, "should allow any domain if email domain is not defined") - v = ValidateEmail("one@two.com", "default") + v = ValidateEmail("one@two.com", nil) assert.True(v, "should allow any domain if email domain is not defined") // Should block non matching domain config.Domains = []string{"test.com"} - v = ValidateEmail("one@two.com", "default") + v = ValidateEmail("one@two.com", nil) assert.False(v, "should not allow user from another domain") // Should allow matching domain config.Domains = []string{"test.com"} - v = ValidateEmail("test@test.com", "default") + v = ValidateEmail("test@test.com", nil) assert.True(v, "should allow user from allowed domain") // Should block non whitelisted email address config.Domains = []string{} config.Whitelist = []string{"test@test.com"} - v = ValidateEmail("one@two.com", "default") + v = ValidateEmail("one@two.com", nil) assert.False(v, "should not allow user not in whitelist") // Should allow matching whitelisted email address config.Domains = []string{} config.Whitelist = []string{"test@test.com"} - v = ValidateEmail("test@test.com", "default") + v = ValidateEmail("test@test.com", nil) assert.True(v, "should allow user in whitelist") // Should allow only matching email address when @@ -99,11 +100,11 @@ func TestAuthValidateEmail(t *testing.T) { config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.MatchWhitelistOrDomain = false - v = ValidateEmail("test@test.com", "default") + v = ValidateEmail("test@test.com", nil) assert.True(v, "should allow user in whitelist") - v = ValidateEmail("test@example.com", "default") + v = ValidateEmail("test@example.com", nil) assert.False(v, "should not allow user from valid domain") - v = ValidateEmail("one@two.com", "default") + v = ValidateEmail("one@two.com", nil) assert.False(v, "should not allow user not in either") // Should allow only matching email address when @@ -111,18 +112,19 @@ func TestAuthValidateEmail(t *testing.T) { config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.Rules = map[string]*Rule{"test": NewRule()} - config.Rules["test"].Whitelist = []string{"test@testrule.com"} - config.Rules["test"].Domains = []string{"testruledomain.com"} + ruleName = "test" + config.Rules[ruleName].Whitelist = []string{"test@testrule.com"} + config.Rules[ruleName].Domains = []string{"testruledomain.com"} config.MatchWhitelistOrDomain = false - v = ValidateEmail("test@testrule.com", "test") + v = ValidateEmail("test@testrule.com", &ruleName) assert.True(v, "should allow user in rule whitelist") - v = ValidateEmail("test@test.com", "test") + v = ValidateEmail("test@test.com", &ruleName) assert.False(v, "should not allow user in global whitelist") - v = ValidateEmail("test@testruledomain.com", "test") + v = ValidateEmail("test@testruledomain.com", &ruleName) assert.False(v, "should not allow user in rule domain list") - v = ValidateEmail("test@example.com", "test") + v = ValidateEmail("test@example.com", &ruleName) assert.False(v, "should not allow user from valid global domain") - v = ValidateEmail("one@two.com", "test") + v = ValidateEmail("one@two.com", &ruleName) assert.False(v, "should not allow user not in either") // Should allow either matching domain or email address when @@ -130,11 +132,11 @@ func TestAuthValidateEmail(t *testing.T) { config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.MatchWhitelistOrDomain = true - v = ValidateEmail("test@test.com", "default") + v = ValidateEmail("test@test.com", nil) assert.True(v, "should allow user in whitelist") - v = ValidateEmail("test@example.com", "default") + v = ValidateEmail("test@example.com", nil) assert.True(v, "should allow user from valid domain") - v = ValidateEmail("one@two.com", "default") + v = ValidateEmail("one@two.com", nil) assert.False(v, "should not allow user not in either") // Should allow either matching domain or email address when @@ -142,28 +144,30 @@ func TestAuthValidateEmail(t *testing.T) { config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.Rules = map[string]*Rule{"test": NewRule()} - config.Rules["test"].Whitelist = []string{"test@testrule.com"} - config.Rules["test"].Domains = []string{"testruledomain.com"} + ruleName = "test" + config.Rules[ruleName].Whitelist = []string{"test@testrule.com"} + config.Rules[ruleName].Domains = []string{"testruledomain.com"} config.MatchWhitelistOrDomain = true - v = ValidateEmail("test@testrule.com", "test") + v = ValidateEmail("test@testrule.com", &ruleName) assert.True(v, "should allow user in rule whitelist") - v = ValidateEmail("test@test.com", "test") + v = ValidateEmail("test@test.com", &ruleName) assert.False(v, "should not allow user in global whitelist") - v = ValidateEmail("test@testruledomain.com", "test") + v = ValidateEmail("test@testruledomain.com", &ruleName) assert.True(v, "should allow user in rule domain list") - v = ValidateEmail("test@example.com", "test") + v = ValidateEmail("test@example.com", &ruleName) assert.False(v, "should not allow user from valid global domain") - v = ValidateEmail("one@two.com", "test") + v = ValidateEmail("one@two.com", &ruleName) assert.False(v, "should not allow user not in either") // Rules should use global whitelist/domains config when not specified config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.Rules = map[string]*Rule{"test": NewRule()} + ruleName = "test" config.MatchWhitelistOrDomain = true - v = ValidateEmail("test@test.com", "test") + v = ValidateEmail("test@test.com", &ruleName) assert.True(v, "should allow user in global whitelist") - v = ValidateEmail("test@example.com", "test") + v = ValidateEmail("test@example.com", &ruleName) assert.True(v, "should allow user from valid global domain") } diff --git a/internal/server.go b/internal/server.go index 45c13e8c..042628fe 100644 --- a/internal/server.go +++ b/internal/server.go @@ -32,9 +32,9 @@ func (s *Server) buildRoutes() { for name, rule := range config.Rules { matchRule := rule.formattedRule() if rule.Action == "allow" { - s.router.AddRoute(matchRule, 1, s.AllowHandler(name)) + s.router.AddRoute(matchRule, 1, s.AllowHandler(&name)) } else { - s.router.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, name)) + s.router.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, &name)) } } @@ -46,9 +46,9 @@ func (s *Server) buildRoutes() { // Add a default handler if config.DefaultAction == "allow" { - s.router.NewRoute().Handler(s.AllowHandler("default")) + s.router.NewRoute().Handler(s.AllowHandler(nil)) } else { - s.router.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, "default")) + s.router.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, nil)) } } @@ -65,7 +65,7 @@ func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) { } // AllowHandler Allows requests -func (s *Server) AllowHandler(rule string) http.HandlerFunc { +func (s *Server) AllowHandler(rule *string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { s.logger(r, "Allow", rule, "Allowing request") w.WriteHeader(200) @@ -73,7 +73,7 @@ func (s *Server) AllowHandler(rule string) http.HandlerFunc { } // AuthHandler Authenticates requests -func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc { +func (s *Server) AuthHandler(providerName string, rule *string) http.HandlerFunc { p, _ := config.GetConfiguredProvider(providerName) return func(w http.ResponseWriter, r *http.Request) { @@ -119,7 +119,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc { func (s *Server) AuthCallbackHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Logging setup - logger := s.logger(r, "AuthCallback", "default", "Handling callback") + logger := s.logger(r, "AuthCallback", nil, "Handling callback") // Check for CSRF cookie c, err := r.Cookie(config.CSRFCookieName) @@ -189,7 +189,7 @@ func (s *Server) LogoutHandler() http.HandlerFunc { // Clear cookie http.SetCookie(w, ClearCookie(r)) - logger := s.logger(r, "Logout", "default", "Handling logout") + logger := s.logger(r, "Logout", nil, "Handling logout") logger.Info("Logged out user") if config.LogoutRedirect != "" { @@ -229,7 +229,7 @@ func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *ht }).Debug("Set CSRF cookie and redirected to provider login url") } -func (s *Server) logger(r *http.Request, handler, rule, msg string) *logrus.Entry { +func (s *Server) logger(r *http.Request, handler string, rule *string, msg string) *logrus.Entry { // Create logger logger := log.WithFields(logrus.Fields{ "handler": handler,