Skip to content

Commit

Permalink
Add ability to provide JWKS directly instead of through URI (#79)
Browse files Browse the repository at this point in the history
* Add ability to provide JWKS directly instead of through URI

This adds the JWKS field to the AuthConfig which allows you to pass a
`jose.JSONWebKeySet` directly to the config instead of attempting to
fetch from a URI.

* Prevent fetch JWKS if not using URI

* Make refreshJWKS a no-op if using JWKS directly

* Add more tests and return an error if JWKS configuration is invalid

* Fix test name
  • Loading branch information
adammohammed authored Feb 13, 2023
1 parent f1f6f95 commit 487d58f
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 23 deletions.
6 changes: 3 additions & 3 deletions ginauth/multitokenmiddleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ func TestMultitokenMiddlewareValidatesTokens(t *testing.T) {

for _, tt := range testCases {
t.Run(tt.testName, func(t *testing.T) {
jwksURI1 := ginjwt.TestHelperJWKSProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
jwksURI2 := ginjwt.TestHelperJWKSProvider(ginjwt.TestPrivRSAKey3ID, ginjwt.TestPrivRSAKey4ID)
jwksURI1 := ginjwt.TestHelperJWKSURIProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
jwksURI2 := ginjwt.TestHelperJWKSURIProvider(ginjwt.TestPrivRSAKey3ID, ginjwt.TestPrivRSAKey4ID)

cfg1 := ginjwt.AuthConfig{Enabled: true, Audience: tt.middlewareAud, Issuer: tt.middlewareIss, JWKSURI: jwksURI1}
cfg2 := ginjwt.AuthConfig{Enabled: true, Audience: tt.middlewareAud, Issuer: tt.middlewareIss, JWKSURI: jwksURI2}
Expand Down Expand Up @@ -258,7 +258,7 @@ func TestMultitokenInvalidAuthHeader(t *testing.T) {

for _, tt := range testCases {
t.Run(tt.testName, func(t *testing.T) {
jwksURI := ginjwt.TestHelperJWKSProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
jwksURI := ginjwt.TestHelperJWKSURIProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
cfg := ginjwt.AuthConfig{Enabled: true, Audience: "aud", Issuer: "iss", JWKSURI: jwksURI}
authMW, err := ginjwt.NewMultiTokenMiddlewareFromConfigs(cfg)
require.NoError(t, err)
Expand Down
3 changes: 3 additions & 0 deletions ginjwt/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,7 @@ var (

// ErrMissingJWKURIFlag is an error returned when the JWK URI isn't provided via a command line flag.
ErrMissingJWKURIFlag = errors.New("JWK URI wasn't provided")

// ErrJWKSConfigConflict is an error when both JWKSURI and JWKS are set
ErrJWKSConfigConflict = errors.New("JWKS and JWKSURI can't both be set at the same time")
)
34 changes: 28 additions & 6 deletions ginjwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,13 @@ type Middleware struct {

// AuthConfig provides the configuration for the authentication service
type AuthConfig struct {
Enabled bool
Audience string
Issuer string
JWKSURI string
Enabled bool
Audience string
Issuer string
JWKSURI string

// JWKS allows the user to specify the JWKS directly instead of through URI
JWKS jose.JSONWebKeySet
LogFields []string
RolesClaim string
UsernameClaim string
Expand All @@ -70,8 +73,22 @@ func NewAuthMiddleware(cfg AuthConfig) (*Middleware, error) {
return mw, nil
}

if err := mw.refreshJWKS(); err != nil {
return nil, err
uriProvided := (cfg.JWKSURI != "")
jwksProvided := len(cfg.JWKS.Keys) > 0

// Either they were both provided, or neither was provided
if uriProvided == jwksProvided {
return nil, fmt.Errorf("%w: either JWKSURI or JWKS must be provided", ErrInvalidAuthConfig)
}

// Only refresh JWKSURI if static one isn't provided
if len(cfg.JWKS.Keys) > 0 {
mw.cachedJWKS = cfg.JWKS
} else {
// Fetch JWKS from URI
if err := mw.refreshJWKS(); err != nil {
return nil, err
}
}

return mw, nil
Expand Down Expand Up @@ -236,6 +253,11 @@ func (m *Middleware) VerifyScopes(c *gin.Context, scopes []string) error {
func (m *Middleware) refreshJWKS() error {
var ctx context.Context

// When using JWKS directly, refresh should be a no-op
if len(m.config.JWKS.Keys) > 0 {
return nil
}

if m.config.JWKSRemoteTimeout != 0 {
var cancel context.CancelFunc

Expand Down
136 changes: 130 additions & 6 deletions ginjwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func TestMiddlewareValidatesTokensWithScopes(t *testing.T) {
middlewareScopes []string
signingKey *rsa.PrivateKey
signingKeyID string
jwksFromURI bool
claims jwt.Claims
claimScopes []string
responseCode int
Expand All @@ -40,6 +41,7 @@ func TestMiddlewareValidatesTokensWithScopes(t *testing.T) {
[]string{"testScope"},
ginjwt.TestPrivRSAKey1,
"randomUnknownID",
true,
jwt.Claims{
Subject: "test-user",
Issuer: "ginjwt.test.issuer",
Expand All @@ -57,6 +59,7 @@ func TestMiddlewareValidatesTokensWithScopes(t *testing.T) {
[]string{"testScope"},
ginjwt.TestPrivRSAKey1,
ginjwt.TestPrivRSAKey2ID,
true,
jwt.Claims{
Subject: "test-user",
Issuer: "ginjwt.test.issuer",
Expand All @@ -74,6 +77,7 @@ func TestMiddlewareValidatesTokensWithScopes(t *testing.T) {
[]string{"testScope"},
ginjwt.TestPrivRSAKey1,
ginjwt.TestPrivRSAKey1ID,
true,
jwt.Claims{
Subject: "test-user",
Issuer: "ginjwt.test.issuer",
Expand All @@ -91,6 +95,7 @@ func TestMiddlewareValidatesTokensWithScopes(t *testing.T) {
[]string{"testScope"},
ginjwt.TestPrivRSAKey1,
ginjwt.TestPrivRSAKey1ID,
true,
jwt.Claims{
Subject: "test-user",
Issuer: "ginjwt.test.issuer",
Expand All @@ -108,6 +113,7 @@ func TestMiddlewareValidatesTokensWithScopes(t *testing.T) {
[]string{"adminscope"},
ginjwt.TestPrivRSAKey1,
ginjwt.TestPrivRSAKey1ID,
true,
jwt.Claims{
Subject: "test-user",
Issuer: "ginjwt.test.issuer",
Expand All @@ -125,6 +131,7 @@ func TestMiddlewareValidatesTokensWithScopes(t *testing.T) {
[]string{"testScope"},
ginjwt.TestPrivRSAKey1,
ginjwt.TestPrivRSAKey1ID,
true,
jwt.Claims{
Subject: "test-user",
Issuer: "ginjwt.test.issuer",
Expand All @@ -143,6 +150,7 @@ func TestMiddlewareValidatesTokensWithScopes(t *testing.T) {
[]string{"testScope"},
ginjwt.TestPrivRSAKey1,
ginjwt.TestPrivRSAKey1ID,
true,
jwt.Claims{
Subject: "test-user",
Issuer: "ginjwt.test.issuer",
Expand All @@ -160,6 +168,7 @@ func TestMiddlewareValidatesTokensWithScopes(t *testing.T) {
[]string{"testScope"},
ginjwt.TestPrivRSAKey1,
ginjwt.TestPrivRSAKey1ID,
true,
jwt.Claims{
Subject: "test-user",
Issuer: "ginjwt.test.issuer",
Expand All @@ -170,13 +179,55 @@ func TestMiddlewareValidatesTokensWithScopes(t *testing.T) {
http.StatusOK,
"ok",
},
{
"invalid key with directly loaded JWKS",
"ginjwt.test",
"ginjwt.test.issuer",
[]string{"testScope"},
ginjwt.TestPrivRSAKey1,
ginjwt.TestPrivRSAKey1ID,
false,
jwt.Claims{
Subject: "test-user",
Issuer: "ginjwt.test.issuer",
NotBefore: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
Audience: jwt.Audience{"ginjwt.test", "another.test.service"},
},
[]string{"testScope", "anotherScope", "more-scopes"},
http.StatusOK,
"ok",
},
{
"valid with directly loaded JWKS",
"ginjwt.test",
"ginjwt.test.issuer",
[]string{"testScope"},
ginjwt.TestPrivRSAKey1,
"invalid-key-id",
false,
jwt.Claims{
Subject: "test-user",
Issuer: "ginjwt.test.issuer",
NotBefore: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
Audience: jwt.Audience{"ginjwt.test", "another.test.service"},
},
[]string{"testScope", "anotherScope", "more-scopes"},
http.StatusUnauthorized,
"invalid token signing key",
},
}

for _, tt := range testCases {
t.Run(tt.testName, func(t *testing.T) {
jwksURI := ginjwt.TestHelperJWKSProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
var jwksURI string
var jwks jose.JSONWebKeySet
if tt.jwksFromURI {
jwksURI = ginjwt.TestHelperJWKSURIProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
} else {
jwks = ginjwt.TestHelperJWKSProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
}

cfg := ginjwt.AuthConfig{Enabled: true, Audience: tt.middlewareAud, Issuer: tt.middlewareIss, JWKSURI: jwksURI}
cfg := ginjwt.AuthConfig{Enabled: true, Audience: tt.middlewareAud, Issuer: tt.middlewareIss, JWKSURI: jwksURI, JWKS: jwks}
authMW, err := ginjwt.NewAuthMiddleware(cfg)
require.NoError(t, err)

Expand Down Expand Up @@ -338,7 +389,7 @@ func TestMiddlewareAuthRequired(t *testing.T) {

for _, tt := range testCases {
t.Run(tt.testName, func(t *testing.T) {
jwksURI := ginjwt.TestHelperJWKSProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
jwksURI := ginjwt.TestHelperJWKSURIProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)

cfg := ginjwt.AuthConfig{Enabled: true, Audience: tt.middlewareAud, Issuer: tt.middlewareIss, JWKSURI: jwksURI}
authMW, err := ginjwt.NewAuthMiddleware(cfg)
Expand Down Expand Up @@ -400,7 +451,7 @@ func TestInvalidAuthHeader(t *testing.T) {

for _, tt := range testCases {
t.Run(tt.testName, func(t *testing.T) {
jwksURI := ginjwt.TestHelperJWKSProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
jwksURI := ginjwt.TestHelperJWKSURIProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
cfg := ginjwt.AuthConfig{Enabled: true, Audience: "aud", Issuer: "iss", JWKSURI: jwksURI}
authMW, err := ginjwt.NewAuthMiddleware(cfg)
require.NoError(t, err)
Expand All @@ -424,7 +475,7 @@ func TestInvalidAuthHeader(t *testing.T) {
}

func TestInvalidJWKURIWithWrongPath(t *testing.T) {
uri := ginjwt.TestHelperJWKSProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
uri := ginjwt.TestHelperJWKSURIProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
uri += "/some-extra-path"
cfg := ginjwt.AuthConfig{Enabled: true, Audience: "aud", Issuer: "iss", JWKSURI: uri}
_, err := ginjwt.NewAuthMiddleware(cfg)
Expand Down Expand Up @@ -588,7 +639,7 @@ func TestVerifyTokenWithScopes(t *testing.T) {

for _, tt := range testCases {
t.Run(tt.testName, func(t *testing.T) {
jwksURI := ginjwt.TestHelperJWKSProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
jwksURI := ginjwt.TestHelperJWKSURIProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
config := ginjwt.AuthConfig{
Enabled: true,
Audience: tt.middlewareAud,
Expand Down Expand Up @@ -620,3 +671,76 @@ func TestVerifyTokenWithScopes(t *testing.T) {
})
}
}

func TestAuthMiddlewareConfig(t *testing.T) {
jwks := ginjwt.TestHelperJWKSProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)
jwksURI := ginjwt.TestHelperJWKSURIProvider(ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey2ID)

testCases := []struct {
name string
input ginjwt.AuthConfig
checkFn func(*testing.T, ginauth.GenericAuthMiddleware, error)
}{
{
name: "ValidWithJWKS",
input: ginjwt.AuthConfig{
Enabled: true,
Audience: "example-aud",
Issuer: "example-iss",
JWKS: jwks,
RoleValidationStrategy: "all",
},
checkFn: func(t *testing.T, mw ginauth.GenericAuthMiddleware, err error) {
assert.NoError(t, err)
assert.NotNil(t, mw)
},
},
{
name: "ValidWithJWKSURI",
input: ginjwt.AuthConfig{
Enabled: true,
Audience: "example-aud",
Issuer: "example-iss",
JWKSURI: jwksURI,
RoleValidationStrategy: "all",
},
checkFn: func(t *testing.T, mw ginauth.GenericAuthMiddleware, err error) {
assert.NoError(t, err)
assert.NotNil(t, mw)
},
},
{
name: "InvalidJWKSConfig",
input: ginjwt.AuthConfig{
Enabled: true,
Audience: "example-aud",
Issuer: "example-iss",
JWKSURI: jwksURI,
JWKS: jwks,
RoleValidationStrategy: "all",
},
checkFn: func(t *testing.T, mw ginauth.GenericAuthMiddleware, err error) {
assert.ErrorIs(t, err, ginjwt.ErrInvalidAuthConfig)
},
},
{
name: "MissingJWKSConfig",
input: ginjwt.AuthConfig{
Enabled: true,
Audience: "example-aud",
Issuer: "example-iss",
RoleValidationStrategy: "all",
},
checkFn: func(t *testing.T, mw ginauth.GenericAuthMiddleware, err error) {
assert.ErrorIs(t, err, ginjwt.ErrInvalidAuthConfig)
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(tt *testing.T) {
mw, err := ginjwt.NewAuthMiddleware(tc.input)
tc.checkFn(tt, mw, err)
})
}
}
23 changes: 15 additions & 8 deletions ginjwt/testtools.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,8 @@ func TestHelperMustMakeSigner(alg jose.SignatureAlgorithm, kid string, k interfa
return sig
}

// TestHelperJWKSProvider returns a url for a webserver that will return JSONWebKeySets
func TestHelperJWKSProvider(keyIDs ...string) string {
gin.SetMode(gin.TestMode)
r := gin.New()

// TestHelperJWKSProvider returns a JWKS
func TestHelperJWKSProvider(keyIDs ...string) jose.JSONWebKeySet {
jwks := make([]jose.JSONWebKey, len(keyIDs))

for idx, keyID := range keyIDs {
Expand All @@ -76,10 +73,20 @@ func TestHelperJWKSProvider(keyIDs ...string) string {
}
}

return jose.JSONWebKeySet{
Keys: jwks,
}
}

// TestHelperJWKSURIProvider returns a url for a webserver that will return JSONWebKeySets
func TestHelperJWKSURIProvider(keyIDs ...string) string {
gin.SetMode(gin.TestMode)
r := gin.New()

keySet := TestHelperJWKSProvider(keyIDs...)

r.GET("/.well-known/jwks.json", func(c *gin.Context) {
c.JSON(http.StatusOK, jose.JSONWebKeySet{
Keys: jwks,
})
c.JSON(http.StatusOK, keySet)
})

listener, err := net.Listen("tcp", ":0")
Expand Down

0 comments on commit 487d58f

Please sign in to comment.