Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Remove misleading token refresh logic from client credentials token source provider #383

Merged
merged 2 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 20 additions & 70 deletions clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@ import (
"os"
"strings"
"sync"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
"k8s.io/apimachinery/pkg/util/wait"

"github.com/flyteorg/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyteidl/clients/go/admin/deviceflow"
Expand Down Expand Up @@ -167,9 +165,8 @@ func GetPKCEAuthTokenSource(ctx context.Context, pkceTokenOrchestrator pkce.Toke
}

type ClientCredentialsTokenSourceProvider struct {
ccConfig clientcredentials.Config
tokenRefreshWindow time.Duration
tokenCache cache.TokenCache
ccConfig clientcredentials.Config
tokenCache cache.TokenCache
}

func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string,
Expand Down Expand Up @@ -201,92 +198,45 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s
Scopes: scopes,
EndpointParams: endpointParams,
},
tokenRefreshWindow: cfg.TokenRefreshWindow.Duration,
tokenCache: tokenCache}, nil
tokenCache: tokenCache}, nil
}

func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
if p.tokenRefreshWindow > 0 {
source := p.ccConfig.TokenSource(ctx)
refreshTime := time.Time{}
if token, err := p.tokenCache.GetToken(); err == nil {
refreshTime = token.Expiry.Add(-getRandomDuration(p.tokenRefreshWindow))
}
return &customTokenSource{
ctx: ctx,
new: source,
mu: sync.Mutex{},
tokenRefreshWindow: p.tokenRefreshWindow,
tokenCache: p.tokenCache,
refreshTime: refreshTime,
}, nil
}
return p.ccConfig.TokenSource(ctx), nil
return &customTokenSource{
ctx: ctx,
new: p.ccConfig.TokenSource(ctx),
mu: sync.Mutex{},
tokenCache: p.tokenCache,
}, nil
}

type customTokenSource struct {
ctx context.Context
new oauth2.TokenSource
tokenRefreshWindow time.Duration
mu sync.Mutex // guards everything else
refreshTime time.Time
failedToRefresh bool
tokenCache cache.TokenCache
}

// fetchTokenFromCache returns the cached token if available, and a bool indicating if we should try to refresh it.
// This function is not thread safe and should be called with the lock held.
func (s *customTokenSource) fetchTokenFromCache() (*oauth2.Token, bool) {
token, err := s.tokenCache.GetToken()
if err != nil {
logger.Infof(s.ctx, "no token found in cache")
return nil, false
}
if !token.Valid() {
logger.Infof(s.ctx, "cached token invalid")
return nil, false
}
if time.Now().After(s.refreshTime) && !s.failedToRefresh {
logger.Infof(s.ctx, "cached token refresh window exceeded")
return token, true
}
return token, false
ctx context.Context
mu sync.Mutex // guards everything else
new oauth2.TokenSource
tokenCache cache.TokenCache
}

func (s *customTokenSource) Token() (*oauth2.Token, error) {
s.mu.Lock()
defer s.mu.Unlock()

cachedToken, needsRefresh := s.fetchTokenFromCache()
if cachedToken != nil && !needsRefresh {
return cachedToken, nil
if token, err := s.tokenCache.GetToken(); err == nil && token.Valid() {
return token, nil
}

token, err := s.new.Token()
if err != nil {
if needsRefresh {
logger.Warnf(s.ctx, "failed to refresh token, using last cached token until expired")
s.failedToRefresh = true
return cachedToken, nil
}
logger.Errorf(s.ctx, "failed to refresh token")
return nil, err
return nil, fmt.Errorf("failed to get token: %w", err)
}
logger.Infof(s.ctx, "refreshed token")
logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry)

err = s.tokenCache.SaveToken(token)
if err != nil {
logger.Warnf(s.ctx, "failed to cache token, using anyway")
logger.Warnf(s.ctx, "failed to cache token")
}
s.failedToRefresh = false
s.refreshTime = token.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow))
return token, nil
}

// Get random duration between 0 and maxDuration
func getRandomDuration(maxDuration time.Duration) time.Duration {
// d is 1.0 to 2.0 times maxDuration
d := wait.Jitter(maxDuration, 1)
return d - maxDuration
return token, nil
}

type DeviceFlowTokenSourceProvider struct {
Expand Down
207 changes: 29 additions & 178 deletions clients/go/admin/token_source_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
tokenCacheMocks "github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyteidl/clients/go/admin/mocks"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytestdlib/config"
)

func TestNewTokenSourceProvider(t *testing.T) {
Expand Down Expand Up @@ -81,149 +80,9 @@ func TestNewTokenSourceProvider(t *testing.T) {
}
}

func TestCustomTokenSource_GetTokenSource(t *testing.T) {
ctx := context.Background()
cfg := GetConfig(ctx)
cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute}
cfg.ClientSecretLocation = ""

hourAhead := time.Now().Add(time.Hour)
validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead}

tests := []struct {
name string
token *oauth2.Token
}{
{
name: "no token",
token: nil,
},
{

name: "valid token",
token: &validToken,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tokenCache := &tokenCacheMocks.TokenCache{}
var tokenErr error = nil
if test.token == nil {
tokenErr = fmt.Errorf("no token")
}
tokenCache.OnGetToken().Return(test.token, tokenErr).Once()
provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "")
assert.NoError(t, err)

source, err := provider.GetTokenSource(ctx)
assert.NoError(t, err)
customSource, ok := source.(*customTokenSource)
assert.True(t, ok)

if test.token == nil {
assert.Equal(t, time.Time{}, customSource.refreshTime)
} else {
assert.LessOrEqual(t, customSource.refreshTime.Unix(), test.token.Expiry.Unix())
assert.GreaterOrEqual(t, customSource.refreshTime.Unix(), test.token.Expiry.Add(-cfg.TokenRefreshWindow.Duration).Unix())
}
})
}
}

