Skip to content

Commit

Permalink
Use pointers for ruleName
Browse files Browse the repository at this point in the history
This removes the chance for the _default_ (no rule applied) case to overlap with a user-specified rule
  • Loading branch information
lozlow committed Jun 10, 2020
1 parent 0d371c4 commit db03097
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 48 deletions.
24 changes: 13 additions & 11 deletions internal/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,19 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) {
// ValidateEmail checks if the given email address matches either a whitelisted
// email address, as defined by the "whitelist" config parameter. Or is part of
// a permitted domain, as defined by the "domains" config parameter
func ValidateEmail(email string, ruleName string) bool {
rule, ruleExists := config.Rules[ruleName]

// Do we need to apply rule-level validation?
if ruleExists && (len(rule.Whitelist) > 0 || len(rule.Domains) > 0) {
if len(rule.Whitelist) > 0 && ValidateWhitelist(email, rule.Whitelist) {
return true
} else if config.MatchWhitelistOrDomain && len(rule.Domains) > 0 && ValidateDomains(email, rule.Domains) {
return true
} else {
return false
func ValidateEmail(email string, ruleName *string) bool {
if ruleName != nil {
rule, ruleExists := config.Rules[*ruleName]

// Do we need to apply rule-level validation?
if ruleExists && (len(rule.Whitelist) > 0 || len(rule.Domains) > 0) {
if len(rule.Whitelist) > 0 && ValidateWhitelist(email, rule.Whitelist) {
return true
} else if config.MatchWhitelistOrDomain && len(rule.Domains) > 0 && ValidateDomains(email, rule.Domains) {
return true
} else {
return false
}
}
}

Expand Down
60 changes: 32 additions & 28 deletions internal/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,105 +65,109 @@ func TestAuthValidateCookie(t *testing.T) {
func TestAuthValidateEmail(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
var ruleName string

// Should allow any
v := ValidateEmail("test@test.com", "default")
v := ValidateEmail("test@test.com", nil)
assert.True(v, "should allow any domain if email domain is not defined")
v = ValidateEmail("one@two.com", "default")
v = ValidateEmail("one@two.com", nil)
assert.True(v, "should allow any domain if email domain is not defined")

// Should block non matching domain
config.Domains = []string{"test.com"}
v = ValidateEmail("one@two.com", "default")
v = ValidateEmail("one@two.com", nil)
assert.False(v, "should not allow user from another domain")

// Should allow matching domain
config.Domains = []string{"test.com"}
v = ValidateEmail("test@test.com", "default")
v = ValidateEmail("test@test.com", nil)
assert.True(v, "should allow user from allowed domain")

// Should block non whitelisted email address
config.Domains = []string{}
config.Whitelist = []string{"test@test.com"}
v = ValidateEmail("one@two.com", "default")
v = ValidateEmail("one@two.com", nil)
assert.False(v, "should not allow user not in whitelist")

// Should allow matching whitelisted email address
config.Domains = []string{}
config.Whitelist = []string{"test@test.com"}
v = ValidateEmail("test@test.com", "default")
v = ValidateEmail("test@test.com", nil)
assert.True(v, "should allow user in whitelist")

// Should allow only matching email address when
// MatchWhitelistOrDomain is disabled
config.Domains = []string{"example.com"}
config.Whitelist = []string{"test@test.com"}
config.MatchWhitelistOrDomain = false
v = ValidateEmail("test@test.com", "default")
v = ValidateEmail("test@test.com", nil)
assert.True(v, "should allow user in whitelist")
v = ValidateEmail("test@example.com", "default")
v = ValidateEmail("test@example.com", nil)
assert.False(v, "should not allow user from valid domain")
v = ValidateEmail("one@two.com", "default")
v = ValidateEmail("one@two.com", nil)
assert.False(v, "should not allow user not in either")

// Should allow only matching email address when
// MatchWhitelistOrDomain is disabled with Rules
config.Domains = []string{"example.com"}
config.Whitelist = []string{"test@test.com"}
config.Rules = map[string]*Rule{"test": NewRule()}
config.Rules["test"].Whitelist = []string{"test@testrule.com"}
config.Rules["test"].Domains = []string{"testruledomain.com"}
ruleName = "test"
config.Rules[ruleName].Whitelist = []string{"test@testrule.com"}
config.Rules[ruleName].Domains = []string{"testruledomain.com"}
config.MatchWhitelistOrDomain = false
v = ValidateEmail("test@testrule.com", "test")
v = ValidateEmail("test@testrule.com", &ruleName)
assert.True(v, "should allow user in rule whitelist")
v = ValidateEmail("test@test.com", "test")
v = ValidateEmail("test@test.com", &ruleName)
assert.False(v, "should not allow user in global whitelist")
v = ValidateEmail("test@testruledomain.com", "test")
v = ValidateEmail("test@testruledomain.com", &ruleName)
assert.False(v, "should not allow user in rule domain list")
v = ValidateEmail("test@example.com", "test")
v = ValidateEmail("test@example.com", &ruleName)
assert.False(v, "should not allow user from valid global domain")
v = ValidateEmail("one@two.com", "test")
v = ValidateEmail("one@two.com", &ruleName)
assert.False(v, "should not allow user not in either")

// Should allow either matching domain or email address when
// MatchWhitelistOrDomain is enabled
config.Domains = []string{"example.com"}
config.Whitelist = []string{"test@test.com"}
config.MatchWhitelistOrDomain = true
v = ValidateEmail("test@test.com", "default")
v = ValidateEmail("test@test.com", nil)
assert.True(v, "should allow user in whitelist")
v = ValidateEmail("test@example.com", "default")
v = ValidateEmail("test@example.com", nil)
assert.True(v, "should allow user from valid domain")
v = ValidateEmail("one@two.com", "default")
v = ValidateEmail("one@two.com", nil)
assert.False(v, "should not allow user not in either")

// Should allow either matching domain or email address when
// MatchWhitelistOrDomain is enabled with Rules
config.Domains = []string{"example.com"}
config.Whitelist = []string{"test@test.com"}
config.Rules = map[string]*Rule{"test": NewRule()}
config.Rules["test"].Whitelist = []string{"test@testrule.com"}
config.Rules["test"].Domains = []string{"testruledomain.com"}
ruleName = "test"
config.Rules[ruleName].Whitelist = []string{"test@testrule.com"}
config.Rules[ruleName].Domains = []string{"testruledomain.com"}
config.MatchWhitelistOrDomain = true
v = ValidateEmail("test@testrule.com", "test")
v = ValidateEmail("test@testrule.com", &ruleName)
assert.True(v, "should allow user in rule whitelist")
v = ValidateEmail("test@test.com", "test")
v = ValidateEmail("test@test.com", &ruleName)
assert.False(v, "should not allow user in global whitelist")
v = ValidateEmail("test@testruledomain.com", "test")
v = ValidateEmail("test@testruledomain.com", &ruleName)
assert.True(v, "should allow user in rule domain list")
v = ValidateEmail("test@example.com", "test")
v = ValidateEmail("test@example.com", &ruleName)
assert.False(v, "should not allow user from valid global domain")
v = ValidateEmail("one@two.com", "test")
v = ValidateEmail("one@two.com", &ruleName)
assert.False(v, "should not allow user not in either")

// Rules should use global whitelist/domains config when not specified
config.Domains = []string{"example.com"}
config.Whitelist = []string{"test@test.com"}
config.Rules = map[string]*Rule{"test": NewRule()}
ruleName = "test"
config.MatchWhitelistOrDomain = true
v = ValidateEmail("test@test.com", "test")
v = ValidateEmail("test@test.com", &ruleName)
assert.True(v, "should allow user in global whitelist")
v = ValidateEmail("test@example.com", "test")
v = ValidateEmail("test@example.com", &ruleName)
assert.True(v, "should allow user from valid global domain")
}

Expand Down
18 changes: 9 additions & 9 deletions internal/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ func (s *Server) buildRoutes() {
for name, rule := range config.Rules {
matchRule := rule.formattedRule()
if rule.Action == "allow" {
s.router.AddRoute(matchRule, 1, s.AllowHandler(name))
s.router.AddRoute(matchRule, 1, s.AllowHandler(&name))
} else {
s.router.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, name))
s.router.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, &name))
}
}

Expand All @@ -46,9 +46,9 @@ func (s *Server) buildRoutes() {

// Add a default handler
if config.DefaultAction == "allow" {
s.router.NewRoute().Handler(s.AllowHandler("default"))
s.router.NewRoute().Handler(s.AllowHandler(nil))
} else {
s.router.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, "default"))
s.router.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, nil))
}
}

