Skip to content

Commit

Permalink
Azure Active Directory(AAD): Fixes stuck requests when background ref…
Browse files Browse the repository at this point in the history
…resh fails to refresh token (#2697)

1. This bug can cause requests to get stuck because the semaphore was not being released in the scenario where multiple requests are waiting for a new token. This only occurs in the scenario where the background refresh has failed to get a new token.
2. This optimizes the scenario where multiple concurrent requests are waiting on the token. Each request now awaits the original task that is getting the token instead of starting its own attempt. This prevents the requests from waiting in serial, one after another, only for each to hit the failure. All the waiting requests will receive the same exception.

The existing test for this scenario used .Wait, which blocked the thread and, as a result, also blocked all the other tasks that were meant to simulate concurrent requests. The tasks now use Task.Run to prevent them from being blocked, and the wait logic was converted to async/await.
  • Loading branch information
j82w authored Sep 9, 2021
1 parent 8f12da0 commit 660ec97
Show file tree
Hide file tree
Showing 2 changed files with 253 additions and 136 deletions.
247 changes: 137 additions & 110 deletions Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
namespace Microsoft.Azure.Cosmos
{
using System;
using System.Diagnostics;
using System.Globalization;
using System.Net;
using System.Threading;
Expand All @@ -15,7 +14,6 @@ namespace Microsoft.Azure.Cosmos
using Microsoft.Azure.Cosmos.Core.Trace;
using Microsoft.Azure.Cosmos.Resource.CosmosExceptions;
using Microsoft.Azure.Cosmos.Tracing;
using Microsoft.Azure.Cosmos.Tracing.TraceData;
using Microsoft.Azure.Documents;

/// <summary>
Expand All @@ -32,7 +30,7 @@ internal sealed class TokenCredentialCache : IDisposable
public static readonly double DefaultBackgroundTokenCredentialRefreshIntervalPercentage = .50;

// The maximum time a task delayed is allowed is Int32.MaxValue in Milliseconds which is roughly 24 days
public static readonly TimeSpan MaxBackgroundRefreshInterval = TimeSpan.FromMilliseconds(Int32.MaxValue);
public static readonly TimeSpan MaxBackgroundRefreshInterval = TimeSpan.FromMilliseconds(int.MaxValue);

// The token refresh retries half the time. Given default of 1hr it will retry at 30m, 15, 7.5, 3.75, 1.875
// If the background refresh fails with less than a minute then just allow the request to hit the exception.
Expand All @@ -49,7 +47,8 @@ internal sealed class TokenCredentialCache : IDisposable
private readonly object backgroundRefreshLock = new object();

private TimeSpan? systemBackgroundTokenCredentialRefreshInterval;
private AccessToken? cachedAccessToken;
private Task<AccessToken>? currentRefreshOperation = null;
private AccessToken? cachedAccessToken = null;
private bool isBackgroundTaskRunning = false;
private bool isDisposed = false;

Expand Down Expand Up @@ -107,8 +106,13 @@ internal async ValueTask<string> GetTokenAsync(ITrace trace)
return this.cachedAccessToken.Value.Token;
}

AccessToken accessToken = await this.RefreshCachedTokenWithRetryHelperAsync(trace);
this.StartBackgroundTokenRefreshLoop();
AccessToken accessToken = await this.GetNewTokenAsync(trace);
if (!this.isBackgroundTaskRunning)
{
// This is a background thread so no need to await
Task backgroundThread = Task.Run(this.StartBackgroundTokenRefreshLoop);
}

return accessToken.Token;
}

Expand All @@ -124,137 +128,160 @@ public void Dispose()
this.isDisposed = true;
}

private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
private async Task<AccessToken> GetNewTokenAsync(
ITrace trace)
{
Exception? lastException = null;
const int totalRetryCount = 2;
for (int retry = 0; retry < totalRetryCount; retry++)
// Use a local variable to avoid the possibility the task gets changed
// between the null check and the await operation.
Task<AccessToken>? currentTask = this.currentRefreshOperation;
if (currentTask != null)
{
if (this.cancellationToken.IsCancellationRequested)
{
DefaultTrace.TraceInformation(
"Stop RefreshTokenWithIndefiniteRetries because cancellation is requested");
// The refresh is already occurring wait on the existing task
return await currentTask;
}

break;
}
try
{
await this.isTokenRefreshingLock.WaitAsync();

using (ITrace getTokenTrace = trace.StartChild(
name: nameof(this.RefreshCachedTokenWithRetryHelperAsync),
component: TraceComponent.Authorization,
level: Tracing.TraceLevel.Info))
// avoid doing the await in the semaphore to unblock the parallel requests
if (this.currentRefreshOperation == null)
{
try
{
return await this.UpdateCachedTokenAsync();
}
catch (RequestFailedException requestFailedException)
{
lastException = requestFailedException;
getTokenTrace.AddDatum(
$"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
requestFailedException);
// ValueTask can not be awaited multiple times
currentTask = this.RefreshCachedTokenWithRetryHelperAsync(trace).AsTask();
this.currentRefreshOperation = currentTask;
}
else
{
currentTask = this.currentRefreshOperation;
}
}
finally
{
this.isTokenRefreshingLock.Release();
}

DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}");
return await currentTask;
}

