From 42d6fcebcae816ddb51e77ef5226a5ef2111e51e Mon Sep 17 00:00:00 2001 From: Jake Willey Date: Thu, 2 Sep 2021 04:44:30 -0700 Subject: [PATCH 1/4] AAD: Fixes stuck requests when token fails to refresh --- .../src/Authorization/TokenCredentialCache.cs | 18 ++++++------ .../CosmosAuthorizationTests.cs | 28 ++++++++++++++----- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index 2488e112d5..3dda23b07a 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -215,17 +215,17 @@ private async Task UpdateCachedTokenAsync() { DateTimeOffset? initialExpireTime = this.cachedAccessToken?.ExpiresOn; - await this.isTokenRefreshingLock.WaitAsync(); - - // Token was already refreshed successfully from another thread. - if (this.cachedAccessToken.HasValue && - (!initialExpireTime.HasValue || this.cachedAccessToken.Value.ExpiresOn != initialExpireTime.Value)) - { - return this.cachedAccessToken.Value; - } - try { + await this.isTokenRefreshingLock.WaitAsync(); + + // Token was already refreshed successfully from another thread. + if (this.cachedAccessToken.HasValue && + (!initialExpireTime.HasValue || this.cachedAccessToken.Value.ExpiresOn != initialExpireTime.Value)) + { + return this.cachedAccessToken.Value; + } + this.cachedAccessToken = await this.tokenCredential.GetTokenAsync( requestContext: this.tokenRequestContext, cancellationToken: default); diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs index f70a5327b4..0241557440 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs @@ -6,6 +6,7 @@ namespace Microsoft.Azure.Cosmos.Tests { using System; using System.Collections.Generic; + using System.Linq; using System.Net; using System.Net.Http; using System.Security.Cryptography.X509Certificates; @@ -381,12 +382,15 @@ public async Task TestTokenCredentialMultiThreadAsync() // When multiple thread calls TokenCredentialCache.GetTokenAsync and a valid cached token // is not available, TokenCredentialCache will only create one task to get token. int numTasks = 100; - - TestTokenCredential testTokenCredential = new TestTokenCredential(() => + bool delayTokenRefresh= true; + TestTokenCredential testTokenCredential = new TestTokenCredential(async () => { - Task.Delay(TimeSpan.FromSeconds(3)).Wait(); + while (delayTokenRefresh) + { + await Task.Delay(TimeSpan.FromMilliseconds(10)); + } - return new ValueTask(this.AccessToken); + return this.AccessToken; }); using (TokenCredentialCache tokenCredentialCache = this.CreateTokenCredentialCache(testTokenCredential)) @@ -395,9 +399,18 @@ public async Task TestTokenCredentialMultiThreadAsync() for (int i = 0; i < numTasks; i++) { - tasks[i] = this.GetAndVerifyTokenAsync(tokenCredentialCache); + tasks[i] = Task.Run(() => this.GetAndVerifyTokenAsync(tokenCredentialCache)); } + bool waitForTasksToStart = false; + do + { + waitForTasksToStart = tasks.Where(x => x.Status == TaskStatus.Created).Any(); + await Task.Delay(TimeSpan.FromMilliseconds(10)); + } while(waitForTasksToStart); + + delayTokenRefresh = false; + await Task.WhenAll(tasks); Assert.AreEqual(1, testTokenCredential.NumTimesInvoked); @@ -422,9 +435,10 @@ private TokenCredentialCache CreateTokenCredentialCache( private async Task GetAndVerifyTokenAsync(TokenCredentialCache tokenCredentialCache) { + string result = await tokenCredentialCache.GetTokenAsync(NoOpTrace.Singleton); Assert.AreEqual( this.AccessToken.Token, - await tokenCredentialCache.GetTokenAsync(NoOpTrace.Singleton)); + result); } private sealed class TestTokenCredential : TokenCredential @@ -445,7 +459,7 @@ public override AccessToken GetToken(TokenRequestContext requestContext, Cancell Assert.AreEqual(1, requestContext.Scopes.Length); Assert.AreEqual(CosmosAuthorizationTests.ExpectedScope, requestContext.Scopes[0]); - return this.accessTokenFunc().Result; + return this.accessTokenFunc().GetAwaiter().GetResult(); } public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) From b665689e0876bda6c8f1c2b46af6af101bb7149d Mon Sep 17 00:00:00 2001 From: Jake Willey Date: Thu, 2 Sep 2021 07:43:29 -0700 Subject: [PATCH 2/4] Adding check to validate semaphore is released in all scenarios --- .../CosmosAuthorizationTests.cs | 35 +++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs index 0241557440..7c1b508295 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs @@ -9,6 +9,7 @@ namespace Microsoft.Azure.Cosmos.Tests using System.Linq; using System.Net; using System.Net.Http; + using System.Reflection; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; @@ -212,6 +213,7 @@ public async Task TestTokenCredentialCacheHappyPathAsync() using (TokenCredentialCache tokenCredentialCache = this.CreateTokenCredentialCache(testTokenCredential)) { await this.GetAndVerifyTokenAsync(tokenCredentialCache); + this.ValidateSemaphoreIsReleased(tokenCredentialCache); } } @@ -240,6 +242,7 @@ public async Task TestTokenCredentialErrorAsync() // TokenCredential.GetTokenAsync() is retried for 3 times, so it should have been invoked for 4 times. Assert.AreEqual(2, testTokenCredential.NumTimesInvoked); + this.ValidateSemaphoreIsReleased(tokenCredentialCache); } } @@ -289,6 +292,7 @@ public async Task TestTokenCredentialBackgroundRefreshAsync() Assert.AreEqual(token2, t3); Assert.AreEqual(2, testTokenCredential.NumTimesInvoked); + this.ValidateSemaphoreIsReleased(tokenCredentialCache); } } @@ -316,6 +320,8 @@ public async Task TestTokenCredentialBackgroundRefreshAsync_OnDispose() TokenCredentialCache tokenCredentialCache = this.CreateTokenCredentialCache(testTokenCredential, TimeSpan.FromMilliseconds(100)); string t1 = await tokenCredentialCache.GetTokenAsync(NoOpTrace.Singleton); + this.ValidateSemaphoreIsReleased(tokenCredentialCache); + tokenCredentialCache.Dispose(); Assert.AreEqual(token1, t1); @@ -373,6 +379,8 @@ public async Task TestTokenCredentialFailedToRefreshAsync() exception, thrownException)); } + + this.ValidateSemaphoreIsReleased(tokenCredentialCache); } } @@ -382,7 +390,7 @@ public async Task TestTokenCredentialMultiThreadAsync() // When multiple thread calls TokenCredentialCache.GetTokenAsync and a valid cached token // is not available, TokenCredentialCache will only create one task to get token. int numTasks = 100; - bool delayTokenRefresh= true; + bool delayTokenRefresh = true; TestTokenCredential testTokenCredential = new TestTokenCredential(async () => { while (delayTokenRefresh) @@ -407,12 +415,21 @@ public async Task TestTokenCredentialMultiThreadAsync() { waitForTasksToStart = tasks.Where(x => x.Status == TaskStatus.Created).Any(); await Task.Delay(TimeSpan.FromMilliseconds(10)); - } while(waitForTasksToStart); + } while (waitForTasksToStart); + + // Verify a task took the semaphore lock + int waitCount = int.MinValue; + do + { + waitCount = this.GetSemaphoreCurrentCount(tokenCredentialCache); + await Task.Delay(TimeSpan.FromMilliseconds(10)); + } while (waitCount > 0); delayTokenRefresh = false; await Task.WhenAll(tasks); + this.ValidateSemaphoreIsReleased(tokenCredentialCache); Assert.AreEqual(1, testTokenCredential.NumTimesInvoked); } } @@ -433,6 +450,20 @@ private TokenCredentialCache CreateTokenCredentialCache( backgroundTokenCredentialRefreshInterval: refreshInterval); } + private int GetSemaphoreCurrentCount(TokenCredentialCache tokenCredentialCache) + { + Type type = typeof(TokenCredentialCache); + FieldInfo sempahoreFieldInfo = type.GetField("isTokenRefreshingLock", BindingFlags.NonPublic | BindingFlags.Instance); + SemaphoreSlim semaphoreSlim = (SemaphoreSlim)sempahoreFieldInfo.GetValue(tokenCredentialCache); + return semaphoreSlim.CurrentCount; + } + + private void ValidateSemaphoreIsReleased(TokenCredentialCache tokenCredentialCache) + { + int currentCount = this.GetSemaphoreCurrentCount(tokenCredentialCache); + Assert.AreEqual(1, currentCount); + } + private async Task GetAndVerifyTokenAsync(TokenCredentialCache tokenCredentialCache) { string result = await tokenCredentialCache.GetTokenAsync(NoOpTrace.Singleton); From b6482023a2698b1bcec5d4fc82b3aa0a417257cc Mon Sep 17 00:00:00 2001 From: Jake Willey Date: Tue, 7 Sep 2021 11:56:03 -0700 Subject: [PATCH 3/4] Fix multiple threads awaiting the failure scenario --- .../src/Authorization/TokenCredentialCache.cs | 245 ++++++++++-------- .../CosmosAuthorizationTests.cs | 47 +++- 2 files changed, 169 insertions(+), 123 deletions(-) diff --git a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs index 3dda23b07a..dfa3b0e4f9 100644 --- a/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs +++ b/Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs @@ -5,7 +5,6 @@ namespace Microsoft.Azure.Cosmos { using System; - using System.Diagnostics; using System.Globalization; using System.Net; using System.Threading; @@ -15,7 +14,6 @@ namespace Microsoft.Azure.Cosmos using Microsoft.Azure.Cosmos.Core.Trace; using Microsoft.Azure.Cosmos.Resource.CosmosExceptions; using Microsoft.Azure.Cosmos.Tracing; - using Microsoft.Azure.Cosmos.Tracing.TraceData; using Microsoft.Azure.Documents; /// @@ -32,7 +30,7 @@ internal sealed class TokenCredentialCache : IDisposable public static readonly double DefaultBackgroundTokenCredentialRefreshIntervalPercentage = .50; // The maximum time a task delayed is allowed is Int32.MaxValue in Milliseconds which is roughly 24 days - public static readonly TimeSpan MaxBackgroundRefreshInterval = TimeSpan.FromMilliseconds(Int32.MaxValue); + public static readonly TimeSpan MaxBackgroundRefreshInterval = TimeSpan.FromMilliseconds(int.MaxValue); // The token refresh retries half the time. Given default of 1hr it will retry at 30m, 15, 7.5, 3.75, 1.875 // If the background refresh fails with less than a minute then just allow the request to hit the exception. @@ -49,7 +47,8 @@ internal sealed class TokenCredentialCache : IDisposable private readonly object backgroundRefreshLock = new object(); private TimeSpan? systemBackgroundTokenCredentialRefreshInterval; - private AccessToken? cachedAccessToken; + private Task? currentRefreshOperation = null; + private AccessToken? cachedAccessToken = null; private bool isBackgroundTaskRunning = false; private bool isDisposed = false; @@ -107,8 +106,13 @@ internal async ValueTask GetTokenAsync(ITrace trace) return this.cachedAccessToken.Value.Token; } - AccessToken accessToken = await this.RefreshCachedTokenWithRetryHelperAsync(trace); - this.StartBackgroundTokenRefreshLoop(); + AccessToken accessToken = await this.GetNewTokenAsync(trace); + if (!this.isBackgroundTaskRunning) + { + // This is a background thread so no need to await + Task backgroundThread = Task.Run(this.StartBackgroundTokenRefreshLoop); + } + return accessToken.Token; } @@ -124,137 +128,160 @@ public void Dispose() this.isDisposed = true; } - private async ValueTask RefreshCachedTokenWithRetryHelperAsync( + private async Task GetNewTokenAsync( ITrace trace) { - Exception? lastException = null; - const int totalRetryCount = 2; - for (int retry = 0; retry < totalRetryCount; retry++) + // Use a local variable to avoid the possibility the task gets changed + // between the null check and the await operation. + Task? currentTask = this.currentRefreshOperation; + if (currentTask != null) { - if (this.cancellationToken.IsCancellationRequested) - { - DefaultTrace.TraceInformation( - "Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); + // The refresh is already occurring wait on the existing task + return await currentTask; + } + + try + { + await this.isTokenRefreshingLock.WaitAsync(); - break; + // avoid doing the await in the semaphore to unblock the parallel requests + if (this.currentRefreshOperation == null) + { + // ValueTask can not be awaited multiple times + currentTask = this.RefreshCachedTokenWithRetryHelperAsync(trace).AsTask(); + this.currentRefreshOperation = currentTask; } + else + { + currentTask = this.currentRefreshOperation; + } + } + finally + { + this.isTokenRefreshingLock.Release(); + } - using (ITrace getTokenTrace = trace.StartChild( - name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), - component: TraceComponent.Authorization, - level: Tracing.TraceLevel.Info)) + return await currentTask; + } + + private async ValueTask RefreshCachedTokenWithRetryHelperAsync( + ITrace trace) + { + try + { + Exception? lastException = null; + const int totalRetryCount = 2; + for (int retry = 0; retry < totalRetryCount; retry++) { - try - { - return await this.UpdateCachedTokenAsync(); - } - catch (RequestFailedException requestFailedException) + if (this.cancellationToken.IsCancellationRequested) { - lastException = requestFailedException; - getTokenTrace.AddDatum( - $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - requestFailedException); + DefaultTrace.TraceInformation( + "Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); - DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}"); - - // Don't retry on auth failures - if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || - requestFailedException.Status == (int)HttpStatusCode.Forbidden) - { - this.cachedAccessToken = default; - throw; - } + break; } - catch (OperationCanceledException operationCancelled) + + using (ITrace getTokenTrace = trace.StartChild( + name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), + component: TraceComponent.Authorization, + level: Tracing.TraceLevel.Info)) { - lastException = operationCancelled; - getTokenTrace.AddDatum( - $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - operationCancelled); + try + { + this.cachedAccessToken = await this.tokenCredential.GetTokenAsync( + requestContext: this.tokenRequestContext, + cancellationToken: default); - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}"); + if (!this.cachedAccessToken.HasValue) + { + throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token."); + } - throw CosmosExceptionFactory.CreateRequestTimeoutException( - message: ClientResources.FailedToGetAadToken, - headers: new Headers() + if (this.cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow) { - SubStatusCode = SubStatusCodes.FailedToGetAadToken, - }, - innerException: lastException, - trace: getTokenTrace); - } - catch (Exception exception) - { - lastException = exception; - getTokenTrace.AddDatum( - $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", - exception); + throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}"); + } - DefaultTrace.TraceError( - $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}"); - } - } - } + if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue) + { + double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage; - if (lastException == null) - { - throw new ArgumentException("Last exception is null."); - } + // Ensure the background refresh interval is a valid range. + refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds); + refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds); + this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds); + } - // The retries have been exhausted. Throw the last exception. - throw lastException; - } + return this.cachedAccessToken.Value; + } + catch (RequestFailedException requestFailedException) + { + lastException = requestFailedException; + getTokenTrace.AddDatum( + $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + requestFailedException); - /// - /// This method takes a lock to only allow one thread to update the token - /// at a time. If the token was updated while it was waiting for the lock it - /// returns the new cached token. - /// - private async Task UpdateCachedTokenAsync() - { - DateTimeOffset? initialExpireTime = this.cachedAccessToken?.ExpiresOn; + DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}"); - try - { - await this.isTokenRefreshingLock.WaitAsync(); + // Don't retry on auth failures + if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || + requestFailedException.Status == (int)HttpStatusCode.Forbidden) + { + this.cachedAccessToken = default; + throw; + } + } + catch (OperationCanceledException operationCancelled) + { + lastException = operationCancelled; + getTokenTrace.AddDatum( + $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + operationCancelled); + + DefaultTrace.TraceError( + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}"); + + throw CosmosExceptionFactory.CreateRequestTimeoutException( + message: ClientResources.FailedToGetAadToken, + headers: new Headers() + { + SubStatusCode = SubStatusCodes.FailedToGetAadToken, + }, + innerException: lastException, + trace: getTokenTrace); + } + catch (Exception exception) + { + lastException = exception; + getTokenTrace.AddDatum( + $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", + exception); - // Token was already refreshed successfully from another thread. - if (this.cachedAccessToken.HasValue && - (!initialExpireTime.HasValue || this.cachedAccessToken.Value.ExpiresOn != initialExpireTime.Value)) - { - return this.cachedAccessToken.Value; + DefaultTrace.TraceError( + $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}"); + } + } } - this.cachedAccessToken = await this.tokenCredential.GetTokenAsync( - requestContext: this.tokenRequestContext, - cancellationToken: default); - - if (!this.cachedAccessToken.HasValue) + if (lastException == null) { - throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token."); + throw new ArgumentException("Last exception is null."); } - if (this.cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow) + // The retries have been exhausted. Throw the last exception. + throw lastException; + } + finally + { + try { - throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}"); + await this.isTokenRefreshingLock.WaitAsync(); + this.currentRefreshOperation = null; } - - if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue) + finally { - double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage; - - // Ensure the background refresh interval is a valid range. - refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds); - refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds); - this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds); + this.isTokenRefreshingLock.Release(); } - - return this.cachedAccessToken.Value; - } - finally - { - this.isTokenRefreshingLock.Release(); } } @@ -300,7 +327,7 @@ private async void StartBackgroundTokenRefreshLoop() DefaultTrace.TraceInformation("BackgroundTokenRefreshLoop() - Invoking refresh"); - await this.UpdateCachedTokenAsync(); + await this.GetNewTokenAsync(Tracing.Trace.GetRootTrace("TokenCredentialCacheBackground refresh")); } catch (Exception ex) { diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs index 7c1b508295..7824d95c29 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs @@ -366,19 +366,30 @@ public async Task TestTokenCredentialFailedToRefreshAsync() // Cache token has expired, and it fails to refresh. await Task.Delay(TimeSpan.FromSeconds(2)); - try - { - await tokenCredentialCache.GetTokenAsync(trace); - Assert.Fail("TokenCredentialCache.GetTokenAsync() is expected to fail but succeeded"); - } - catch (Exception thrownException) + // Simulate multiple concurrent request on the failed token + List tasks = new List(); + for (int i = 0; i < 20; i++) { - // It should just throw the original exception and not be wrapped in a CosmosException - // This avoids any confusion on where the error was thrown from. - Assert.IsTrue(object.ReferenceEquals( - exception, - thrownException)); + Task task = Task.Run(async () => + { + try + { + await tokenCredentialCache.GetTokenAsync(trace); + Assert.Fail("TokenCredentialCache.GetTokenAsync() is expected to fail but succeeded"); + } + catch (Exception thrownException) + { + // It should just throw the original exception and not be wrapped in a CosmosException + // This avoids any confusion on where the error was thrown from. + Assert.IsTrue(object.ReferenceEquals( + exception, + thrownException), $"Incorrect exception thrown: Expected: {exception}; Actual: {thrownException}"); + } + }); + tasks.Add(task); } + + await Task.WhenAll(tasks); this.ValidateSemaphoreIsReleased(tokenCredentialCache); } @@ -418,12 +429,12 @@ public async Task TestTokenCredentialMultiThreadAsync() } while (waitForTasksToStart); // Verify a task took the semaphore lock - int waitCount = int.MinValue; + bool isRefreshing = false; do { - waitCount = this.GetSemaphoreCurrentCount(tokenCredentialCache); + isRefreshing = this.IsTokenRefreshInProgress(tokenCredentialCache); await Task.Delay(TimeSpan.FromMilliseconds(10)); - } while (waitCount > 0); + } while (!isRefreshing); delayTokenRefresh = false; @@ -450,6 +461,14 @@ private TokenCredentialCache CreateTokenCredentialCache( backgroundTokenCredentialRefreshInterval: refreshInterval); } + private bool IsTokenRefreshInProgress(TokenCredentialCache tokenCredentialCache) + { + Type type = typeof(TokenCredentialCache); + FieldInfo sempahoreFieldInfo = type.GetField("currentRefreshOperation", BindingFlags.NonPublic | BindingFlags.Instance); + Task refreshToken = (Task)sempahoreFieldInfo.GetValue(tokenCredentialCache); + return refreshToken != null; + } + private int GetSemaphoreCurrentCount(TokenCredentialCache tokenCredentialCache) { Type type = typeof(TokenCredentialCache); From eeaa39e7fcbe47f22f0073a8993d839f0b5734b1 Mon Sep 17 00:00:00 2001 From: Jake Willey Date: Wed, 8 Sep 2021 05:22:29 -0700 Subject: [PATCH 4/4] Add more testing for success, failure, success scenario with multiple threads --- .../CosmosAuthorizationTests.cs | 48 ++++++++++++++----- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs index 7824d95c29..3317d15fb9 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosAuthorizationTests.cs @@ -6,6 +6,7 @@ namespace Microsoft.Azure.Cosmos.Tests { using System; using System.Collections.Generic; + using System.Diagnostics; using System.Linq; using System.Net; using System.Net.Http; @@ -332,43 +333,52 @@ public async Task TestTokenCredentialBackgroundRefreshAsync_OnDispose() public async Task TestTokenCredentialFailedToRefreshAsync() { string token = "Token"; - bool firstTimeGetToken = true; + bool throwExceptionOnGetToken = false; Exception exception = new Exception(); TestTokenCredential testTokenCredential = new TestTokenCredential(() => { - if (firstTimeGetToken) + if (throwExceptionOnGetToken) { - firstTimeGetToken = false; - - return new ValueTask(new AccessToken(token, DateTimeOffset.UtcNow + TimeSpan.FromSeconds(6))); + throw exception; } else { - throw exception; + return new ValueTask(new AccessToken(token, DateTimeOffset.UtcNow + TimeSpan.FromSeconds(8))); } }); - using ITrace trace = Trace.GetRootTrace("test"); + using ITrace trace = Cosmos.Tracing.Trace.GetRootTrace("test"); using (TokenCredentialCache tokenCredentialCache = this.CreateTokenCredentialCache(testTokenCredential)) { Assert.AreEqual(token, await tokenCredentialCache.GetTokenAsync(trace)); + Assert.AreEqual(1, testTokenCredential.NumTimesInvoked); + throwExceptionOnGetToken = true; - // Token is valid for 6 seconds. Client TokenCredentialRefreshBuffer is set to 5 seconds. + // Token is valid for 10 seconds. Client TokenCredentialRefreshBuffer is set to 5 seconds. // After waiting for 2 seconds, the cache token is still valid, but it will be refreshed in the background. await Task.Delay(TimeSpan.FromSeconds(2)); Assert.AreEqual(token, await tokenCredentialCache.GetTokenAsync(trace)); + Assert.AreEqual(1, testTokenCredential.NumTimesInvoked); // Token refreshes fails except for the first time, but the cached token will be served as long as it is valid. - await Task.Delay(TimeSpan.FromSeconds(3)); + // Wait for the background refresh to occur. It should fail but the cached token should still be valid + Stopwatch stopwatch = Stopwatch.StartNew(); + while (testTokenCredential.NumTimesInvoked != 3) + { + Assert.IsTrue(stopwatch.Elapsed.TotalSeconds < 10, "The background task did not start in 10 seconds"); + await Task.Delay(200); + } Assert.AreEqual(token, await tokenCredentialCache.GetTokenAsync(trace)); + Assert.AreEqual(3, testTokenCredential.NumTimesInvoked, $"The cached token was not used. Waited time for background refresh: {stopwatch.Elapsed.TotalSeconds} seconds"); // Cache token has expired, and it fails to refresh. - await Task.Delay(TimeSpan.FromSeconds(2)); + await Task.Delay(TimeSpan.FromSeconds(5)); + throwExceptionOnGetToken = true; // Simulate multiple concurrent request on the failed token List tasks = new List(); - for (int i = 0; i < 20; i++) + for (int i = 0; i < 40; i++) { Task task = Task.Run(async () => { @@ -392,6 +402,22 @@ public async Task TestTokenCredentialFailedToRefreshAsync() await Task.WhenAll(tasks); this.ValidateSemaphoreIsReleased(tokenCredentialCache); + + + // Simulate multiple concurrent request that should succeed after a failure + throwExceptionOnGetToken = false; + int numGetTokenCallsAfterFailures = testTokenCredential.NumTimesInvoked; + tasks = new List(); + for (int i = 0; i < 40; i++) + { + Task task = Task.Run(async () => await tokenCredentialCache.GetTokenAsync(trace)); + tasks.Add(task); + } + + await Task.WhenAll(tasks); + Assert.AreEqual(numGetTokenCallsAfterFailures+1, testTokenCredential.NumTimesInvoked, "There should only be 1 GetToken call to get the new token after the failures"); + + this.ValidateSemaphoreIsReleased(tokenCredentialCache); } }