diff --git a/clients/go/admin/token_source_provider.go b/clients/go/admin/token_source_provider.go index 41db678f6..ba6bb0a46 100644 --- a/clients/go/admin/token_source_provider.go +++ b/clients/go/admin/token_source_provider.go @@ -8,11 +8,9 @@ import ( "os" "strings" "sync" - "time" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" - "k8s.io/apimachinery/pkg/util/wait" "github.com/flyteorg/flyteidl/clients/go/admin/cache" "github.com/flyteorg/flyteidl/clients/go/admin/deviceflow" @@ -167,9 +165,8 @@ func GetPKCEAuthTokenSource(ctx context.Context, pkceTokenOrchestrator pkce.Toke } type ClientCredentialsTokenSourceProvider struct { - ccConfig clientcredentials.Config - tokenRefreshWindow time.Duration - tokenCache cache.TokenCache + ccConfig clientcredentials.Config + tokenCache cache.TokenCache } func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string, @@ -201,92 +198,45 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s Scopes: scopes, EndpointParams: endpointParams, }, - tokenRefreshWindow: cfg.TokenRefreshWindow.Duration, - tokenCache: tokenCache}, nil + tokenCache: tokenCache}, nil } func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) { - if p.tokenRefreshWindow > 0 { - source := p.ccConfig.TokenSource(ctx) - refreshTime := time.Time{} - if token, err := p.tokenCache.GetToken(); err == nil { - refreshTime = token.Expiry.Add(-getRandomDuration(p.tokenRefreshWindow)) - } - return &customTokenSource{ - ctx: ctx, - new: source, - mu: sync.Mutex{}, - tokenRefreshWindow: p.tokenRefreshWindow, - tokenCache: p.tokenCache, - refreshTime: refreshTime, - }, nil - } - return p.ccConfig.TokenSource(ctx), nil + return &customTokenSource{ + ctx: ctx, + new: p.ccConfig.TokenSource(ctx), + mu: sync.Mutex{}, + tokenCache: p.tokenCache, + }, nil } type customTokenSource struct { - ctx context.Context - new oauth2.TokenSource - tokenRefreshWindow time.Duration - mu sync.Mutex // guards everything else - refreshTime time.Time - failedToRefresh bool - tokenCache cache.TokenCache -} - -// fetchTokenFromCache returns the cached token if available, and a bool indicating if we should try to refresh it. -// This function is not thread safe and should be called with the lock held. -func (s *customTokenSource) fetchTokenFromCache() (*oauth2.Token, bool) { - token, err := s.tokenCache.GetToken() - if err != nil { - logger.Infof(s.ctx, "no token found in cache") - return nil, false - } - if !token.Valid() { - logger.Infof(s.ctx, "cached token invalid") - return nil, false - } - if time.Now().After(s.refreshTime) && !s.failedToRefresh { - logger.Infof(s.ctx, "cached token refresh window exceeded") - return token, true - } - return token, false + ctx context.Context + mu sync.Mutex // guards everything else + new oauth2.TokenSource + tokenCache cache.TokenCache } func (s *customTokenSource) Token() (*oauth2.Token, error) { s.mu.Lock() defer s.mu.Unlock() - cachedToken, needsRefresh := s.fetchTokenFromCache() - if cachedToken != nil && !needsRefresh { - return cachedToken, nil + if token, err := s.tokenCache.GetToken(); err == nil && token.Valid() { + return token, nil } token, err := s.new.Token() if err != nil { - if needsRefresh { - logger.Warnf(s.ctx, "failed to refresh token, using last cached token until expired") - s.failedToRefresh = true - return cachedToken, nil - } - logger.Errorf(s.ctx, "failed to refresh token") - return nil, err + return nil, fmt.Errorf("failed to get token: %w", err) } - logger.Infof(s.ctx, "refreshed token") + logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry) + err = s.tokenCache.SaveToken(token) if err != nil { - logger.Warnf(s.ctx, "failed to cache token, using anyway") + logger.Warnf(s.ctx, "failed to cache token") } - s.failedToRefresh = false - s.refreshTime = token.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow)) - return token, nil -} -// Get random duration between 0 and maxDuration -func getRandomDuration(maxDuration time.Duration) time.Duration { - // d is 1.0 to 2.0 times maxDuration - d := wait.Jitter(maxDuration, 1) - return d - maxDuration + return token, nil } type DeviceFlowTokenSourceProvider struct { diff --git a/clients/go/admin/token_source_provider_test.go b/clients/go/admin/token_source_provider_test.go index b1858b46c..a0d4cb240 100644 --- a/clients/go/admin/token_source_provider_test.go +++ b/clients/go/admin/token_source_provider_test.go @@ -14,7 +14,6 @@ import ( tokenCacheMocks "github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks" adminMocks "github.com/flyteorg/flyteidl/clients/go/admin/mocks" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" - "github.com/flyteorg/flytestdlib/config" ) func TestNewTokenSourceProvider(t *testing.T) { @@ -81,149 +80,9 @@ func TestNewTokenSourceProvider(t *testing.T) { } } -func TestCustomTokenSource_GetTokenSource(t *testing.T) { - ctx := context.Background() - cfg := GetConfig(ctx) - cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute} - cfg.ClientSecretLocation = "" - - hourAhead := time.Now().Add(time.Hour) - validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead} - - tests := []struct { - name string - token *oauth2.Token - }{ - { - name: "no token", - token: nil, - }, - { - - name: "valid token", - token: &validToken, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - tokenCache := &tokenCacheMocks.TokenCache{} - var tokenErr error = nil - if test.token == nil { - tokenErr = fmt.Errorf("no token") - } - tokenCache.OnGetToken().Return(test.token, tokenErr).Once() - provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "") - assert.NoError(t, err) - - source, err := provider.GetTokenSource(ctx) - assert.NoError(t, err) - customSource, ok := source.(*customTokenSource) - assert.True(t, ok) - - if test.token == nil { - assert.Equal(t, time.Time{}, customSource.refreshTime) - } else { - assert.LessOrEqual(t, customSource.refreshTime.Unix(), test.token.Expiry.Unix()) - assert.GreaterOrEqual(t, customSource.refreshTime.Unix(), test.token.Expiry.Add(-cfg.TokenRefreshWindow.Duration).Unix()) - } - }) - } -} - -func TestCustomTokenSource_fetchTokenFromCache(t *testing.T) { - ctx := context.Background() - cfg := GetConfig(ctx) - cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute} - cfg.ClientSecretLocation = "" - - minuteAgo := time.Now().Add(-time.Minute) - hourAhead := time.Now().Add(time.Hour) - invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo} - validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead} - - tests := []struct { - name string - refreshTime time.Time - failedToRefresh bool - token *oauth2.Token - expectToken bool - expectNeedsRefresh bool - }{ - { - name: "no token", - refreshTime: hourAhead, - failedToRefresh: false, - token: nil, - expectToken: false, - expectNeedsRefresh: false, - }, - { - name: "invalid token", - refreshTime: hourAhead, - failedToRefresh: false, - token: &invalidToken, - expectToken: false, - expectNeedsRefresh: false, - }, - { - name: "refresh exceeded", - refreshTime: minuteAgo, - failedToRefresh: false, - token: &validToken, - expectToken: true, - expectNeedsRefresh: true, - }, - { - name: "refresh exceeded failed", - refreshTime: minuteAgo, - failedToRefresh: true, - token: &validToken, - expectToken: true, - expectNeedsRefresh: false, - }, - { - name: "valid token", - refreshTime: hourAhead, - failedToRefresh: false, - token: &validToken, - expectToken: true, - expectNeedsRefresh: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - tokenCache := &tokenCacheMocks.TokenCache{} - var tokenErr error = nil - if test.token == nil { - tokenErr = fmt.Errorf("no token") - } - tokenCache.OnGetToken().Return(test.token, tokenErr).Twice() - provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "") - assert.NoError(t, err) - source, err := provider.GetTokenSource(ctx) - assert.NoError(t, err) - customSource, ok := source.(*customTokenSource) - assert.True(t, ok) - - customSource.refreshTime = test.refreshTime - customSource.failedToRefresh = test.failedToRefresh - token, needsRefresh := customSource.fetchTokenFromCache() - if test.expectToken { - assert.NotNil(t, token) - } else { - assert.Nil(t, token) - } - assert.Equal(t, test.expectNeedsRefresh, needsRefresh) - }) - } -} - func TestCustomTokenSource_Token(t *testing.T) { ctx := context.Background() cfg := GetConfig(ctx) - cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute} cfg.ClientSecretLocation = "" minuteAgo := time.Now().Add(-time.Minute) @@ -234,51 +93,41 @@ func TestCustomTokenSource_Token(t *testing.T) { newToken := oauth2.Token{AccessToken: "foo", Expiry: twoHourAhead} tests := []struct { - name string - refreshTime time.Time - failedToRefresh bool - token *oauth2.Token - newToken *oauth2.Token - expectedToken *oauth2.Token + name string + token *oauth2.Token + newToken *oauth2.Token + expectedToken *oauth2.Token }{ { - name: "cached token", - refreshTime: hourAhead, - failedToRefresh: false, - token: &validToken, - newToken: nil, - expectedToken: &validToken, + name: "no cached token", + token: nil, + newToken: &newToken, + expectedToken: &newToken, }, { - name: "failed refresh still valid", - refreshTime: minuteAgo, - failedToRefresh: false, - token: &validToken, - newToken: nil, - expectedToken: &validToken, + name: "cached token valid", + token: &validToken, + newToken: nil, + expectedToken: &validToken, }, { - name: "failed refresh invalid", - refreshTime: minuteAgo, - failedToRefresh: false, - token: &invalidToken, - newToken: nil, - expectedToken: nil, + name: "cached token expired", + token: &invalidToken, + newToken: &newToken, + expectedToken: &newToken, }, { - name: "refresh", - refreshTime: minuteAgo, - failedToRefresh: false, - token: &invalidToken, - newToken: &newToken, - expectedToken: &newToken, + name: "failed new token", + token: &invalidToken, + newToken: nil, + expectedToken: nil, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { tokenCache := &tokenCacheMocks.TokenCache{} - tokenCache.OnGetToken().Return(test.token, nil).Twice() + tokenCache.OnGetToken().Return(test.token, nil).Once() provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "") assert.NoError(t, err) source, err := provider.GetTokenSource(ctx) @@ -287,14 +136,14 @@ func TestCustomTokenSource_Token(t *testing.T) { assert.True(t, ok) mockSource := &adminMocks.TokenSource{} - if test.newToken != nil { - mockSource.OnToken().Return(test.newToken, nil) - } else { - mockSource.OnToken().Return(nil, fmt.Errorf("refresh token failed")) + if test.token != &validToken { + if test.newToken != nil { + mockSource.OnToken().Return(test.newToken, nil) + } else { + mockSource.OnToken().Return(nil, fmt.Errorf("refresh token failed")) + } } customSource.new = mockSource - customSource.refreshTime = test.refreshTime - customSource.failedToRefresh = test.failedToRefresh if test.newToken != nil { tokenCache.OnSaveToken(test.newToken).Return(nil).Once() } @@ -306,6 +155,8 @@ func TestCustomTokenSource_Token(t *testing.T) { assert.Nil(t, token) assert.Error(t, err) } + tokenCache.AssertExpectations(t) + mockSource.AssertExpectations(t) }) } }