func TestCustomTokenSource_fetchTokenFromCache(t *testing.T) {
ctx := context.Background()
cfg := GetConfig(ctx)
cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute}
cfg.ClientSecretLocation = ""

minuteAgo := time.Now().Add(-time.Minute)
hourAhead := time.Now().Add(time.Hour)
invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo}
validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead}

tests := []struct {
name string
refreshTime time.Time
failedToRefresh bool
token *oauth2.Token
expectToken bool
expectNeedsRefresh bool
}{
{
name: "no token",
refreshTime: hourAhead,
failedToRefresh: false,
token: nil,
expectToken: false,
expectNeedsRefresh: false,
},
{
name: "invalid token",
refreshTime: hourAhead,
failedToRefresh: false,
token: &invalidToken,
expectToken: false,
expectNeedsRefresh: false,
},
{
name: "refresh exceeded",
refreshTime: minuteAgo,
failedToRefresh: false,
token: &validToken,
expectToken: true,
expectNeedsRefresh: true,
},
{
name: "refresh exceeded failed",
refreshTime: minuteAgo,
failedToRefresh: true,
token: &validToken,
expectToken: true,
expectNeedsRefresh: false,
},
{
name: "valid token",
refreshTime: hourAhead,
failedToRefresh: false,
token: &validToken,
expectToken: true,
expectNeedsRefresh: false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tokenCache := &tokenCacheMocks.TokenCache{}
var tokenErr error = nil
if test.token == nil {
tokenErr = fmt.Errorf("no token")
}
tokenCache.OnGetToken().Return(test.token, tokenErr).Twice()
provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "")
assert.NoError(t, err)
source, err := provider.GetTokenSource(ctx)
assert.NoError(t, err)
customSource, ok := source.(*customTokenSource)
assert.True(t, ok)

customSource.refreshTime = test.refreshTime
customSource.failedToRefresh = test.failedToRefresh
token, needsRefresh := customSource.fetchTokenFromCache()
if test.expectToken {
assert.NotNil(t, token)
} else {
assert.Nil(t, token)
}
assert.Equal(t, test.expectNeedsRefresh, needsRefresh)
})
}
}

func TestCustomTokenSource_Token(t *testing.T) {
ctx := context.Background()
cfg := GetConfig(ctx)
cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute}
cfg.ClientSecretLocation = ""

minuteAgo := time.Now().Add(-time.Minute)
Expand All @@ -234,51 +93,41 @@ func TestCustomTokenSource_Token(t *testing.T) {
newToken := oauth2.Token{AccessToken: "foo", Expiry: twoHourAhead}

tests := []struct {
name string
refreshTime time.Time
failedToRefresh bool
token *oauth2.Token
newToken *oauth2.Token
expectedToken *oauth2.Token
name string
token *oauth2.Token
newToken *oauth2.Token
expectedToken *oauth2.Token
}{
{
name: "cached token",
refreshTime: hourAhead,
failedToRefresh: false,
token: &validToken,
newToken: nil,
expectedToken: &validToken,
name: "no cached token",
token: nil,
newToken: &newToken,
expectedToken: &newToken,
},
{
name: "failed refresh still valid",
refreshTime: minuteAgo,
failedToRefresh: false,
token: &validToken,
newToken: nil,
expectedToken: &validToken,
name: "cached token valid",
token: &validToken,
newToken: nil,
expectedToken: &validToken,
},
{
name: "failed refresh invalid",
refreshTime: minuteAgo,
failedToRefresh: false,
token: &invalidToken,
newToken: nil,
expectedToken: nil,
name: "cached token expired",
token: &invalidToken,
newToken: &newToken,
expectedToken: &newToken,
},
{
name: "refresh",
refreshTime: minuteAgo,
failedToRefresh: false,
token: &invalidToken,
newToken: &newToken,
expectedToken: &newToken,
name: "failed new token",
token: &invalidToken,
newToken: nil,
expectedToken: nil,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tokenCache := &tokenCacheMocks.TokenCache{}
tokenCache.OnGetToken().Return(test.token, nil).Twice()
tokenCache.OnGetToken().Return(test.token, nil).Once()
provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "")
assert.NoError(t, err)
source, err := provider.GetTokenSource(ctx)
Expand All @@ -287,14 +136,14 @@ func TestCustomTokenSource_Token(t *testing.T) {
assert.True(t, ok)

mockSource := &adminMocks.TokenSource{}
if test.newToken != nil {
mockSource.OnToken().Return(test.newToken, nil)
} else {
mockSource.OnToken().Return(nil, fmt.Errorf("refresh token failed"))
if test.token != &validToken {
if test.newToken != nil {
mockSource.OnToken().Return(test.newToken, nil)
} else {
mockSource.OnToken().Return(nil, fmt.Errorf("refresh token failed"))
}
}
customSource.new = mockSource
customSource.refreshTime = test.refreshTime
customSource.failedToRefresh = test.failedToRefresh
if test.newToken != nil {
tokenCache.OnSaveToken(test.newToken).Return(nil).Once()
}
Expand All @@ -306,6 +155,8 @@ func TestCustomTokenSource_Token(t *testing.T) {
assert.Nil(t, token)
assert.Error(t, err)
}
tokenCache.AssertExpectations(t)
mockSource.AssertExpectations(t)
})
}
}