diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index 2488e112d5..dfa3b0e4f9 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -5,7 +5,6 @@ namespace Microsoft.Azure.Cosmos { using System; - using System.Diagnostics; using System.Globalization; using System.Net; using System.Threading; @@ -15,7 +14,6 @@ namespace Microsoft.Azure.Cosmos using Microsoft.Azure.Cosmos.Core.Trace; using Microsoft.Azure.Cosmos.Resource.CosmosExceptions; using Microsoft.Azure.Cosmos.Tracing; - using Microsoft.Azure.Cosmos.Tracing.TraceData; using Microsoft.Azure.Documents; /// @@ -32,7 +30,7 @@ internal sealed class TokenCredentialCache : IDisposable public static readonly double DefaultBackgroundTokenCredentialRefreshIntervalPercentage = .50; // The maximum time a task delayed is allowed is Int32.MaxValue in Milliseconds which is roughly 24 days - public static readonly TimeSpan MaxBackgroundRefreshInterval = TimeSpan.FromMilliseconds(Int32.MaxValue); + public static readonly TimeSpan MaxBackgroundRefreshInterval = TimeSpan.FromMilliseconds(int.MaxValue); // The token refresh retries half the time. Given default of 1hr it will retry at 30m, 15, 7.5, 3.75, 1.875 // If the background refresh fails with less than a minute then just allow the request to hit the exception. @@ -49,7 +47,8 @@ internal sealed class TokenCredentialCache : IDisposable private readonly object backgroundRefreshLock = new object(); private TimeSpan? systemBackgroundTokenCredentialRefreshInterval; - private AccessToken? cachedAccessToken; + private Task? currentRefreshOperation = null; + private AccessToken? cachedAccessToken = null; private bool isBackgroundTaskRunning = false; private bool isDisposed = false; @@ -107,8 +106,13 @@ internal async ValueTask GetTokenAsync(ITrace trace) return this.cachedAccessToken.Value.Token; } - AccessToken accessToken = await this.RefreshCachedTokenWithRetryHelperAsync(trace); - this.StartBackgroundTokenRefreshLoop(); + AccessToken accessToken = await this.GetNewTokenAsync(trace); + if (!this.isBackgroundTaskRunning) + { + // This is a background thread so no need to await + Task backgroundThread = Task.Run(this.StartBackgroundTokenRefreshLoop); + } + return accessToken.Token; } @@ -124,137 +128,160 @@ public void Dispose() this.isDisposed = true; } - private async ValueTask RefreshCachedTokenWithRetryHelperAsync( + private async Task GetNewTokenAsync( ITrace trace) { - Exception? lastException = null; - const int totalRetryCount = 2; - for (int retry = 0; retry < totalRetryCount; retry++) + // Use a local variable to avoid the possibility the task gets changed + // between the null check and the await operation. + Task? currentTask = this.currentRefreshOperation; + if (currentTask != null) { - if (this.cancellationToken.IsCancellationRequested) - { - DefaultTrace.TraceInformation( - "Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); + // The refresh is already occurring wait on the existing task + return await currentTask; + } - break; - } + try + { + await this.isTokenRefreshingLock.WaitAsync(); - using (ITrace getTokenTrace = trace.StartChild( - name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), - component: TraceComponent.Authorization, - level: Tracing.TraceLevel.Info)) + // avoid doing the await in the semaphore to unblock the parallel requests + if (this.currentRefreshOperation == null) { - try - { - return await this.UpdateCachedTokenAsync(); - } - catch (RequestFailedException requestFailedException) - { - lastException = requestFailedException; - getTokenTrace.AddDatum( - $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - requestFailedException); + // ValueTask can not be awaited multiple times + currentTask = this.RefreshCachedTokenWithRetryHelperAsync(trace).AsTask(); + this.currentRefreshOperation = currentTask; + } + else + { + currentTask = this.currentRefreshOperation; + } + } + finally + { + this.isTokenRefreshingLock.Release(); + } - DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}"); + return await currentTask; + } - // Don't retry on auth failures - if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || - requestFailedException.Status == (int)HttpStatusCode.Forbidden) - { - this.cachedAccessToken = default; - throw; - } - } - catch (OperationCanceledException operationCancelled) + private async ValueTask RefreshCachedTokenWithRetryHelperAsync( + ITrace trace) + { + try + { + Exception? lastException = null; + const int totalRetryCount = 2; + for (int retry = 0; retry < totalRetryCount; retry++) + { + if (this.cancellationToken.IsCancellationRequested) { - lastException = operationCancelled; - getTokenTrace.AddDatum( - $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - operationCancelled); + DefaultTrace.TraceInformation( + "Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}"); - - throw CosmosExceptionFactory.CreateRequestTimeoutException( - message: ClientResources.FailedToGetAadToken, - headers: new Headers() - { - SubStatusCode = SubStatusCodes.FailedToGetAadToken, - }, - innerException: lastException, - trace: getTokenTrace); + break; } - catch (Exception exception) - { - lastException = exception; - getTokenTrace.AddDatum( - $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - exception); - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}"); - } - } - } + using (ITrace getTokenTrace = trace.StartChild( + name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), + component: TraceComponent.Authorization, + level: Tracing.TraceLevel.Info)) + { + try + { + this.cachedAccessToken = await this.tokenCredential.GetTokenAsync( + requestContext: this.tokenRequestContext, + cancellationToken: default); - if (lastException == null) - { - throw new ArgumentException("Last exception is null."); - } + if (!this.cachedAccessToken.HasValue) + { + throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token."); + } - // The retries have been exhausted. Throw the last exception. - throw lastException; - } + if (this.cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow) + { + throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}"); + } - /// - /// This method takes a lock to only allow one thread to update the token - /// at a time. If the token was updated while it was waiting for the lock it - /// returns the new cached token. - /// - private async Task UpdateCachedTokenAsync() - { - DateTimeOffset? initialExpireTime = this.cachedAccessToken?.ExpiresOn; + if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue) + { + double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage; - await this.isTokenRefreshingLock.WaitAsync(); + // Ensure the background refresh interval is a valid range. + refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds); + refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds); + this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds); + } - // Token was already refreshed successfully from another thread. - if (this.cachedAccessToken.HasValue && - (!initialExpireTime.HasValue || this.cachedAccessToken.Value.ExpiresOn != initialExpireTime.Value)) - { - return this.cachedAccessToken.Value; - } + return this.cachedAccessToken.Value; + } + catch (RequestFailedException requestFailedException) + { + lastException = requestFailedException; + getTokenTrace.AddDatum( + $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + requestFailedException); - try - { - this.cachedAccessToken = await this.tokenCredential.GetTokenAsync( - requestContext: this.tokenRequestContext, - cancellationToken: default); + DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}"); - if (!this.cachedAccessToken.HasValue) - { - throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token."); - } + // Don't retry on auth failures + if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || + requestFailedException.Status == (int)HttpStatusCode.Forbidden) + { + this.cachedAccessToken = default; + throw; + } + } + catch (OperationCanceledException operationCancelled) + { + lastException = operationCancelled; + getTokenTrace.AddDatum( + $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + operationCancelled); + + DefaultTrace.TraceError( + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}"); + + throw CosmosExceptionFactory.CreateRequestTimeoutException( + message: ClientResources.FailedToGetAadToken, + headers: new Headers() + { + SubStatusCode = SubStatusCodes.FailedToGetAadToken, + }, + innerException: lastException, + trace: getTokenTrace); + } + catch (Exception exception) + { + lastException = exception; + getTokenTrace.AddDatum( + $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + exception); - if (this.cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow) - { - throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}"); + DefaultTrace.TraceError( + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}"); + } + } } - if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue) + if (lastException == null) { - double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage; - - // Ensure the background refresh interval is a valid range. - refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds); - refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds); - this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds); + throw new ArgumentException("Last exception is null."); } - return this.cachedAccessToken.Value; + // The retries have been exhausted. Throw the last exception. + throw lastException; } finally { - this.isTokenRefreshingLock.Release(); + try + { + await this.isTokenRefreshingLock.WaitAsync(); + this.currentRefreshOperation = null; + } + finally + { + this.isTokenRefreshingLock.Release(); + } } } @@ -300,7 +327,7 @@ private async void StartBackgroundTokenRefreshLoop() DefaultTrace.TraceInformation("BackgroundTokenRefreshLoop() - Invoking refresh"); - await this.UpdateCachedTokenAsync(); + await this.GetNewTokenAsync(Tracing.Trace.GetRootTrace("TokenCredentialCacheBackground refresh")); } catch (Exception ex) { diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs index c1789f51ef..9daf960f79 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs @@ -6,8 +6,11 @@ namespace Microsoft.Azure.Cosmos.Tests { using System; using System.Collections.Generic; + using System.Diagnostics; + using System.Linq; using System.Net; using System.Net.Http; + using System.Reflection; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; @@ -209,6 +212,7 @@ public async Task TestTokenCredentialCacheHappyPathAsync() using (TokenCredentialCache tokenCredentialCache = this.CreateTokenCredentialCache(testTokenCredential)) { await this.GetAndVerifyTokenAsync(tokenCredentialCache); + this.ValidateSemaphoreIsReleased(tokenCredentialCache); } } @@ -237,6 +241,7 @@ public async Task TestTokenCredentialErrorAsync() // TokenCredential.GetTokenAsync() is retried for 3 times, so it should have been invoked for 4 times. Assert.AreEqual(2, testTokenCredential.NumTimesInvoked); + this.ValidateSemaphoreIsReleased(tokenCredentialCache); } } @@ -286,6 +291,7 @@ public async Task TestTokenCredentialBackgroundRefreshAsync() Assert.AreEqual(token2, t3); Assert.AreEqual(2, testTokenCredential.NumTimesInvoked); + this.ValidateSemaphoreIsReleased(tokenCredentialCache); } } @@ -313,6 +319,8 @@ public async Task TestTokenCredentialBackgroundRefreshAsync_OnDispose() TokenCredentialCache tokenCredentialCache = this.CreateTokenCredentialCache(testTokenCredential, TimeSpan.FromMilliseconds(100)); string t1 = await tokenCredentialCache.GetTokenAsync(NoOpTrace.Singleton); + this.ValidateSemaphoreIsReleased(tokenCredentialCache); + tokenCredentialCache.Dispose(); Assert.AreEqual(token1, t1); @@ -323,53 +331,91 @@ public async Task TestTokenCredentialBackgroundRefreshAsync_OnDispose() public async Task TestTokenCredentialFailedToRefreshAsync() { string token = "Token"; - bool firstTimeGetToken = true; + bool throwExceptionOnGetToken = false; Exception exception = new Exception(); TestTokenCredential testTokenCredential = new TestTokenCredential(() => { - if (firstTimeGetToken) + if (throwExceptionOnGetToken) { - firstTimeGetToken = false; - - return new ValueTask(new AccessToken(token, DateTimeOffset.UtcNow + TimeSpan.FromSeconds(6))); + throw exception; } else { - throw exception; + return new ValueTask(new AccessToken(token, DateTimeOffset.UtcNow + TimeSpan.FromSeconds(8))); } }); - using ITrace trace = Trace.GetRootTrace("test"); + using ITrace trace = Cosmos.Tracing.Trace.GetRootTrace("test"); using (TokenCredentialCache tokenCredentialCache = this.CreateTokenCredentialCache(testTokenCredential)) { Assert.AreEqual(token, await tokenCredentialCache.GetTokenAsync(trace)); + Assert.AreEqual(1, testTokenCredential.NumTimesInvoked); + throwExceptionOnGetToken = true; - // Token is valid for 6 seconds. Client TokenCredentialRefreshBuffer is set to 5 seconds. + // Token is valid for 10 seconds. Client TokenCredentialRefreshBuffer is set to 5 seconds. // After waiting for 2 seconds, the cache token is still valid, but it will be refreshed in the background. await Task.Delay(TimeSpan.FromSeconds(2)); Assert.AreEqual(token, await tokenCredentialCache.GetTokenAsync(trace)); + Assert.AreEqual(1, testTokenCredential.NumTimesInvoked); // Token refreshes fails except for the first time, but the cached token will be served as long as it is valid. - await Task.Delay(TimeSpan.FromSeconds(3)); + // Wait for the background refresh to occur. It should fail but the cached token should still be valid + Stopwatch stopwatch = Stopwatch.StartNew(); + while (testTokenCredential.NumTimesInvoked != 3) + { + Assert.IsTrue(stopwatch.Elapsed.TotalSeconds < 10, "The background task did not start in 10 seconds"); + await Task.Delay(200); + } Assert.AreEqual(token, await tokenCredentialCache.GetTokenAsync(trace)); + Assert.AreEqual(3, testTokenCredential.NumTimesInvoked, $"The cached token was not used. Waited time for background refresh: {stopwatch.Elapsed.TotalSeconds} seconds"); // Cache token has expired, and it fails to refresh. - await Task.Delay(TimeSpan.FromSeconds(2)); + await Task.Delay(TimeSpan.FromSeconds(5)); + throwExceptionOnGetToken = true; - try + // Simulate multiple concurrent request on the failed token + List tasks = new List(); + for (int i = 0; i < 40; i++) { - await tokenCredentialCache.GetTokenAsync(trace); - Assert.Fail("TokenCredentialCache.GetTokenAsync() is expected to fail but succeeded"); + Task task = Task.Run(async () => + { + try + { + await tokenCredentialCache.GetTokenAsync(trace); + Assert.Fail("TokenCredentialCache.GetTokenAsync() is expected to fail but succeeded"); + } + catch (Exception thrownException) + { + // It should just throw the original exception and not be wrapped in a CosmosException + // This avoids any confusion on where the error was thrown from. + Assert.IsTrue(object.ReferenceEquals( + exception, + thrownException), $"Incorrect exception thrown: Expected: {exception}; Actual: {thrownException}"); + } + }); + tasks.Add(task); } - catch (Exception thrownException) + + await Task.WhenAll(tasks); + + this.ValidateSemaphoreIsReleased(tokenCredentialCache); + + + // Simulate multiple concurrent request that should succeed after a failure + throwExceptionOnGetToken = false; + int numGetTokenCallsAfterFailures = testTokenCredential.NumTimesInvoked; + tasks = new List(); + for (int i = 0; i < 40; i++) { - // It should just throw the original exception and not be wrapped in a CosmosException - // This avoids any confusion on where the error was thrown from. - Assert.IsTrue(object.ReferenceEquals( - exception, - thrownException)); + Task task = Task.Run(async () => await tokenCredentialCache.GetTokenAsync(trace)); + tasks.Add(task); } + + await Task.WhenAll(tasks); + Assert.AreEqual(numGetTokenCallsAfterFailures+1, testTokenCredential.NumTimesInvoked, "There should only be 1 GetToken call to get the new token after the failures"); + + this.ValidateSemaphoreIsReleased(tokenCredentialCache); } } @@ -379,12 +425,15 @@ public async Task TestTokenCredentialMultiThreadAsync() // When multiple thread calls TokenCredentialCache.GetTokenAsync and a valid cached token // is not available, TokenCredentialCache will only create one task to get token. int numTasks = 100; - - TestTokenCredential testTokenCredential = new TestTokenCredential(() => + bool delayTokenRefresh = true; + TestTokenCredential testTokenCredential = new TestTokenCredential(async () => { - Task.Delay(TimeSpan.FromSeconds(3)).Wait(); + while (delayTokenRefresh) + { + await Task.Delay(TimeSpan.FromMilliseconds(10)); + } - return new ValueTask(this.AccessToken); + return this.AccessToken; }); using (TokenCredentialCache tokenCredentialCache = this.CreateTokenCredentialCache(testTokenCredential)) @@ -393,11 +442,29 @@ public async Task TestTokenCredentialMultiThreadAsync() for (int i = 0; i < numTasks; i++) { - tasks[i] = this.GetAndVerifyTokenAsync(tokenCredentialCache); + tasks[i] = Task.Run(() => this.GetAndVerifyTokenAsync(tokenCredentialCache)); } + bool waitForTasksToStart = false; + do + { + waitForTasksToStart = tasks.Where(x => x.Status == TaskStatus.Created).Any(); + await Task.Delay(TimeSpan.FromMilliseconds(10)); + } while (waitForTasksToStart); + + // Verify a task took the semaphore lock + bool isRefreshing = false; + do + { + isRefreshing = this.IsTokenRefreshInProgress(tokenCredentialCache); + await Task.Delay(TimeSpan.FromMilliseconds(10)); + } while (!isRefreshing); + + delayTokenRefresh = false; + await Task.WhenAll(tasks); + this.ValidateSemaphoreIsReleased(tokenCredentialCache); Assert.AreEqual(1, testTokenCredential.NumTimesInvoked); } } @@ -418,11 +485,34 @@ private TokenCredentialCache CreateTokenCredentialCache( backgroundTokenCredentialRefreshInterval: refreshInterval); } + private bool IsTokenRefreshInProgress(TokenCredentialCache tokenCredentialCache) + { + Type type = typeof(TokenCredentialCache); + FieldInfo sempahoreFieldInfo = type.GetField("currentRefreshOperation", BindingFlags.NonPublic | BindingFlags.Instance); + Task refreshToken = (Task)sempahoreFieldInfo.GetValue(tokenCredentialCache); + return refreshToken != null; + } + + private int GetSemaphoreCurrentCount(TokenCredentialCache tokenCredentialCache) + { + Type type = typeof(TokenCredentialCache); + FieldInfo sempahoreFieldInfo = type.GetField("isTokenRefreshingLock", BindingFlags.NonPublic | BindingFlags.Instance); + SemaphoreSlim semaphoreSlim = (SemaphoreSlim)sempahoreFieldInfo.GetValue(tokenCredentialCache); + return semaphoreSlim.CurrentCount; + } + + private void ValidateSemaphoreIsReleased(TokenCredentialCache tokenCredentialCache) + { + int currentCount = this.GetSemaphoreCurrentCount(tokenCredentialCache); + Assert.AreEqual(1, currentCount); + } + private async Task GetAndVerifyTokenAsync(TokenCredentialCache tokenCredentialCache) { + string result = await tokenCredentialCache.GetTokenAsync(NoOpTrace.Singleton); Assert.AreEqual( this.AccessToken.Token, - await tokenCredentialCache.GetTokenAsync(NoOpTrace.Singleton)); + result); } private sealed class TestTokenCredential : TokenCredential @@ -443,7 +533,7 @@ public override AccessToken GetToken(TokenRequestContext requestContext, Cancell Assert.AreEqual(1, requestContext.Scopes.Length); Assert.AreEqual(CosmosAuthorizationTests.ExpectedScope, requestContext.Scopes[0]); - return this.accessTokenFunc().Result; + return this.accessTokenFunc().GetAwaiter().GetResult(); } public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken)