Skip to content

Commit

Permalink
filters/auth/grant: support client id and secret file placeholders (#…
Browse files Browse the repository at this point in the history
…2246)

This change enables `{host}` placeholder in the client id and secret filenames.

E.g. for the request to `foo.example.org` when flag values are `-oauth2-client-id-file=/var/run/secrets/{host}-client-id` and
`-oauth2-client-secret-file=/var/run/secrets/{host}-client-secret`
the client id and secret files would be `/var/run/secrets/foo.example.org-client-id` and
`/var/run/secrets/foo.example.org-client-secret` respectively.

Signed-off-by: Alexander Yastrebov <alexander.yastrebov@zalando.de>
  • Loading branch information
AlexanderYastrebov authored Mar 7, 2023
1 parent 2c1c3b3 commit 5c9d7c2
Show file tree
Hide file tree
Showing 8 changed files with 315 additions and 91 deletions.
6 changes: 4 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,10 @@ func NewConfig() *Config {
flag.StringVar(&cfg.Oauth2SecretFile, "oauth2-secret-file", "", "sets the filename with the encryption key for the authentication cookie and grant flow state stored in secrets registry")
flag.StringVar(&cfg.Oauth2ClientID, "oauth2-client-id", "", "sets the OAuth2 client id of the current service, used to exchange the access code")
flag.StringVar(&cfg.Oauth2ClientSecret, "oauth2-client-secret", "", "sets the OAuth2 client secret associated with the oauth2-client-id, used to exchange the access code")
flag.StringVar(&cfg.Oauth2ClientIDFile, "oauth2-client-id-file", "", "sets the path of the file containing the OAuth2 client id of the current service, used to exchange the access code")
flag.StringVar(&cfg.Oauth2ClientSecretFile, "oauth2-client-secret-file", "", "sets the path of the file containing the OAuth2 client secret associated with the oauth2-client-id, used to exchange the access code")
flag.StringVar(&cfg.Oauth2ClientIDFile, "oauth2-client-id-file", "", "sets the path of the file containing the OAuth2 client id of the current service, used to exchange the access code. "+
"File name may contain {host} placeholder which will be replaced by the request host")
flag.StringVar(&cfg.Oauth2ClientSecretFile, "oauth2-client-secret-file", "", "sets the path of the file containing the OAuth2 client secret associated with the oauth2-client-id, used to exchange the access code. "+
"File name may contain {host} placeholder which will be replaced by the request host")
flag.StringVar(&cfg.Oauth2CallbackPath, "oauth2-callback-path", "", "sets the path where the OAuth2 callback requests with the authorization code should be redirected to")
flag.DurationVar(&cfg.Oauth2TokeninfoTimeout, "oauth2-tokeninfo-timeout", 2*time.Second, "sets the default tokeninfo request timeout duration to 2000ms")
flag.IntVar(&cfg.Oauth2TokeninfoCacheSize, "oauth2-tokeninfo-cache-size", 0, "non-zero value enables tokeninfo cache and sets the maximum number of cached tokens")
Expand Down
24 changes: 19 additions & 5 deletions filters/auth/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ func loginRedirect(ctx filters.FilterContext, config *OAuthConfig) {

func loginRedirectWithOverride(ctx filters.FilterContext, config *OAuthConfig, originalOverride string) {
req := ctx.Request()

authConfig, err := config.GetConfig(req)
if err != nil {
log.Debugf("Failed to obtain auth config: %v", err)
ctx.Serve(&http.Response{
StatusCode: http.StatusForbidden,
})
return
}

redirect, original := config.RedirectURLs(req)

if originalOverride != "" {
Expand All @@ -75,7 +85,6 @@ func loginRedirectWithOverride(ctx filters.FilterContext, config *OAuthConfig, o
return
}

authConfig := config.GetConfig()
ctx.Serve(&http.Response{
StatusCode: http.StatusTemporaryRedirect,
Header: http.Header{
Expand All @@ -84,7 +93,7 @@ func loginRedirectWithOverride(ctx filters.FilterContext, config *OAuthConfig, o
})
}

func (f *grantFilter) refreshToken(c *cookie) (*oauth2.Token, error) {
func (f *grantFilter) refreshToken(c *cookie, req *http.Request) (*oauth2.Token, error) {
// Set the expiry of the token to the past to trigger oauth2.TokenSource
// to refresh the access token.
token := &oauth2.Token{
Expand All @@ -95,9 +104,14 @@ func (f *grantFilter) refreshToken(c *cookie) (*oauth2.Token, error) {

ctx := providerContext(f.config)

authConfig, err := f.config.GetConfig(req)
if err != nil {
return nil, err
}

// oauth2.TokenSource implements the refresh functionality,
// we're hijacking it here.
tokenSource := f.config.GetConfig().TokenSource(ctx, token)
tokenSource := authConfig.TokenSource(ctx, token)
return tokenSource.Token()
}

Expand All @@ -106,7 +120,7 @@ func (f *grantFilter) refreshTokenIfRequired(c *cookie, ctx filters.FilterContex

if c.isAccessTokenExpired() {
if canRefresh {
token, err := f.refreshToken(c)
token, err := f.refreshToken(c, ctx.Request())
if err == nil {
// Remember that this token was just successfully refreshed
// so that we can send an updated cookie in the response.
Expand Down Expand Up @@ -175,7 +189,7 @@ func (f *grantFilter) Request(ctx filters.FilterContext) {
}

token, err := f.refreshTokenIfRequired(c, ctx)
if err != nil && c.isAccessTokenExpired() {
if err != nil {
// Refresh failed and we no longer have a valid access token.
loginRedirect(ctx, f.config)
return
Expand Down
177 changes: 177 additions & 0 deletions filters/auth/grant_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package auth_test

import (
"context"
"crypto/tls"
"encoding/json"
"net"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"

Expand All @@ -20,6 +24,7 @@ import (
"github.com/zalando/skipper/secrets"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const (
Expand All @@ -35,6 +40,42 @@ const (
testQueryParamValue = "param_value"
)

type loggingRoundTripper struct {
http.RoundTripper
t *testing.T
}

func (rt *loggingRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) {
rt.t.Logf("\n%v", rt.requestString(req))

resp, err = rt.RoundTripper.RoundTrip(req)

if err == nil {
rt.t.Logf("\n%v", rt.responseString(resp))
} else {
rt.t.Logf("response err: %v", err)
}
return
}

func (rt *loggingRoundTripper) requestString(req *http.Request) string {
tmp := req.Clone(context.Background())
tmp.Body = nil

var b strings.Builder
_ = tmp.Write(&b)
return b.String()
}

func (rt *loggingRoundTripper) responseString(resp *http.Response) string {
tmp := *resp
tmp.Body = nil

var b strings.Builder
_ = tmp.Write(&b)
return b.String()
}

func newGrantTestTokeninfo(validToken string, tokenInfoJSON string) *httptest.Server {
if tokenInfoJSON == "" {
tokenInfoJSON = "{}"
Expand Down Expand Up @@ -131,6 +172,7 @@ func newGrantTestConfig(tokeninfoURL, providerURL string) *auth.OAuthConfig {
ClientID: testClientID,
ClientSecret: testClientSecret,
Secrets: secrets.NewRegistry(),
SecretsProvider: secrets.NewSecretPaths(1 * time.Hour),
SecretFile: testSecretFile,
TokeninfoURL: tokeninfoURL,
AuthURL: providerURL + "/auth",
Expand Down Expand Up @@ -843,3 +885,138 @@ func TestGrantTokeninfoKeys(t *testing.T) {

assert.JSONEq(t, `{"uid":"bar", "scope":["baz"]}`, rsp.Header.Get("Backend-X-Tokeninfo-Forward"))
}

func TestGrantCredentialsFile(t *testing.T) {
const (
fooDomain = "foo.skipper.test"
barDomain = "bar.skipper.test"
)

dnstest.LoopbackNames(t, fooDomain, barDomain)

secretsDir := t.TempDir()

clientIdFile := secretsDir + "/test-client-id"
clientSecretFile := secretsDir + "/test-client-secret"

require.NoError(t, os.WriteFile(clientIdFile, []byte(testClientID), 0644))
require.NoError(t, os.WriteFile(clientSecretFile, []byte(testClientSecret), 0644))

provider := newGrantTestAuthServer(testToken, testAccessCode)
defer provider.Close()

tokeninfo := newGrantTestTokeninfo(testToken, "")
defer tokeninfo.Close()

zero := 0
config := newGrantTestConfig(tokeninfo.URL, provider.URL)
config.TokenCookieRemoveSubdomains = &zero
config.ClientID = ""
config.ClientSecret = ""
config.ClientIDFile = clientIdFile
config.ClientSecretFile = clientSecretFile

routes := eskip.MustParse(`* -> oauthGrant() -> status(204) -> <shunt>`)

proxy, client := newAuthProxy(t, config, routes, fooDomain, barDomain)
defer proxy.Close()

// Follow redirects as store cookies
client.CheckRedirect = nil
client.Jar, _ = cookiejar.New(nil)
httpLogger := &loggingRoundTripper{client.Transport, t}
client.Transport = httpLogger

resetClient := func(t *testing.T) {
client.Jar, _ = cookiejar.New(nil)
httpLogger.t = t
}

t.Run("request to "+fooDomain+" succeeds", func(t *testing.T) {
resetClient(t)

rsp, err := client.Get(proxy.URL + "/test")
require.NoError(t, err)
rsp.Body.Close()

checkStatus(t, rsp, http.StatusNoContent)
})

t.Run("request to "+barDomain+" succeeds", func(t *testing.T) {
resetClient(t)

barUrl := "https://" + net.JoinHostPort(barDomain, proxy.Port)

rsp, err := client.Get(barUrl + "/test")
require.NoError(t, err)
rsp.Body.Close()

checkStatus(t, rsp, http.StatusNoContent)
})
}

func TestGrantCredentialsPlaceholder(t *testing.T) {
const (
fooDomain = "foo.skipper.test"
barDomain = "bar.skipper.test"
)

dnstest.LoopbackNames(t, fooDomain, barDomain)

secretsDir := t.TempDir()

require.NoError(t, os.WriteFile(secretsDir+"/"+fooDomain+"-client-id", []byte(testClientID), 0644))
require.NoError(t, os.WriteFile(secretsDir+"/"+fooDomain+"-client-secret", []byte(testClientSecret), 0644))

provider := newGrantTestAuthServer(testToken, testAccessCode)
defer provider.Close()

tokeninfo := newGrantTestTokeninfo(testToken, "")
defer tokeninfo.Close()

zero := 0
config := newGrantTestConfig(tokeninfo.URL, provider.URL)
config.TokenCookieRemoveSubdomains = &zero
config.ClientID = ""
config.ClientSecret = ""
config.ClientIDFile = secretsDir + "/{host}-client-id"
config.ClientSecretFile = secretsDir + "/{host}-client-secret"

routes := eskip.MustParse(`* -> oauthGrant() -> status(204) -> <shunt>`)

proxy, client := newAuthProxy(t, config, routes, fooDomain, barDomain)
defer proxy.Close()

// Follow redirects as store cookies
client.CheckRedirect = nil
client.Jar, _ = cookiejar.New(nil)
httpLogger := &loggingRoundTripper{client.Transport, t}
client.Transport = httpLogger

resetClient := func(t *testing.T) {
client.Jar, _ = cookiejar.New(nil)
httpLogger.t = t
}

t.Run("request to the hostname with existing client credentials succeeds", func(t *testing.T) {
resetClient(t)

rsp, err := client.Get(proxy.URL + "/test")
require.NoError(t, err)
rsp.Body.Close()

checkStatus(t, rsp, http.StatusNoContent)
})

t.Run("request to the hostname without existing client credentials is forbidden", func(t *testing.T) {
resetClient(t)

barUrl := "https://" + net.JoinHostPort(barDomain, proxy.Port)

rsp, err := client.Get(barUrl + "/test")
require.NoError(t, err)
rsp.Body.Close()

checkStatus(t, rsp, http.StatusForbidden)
})
}
18 changes: 9 additions & 9 deletions filters/auth/grantcallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,18 @@ func (s *grantCallbackSpec) CreateFilter([]interface{}) (filters.Filter, error)
}, nil
}

func (f *grantCallbackFilter) exchangeAccessToken(code string, redirectURI string) (*oauth2.Token, error) {
func (f *grantCallbackFilter) exchangeAccessToken(req *http.Request, code string) (*oauth2.Token, error) {
authConfig, err := f.config.GetConfig(req)
if err != nil {
return nil, err
}
redirectURI, _ := f.config.RedirectURLs(req)
ctx := providerContext(f.config)
params := f.config.GetAuthURLParameters(redirectURI)
return f.config.GetConfig().Exchange(ctx, code, params...)
return authConfig.Exchange(ctx, code, params...)
}

func (f *grantCallbackFilter) loginCallback(ctx filters.FilterContext) {
func (f *grantCallbackFilter) Request(ctx filters.FilterContext) {
req := ctx.Request()
q := req.URL.Query()

Expand Down Expand Up @@ -79,8 +84,7 @@ func (f *grantCallbackFilter) loginCallback(ctx filters.FilterContext) {
return
}

redirectURI, _ := f.config.RedirectURLs(req)
token, err := f.exchangeAccessToken(code, redirectURI)
token, err := f.exchangeAccessToken(req, code)
if err != nil {
log.Errorf("Failed to exchange access token: %v.", err)
serverError(ctx)
Expand All @@ -103,8 +107,4 @@ func (f *grantCallbackFilter) loginCallback(ctx filters.FilterContext) {
})
}

func (f *grantCallbackFilter) Request(ctx filters.FilterContext) {
f.loginCallback(ctx)
}

func (f *grantCallbackFilter) Response(ctx filters.FilterContext) {}
Loading

0 comments on commit 5c9d7c2

Please sign in to comment.