Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
Remove misleading token refresh logic from client credentials token s…
Browse files Browse the repository at this point in the history
…ource provider (#383)
  • Loading branch information
andrewwdye authored Mar 28, 2023
1 parent 40e0875 commit 9cdb02c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 248 deletions.
90 changes: 20 additions & 70 deletions flyteidl/clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@ import (
"os"
"strings"
"sync"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
"k8s.io/apimachinery/pkg/util/wait"

"github.com/flyteorg/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyteidl/clients/go/admin/deviceflow"
Expand Down Expand Up @@ -167,9 +165,8 @@ func GetPKCEAuthTokenSource(ctx context.Context, pkceTokenOrchestrator pkce.Toke
}

type ClientCredentialsTokenSourceProvider struct {
ccConfig clientcredentials.Config
tokenRefreshWindow time.Duration
tokenCache cache.TokenCache
ccConfig clientcredentials.Config
tokenCache cache.TokenCache
}

func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string,
Expand Down Expand Up @@ -201,92 +198,45 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s
Scopes: scopes,
EndpointParams: endpointParams,
},
tokenRefreshWindow: cfg.TokenRefreshWindow.Duration,
tokenCache: tokenCache}, nil
tokenCache: tokenCache}, nil
}

func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
if p.tokenRefreshWindow > 0 {
source := p.ccConfig.TokenSource(ctx)
refreshTime := time.Time{}
if token, err := p.tokenCache.GetToken(); err == nil {
refreshTime = token.Expiry.Add(-getRandomDuration(p.tokenRefreshWindow))
}
return &customTokenSource{
ctx: ctx,
new: source,
mu: sync.Mutex{},
tokenRefreshWindow: p.tokenRefreshWindow,
tokenCache: p.tokenCache,
refreshTime: refreshTime,
}, nil
}
return p.ccConfig.TokenSource(ctx), nil
return &customTokenSource{
ctx: ctx,
new: p.ccConfig.TokenSource(ctx),
mu: sync.Mutex{},
tokenCache: p.tokenCache,
}, nil
}

type customTokenSource struct {
ctx context.Context
new oauth2.TokenSource
tokenRefreshWindow time.Duration
mu sync.Mutex // guards everything else
refreshTime time.Time
failedToRefresh bool
tokenCache cache.TokenCache
}

// fetchTokenFromCache returns the cached token if available, and a bool indicating if we should try to refresh it.
// This function is not thread safe and should be called with the lock held.
func (s *customTokenSource) fetchTokenFromCache() (*oauth2.Token, bool) {
token, err := s.tokenCache.GetToken()
if err != nil {
logger.Infof(s.ctx, "no token found in cache")
return nil, false
}
if !token.Valid() {
logger.Infof(s.ctx, "cached token invalid")
return nil, false
}
if time.Now().After(s.refreshTime) && !s.failedToRefresh {
logger.Infof(s.ctx, "cached token refresh window exceeded")
return token, true
}
return token, false
ctx context.Context
mu sync.Mutex // guards everything else
new oauth2.TokenSource
tokenCache cache.TokenCache
}

func (s *customTokenSource) Token() (*oauth2.Token, error) {
s.mu.Lock()
defer s.mu.Unlock()

cachedToken, needsRefresh := s.fetchTokenFromCache()
if cachedToken != nil && !needsRefresh {
return cachedToken, nil
if token, err := s.tokenCache.GetToken(); err == nil && token.Valid() {
return token, nil
}

token, err := s.new.Token()
if err != nil {
if needsRefresh {
logger.Warnf(s.ctx, "failed to refresh token, using last cached token until expired")
s.failedToRefresh = true
return cachedToken, nil
}
logger.Errorf(s.ctx, "failed to refresh token")
return nil, err
return nil, fmt.Errorf("failed to get token: %w", err)
}
logger.Infof(s.ctx, "refreshed token")
logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry)

err = s.tokenCache.SaveToken(token)
if err != nil {
logger.Warnf(s.ctx, "failed to cache token, using anyway")
logger.Warnf(s.ctx, "failed to cache token")
}
s.failedToRefresh = false
s.refreshTime = token.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow))
return token, nil
}

// Get random duration between 0 and maxDuration
func getRandomDuration(maxDuration time.Duration) time.Duration {
// d is 1.0 to 2.0 times maxDuration
d := wait.Jitter(maxDuration, 1)
return d - maxDuration
return token, nil
}

type DeviceFlowTokenSourceProvider struct {
Expand Down
207 changes: 29 additions & 178 deletions flyteidl/clients/go/admin/token_source_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
tokenCacheMocks "github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyteidl/clients/go/admin/mocks"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytestdlib/config"
)

func TestNewTokenSourceProvider(t *testing.T) {
Expand Down Expand Up @@ -81,149 +80,9 @@ func TestNewTokenSourceProvider(t *testing.T) {
}
}

func TestCustomTokenSource_GetTokenSource(t *testing.T) {
ctx := context.Background()
cfg := GetConfig(ctx)
cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute}
cfg.ClientSecretLocation = ""

hourAhead := time.Now().Add(time.Hour)
validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead}

tests := []struct {
name string
token *oauth2.Token
}{
{
name: "no token",
token: nil,
},
{

name: "valid token",
token: &validToken,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tokenCache := &tokenCacheMocks.TokenCache{}
var tokenErr error = nil
if test.token == nil {
tokenErr = fmt.Errorf("no token")
}
tokenCache.OnGetToken().Return(test.token, tokenErr).Once()
provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "")
assert.NoError(t, err)

source, err := provider.GetTokenSource(ctx)
assert.NoError(t, err)
customSource, ok := source.(*customTokenSource)
assert.True(t, ok)

if test.token == nil {
assert.Equal(t, time.Time{}, customSource.refreshTime)
} else {
assert.LessOrEqual(t, customSource.refreshTime.Unix(), test.token.Expiry.Unix())
assert.GreaterOrEqual(t, customSource.refreshTime.Unix(), test.token.Expiry.Add(-cfg.TokenRefreshWindow.Duration).Unix())
}
})
}
}