Expand All @@ -65,15 +65,15 @@ func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
}

// AllowHandler Allows requests
func (s *Server) AllowHandler(rule string) http.HandlerFunc {
func (s *Server) AllowHandler(rule *string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s.logger(r, "Allow", rule, "Allowing request")
w.WriteHeader(200)
}
}

// AuthHandler Authenticates requests
func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
func (s *Server) AuthHandler(providerName string, rule *string) http.HandlerFunc {
p, _ := config.GetConfiguredProvider(providerName)

return func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -119,7 +119,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
func (s *Server) AuthCallbackHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Logging setup
logger := s.logger(r, "AuthCallback", "default", "Handling callback")
logger := s.logger(r, "AuthCallback", nil, "Handling callback")

// Check for CSRF cookie
c, err := r.Cookie(config.CSRFCookieName)
Expand Down Expand Up @@ -189,7 +189,7 @@ func (s *Server) LogoutHandler() http.HandlerFunc {
// Clear cookie
http.SetCookie(w, ClearCookie(r))

logger := s.logger(r, "Logout", "default", "Handling logout")
logger := s.logger(r, "Logout", nil, "Handling logout")
logger.Info("Logged out user")

if config.LogoutRedirect != "" {
Expand Down Expand Up @@ -229,7 +229,7 @@ func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *ht
}).Debug("Set CSRF cookie and redirected to provider login url")
}

func (s *Server) logger(r *http.Request, handler, rule, msg string) *logrus.Entry {
func (s *Server) logger(r *http.Request, handler string, rule *string, msg string) *logrus.Entry {
// Create logger
logger := log.WithFields(logrus.Fields{
"handler": handler,
Expand Down

0 comments on commit db03097

Please sign in to comment.