// Don't retry on auth failures
if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized ||
requestFailedException.Status == (int)HttpStatusCode.Forbidden)
{
this.cachedAccessToken = default;
throw;
}
}
catch (OperationCanceledException operationCancelled)
private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
ITrace trace)
{
try
{
Exception? lastException = null;
const int totalRetryCount = 2;
for (int retry = 0; retry < totalRetryCount; retry++)
{
if (this.cancellationToken.IsCancellationRequested)
{
lastException = operationCancelled;
getTokenTrace.AddDatum(
$"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
operationCancelled);
DefaultTrace.TraceInformation(
"Stop RefreshTokenWithIndefiniteRetries because cancellation is requested");

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}");

throw CosmosExceptionFactory.CreateRequestTimeoutException(
message: ClientResources.FailedToGetAadToken,
headers: new Headers()
{
SubStatusCode = SubStatusCodes.FailedToGetAadToken,
},
innerException: lastException,
trace: getTokenTrace);
break;
}
catch (Exception exception)
{
lastException = exception;
getTokenTrace.AddDatum(
$"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
exception);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}");
}
}
}
using (ITrace getTokenTrace = trace.StartChild(
name: nameof(this.RefreshCachedTokenWithRetryHelperAsync),
component: TraceComponent.Authorization,
level: Tracing.TraceLevel.Info))
{
try
{
this.cachedAccessToken = await this.tokenCredential.GetTokenAsync(
requestContext: this.tokenRequestContext,
cancellationToken: default);

if (lastException == null)
{
throw new ArgumentException("Last exception is null.");
}
if (!this.cachedAccessToken.HasValue)
{
throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token.");
}

// The retries have been exhausted. Throw the last exception.
throw lastException;
}
if (this.cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow)
{
throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}");
}

/// <summary>
/// This method takes a lock to only allow one thread to update the token
/// at a time. If the token was updated while it was waiting for the lock it
/// returns the new cached token.
/// </summary>
private async Task<AccessToken> UpdateCachedTokenAsync()
{
DateTimeOffset? initialExpireTime = this.cachedAccessToken?.ExpiresOn;
if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue)
{
double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage;

await this.isTokenRefreshingLock.WaitAsync();
// Ensure the background refresh interval is a valid range.
refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds);
refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds);
this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds);
}

// Token was already refreshed successfully from another thread.
if (this.cachedAccessToken.HasValue &&
(!initialExpireTime.HasValue || this.cachedAccessToken.Value.ExpiresOn != initialExpireTime.Value))
{
return this.cachedAccessToken.Value;
}
return this.cachedAccessToken.Value;
}
catch (RequestFailedException requestFailedException)
{
lastException = requestFailedException;
getTokenTrace.AddDatum(
$"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
requestFailedException);

try
{
this.cachedAccessToken = await this.tokenCredential.GetTokenAsync(
requestContext: this.tokenRequestContext,
cancellationToken: default);
DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}");

if (!this.cachedAccessToken.HasValue)
{
throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token.");
}
// Don't retry on auth failures
if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized ||
requestFailedException.Status == (int)HttpStatusCode.Forbidden)
{
this.cachedAccessToken = default;
throw;
}
}
catch (OperationCanceledException operationCancelled)
{
lastException = operationCancelled;
getTokenTrace.AddDatum(
$"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
operationCancelled);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}");

throw CosmosExceptionFactory.CreateRequestTimeoutException(
message: ClientResources.FailedToGetAadToken,
headers: new Headers()
{
SubStatusCode = SubStatusCodes.FailedToGetAadToken,
},
innerException: lastException,
trace: getTokenTrace);
}
catch (Exception exception)
{
lastException = exception;
getTokenTrace.AddDatum(
$"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
exception);

if (this.cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow)
{
throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}");
DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException}");
}
}
}

if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue)
if (lastException == null)
{
double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage;

// Ensure the background refresh interval is a valid range.
refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds);
refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds);
this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds);
throw new ArgumentException("Last exception is null.");
}

return this.cachedAccessToken.Value;
// The retries have been exhausted. Throw the last exception.
throw lastException;
}
finally
{
this.isTokenRefreshingLock.Release();
try
{
await this.isTokenRefreshingLock.WaitAsync();
this.currentRefreshOperation = null;
}
finally
{
this.isTokenRefreshingLock.Release();
}
}
}

Expand Down Expand Up @@ -300,7 +327,7 @@ private async void StartBackgroundTokenRefreshLoop()

DefaultTrace.TraceInformation("BackgroundTokenRefreshLoop() - Invoking refresh");

await this.UpdateCachedTokenAsync();
await this.GetNewTokenAsync(Tracing.Trace.GetRootTrace("TokenCredentialCacheBackground refresh"));
}
catch (Exception ex)
{
Expand Down
Loading

0 comments on commit 660ec97

Please sign in to comment.