func TestCustomTokenSource_fetchTokenFromCache(t *testing.T) {
ctx := context.Background()
cfg := GetConfig(ctx)
cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute}
cfg.ClientSecretLocation = ""

minuteAgo := time.Now().Add(-time.Minute)
hourAhead := time.Now().Add(time.Hour)
invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo}
validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead}

tests := []struct {
name string
refreshTime time.Time
failedToRefresh bool
token *oauth2.Token
expectToken bool
expectNeedsRefresh bool
}{
{
name: "no token",
refreshTime: hourAhead,
failedToRefresh: false,
token: nil,
expectToken: false,
expectNeedsRefresh: false,
},
{
name: "invalid token",
refreshTime: hourAhead,
failedToRefresh: false,
token: &invalidToken,
expectToken: false,
expectNeedsRefresh: false,
},
{
name: "refresh exceeded",
refreshTime: minuteAgo,
failedToRefresh: false,
token: &validToken,
expectToken: true,
expectNeedsRefresh: true,
},
{
name: "refresh exceeded failed",
refreshTime: minuteAgo,
failedToRefresh: true,
token: &validToken,
expectToken: true,
expectNeedsRefresh: false,
},
{
name: "valid token",
refreshTime: hourAhead,
failedToRefresh: false,
token: &validToken,
expectToken: true,
expectNeedsRefresh: false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tokenCache := &tokenCacheMocks.TokenCache{}
var tokenErr error = nil
if test.token == nil {
tokenErr = fmt.Errorf("no token")
}
tokenCache.OnGetToken().Return(test.token, tokenErr).Twice()
provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "")
assert.NoError(t, err)
source, err := provider.GetTokenSource(ctx)
assert.NoError(t, err)
customSource, ok := source.(*customTokenSource)
assert.True(t, ok)

customSource.refreshTime = test.refreshTime
customSource.failedToRefresh = test.failedToRefresh
token, needsRefresh := customSource.fetchTokenFromCache()
if test.expectToken {
assert.NotNil(t, token)
} else {
assert.Nil(t, token)
}
assert.Equal(t, test.expectNeedsRefresh, needsRefresh)
})
}
}

func TestCustomTokenSource_Token(t *testing.T) {
ctx := context.Background()
cfg := GetConfig(ctx)
cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute}
cfg.ClientSecretLocation = ""

minuteAgo := time.Now().Add(-time.Minute)
Expand All @@ -234,51 +93,41 @@ func TestCustomTokenSource_Token(t *testing.T) {
newToken := oauth2.Token{AccessToken: "foo", Expiry: twoHourAhead}

tests := []struct {
name string
refreshTime time.Time
failedToRefresh bool
token *oauth2.Token
newToken *oauth2.Token
expectedToken *oauth2.Token
name string
token *oauth2.Token
newToken *oauth2.Token
expectedToken *oauth2.Token
}{
{
name: "cached token",
refreshTime: hourAhead,
failedToRefresh: false,
token: &validToken,
newToken: nil,
expectedToken: &validToken,
name: "no cached token",
token: nil,
newToken: &newToken,
expectedToken: &newToken,
},
{
name: "failed refresh still valid",
refreshTime: minuteAgo,
failedToRefresh: false,
token: &validToken,
newToken: nil,
expectedToken: &validToken,
name: "cached token valid",
token: &validToken,
newToken: nil,
expectedToken: &validToken,
},
{
name: "failed refresh invalid",
refreshTime: minuteAgo,
failedToRefresh: false,
token: &invalidToken,
newToken: nil,
expectedToken: nil,
name: "cached token expired",
token: &invalidToken,
newToken: &newToken,
expectedToken: &newToken,
},
{
name: "refresh",
refreshTime: minuteAgo,
failedToRefresh: false,
token: &invalidToken,
newToken: &newToken,
expectedToken: &newToken,
name: "failed new token",
token: &invalidToken,
newToken: nil,
expectedToken: nil,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tokenCache := &tokenCacheMocks.TokenCache{}
tokenCache.OnGetToken().Return(test.token, nil).Twice()
tokenCache.OnGetToken().Return(test.token, nil).Once()
provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "")
assert.NoError(t, err)
source, err := provider.GetTokenSource(ctx)
Expand All @@ -287,14 +136,14 @@ func TestCustomTokenSource_Token(t *testing.T) {
assert.True(t, ok)

mockSource := &adminMocks.TokenSource{}
if test.newToken != nil {
mockSource.OnToken().Return(test.newToken, nil)
} else {
mockSource.OnToken().Return(nil, fmt.Errorf("refresh token failed"))
if test.token != &validToken {
if test.newToken != nil {
mockSource.OnToken().Return(test.newToken, nil)
} else {
mockSource.OnToken().Return(nil, fmt.Errorf("refresh token failed"))
}
}
customSource.new = mockSource
customSource.refreshTime = test.refreshTime
customSource.failedToRefresh = test.failedToRefresh
if test.newToken != nil {
tokenCache.OnSaveToken(test.newToken).Return(nil).Once()
}
Expand All @@ -306,6 +155,8 @@ func TestCustomTokenSource_Token(t *testing.T) {
assert.Nil(t, token)
assert.Error(t, err)
}
tokenCache.AssertExpectations(t)
mockSource.AssertExpectations(t)
})
}
}

0 comments on commit 9cdb02c

Please sign in to comment.