Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Availability: Fixes get account info retry logic to not go to secondary regions on 403(Forbidden) #2511

Merged
merged 9 commits into from
May 29, 2021
36 changes: 27 additions & 9 deletions Microsoft.Azure.Cosmos/src/Routing/GlobalEndpointManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ private class GetAccountPropertiesHelper
private readonly Uri DefaultEndpoint;
private readonly IEnumerator<string>? Locations;
private readonly Func<Uri, Task<AccountProperties>> GetDatabaseAccountFn;
private readonly List<Exception> TransientExceptions = new List<Exception>();
private AccountProperties? AccountProperties = null;
private Exception? NonRetriableException = null;
private Exception? LastTransientException = null;

public GetAccountPropertiesHelper(
Uri defaultEndpoint,
Expand Down Expand Up @@ -190,12 +190,17 @@ public async Task<AccountProperties> GetAccountPropertiesAsync()
tasksToWaitOn.Remove(completedTask);
}

if (this.LastTransientException == null)
if (!this.TransientExceptions.Any())
j82w marked this conversation as resolved.
Show resolved Hide resolved
{
throw new ArgumentException("Account properties and NonRetriableException are null and there is no LastTransientException.");
j82w marked this conversation as resolved.
Show resolved Hide resolved
}

throw this.LastTransientException;
if (this.TransientExceptions.Count == 1)
{
throw this.TransientExceptions.First();
j82w marked this conversation as resolved.
Show resolved Hide resolved
}

throw new AggregateException(this.TransientExceptions);
}

private async Task<AccountProperties> GetOnlyGlobalEndpointAsync()
Expand All @@ -217,12 +222,17 @@ private async Task<AccountProperties> GetOnlyGlobalEndpointAsync()
throw this.NonRetriableException;
}

if (this.LastTransientException != null)
if (!this.TransientExceptions.Any())
j82w marked this conversation as resolved.
Show resolved Hide resolved
{
throw new ArgumentException("Account properties and NonRetriableException are null and there is no LastTransientException.");
j82w marked this conversation as resolved.
Show resolved Hide resolved
}

if (this.TransientExceptions.Count == 1)
{
throw this.LastTransientException;
throw this.TransientExceptions.First();
j82w marked this conversation as resolved.
Show resolved Hide resolved
}

throw new ArgumentException("The account properties and exceptions are null");
throw new AggregateException(this.TransientExceptions);
}

/// <summary>
Expand Down Expand Up @@ -279,7 +289,11 @@ private async Task GetAndUpdateAccountPropertiesAsync(Uri endpoint)
{
if (this.CancellationTokenSource.IsCancellationRequested)
{
this.LastTransientException = new OperationCanceledException("GlobalEndpointManager: Get account information canceled");
lock (this.TransientExceptions)
{
this.TransientExceptions.Add(new OperationCanceledException("GlobalEndpointManager: Get account information canceled"));
}

return;
}

Expand All @@ -302,14 +316,18 @@ private async Task GetAndUpdateAccountPropertiesAsync(Uri endpoint)
}
else
{
this.LastTransientException = e;
lock (this.TransientExceptions)
{
this.TransientExceptions.Add(e);
}
}
}
}

private static bool IsNonRetriableException(Exception exception)
{
if (exception is DocumentClientException dce && dce.StatusCode == HttpStatusCode.Unauthorized)
if (exception is DocumentClientException dce &&
(dce.StatusCode == HttpStatusCode.Unauthorized || dce.StatusCode == HttpStatusCode.Forbidden))
{
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,69 @@ await GlobalEndpointManager.GetDatabaseAccountFromAnyLocationsAsync(
Assert.IsTrue(count <= 3, "Global endpoint is 1, 2 tasks going to regions parallel");
Assert.AreEqual(2, count, "Only request should be made");
}

count = 0;
try
{
await GlobalEndpointManager.GetDatabaseAccountFromAnyLocationsAsync(
defaultEndpoint: defaultEndpoint,
locations: new List<string>(){
"westus",
"southeastasia",
"northcentralus"
},
getDatabaseAccountFn: (uri) =>
{
count++;
if (uri == defaultEndpoint)
{
throw new Microsoft.Azure.Documents.ForbiddenException("Mock ForbiddenException exception");
}

throw new Exception("This should never be hit since it should stop after the global endpoint hit the nonretriable exception");
},
cancellationToken: default);

Assert.Fail("Should throw the ForbiddenException");
}
catch (Microsoft.Azure.Documents.ForbiddenException)
{
Assert.AreEqual(1, count, "Only request should be made");
}

// All endpoints failed. Validate aggregate exception
count = 0;
HashSet<Exception> exceptions = new HashSet<Exception>();
try
{
await GlobalEndpointManager.GetDatabaseAccountFromAnyLocationsAsync(
defaultEndpoint: defaultEndpoint,
locations: new List<string>(){
"westus",
"southeastasia",
"northcentralus"
},
getDatabaseAccountFn: (uri) =>
{
count++;
Exception exception = new HttpRequestException("Mock HttpRequestException exception:" + count);
exceptions.Add(exception);
throw exception;
},
cancellationToken: default);

Assert.Fail("Should throw the AggregateException");
}
catch (AggregateException aggregateException)
{
Assert.AreEqual(4, count, "All endpoints should have been tried. 1 global, 3 regional endpoints");
Assert.AreEqual(4, exceptions.Count, "Some exceptions were not logged");
Assert.AreEqual(4, aggregateException.InnerExceptions.Count, "aggregateException should have 4 inner exceptions");
foreach(Exception exception in aggregateException.InnerExceptions)
{
Assert.IsTrue(exceptions.Contains(exception));
}
}
}

/// <summary>
Expand Down