Skip to content

Commit

Permalink
refactor: improve dependency injection capabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Jul 11, 2024
1 parent 27de382 commit 010f694
Show file tree
Hide file tree
Showing 13 changed files with 249 additions and 167 deletions.
7 changes: 2 additions & 5 deletions compose/compose_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,10 @@ type HMACSHAStrategyConfigurator interface {
}

func NewOAuth2HMACStrategy(config HMACSHAStrategyConfigurator) *oauth2.HMACSHAStrategy {
return &oauth2.HMACSHAStrategy{
Enigma: &hmac.HMACStrategy{Config: config},
Config: config,
}
return oauth2.NewHMACSHAStrategy(&hmac.HMACStrategy{Config: config}, config)
}

func NewOAuth2JWTStrategy(keyGetter func(context.Context) (interface{}, error), strategy *oauth2.HMACSHAStrategy, config fosite.Configurator) *oauth2.DefaultJWTStrategy {
func NewOAuth2JWTStrategy(keyGetter func(context.Context) (interface{}, error), strategy oauth2.CoreStrategy, config fosite.Configurator) *oauth2.DefaultJWTStrategy {
return &oauth2.DefaultJWTStrategy{
Signer: &jwt.DefaultSigner{GetPrivateKey: keyGetter},
HMACSHAStrategy: strategy,
Expand Down
2 changes: 1 addition & 1 deletion handler/oauth2/flow_authorize_code_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func parseUrl(uu string) *url.URL {

func TestAuthorizeCode_HandleAuthorizeEndpointRequest(t *testing.T) {
for k, strategy := range map[string]CoreStrategy{
"hmac": &hmacshaStrategy,
"hmac": hmacshaStrategy,
} {
t.Run("strategy="+k, func(t *testing.T) {
store := storage.NewMemoryStore()
Expand Down
12 changes: 6 additions & 6 deletions handler/oauth2/flow_authorize_code_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (

func TestAuthorizeCode_PopulateTokenEndpointResponse(t *testing.T) {
for k, strategy := range map[string]CoreStrategy{
"hmac": &hmacshaStrategy,
"hmac": hmacshaStrategy,
} {
t.Run("strategy="+k, func(t *testing.T) {
store := storage.NewMemoryStore()
Expand Down Expand Up @@ -241,14 +241,14 @@ func TestAuthorizeCode_PopulateTokenEndpointResponse(t *testing.T) {

func TestAuthorizeCode_HandleTokenEndpointRequest(t *testing.T) {
for k, strategy := range map[string]CoreStrategy{
"hmac": &hmacshaStrategy,
"hmac": hmacshaStrategy,
} {
t.Run("strategy="+k, func(t *testing.T) {
store := storage.NewMemoryStore()

h := AuthorizeExplicitGrantHandler{
CoreStorage: store,
AuthorizeCodeStrategy: &hmacshaStrategy,
AuthorizeCodeStrategy: hmacshaStrategy,
TokenRevocationStorage: store,
Config: &fosite.Config{
ScopeStrategy: fosite.HierarchicScopeStrategy,
Expand Down Expand Up @@ -657,9 +657,9 @@ func TestAuthorizeCodeTransactional_HandleTokenEndpointRequest(t *testing.T) {
mockTransactional,
mockCoreStore,
},
AccessTokenStrategy: &strategy,
RefreshTokenStrategy: &strategy,
AuthorizeCodeStrategy: &strategy,
AccessTokenStrategy: strategy,
RefreshTokenStrategy: strategy,
AuthorizeCodeStrategy: strategy,
Config: &fosite.Config{
ScopeStrategy: fosite.HierarchicScopeStrategy,
AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy,
Expand Down
12 changes: 6 additions & 6 deletions handler/oauth2/flow_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) {
}

for k, strategy := range map[string]RefreshTokenStrategy{
"hmac": &hmacshaStrategy,
"hmac": hmacshaStrategy,
} {
t.Run("strategy="+k, func(t *testing.T) {
store := storage.NewMemoryStore()
Expand Down Expand Up @@ -419,8 +419,8 @@ func TestRefreshFlowTransactional_HandleTokenEndpointRequest(t *testing.T) {
mockTransactional,
mockRevocationStore,
},
AccessTokenStrategy: &hmacshaStrategy,
RefreshTokenStrategy: &hmacshaStrategy,
AccessTokenStrategy: hmacshaStrategy,
RefreshTokenStrategy: hmacshaStrategy,
Config: &fosite.Config{
AccessTokenLifespan: time.Hour,
ScopeStrategy: fosite.HierarchicScopeStrategy,
Expand All @@ -440,7 +440,7 @@ func TestRefreshFlow_PopulateTokenEndpointResponse(t *testing.T) {
var aresp *fosite.AccessResponse

for k, strategy := range map[string]CoreStrategy{
"hmac": &hmacshaStrategy,
"hmac": hmacshaStrategy,
} {
t.Run("strategy="+k, func(t *testing.T) {
store := storage.NewMemoryStore()
Expand Down Expand Up @@ -1071,8 +1071,8 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) {
mockTransactional,
mockRevocationStore,
},
AccessTokenStrategy: &hmacshaStrategy,
RefreshTokenStrategy: &hmacshaStrategy,
AccessTokenStrategy: hmacshaStrategy,
RefreshTokenStrategy: hmacshaStrategy,
Config: &fosite.Config{
AccessTokenLifespan: time.Hour,
ScopeStrategy: fosite.HierarchicScopeStrategy,
Expand Down
9 changes: 9 additions & 0 deletions handler/oauth2/providers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package oauth2

import "github.com/ory/fosite"

type lifespanProvider interface {
fosite.AccessTokenLifespanProvider
fosite.RefreshTokenLifespanProvider
fosite.AuthorizeCodeLifespanProvider
}
122 changes: 0 additions & 122 deletions handler/oauth2/strategy_hmacsha.go

This file was deleted.

108 changes: 108 additions & 0 deletions handler/oauth2/strategy_hmacsha_plain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package oauth2

import (
"context"
"time"

"github.com/ory/x/errorsx"

"github.com/ory/fosite"
enigma "github.com/ory/fosite/token/hmac"
)

var _ CoreStrategy = (*HMACSHAStrategyUnPrefixed)(nil)

type HMACSHAStrategyUnPrefixed struct {
Enigma *enigma.HMACStrategy
Config lifespanProvider
}

func NewHMACSHAStrategyUnPrefixed(
enigma *enigma.HMACStrategy,
config lifespanProvider,
) *HMACSHAStrategyUnPrefixed {
return &HMACSHAStrategyUnPrefixed{
Enigma: enigma,
Config: config,
}
}

func (h *HMACSHAStrategyUnPrefixed) AccessTokenSignature(ctx context.Context, token string) string {
return h.Enigma.Signature(token)
}
func (h *HMACSHAStrategyUnPrefixed) RefreshTokenSignature(ctx context.Context, token string) string {
return h.Enigma.Signature(token)
}
func (h *HMACSHAStrategyUnPrefixed) AuthorizeCodeSignature(ctx context.Context, token string) string {
return h.Enigma.Signature(token)
}

func (h *HMACSHAStrategyUnPrefixed) GenerateAccessToken(ctx context.Context, _ fosite.Requester) (token string, signature string, err error) {
token, sig, err := h.Enigma.Generate(ctx)
if err != nil {
return "", "", err
}

return token, sig, nil
}

func (h *HMACSHAStrategyUnPrefixed) ValidateAccessToken(ctx context.Context, r fosite.Requester, token string) (err error) {
var exp = r.GetSession().GetExpiresAt(fosite.AccessToken)
if exp.IsZero() && r.GetRequestedAt().Add(h.Config.GetAccessTokenLifespan(ctx)).Before(time.Now().UTC()) {
return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Access token expired at '%s'.", r.GetRequestedAt().Add(h.Config.GetAccessTokenLifespan(ctx))))
}

if !exp.IsZero() && exp.Before(time.Now().UTC()) {
return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Access token expired at '%s'.", exp))
}

return h.Enigma.Validate(ctx, token)
}

func (h *HMACSHAStrategyUnPrefixed) GenerateRefreshToken(ctx context.Context, _ fosite.Requester) (token string, signature string, err error) {
token, sig, err := h.Enigma.Generate(ctx)
if err != nil {
return "", "", err
}

return token, sig, nil
}

func (h *HMACSHAStrategyUnPrefixed) ValidateRefreshToken(ctx context.Context, r fosite.Requester, token string) (err error) {
var exp = r.GetSession().GetExpiresAt(fosite.RefreshToken)
if exp.IsZero() {
// Unlimited lifetime
return h.Enigma.Validate(ctx, token)
}

if !exp.IsZero() && exp.Before(time.Now().UTC()) {
return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Refresh token expired at '%s'.", exp))
}

return h.Enigma.Validate(ctx, token)
}

func (h *HMACSHAStrategyUnPrefixed) GenerateAuthorizeCode(ctx context.Context, _ fosite.Requester) (token string, signature string, err error) {
token, sig, err := h.Enigma.Generate(ctx)
if err != nil {
return "", "", err
}

return token, sig, nil
}

func (h *HMACSHAStrategyUnPrefixed) ValidateAuthorizeCode(ctx context.Context, r fosite.Requester, token string) (err error) {
var exp = r.GetSession().GetExpiresAt(fosite.AuthorizeCode)
if exp.IsZero() && r.GetRequestedAt().Add(h.Config.GetAuthorizeCodeLifespan(ctx)).Before(time.Now().UTC()) {
return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Authorize code expired at '%s'.", r.GetRequestedAt().Add(h.Config.GetAuthorizeCodeLifespan(ctx))))
}

if !exp.IsZero() && exp.Before(time.Now().UTC()) {
return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Authorize code expired at '%s'.", exp))
}

return h.Enigma.Validate(ctx, token)
}
Loading

0 comments on commit 010f694

Please sign in to comment.