diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..b41214e6 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,36 @@ +name: CI + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + + test: + name: Test + runs-on: ubuntu-latest + steps: + + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ^1.13 + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: Get dependencies + run: | + go get -v -t -d ./... + if [ -f Gopkg.toml ]; then + curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh + dep ensure + fi + + - name: Build + run: go build -v ./... + + - name: Test + run: go test -v ./... diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 00000000..f8786ef5 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,71 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +name: "CodeQL" + +on: + push: + branches: [master] + pull_request: + # The branches below must be a subset of the branches above + branches: [master] + schedule: + - cron: '0 10 * * 2' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + + strategy: + fail-fast: false + matrix: + # Override automatic language detection by changing the below list + # Supported options are ['csharp', 'cpp', 'go', 'java', 'javascript', 'python'] + language: ['go'] + # Learn more... + # https://docs.github.com/en/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#overriding-automatic-language-detection + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + with: + # We must fetch at least the immediate parents so that if this is + # a pull request then we can checkout the head. + fetch-depth: 2 + + # If this run was triggered by a pull request event, then checkout + # the head of the pull request instead of the merge commit. + - run: git checkout HEAD^2 + if: ${{ github.event_name == 'pull_request' }} + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v1 + + # ℹī¸ Command-line programs to run using the OS shell. + # 📚 https://git.io/JvXDl + + # ✏ī¸ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml new file mode 100644 index 00000000..65fef000 --- /dev/null +++ b/.github/workflows/push.yml @@ -0,0 +1,39 @@ +name: Traefik Forward Auth +on: [push] +jobs: + test: + name: Test with Go version - + runs-on: ubuntu-latest + + strategy: + matrix: + go: ['1.12', '1.13', '1.14'] + + steps: + - uses: actions/checkout@v1 + + - name: Setup Go + uses: actions/setup-go@v1 + with: + go-version: ${{ matrix.go }} + + - name: Run Tests + run: go test ./... + + publish: + name: Publish Docker image + needs: test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@master + with: + fetch-depth: '0' + - name: Publish to Docker Registry + uses: docker/build-push-action@v1 + with: + repository: ${{ github.repository }} + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + tag_with_ref: true + tag_with_sha: true + diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index bcd1d064..00000000 --- a/.travis.yml +++ /dev/null @@ -1,5 +0,0 @@ -language: go -sudo: false -go: - - "1.12" -script: env GO111MODULE=on go test -v ./... diff --git a/README.md b/README.md index aea9da5a..fdfeffff 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -# Traefik Forward Auth [![Build Status](https://travis-ci.org/thomseddon/traefik-forward-auth.svg?branch=master)](https://travis-ci.org/thomseddon/traefik-forward-auth) [![Go Report Card](https://goreportcard.com/badge/github.com/thomseddon/traefik-forward-auth)](https://goreportcard.com/report/github.com/thomseddon/traefik-forward-auth) ![Docker Pulls](https://img.shields.io/docker/pulls/thomseddon/traefik-forward-auth.svg) [![GitHub release](https://img.shields.io/github/release/thomseddon/traefik-forward-auth.svg)](https://GitHub.com/thomseddon/traefik-forward-auth/releases/) +# Traefik Forward Auth ![Build Status](https://img.shields.io/github/workflow/status/thomseddon/traefik-forward-auth/CI) [![Go Report Card](https://goreportcard.com/badge/github.com/thomseddon/traefik-forward-auth)](https://goreportcard.com/report/github.com/thomseddon/traefik-forward-auth) ![Docker Pulls](https://img.shields.io/docker/pulls/thomseddon/traefik-forward-auth.svg) [![GitHub release](https://img.shields.io/github/release/thomseddon/traefik-forward-auth.svg)](https://GitHub.com/thomseddon/traefik-forward-auth/releases/) A minimal forward authentication service that provides OAuth/SSO login and authentication for the [traefik](https://github.com/containous/traefik) reverse proxy/load balancer. @@ -9,8 +9,8 @@ A minimal forward authentication service that provides OAuth/SSO login and authe - Seamlessly overlays any http service with a single endpoint (see: `url-path` in [Configuration](#configuration)) - Supports multiple providers including Google and OpenID Connect (supported by Azure, Github, Salesforce etc.) - Supports multiple domains/subdomains by dynamically generating redirect_uri's -- Allows authentication to be selectively applied/bypassed based on request parameters (see `rules` in [Configuration](#configuration))) -- Supports use of centralised authentication host/redirect_uri (see `auth-host` in [Configuration](#configuration))) +- Allows authentication to be selectively applied/bypassed based on request parameters (see `rules` in [Configuration](#configuration)) +- Supports use of centralised authentication host/redirect_uri (see `auth-host` in [Configuration](#configuration)) - Allows authentication to persist across multiple domains (see [Cookie Domains](#cookie-domains)) - Supports extended authentication beyond Google token lifetime (see: `lifetime` in [Configuration](#configuration)) @@ -321,6 +321,7 @@ All options can be supplied in any of the following ways, in the following prece - `action` - same usage as [`default-action`](#default-action), supported values: - `auth` (default) - `allow` + - `domains` - optional, same usage as [`domain`](#domain) - `provider` - same usage as [`default-provider`](#default-provider), supported values: - `google` - `oidc` @@ -333,6 +334,7 @@ All options can be supplied in any of the following ways, in the following prece - ``Path(`path`, `/articles/{category}/{id:[0-9]+}`, ...)`` - ``PathPrefix(`/products/`, `/articles/{category}/{id:[0-9]+}`)`` - ``Query(`foo=bar`, `bar=baz`)`` + - `whitelist` - optional, same usage as whitelist`](#whitelist) For example: ``` @@ -348,6 +350,11 @@ All options can be supplied in any of the following ways, in the following prece rule.oidc.action = auth rule.oidc.provider = oidc rule.oidc.rule = PathPrefix(`/github`) + + # Allow jane@example.com to `/janes-eyes-only` + rule.two.action = allow + rule.two.rule = Path(`/janes-eyes-only`) + rule.two.whitelist = jane@example.com ``` Note: It is possible to break your redirect flow with rules, please be careful not to create an `allow` rule that matches your redirect_uri unless you know what you're doing. This limitation is being tracked in in #101 and the behaviour will change in future releases. @@ -361,7 +368,7 @@ You can restrict who can login with the following parameters: * `domain` - Use this to limit logins to a specific domain, e.g. test.com only * `whitelist` - Use this to only allow specific users to login e.g. thom@test.com only -Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3). +Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3). If you set `domains` or `whitelist` on a rule, the global configuration is ignored. ### Forwarded Headers diff --git a/internal/auth.go b/internal/auth.go index 9c667a01..1fac6617 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -59,18 +59,28 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) { // ValidateUser checks if the given user matches either a whitelisted // user, as defined by the "whitelist" config parameter. Or is part of // a permitted domain, as defined by the "domains" config parameter -func ValidateUser(user string) bool { +func ValidateUser(user, ruleName string) bool { + // Use global config by default + whitelist := config.Whitelist + domains := config.Domains + + if rule, ok := config.Rules[ruleName]; ok { + // Override with rule config if found + if len(rule.Whitelist) > 0 || len(rule.Domains) > 0 { + whitelist = rule.Whitelist + domains = rule.Domains + } + } + // Do we have any validation to perform? - if len(config.Whitelist) == 0 && len(config.Domains) == 0 { + if len(whitelist) == 0 && len(domains) == 0 { return true } // Email whitelist validation - if len(config.Whitelist) > 0 { - for _, whitelist := range config.Whitelist { - if user == whitelist { - return true - } + if len(whitelist) > 0 { + if ValidateWhitelist(user, whitelist) { + return true } // If we're not matching *either*, stop here @@ -80,18 +90,34 @@ func ValidateUser(user string) bool { } // Domain validation - if len(config.Domains) > 0 { - parts := strings.Split(user, "@") - if len(parts) < 2 { - return false - } - for _, domain := range config.Domains { - if domain == parts[1] { - return true - } + if len(domains) > 0 && ValidateDomains(user, domains) { + return true + } + + return false +} + +// ValidateWhitelist checks if the email is in whitelist +func ValidateWhitelist(user string, whitelist CommaSeparatedList) bool { + for _, whitelist := range whitelist { + if user == whitelist { + return true } } + return false +} +// ValidateDomains checks if the email matches a whitelisted domain +func ValidateDomains(user string, domains CommaSeparatedList) bool { + parts := strings.Split(user, "@") + if len(parts) < 2 { + return false + } + for _, domain := range domains { + if domain == parts[1] { + return true + } + } return false } @@ -170,23 +196,31 @@ func ClearCookie(r *http.Request) *http.Cookie { } } +func buildCSRFCookieName(nonce string) string { + return config.CSRFCookieName + "_" + nonce[:6] +} + // MakeCSRFCookie makes a csrf cookie (used during login only) +// +// Note, CSRF cookies live shorter than auth cookies, a fixed 1h. +// That's because some CSRF cookies may belong to auth flows that don't complete +// and thus may not get cleared by ClearCookie. func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie { return &http.Cookie{ - Name: config.CSRFCookieName, + Name: buildCSRFCookieName(nonce), Value: nonce, Path: "/", Domain: csrfCookieDomain(r), HttpOnly: true, Secure: !config.InsecureCookie, - Expires: cookieExpiry(), + Expires: time.Now().Local().Add(time.Hour * 1), } } // ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie -func ClearCSRFCookie(r *http.Request) *http.Cookie { +func ClearCSRFCookie(r *http.Request, c *http.Cookie) *http.Cookie { return &http.Cookie{ - Name: config.CSRFCookieName, + Name: c.Name, Value: "", Path: "/", Domain: csrfCookieDomain(r), @@ -196,18 +230,18 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie { } } -// ValidateCSRFCookie validates the csrf cookie against state -func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) { - state := r.URL.Query().Get("state") +// FindCSRFCookie extracts the CSRF cookie from the request based on state. +func FindCSRFCookie(r *http.Request, state string) (c *http.Cookie, err error) { + // Check for CSRF cookie + return r.Cookie(buildCSRFCookieName(state)) +} +// ValidateCSRFCookie validates the csrf cookie against state +func ValidateCSRFCookie(c *http.Cookie, state string) (valid bool, provider string, redirect string, err error) { if len(c.Value) != 32 { return false, "", "", errors.New("Invalid CSRF cookie value") } - if len(state) < 34 { - return false, "", "", errors.New("Invalid CSRF state value") - } - // Check nonce match if c.Value != state[:32] { return false, "", "", errors.New("CSRF cookie does not match state") @@ -229,6 +263,14 @@ func MakeState(r *http.Request, p provider.Provider, nonce string) string { return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r)) } +// ValidateState checks whether the state is of right length. +func ValidateState(state string) error { + if len(state) < 34 { + return errors.New("Invalid CSRF state value") + } + return nil +} + // Nonce generates a random nonce func Nonce() (error, string) { nonce := make([]byte, 16) @@ -350,7 +392,7 @@ func (c *CookieDomains) UnmarshalFlag(value string) error { return nil } -// MarshalFlag converts an array of CookieDomain to a comma separated list +// MarshalFlag converts an array of CookieDomain to a comma seperated list func (c *CookieDomains) MarshalFlag() (string, error) { var domains []string for _, d := range *c { diff --git a/internal/auth_test.go b/internal/auth_test.go index d626830c..931dd199 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -1,7 +1,6 @@ package tfa import ( - "fmt" "net/http" "net/url" "strings" @@ -62,36 +61,35 @@ func TestAuthValidateCookie(t *testing.T) { assert.Equal("test@test.com", email, "valid request should return user email") } -func TestAuthValidateEmail(t *testing.T) { +func TestAuthValidateUser(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) - // Should allow any - v := ValidateUser("test@test.com") + // Should allow any with no whitelist/domain is specified + v := ValidateUser("test@test.com", "default") assert.True(v, "should allow any domain if email domain is not defined") - v = ValidateUser("one@two.com") + v = ValidateUser("one@two.com", "default") assert.True(v, "should allow any domain if email domain is not defined") - // Should block non matching domain - config.Domains = []string{"test.com"} - v = ValidateUser("one@two.com") - assert.False(v, "should not allow user from another domain") - // Should allow matching domain config.Domains = []string{"test.com"} - v = ValidateUser("test@test.com") + v = ValidateUser("one@two.com", "default") + assert.False(v, "should not allow user from another domain") + v = ValidateUser("test@test.com", "default") assert.True(v, "should allow user from allowed domain") // Should block non whitelisted email address config.Domains = []string{} config.Whitelist = []string{"test@test.com"} - v = ValidateUser("one@two.com") + v = ValidateUser("one@two.com", "default") assert.False(v, "should not allow user not in whitelist") // Should allow matching whitelisted email address config.Domains = []string{} config.Whitelist = []string{"test@test.com"} - v = ValidateUser("test@test.com") + v = ValidateUser("one@two.com", "default") + assert.False(v, "should not allow user not in whitelist") + v = ValidateUser("test@test.com", "default") assert.True(v, "should allow user in whitelist") // Should allow only matching email address when @@ -99,28 +97,115 @@ func TestAuthValidateEmail(t *testing.T) { config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.MatchWhitelistOrDomain = false - v = ValidateUser("test@test.com") + v = ValidateUser("test@test.com", "default") assert.True(v, "should allow user in whitelist") - v = ValidateUser("test@example.com") + v = ValidateUser("test@example.com", "default") assert.False(v, "should not allow user from valid domain") - v = ValidateUser("one@two.com") + v = ValidateUser("one@two.com", "default") assert.False(v, "should not allow user not in either") + v = ValidateUser("test@example.com", "default") + assert.False(v, "should not allow user from allowed domain") + v = ValidateUser("test@test.com", "default") + assert.True(v, "should allow user in whitelist") // Should allow either matching domain or email address when // MatchWhitelistOrDomain is enabled config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.MatchWhitelistOrDomain = true - v = ValidateUser("test@test.com") + v = ValidateUser("one@two.com", "default") + assert.False(v, "should not allow user not in either") + v = ValidateUser("test@example.com", "default") + assert.True(v, "should allow user from allowed domain") + v = ValidateUser("test@test.com", "default") assert.True(v, "should allow user in whitelist") - v = ValidateUser("test@example.com") + v = ValidateUser("test@example.com", "default") assert.True(v, "should allow user from valid domain") - v = ValidateUser("one@two.com") + + // Rule testing + + // Should use global whitelist/domain when not specified on rule + config.Domains = []string{"example.com"} + config.Whitelist = []string{"test@test.com"} + config.Rules = map[string]*Rule{"test": NewRule()} + config.MatchWhitelistOrDomain = true + v = ValidateUser("one@two.com", "test") assert.False(v, "should not allow user not in either") + v = ValidateUser("test@example.com", "test") + assert.True(v, "should allow user from allowed global domain") + v = ValidateUser("test@test.com", "test") + assert.True(v, "should allow user in global whitelist") + + // Should allow matching domain in rule + config.Domains = []string{"testglobal.com"} + config.Whitelist = []string{} + rule := NewRule() + config.Rules = map[string]*Rule{"test": rule} + rule.Domains = []string{"testrule.com"} + config.MatchWhitelistOrDomain = false + v = ValidateUser("one@two.com", "test") + assert.False(v, "should not allow user from another domain") + v = ValidateUser("one@testglobal.com", "test") + assert.False(v, "should not allow user from global domain") + v = ValidateUser("test@testrule.com", "test") + assert.True(v, "should allow user from allowed domain") // Should allow comma separated email config.Whitelist = []string{"test@test.com", "test2@test2.com"} - v = ValidateUser("test2@test2.com") + v = ValidateUser("test2@test2.com", "default") + + // Should allow matching whitelist in rule + config.Domains = []string{} + config.Whitelist = []string{"test@testglobal.com"} + rule = NewRule() + config.Rules = map[string]*Rule{"test": rule} + rule.Whitelist = []string{"test@testrule.com"} + config.MatchWhitelistOrDomain = false + v = ValidateUser("one@two.com", "test") + assert.False(v, "should not allow user from another domain") + v = ValidateUser("test@testglobal.com", "test") + assert.False(v, "should not allow user from global domain") + v = ValidateUser("test@testrule.com", "test") + assert.True(v, "should allow user from allowed domain") + + // Should allow only matching email address when + // MatchWhitelistOrDomain is disabled + config.Domains = []string{"exampleglobal.com"} + config.Whitelist = []string{"test@testglobal.com"} + rule = NewRule() + config.Rules = map[string]*Rule{"test": rule} + rule.Domains = []string{"examplerule.com"} + rule.Whitelist = []string{"test@testrule.com"} + config.MatchWhitelistOrDomain = false + v = ValidateUser("one@two.com", "test") + assert.False(v, "should not allow user not in either") + v = ValidateUser("test@testglobal.com", "test") + assert.False(v, "should not allow user in global whitelist") + v = ValidateUser("test@exampleglobal.com", "test") + assert.False(v, "should not allow user from global domain") + v = ValidateUser("test@examplerule.com", "test") + assert.False(v, "should not allow user from allowed domain") + v = ValidateUser("test@testrule.com", "test") + assert.True(v, "should allow user in whitelist") + + // Should allow either matching domain or email address when + // MatchWhitelistOrDomain is enabled + config.Domains = []string{"exampleglobal.com"} + config.Whitelist = []string{"test@testglobal.com"} + rule = NewRule() + config.Rules = map[string]*Rule{"test": rule} + rule.Domains = []string{"examplerule.com"} + rule.Whitelist = []string{"test@testrule.com"} + config.MatchWhitelistOrDomain = true + v = ValidateUser("one@two.com", "test") + assert.False(v, "should not allow user not in either") + v = ValidateUser("test@testglobal.com", "test") + assert.False(v, "should not allow user in global whitelist") + v = ValidateUser("test@exampleglobal.com", "test") + assert.False(v, "should not allow user from global domain") + v = ValidateUser("test@examplerule.com", "test") + assert.True(v, "should allow user from allowed domain") + v = ValidateUser("test@testrule.com", "test") assert.True(v, "should allow user in whitelist") } @@ -222,29 +307,30 @@ func TestAuthMakeCSRFCookie(t *testing.T) { // No cookie domain or auth url c := MakeCSRFCookie(r, "12345678901234567890123456789012") + assert.Equal("_forward_auth_csrf_123456", c.Name) assert.Equal("app.example.com", c.Domain) // With cookie domain but no auth url - config = &Config{ - CookieDomains: []CookieDomain{*NewCookieDomain("example.com")}, - } - c = MakeCSRFCookie(r, "12345678901234567890123456789012") + config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")} + c = MakeCSRFCookie(r, "12222278901234567890123456789012") + assert.Equal("_forward_auth_csrf_122222", c.Name) assert.Equal("app.example.com", c.Domain) // With cookie domain and auth url - config = &Config{ - AuthHost: "auth.example.com", - CookieDomains: []CookieDomain{*NewCookieDomain("example.com")}, - } - c = MakeCSRFCookie(r, "12345678901234567890123456789012") + config.AuthHost = "auth.example.com" + config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")} + c = MakeCSRFCookie(r, "12333378901234567890123456789012") + assert.Equal("_forward_auth_csrf_123333", c.Name) assert.Equal("example.com", c.Domain) } func TestAuthClearCSRFCookie(t *testing.T) { + assert := assert.New(t) config, _ = NewConfig([]string{}) r, _ := http.NewRequest("GET", "http://example.com", nil) - c := ClearCSRFCookie(r) + c := ClearCSRFCookie(r, &http.Cookie{Name: "someCsrfCookie"}) + assert.Equal("someCsrfCookie", c.Name) if c.Value != "" { t.Error("ClearCSRFCookie should create cookie with empty value") } @@ -254,56 +340,57 @@ func TestAuthValidateCSRFCookie(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) c := &http.Cookie{} - - newCsrfRequest := func(state string) *http.Request { - u := fmt.Sprintf("http://example.com?state=%s", state) - r, _ := http.NewRequest("GET", u, nil) - return r - } + state := "" // Should require 32 char string - r := newCsrfRequest("") + state = "" c.Value = "" - valid, _, _, err := ValidateCSRFCookie(r, c) + valid, _, _, err := ValidateCSRFCookie(c, state) assert.False(valid) if assert.Error(err) { assert.Equal("Invalid CSRF cookie value", err.Error()) } c.Value = "123456789012345678901234567890123" - valid, _, _, err = ValidateCSRFCookie(r, c) + valid, _, _, err = ValidateCSRFCookie(c, state) assert.False(valid) if assert.Error(err) { assert.Equal("Invalid CSRF cookie value", err.Error()) } - // Should require valid state - r = newCsrfRequest("12345678901234567890123456789012:") - c.Value = "12345678901234567890123456789012" - valid, _, _, err = ValidateCSRFCookie(r, c) - assert.False(valid) - if assert.Error(err) { - assert.Equal("Invalid CSRF state value", err.Error()) - } - // Should require provider - r = newCsrfRequest("12345678901234567890123456789012:99") + state = "12345678901234567890123456789012:99" c.Value = "12345678901234567890123456789012" - valid, _, _, err = ValidateCSRFCookie(r, c) + valid, _, _, err = ValidateCSRFCookie(c, state) assert.False(valid) if assert.Error(err) { assert.Equal("Invalid CSRF state format", err.Error()) } // Should allow valid state - r = newCsrfRequest("12345678901234567890123456789012:p99:url123") + state = "12345678901234567890123456789012:p99:url123" c.Value = "12345678901234567890123456789012" - valid, provider, redirect, err := ValidateCSRFCookie(r, c) + valid, provider, redirect, err := ValidateCSRFCookie(c, state) assert.True(valid, "valid request should return valid") assert.Nil(err, "valid request should not return an error") assert.Equal("p99", provider, "valid request should return correct provider") assert.Equal("url123", redirect, "valid request should return correct redirect") } +func TestValidateState(t *testing.T) { + assert := assert.New(t) + + // Should require valid state + state := "12345678901234567890123456789012:" + err := ValidateState(state) + if assert.Error(err) { + assert.Equal("Invalid CSRF state value", err.Error()) + } + // Should pass this state + state = "12345678901234567890123456789012:p99:url123" + err = ValidateState(state) + assert.Nil(err, "valid request should not return an error") +} + func TestMakeState(t *testing.T) { assert := assert.New(t) diff --git a/internal/config.go b/internal/config.go index 70b2f778..9b5859d5 100644 --- a/internal/config.go +++ b/internal/config.go @@ -211,6 +211,14 @@ func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args [ rule.Rule = val case "provider": rule.Provider = val + case "whitelist": + list := CommaSeparatedList{} + list.UnmarshalFlag(val) + rule.Whitelist = list + case "domains": + list := CommaSeparatedList{} + list.UnmarshalFlag(val) + rule.Domains = list default: return args, fmt.Errorf("invalid route param: %v", option) } @@ -327,9 +335,11 @@ func (c *Config) setupProvider(name string) error { // Rule holds defined rules type Rule struct { - Action string - Rule string - Provider string + Action string + Rule string + Provider string + Whitelist CommaSeparatedList + Domains CommaSeparatedList } // NewRule creates a new rule object diff --git a/internal/server.go b/internal/server.go index 2d6b2b15..55288659 100644 --- a/internal/server.go +++ b/internal/server.go @@ -102,7 +102,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc { } // Validate user - valid := ValidateUser(user) + valid := ValidateUser(user, rule) if !valid { logger.WithField("user", user).Warn("Invalid user") http.Error(w, fmt.Sprintf("User '%s' is not authorized", user), 401) @@ -122,16 +122,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { // Logging setup logger := s.logger(r, "AuthCallback", "default", "Handling callback") + // Check state + state := r.URL.Query().Get("state") + if err := ValidateState(state); err != nil { + logger.WithFields(logrus.Fields{ + "error": err, + }).Warn("Error validating state") + http.Error(w, "Not authorized", 401) + return + } + // Check for CSRF cookie - c, err := r.Cookie(config.CSRFCookieName) + c, err := FindCSRFCookie(r, state) if err != nil { logger.Info("Missing csrf cookie") http.Error(w, "Not authorized", 401) return } - // Validate state - valid, providerName, redirect, err := ValidateCSRFCookie(r, c) + // Validate CSRF cookie against state + valid, providerName, redirect, err := ValidateCSRFCookie(c, state) if !valid { logger.WithFields(logrus.Fields{ "error": err, @@ -154,7 +164,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { } // Clear CSRF cookie - http.SetCookie(w, ClearCSRFCookie(r)) + http.SetCookie(w, ClearCSRFCookie(r, c)) // Exchange code for token token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code")) diff --git a/internal/server_test.go b/internal/server_test.go index 2e543400..8ec0f01d 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -98,7 +98,7 @@ func TestServerAuthHandlerExpired(t *testing.T) { // Check for CSRF cookie var cookie *http.Cookie for _, c := range res.Cookies() { - if c.Name == config.CSRFCookieName { + if strings.HasPrefix(c.Name, config.CSRFCookieName) { cookie = c } }