diff --git a/sdk/core/Azure.Core/src/Shared/BearerTokenChallengeAuthenticationPolicy.cs b/sdk/core/Azure.Core/src/Shared/BearerTokenChallengeAuthenticationPolicy.cs new file mode 100644 index 0000000000000..f3a85b89fbc0b --- /dev/null +++ b/sdk/core/Azure.Core/src/Shared/BearerTokenChallengeAuthenticationPolicy.cs @@ -0,0 +1,480 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.Diagnostics; + +#nullable enable + +namespace Azure.Core.Pipeline +{ + /// + /// A policy that sends an provided by a as an Authentication header. + /// Note: This class is currently in preview and is therefore subject to possible future breaking changes. + /// + internal class BearerTokenChallengeAuthenticationPolicy : HttpPipelinePolicy + { + private const string ChallengeHeader = "WWW-Authenticate"; + private readonly AccessTokenCache _accessTokenCache; + private string[] _scopes; + + /// + /// Creates a new instance of using provided token credential and scope to authenticate for. + /// + /// The token credential to use for authentication. + /// The scope to authenticate for. + public BearerTokenChallengeAuthenticationPolicy(TokenCredential credential, string scope) : this(credential, new[] { scope }) { } + + /// + /// Creates a new instance of using provided token credential and scopes to authenticate for. + /// + /// The token credential to use for authentication. + /// Scopes to authenticate for. + public BearerTokenChallengeAuthenticationPolicy(TokenCredential credential, IEnumerable scopes) + : this(credential, scopes, TimeSpan.FromMinutes(5), TimeSpan.FromSeconds(30)) { } + + internal BearerTokenChallengeAuthenticationPolicy(TokenCredential credential, IEnumerable scopes, TimeSpan tokenRefreshOffset, TimeSpan tokenRefreshRetryDelay) + { + Argument.AssertNotNull(credential, nameof(credential)); + Argument.AssertNotNull(scopes, nameof(scopes)); + + _scopes = scopes.ToArray(); + _accessTokenCache = new AccessTokenCache(credential, tokenRefreshOffset, tokenRefreshRetryDelay, scopes.ToArray()); + } + + /// + public override ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline) + { + return ProcessAsync(message, pipeline, true); + } + + /// + public override void Process(HttpMessage message, ReadOnlyMemory pipeline) + { + ProcessAsync(message, pipeline, false).EnsureCompleted(); + } + + /// + /// Executed in the event a 401 response with a WWW-Authenticate authentication challenge header is received after the initial request. + /// + /// This implementation handles common authentication challenges such as claims challenges. Service client libraries may derive from this and extend to handle service specific authentication challenges. + /// The to be authenticated. + /// If the return value is true, a . + /// A boolean indicated whether the request contained a valid challenge and a was successfully initialized with it. + protected virtual bool TryGetTokenRequestContextFromChallenge(HttpMessage message, out TokenRequestContext context) + { + context = default; + + var claimsChallenge = GetClaimsChallenge(message.Response); + + if (claimsChallenge != null) + { + context = new TokenRequestContext(_scopes, message.Request.ClientRequestId, claimsChallenge); + return true; + } + + return false; + } + + private async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline, bool async) + { + if (message.Request.Uri.Scheme != Uri.UriSchemeHttps) + { + throw new InvalidOperationException("Bearer token authentication is not permitted for non TLS protected (https) endpoints."); + } + + TokenRequestContext context; + + // If the message already has a challenge response due to a sub-class pre-processing the request, get the context from the challenge. + if (message.HasResponse && message.Response.Status == (int)HttpStatusCode.Unauthorized && message.Response.Headers.Contains(ChallengeHeader)) + { + if (!TryGetTokenRequestContextFromChallenge(message, out context)) + { + // We were unsuccessful in handling the challenge, so bail out now. + return; + } + _scopes = context.Scopes; + } + else + { + context = new TokenRequestContext(_scopes, message.Request.ClientRequestId); + } + + await AuthenticateRequestAsync(message, context, async).ConfigureAwait(false); + + if (async) + { + await ProcessNextAsync(message, pipeline).ConfigureAwait(false); + } + else + { + ProcessNext(message, pipeline); + } + + // Check if we have received a challenge or we have not yet issued the first request. + if (message.Response.Status == (int)HttpStatusCode.Unauthorized && message.Response.Headers.Contains(ChallengeHeader)) + { + // Attempt to get the TokenRequestContext based on the challenge. + // If we fail to get the context, the challenge was not present or invalid. + // If we succeed in getting the context, authenticate the request and pass it up the policy chain. + if (TryGetTokenRequestContextFromChallenge(message, out context)) + { + // Ensure the scopes are consistent with what was set by . + _scopes = context.Scopes; + + await AuthenticateRequestAsync(message, context, async).ConfigureAwait(false); + + if (async) + { + await ProcessNextAsync(message, pipeline).ConfigureAwait(false); + } + else + { + ProcessNext(message, pipeline); + } + } + } + } + + private async Task AuthenticateRequestAsync(HttpMessage message, TokenRequestContext context, bool async) + { + string headerValue; + if (async) + { + headerValue = await _accessTokenCache.GetHeaderValueAsync(message, context, async).ConfigureAwait(false); + } + else + { + headerValue = _accessTokenCache.GetHeaderValueAsync(message, context, async).EnsureCompleted(); + } + + message.Request.SetHeader(HttpHeader.Names.Authorization, headerValue); + } + + private static string? GetClaimsChallenge(Response response) + { + if (response.Status != (int)HttpStatusCode.Unauthorized || !response.Headers.TryGetValue(ChallengeHeader, out string? headerValue)) + { + return null; + } + + ReadOnlySpan bearer = "Bearer".AsSpan(); + ReadOnlySpan claims = "claims".AsSpan(); + ReadOnlySpan headerSpan = headerValue.AsSpan(); + + // Iterate through each challenge value. + while (TryGetNextChallenge(ref headerSpan, out var challengeKey)) + { + // Enumerate each key=value parameter until we find the 'claims' key on the 'Bearer' challenge. + while (TryGetNextParameter(ref headerSpan, out var key, out var value)) + { + if (challengeKey.Equals(bearer, StringComparison.OrdinalIgnoreCase) && key.Equals(claims, StringComparison.OrdinalIgnoreCase)) + { + return Base64Url.DecodeString(value.ToString()); + } + } + } + + return null; + } + + /// + /// Iterates through the challenge schemes present in a challenge header. + /// + /// + /// The header value which will be sliced to remove the first parsed . + /// + /// The parsed challenge scheme. + /// + /// true if a challenge scheme was successfully parsed. + /// The value of should be passed to to parse the challenge parameters if true. + /// + internal static bool TryGetNextChallenge(ref ReadOnlySpan headerValue, out ReadOnlySpan challengeKey) + { + challengeKey = default; + + headerValue = headerValue.TrimStart(' '); + int endOfChallengeKey = headerValue.IndexOf(' '); + + if (endOfChallengeKey < 0) + { + return false; + } + + challengeKey = headerValue.Slice(0, endOfChallengeKey); + + // Slice the challenge key from the headerValue + headerValue = headerValue.Slice(endOfChallengeKey + 1); + + return true; + } + + /// + /// Iterates through a challenge header value after being parsed by . + /// + /// The header value after being parsed by . + /// The parsed challenge parameter key. + /// The parsed challenge parameter value. + /// The challenge parameter key / value pair separator. The default is '='. + /// + /// true if the next available challenge parameter was successfully parsed. + /// false if there are no more parameters for the current challenge scheme or an additional challenge scheme was encountered in the . + /// The value of should be passed again to to attempt to parse any additional challenge schemes if false. + /// + internal static bool TryGetNextParameter(ref ReadOnlySpan headerValue, out ReadOnlySpan paramKey, out ReadOnlySpan paramValue, char separator = '=') + { + paramKey = default; + paramValue = default; + var spaceOrComma = " ,".AsSpan(); + + // Trim any separater prefixes. + headerValue = headerValue.TrimStart(spaceOrComma); + + int nextSpace = headerValue.IndexOf(' '); + int nextSeparator = headerValue.IndexOf(separator); + + if (nextSpace < nextSeparator && nextSpace != -1) + { + // we encountered another challenge value. + return false; + } + + if (nextSeparator < 0) + return false; + + // Get the paramKey. + paramKey = headerValue.Slice(0, nextSeparator).Trim(); + + // Slice to remove the 'paramKey=' from the parameters. + headerValue = headerValue.Slice(nextSeparator + 1); + + // The start of paramValue will usually be a quoted string. Find the first quote. + int quoteIndex = headerValue.IndexOf('\"'); + + // Get the paramValue, which is delimited by the trailing quote. + headerValue = headerValue.Slice(quoteIndex + 1); + if (quoteIndex >= 0) + { + // The values are quote wrapped + paramValue = headerValue.Slice(0, headerValue.IndexOf('\"')); + } + else + { + //the values are not quote wrapped (storage is one example of this) + // either find the next space indicating the delimiter to the next value, or go to the end since this is the last value. + int trailingDelimiterIndex = headerValue.IndexOfAny(spaceOrComma); + if (trailingDelimiterIndex >= 0) + { + paramValue = headerValue.Slice(0, trailingDelimiterIndex); + } + else + { + paramValue = headerValue; + } + } + + // Slice to remove the '"paramValue"' from the parameters. + if (headerValue != paramValue) + headerValue = headerValue.Slice(paramValue.Length + 1); + + return true; + } + + private class AccessTokenCache + { + private readonly object _syncObj = new object(); + private readonly TokenCredential _credential; + private readonly TimeSpan _tokenRefreshOffset; + private readonly TimeSpan _tokenRefreshRetryDelay; + + private TokenRequestContext? _currentContext; + private TaskCompletionSource? _infoTcs; + private TaskCompletionSource? _backgroundUpdateTcs; + public AccessTokenCache(TokenCredential credential, TimeSpan tokenRefreshOffset, TimeSpan tokenRefreshRetryDelay, string[] initialScopes) + { + _credential = credential; + _tokenRefreshOffset = tokenRefreshOffset; + _tokenRefreshRetryDelay = tokenRefreshRetryDelay; + _currentContext = new TokenRequestContext(initialScopes); + } + + public async ValueTask GetHeaderValueAsync(HttpMessage message, TokenRequestContext context, bool async) + { + bool getTokenFromCredential; + TaskCompletionSource headerValueTcs; + TaskCompletionSource? backgroundUpdateTcs; + (headerValueTcs, backgroundUpdateTcs, getTokenFromCredential) = GetTaskCompletionSources(context); + HeaderValueInfo info; + + if (getTokenFromCredential) + { + if (backgroundUpdateTcs != null) + { + if (async) + { + info = await headerValueTcs.Task.ConfigureAwait(false); + } + else + { +#pragma warning disable AZC0104 // Use EnsureCompleted() directly on asynchronous method return value. + info = headerValueTcs.Task.EnsureCompleted(); +#pragma warning restore AZC0104 // Use EnsureCompleted() directly on asynchronous method return value. + } + _ = Task.Run(() => GetHeaderValueFromCredentialInBackgroundAsync(backgroundUpdateTcs, info, context, async)); + return info.HeaderValue; + } + + try + { + info = await GetHeaderValueFromCredentialAsync(context, async, message.CancellationToken).ConfigureAwait(false); + headerValueTcs.SetResult(info); + } + catch (OperationCanceledException) + { + headerValueTcs.SetCanceled(); + throw; + } + catch (Exception exception) + { + headerValueTcs.SetException(exception); + throw; + } + } + + var headerValueTask = headerValueTcs.Task; + if (!headerValueTask.IsCompleted) + { + if (async) + { + await headerValueTask.AwaitWithCancellation(message.CancellationToken); + } + else + { + try + { + headerValueTask.Wait(message.CancellationToken); + } + catch (AggregateException) { } // ignore exception here to rethrow it with EnsureCompleted + } + } + if (async) + { + info = await headerValueTcs.Task.ConfigureAwait(false); + } + else + { +#pragma warning disable AZC0104 // Use EnsureCompleted() directly on asynchronous method return value. + info = headerValueTcs.Task.EnsureCompleted(); +#pragma warning restore AZC0104 // Use EnsureCompleted() directly on asynchronous method return value. + } + + return info.HeaderValue; + } + + private (TaskCompletionSource, TaskCompletionSource?, bool) GetTaskCompletionSources(TokenRequestContext context) + { + lock (_syncObj) + { + // Initial state. GetTaskCompletionSources has been called for the first time + if (_infoTcs == null || RequestRequiresNewToken(context)) + { + _currentContext = context; + _infoTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _backgroundUpdateTcs = default; + return (_infoTcs, _backgroundUpdateTcs, true); + } + + // Getting new access token is in progress, wait for it + if (!_infoTcs.Task.IsCompleted) + { + _backgroundUpdateTcs = default; + return (_infoTcs, _backgroundUpdateTcs, false); + } + + DateTimeOffset now = DateTimeOffset.UtcNow; + // Access token has been successfully acquired in background and it is not expired yet, use it instead of current one + if (_backgroundUpdateTcs != null && _backgroundUpdateTcs.Task.Status == TaskStatus.RanToCompletion && _backgroundUpdateTcs.Task.Result.ExpiresOn > now) + { + _infoTcs = _backgroundUpdateTcs; + _backgroundUpdateTcs = default; + } + + // Attempt to get access token has failed or it has already expired. Need to get a new one + if (_infoTcs.Task.Status != TaskStatus.RanToCompletion || now >= _infoTcs.Task.Result.ExpiresOn) + { + _infoTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + return (_infoTcs, default, true); + } + + // Access token is still valid but is about to expire, try to get it in background + if (now >= _infoTcs.Task.Result.RefreshOn && _backgroundUpdateTcs == null) + { + _backgroundUpdateTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + return (_infoTcs, _backgroundUpdateTcs, true); + } + + // Access token is valid, use it + return (_infoTcs, default, false); + } + } + + // must be called under lock (_syncObj) + private bool RequestRequiresNewToken(TokenRequestContext context) => + _currentContext == null || + (context.Scopes != null && !context.Scopes.SequenceEqual(_currentContext.Value.Scopes)) || + (context.Claims != null && !string.Equals(context.Claims, _currentContext.Value.Claims)); + + private async ValueTask GetHeaderValueFromCredentialInBackgroundAsync(TaskCompletionSource backgroundUpdateTcs, HeaderValueInfo info, TokenRequestContext context, bool async) + { + var cts = new CancellationTokenSource(_tokenRefreshRetryDelay); + try + { + HeaderValueInfo newInfo = await GetHeaderValueFromCredentialAsync(context, async, cts.Token).ConfigureAwait(false); + backgroundUpdateTcs.SetResult(newInfo); + } + catch (OperationCanceledException oce) when (cts.IsCancellationRequested) + { + backgroundUpdateTcs.SetResult(new HeaderValueInfo(info.HeaderValue, info.ExpiresOn, DateTimeOffset.UtcNow)); + AzureCoreEventSource.Singleton.BackgroundRefreshFailed(context.ParentRequestId ?? string.Empty, oce.ToString()); + } + catch (Exception e) + { + backgroundUpdateTcs.SetResult(new HeaderValueInfo(info.HeaderValue, info.ExpiresOn, DateTimeOffset.UtcNow + _tokenRefreshRetryDelay)); + AzureCoreEventSource.Singleton.BackgroundRefreshFailed(context.ParentRequestId ?? string.Empty, e.ToString()); + } + finally + { + cts.Dispose(); + } + } + + private async ValueTask GetHeaderValueFromCredentialAsync(TokenRequestContext context, bool async, CancellationToken cancellationToken) + { + AccessToken token = async + ? await _credential.GetTokenAsync(context, cancellationToken).ConfigureAwait(false) + : _credential.GetToken(context, cancellationToken); + + return new HeaderValueInfo("Bearer " + token.Token, token.ExpiresOn, token.ExpiresOn - _tokenRefreshOffset); + } + + private readonly struct HeaderValueInfo + { + public string HeaderValue { get; } + public DateTimeOffset ExpiresOn { get; } + public DateTimeOffset RefreshOn { get; } + + public HeaderValueInfo(string headerValue, DateTimeOffset expiresOn, DateTimeOffset refreshOn) + { + HeaderValue = headerValue; + ExpiresOn = expiresOn; + RefreshOn = refreshOn; + } + } + } + } +} diff --git a/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj b/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj index 19c3058fc5cb8..593a7415302f5 100644 --- a/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj +++ b/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj @@ -24,6 +24,7 @@ + diff --git a/sdk/core/Azure.Core/tests/BearerTokenChallengeAuthenticationPolicyTests.cs b/sdk/core/Azure.Core/tests/BearerTokenChallengeAuthenticationPolicyTests.cs new file mode 100644 index 0000000000000..a9bbf035b6885 --- /dev/null +++ b/sdk/core/Azure.Core/tests/BearerTokenChallengeAuthenticationPolicyTests.cs @@ -0,0 +1,815 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.Pipeline; +using Azure.Core.TestFramework; +using Moq; +using NUnit.Framework; + +namespace Azure.Core.Tests +{ + public class BearerTokenChallengeAuthenticationPolicyTests : SyncAsyncPolicyTestBase + { + public BearerTokenChallengeAuthenticationPolicyTests(bool isAsync) : base(isAsync) { } + + [Test] + public async Task BearerTokenChallengeAuthenticationPolicy_UsesTokenProvidedByCredentials() + { + var credential = new TokenCredentialStub( + (r, c) => r.Scopes.SequenceEqual(new[] { "scope1", "scope2" }) ? new AccessToken("token", DateTimeOffset.MaxValue) : default, IsAsync); + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, new[] { "scope1", "scope2" }); + + MockTransport transport = CreateMockTransport(new MockResponse(200)); + await SendGetRequest(transport, policy, uri: new Uri("https://example.com")); + + Assert.True(transport.SingleRequest.Headers.TryGetValue("Authorization", out string authValue)); + Assert.AreEqual("Bearer token", authValue); + } + + [Test] + public async Task BearerTokenChallengeAuthenticationPolicy_RequestsTokenEveryRequest() + { + var accessTokens = new Queue(); + accessTokens.Enqueue(new AccessToken("token1", DateTimeOffset.UtcNow)); + accessTokens.Enqueue(new AccessToken("token2", DateTimeOffset.UtcNow)); + + var credential = new TokenCredentialStub( + (r, c) => r.Scopes.SequenceEqual(new[] { "scope1", "scope2" }) ? accessTokens.Dequeue() : default, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, new[] { "scope1", "scope2" }); + MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200)); + + await SendGetRequest(transport, policy, uri: new Uri("https://example.com")); + await SendGetRequest(transport, policy, uri: new Uri("https://example.com")); + + Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value)); + Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value)); + + Assert.AreEqual("Bearer token1", auth1Value); + Assert.AreEqual("Bearer token2", auth2Value); + } + + [Test] + public async Task BearerTokenChallengeAuthenticationPolicy_CachesHeaderValue() + { + var credential = new TokenCredentialStub( + (r, c) => r.Scopes.SequenceEqual(new[] { "scope" }) ? new AccessToken("token", DateTimeOffset.MaxValue) : default, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope"); + MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200)); + + await SendGetRequest(transport, policy, uri: new Uri("https://example.com")); + await SendGetRequest(transport, policy, uri: new Uri("https://example.com")); + + Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value)); + Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value)); + + Assert.AreSame(auth1Value, auth1Value); + Assert.AreEqual("Bearer token", auth2Value); + } + + [Test] + public void BearerTokenChallengeAuthenticationPolicy_ThrowsForNonTlsEndpoint() + { + var credential = new TokenCredentialStub( + (r, c) => r.Scopes.SequenceEqual(new[] { "scope" }) ? new AccessToken("token", DateTimeOffset.MaxValue) : default, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope"); + MockTransport transport = CreateMockTransport(); + + Assert.ThrowsAsync(async () => await SendGetRequest(transport, policy, uri: new Uri("http://example.com"))); + } + + [Test] + public void BearerTokenChallengeAuthenticationPolicy_ThrowsForEmptyToken() + { + var credential = new TokenCredentialStub((r, c) => new AccessToken(string.Empty, DateTimeOffset.MaxValue), IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope"); + MockTransport transport = CreateMockTransport(); + + Assert.ThrowsAsync(async () => await SendGetRequest(transport, policy, uri: new Uri("http://example.com"))); + } + + [Test] + public async Task BearerTokenChallengeAuthenticationPolicy_OneHundredConcurrentCalls() + { + var credential = new TokenCredentialStub((r, c) => + { + Thread.Sleep(100); + return new AccessToken(Guid.NewGuid().ToString(), DateTimeOffset.UtcNow.AddMinutes(30)); + }, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope"); + MockTransport transport = CreateMockTransport(r => new MockResponse(200)); + var requestTasks = new Task[100]; + + for (int i = 0; i < requestTasks.Length; i++) + { + requestTasks[i] = SendGetRequest(transport, policy, uri: new Uri("https://example.com")); + } + + await Task.WhenAll(requestTasks); + Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value)); + + for (int i = 1; i < requestTasks.Length; i++) + { + Assert.True(transport.Requests[i].Headers.TryGetValue("Authorization", out string authValue)); + Assert.AreEqual(auth1Value, authValue); + } + } + + [Test] + public async Task BearerTokenChallengeAuthenticationPolicy_GatedConcurrentCalls() + { + var requestMre = new ManualResetEventSlim(false); + var responseMre = new ManualResetEventSlim(false); + var credential = new TokenCredentialStub((r, c) => + { + requestMre.Set(); + responseMre.Wait(c); + return new AccessToken(Guid.NewGuid().ToString(), DateTimeOffset.UtcNow.AddMinutes(30)); + }, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope"); + MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200)); + + var firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com")); + requestMre.Wait(); + + var secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com")); + responseMre.Set(); + + await Task.WhenAll(firstRequestTask, secondRequestTask); + + Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value)); + Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value)); + + Assert.AreEqual(auth1Value, auth2Value); + } + + [Test] + public async Task BearerTokenChallengeAuthenticationPolicy_SucceededFailedSucceeded() + { + var requestMre = new ManualResetEventSlim(false); + var callCount = 0; + var credential = new TokenCredentialStub((r, c) => + { + Interlocked.Increment(ref callCount); + var offsetTime = DateTimeOffset.UtcNow; + requestMre.Set(); + + return callCount == 2 + ? throw new InvalidOperationException("Call Failed") + : new AccessToken(Guid.NewGuid().ToString(), offsetTime.AddMilliseconds(1000)); + }, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, new[] { "scope" }, TimeSpan.FromMilliseconds(100), TimeSpan.FromSeconds(30)); + MockTransport transport = CreateMockTransport(r => new MockResponse(200)); + + var firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/1")); + var secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/2")); + + requestMre.Wait(); + await Task.Delay(200); + + await Task.WhenAll(firstRequestTask, secondRequestTask); + await Task.Delay(1000); + + Assert.AreEqual(1, callCount); + requestMre.Reset(); + + var failedTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/3/failed")); + requestMre.Wait(); + + Assert.AreEqual(2, callCount); + Assert.ThrowsAsync(async () => await failedTask); + + requestMre.Reset(); + + firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/4")); + secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/5")); + + requestMre.Wait(); + + await Task.WhenAll(firstRequestTask, secondRequestTask); + + Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value)); + Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value)); + Assert.True(transport.Requests[2].Headers.TryGetValue("Authorization", out string auth3Value)); + Assert.True(transport.Requests[3].Headers.TryGetValue("Authorization", out string auth4Value)); + + Assert.AreEqual(3, callCount); + Assert.AreEqual(auth1Value, auth2Value); + Assert.AreNotEqual(auth2Value, auth3Value); + Assert.AreEqual(auth3Value, auth4Value); + } + + [Test] + public async Task BearerTokenChallengeAuthenticationPolicy_TokenAlmostExpired() + { + var requestMre = new ManualResetEventSlim(true); + var responseMre = new ManualResetEventSlim(true); + var currentTime = DateTimeOffset.UtcNow; + var expires = new Queue(new[] { currentTime.AddMinutes(2), currentTime.AddMinutes(30) }); + var callCount = 0; + var credential = new TokenCredentialStub((r, c) => + { + requestMre.Set(); + responseMre.Wait(c); + requestMre.Reset(); + callCount++; + + return new AccessToken(Guid.NewGuid().ToString(), expires.Dequeue()); + }, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope"); + MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200), new MockResponse(200), new MockResponse(200)); + + await SendGetRequest(transport, policy, uri: new Uri("https://example.com/1/Original")); + responseMre.Reset(); + + Task requestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/3/Refresh")); + requestMre.Wait(); + + await SendGetRequest(transport, policy, uri: new Uri("https://example.com/2/AlmostExpired")); + await requestTask; + responseMre.Set(); + await Task.Delay(1_000); + + await SendGetRequest(transport, policy, uri: new Uri("https://example.com/4/AfterRefresh")); + + Assert.AreEqual(2, callCount); + + Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value)); + Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value)); + Assert.True(transport.Requests[2].Headers.TryGetValue("Authorization", out string auth3Value)); + Assert.True(transport.Requests[3].Headers.TryGetValue("Authorization", out string auth4Value)); + + Assert.AreEqual(auth1Value, auth2Value); + Assert.AreEqual(auth2Value, auth3Value); + Assert.AreNotEqual(auth3Value, auth4Value); + } + + [Test] + public async Task BearerTokenChallengeAuthenticationPolicy_TokenAlmostExpired_NoRefresh() + { + var requestMre = new ManualResetEventSlim(true); + var responseMre = new ManualResetEventSlim(true); + var currentTime = DateTimeOffset.UtcNow; + var callCount = 0; + + var credential = new TokenCredentialStub((r, c) => + { + callCount++; + responseMre.Wait(c); + requestMre.Set(); + + return new AccessToken(Guid.NewGuid().ToString(), currentTime.AddMinutes(2)); + }, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope"); + MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200), new MockResponse(200), new MockResponse(200)); + + await SendGetRequest(transport, policy, uri: new Uri("https://example.com/1/Original")); + requestMre.Wait(); + responseMre.Reset(); + + await SendGetRequest(transport, policy, uri: new Uri("https://example.com/2/AlmostExpired")); + await SendGetRequest(transport, policy, uri: new Uri("https://example.com/3/AlmostExpired")); + await SendGetRequest(transport, policy, uri: new Uri("https://example.com/4/AlmostExpired")); + + requestMre.Reset(); + responseMre.Set(); + requestMre.Wait(); + + Assert.AreEqual(2, callCount); + + Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value)); + Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value)); + Assert.True(transport.Requests[2].Headers.TryGetValue("Authorization", out string auth3Value)); + Assert.True(transport.Requests[3].Headers.TryGetValue("Authorization", out string auth4Value)); + + Assert.AreEqual(auth1Value, auth2Value); + Assert.AreEqual(auth2Value, auth3Value); + Assert.AreEqual(auth3Value, auth4Value); + } + + [Test] + public async Task BearerTokenChallengeAuthenticationPolicy_TokenExpired() + { + var requestMre = new ManualResetEventSlim(true); + var responseMre = new ManualResetEventSlim(true); + var currentTime = DateTimeOffset.UtcNow; + var expires = new Queue(new[] { currentTime.AddSeconds(2), currentTime.AddMinutes(30) }); + var credential = new TokenCredentialStub((r, c) => + { + requestMre.Set(); + responseMre.Wait(c); + return new AccessToken(Guid.NewGuid().ToString(), expires.Dequeue()); + }, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, new[] { "scope" }, TimeSpan.FromSeconds(2), TimeSpan.FromMilliseconds(50)); + MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200), new MockResponse(200)); + + await SendGetRequest(transport, policy, uri: new Uri("https://example.com/0")); + Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string authValue)); + + await Task.Delay(3_000); + + requestMre.Reset(); + responseMre.Reset(); + + var firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/1")); + var secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/2")); + requestMre.Wait(); + await Task.Delay(1_000); + responseMre.Set(); + + await Task.WhenAll(firstRequestTask, secondRequestTask); + + Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth1Value)); + Assert.True(transport.Requests[2].Headers.TryGetValue("Authorization", out string auth2Value)); + + Assert.AreNotEqual(authValue, auth1Value); + Assert.AreEqual(auth1Value, auth2Value); + } + + [Test] + public void BearerTokenChallengeAuthenticationPolicy_OneHundredConcurrentCallsFailed() + { + var credential = new TokenCredentialStub((r, c) => + { + Thread.Sleep(100); + throw new InvalidOperationException("Error"); + }, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope"); + MockTransport transport = CreateMockTransport(r => new MockResponse(200)); + var requestTasks = new Task[100]; + + for (int i = 0; i < requestTasks.Length; i++) + { + requestTasks[i] = SendGetRequest(transport, policy, uri: new Uri("https://example.com")); + } + + Assert.CatchAsync(async () => await Task.WhenAll(requestTasks)); + + foreach (Task task in requestTasks) + { + Assert.IsTrue(task.IsFaulted); + } + } + + [Test] + public async Task BearerTokenChallengeAuthenticationPolicy_GatedConcurrentCallsFailed() + { + var requestMre = new ManualResetEventSlim(false); + var responseMre = new ManualResetEventSlim(false); + var credential = new TokenCredentialStub((r, c) => + { + requestMre.Set(); + responseMre.Wait(c); + throw new InvalidOperationException("Error"); + }, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope"); + MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200)); + + var firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com")); + var secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com")); + + requestMre.Wait(); + await Task.Delay(1_000); + responseMre.Set(); + + Assert.CatchAsync(async () => await Task.WhenAll(firstRequestTask, secondRequestTask)); + + Assert.IsTrue(firstRequestTask.IsFaulted); + Assert.IsTrue(secondRequestTask.IsFaulted); + Assert.AreEqual(firstRequestTask.Exception.InnerException, secondRequestTask.Exception.InnerException); + } + + [Test] + public async Task BearerTokenChallengeAuthenticationPolicy_TokenExpiredThenFailed() + { + var requestMre = new ManualResetEventSlim(true); + var responseMre = new ManualResetEventSlim(true); + var fail = false; + var credential = new TokenCredentialStub((r, c) => + { + requestMre.Set(); + responseMre.Wait(c); + if (fail) + { + throw new InvalidOperationException("Error"); + } + + fail = true; + return new AccessToken(Guid.NewGuid().ToString(), DateTimeOffset.UtcNow.AddSeconds(2)); + }, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, new[] { "scope" }, TimeSpan.FromSeconds(2), TimeSpan.FromMilliseconds(50)); + MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200), new MockResponse(200)); + + await SendGetRequest(transport, policy, uri: new Uri("https://example.com/0")); + Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string _)); + + await Task.Delay(3_000); + + requestMre.Reset(); + responseMre.Reset(); + + var firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com")); + var secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com")); + + requestMre.Wait(); + await Task.Delay(1_000); + responseMre.Set(); + + Assert.CatchAsync(async () => await Task.WhenAll(firstRequestTask, secondRequestTask)); + + Assert.IsTrue(firstRequestTask.IsFaulted); + Assert.IsTrue(secondRequestTask.IsFaulted); + Assert.AreEqual(firstRequestTask.Exception.InnerException, secondRequestTask.Exception.InnerException); + } + + [Test] + [Ignore("https://github.com/Azure/azure-sdk-for-net/issues/14612")] + public async Task BearerTokenChallengeAuthenticationPolicy_TokenAlmostExpiredThenFailed() + { + var requestMre = new ManualResetEventSlim(true); + var responseMre = new ManualResetEventSlim(true); + var credentialMre = new ManualResetEventSlim(false); + + var getTokenRequestTimes = new List(); + var transportCallCount = 0; + var credential = new TokenCredentialStub((r, c) => + { + if (transportCallCount > 0) + { + credentialMre.Set(); + getTokenRequestTimes.Add(DateTimeOffset.UtcNow); + throw new InvalidOperationException("Error"); + } + + return new AccessToken(Guid.NewGuid().ToString(), DateTimeOffset.UtcNow.AddMinutes(1.5)); + }, IsAsync); + + var tokenRefreshRetryDelay = TimeSpan.FromSeconds(2); + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, new[] { "scope" }, TimeSpan.FromMinutes(2), tokenRefreshRetryDelay); + MockTransport transport = CreateMockTransport(r => + { + requestMre.Set(); + responseMre.Wait(); + if (Interlocked.Increment(ref transportCallCount) == 4) + { + credentialMre.Wait(); + } + return new MockResponse(200); + }); + + await SendGetRequest(transport, policy, uri: new Uri("https://example.com/1")); + Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value)); + + requestMre.Reset(); + responseMre.Reset(); + + Task requestTask1 = SendGetRequest(transport, policy, uri: new Uri("https://example.com/2/TokenFromCache/RefreshInBackground")); + Task requestTask2 = SendGetRequest(transport, policy, uri: new Uri("https://example.com/3/TokenFromCache/")); + + requestMre.Wait(); + responseMre.Set(); + + await Task.WhenAll(requestTask1, requestTask2); + + await SendGetRequest(transport, policy, uri: new Uri("https://example.com/4/TokenFromCache")); + + await Task.Delay((int)tokenRefreshRetryDelay.TotalMilliseconds + 1_000); + credentialMre.Reset(); + + await SendGetRequest(transport, policy, uri: new Uri("https://example.com/5/TokenFromCache/GetTokenFailed")); + credentialMre.Wait(); + + Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value)); + Assert.True(transport.Requests[2].Headers.TryGetValue("Authorization", out string auth3Value)); + Assert.True(transport.Requests[3].Headers.TryGetValue("Authorization", out string auth4Value)); + Assert.True(transport.Requests[4].Headers.TryGetValue("Authorization", out string auth5Value)); + + Assert.AreEqual(auth1Value, auth2Value); + Assert.AreEqual(auth2Value, auth3Value); + Assert.AreEqual(auth3Value, auth4Value); + Assert.AreEqual(auth4Value, auth5Value); + + Assert.AreEqual(2, getTokenRequestTimes.Count); + Assert.True(getTokenRequestTimes[1] - getTokenRequestTimes[0] > tokenRefreshRetryDelay); + } + + [Test] + public void BearerTokenChallengeAuthenticationPolicy_GatedConcurrentCallsCancelled() + { + var requestMre = new ManualResetEventSlim(false); + var responseMre = new ManualResetEventSlim(false); + var cts = new CancellationTokenSource(); + var credential = new TokenCredentialStub((r, c) => + { + requestMre.Set(); + responseMre.Wait(c); + throw new InvalidOperationException("Error"); + }, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope"); + MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200)); + + var firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: default); + requestMre.Wait(); + + var secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: cts.Token); + cts.Cancel(); + + Assert.CatchAsync(async () => await secondRequestTask); + responseMre.Set(); + + Assert.CatchAsync(async () => await firstRequestTask); + } + + private const string CaeInsufficientClaimsChallenge = "Bearer realm=\"\", authorization_uri=\"https://login.microsoftonline.com/common/oauth2/authorize\", client_id=\"00000003-0000-0000-c000-000000000000\", error=\"insufficient_claims\", claims=\"eyJhY2Nlc3NfdG9rZW4iOiB7ImZvbyI6ICJiYXIifX0=\""; + private const string CaeInsufficientClaimsChallengeValue = "eyJhY2Nlc3NfdG9rZW4iOiB7ImZvbyI6ICJiYXIifX0="; + private static readonly Challenge ParsedCaeInsufficientClaimsChallenge = new Challenge + { + Scheme = "Bearer", + Parameters = + { + ("realm", ""), + ("authorization_uri", "https://login.microsoftonline.com/common/oauth2/authorize"), + ("client_id", "00000003-0000-0000-c000-000000000000"), + ("error", "insufficient_claims"), + ("claims", "eyJhY2Nlc3NfdG9rZW4iOiB7ImZvbyI6ICJiYXIifX0="), + } + }; + + private const string CaeSessionsRevokedClaimsChallenge = "Bearer authorization_uri=\"https://login.windows-ppe.net/\", error=\"invalid_token\", error_description=\"User session has been revoked\", claims=\"eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=\""; + private const string CaeSessionsRevokedClaimsChallengeValue = "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0="; + private static readonly Challenge ParsedCaeSessionsRevokedClaimsChallenge = new Challenge + { + Scheme = "Bearer", + Parameters = + { + ("authorization_uri", "https://login.windows-ppe.net/"), + ("error", "invalid_token"), + ("error_description", "User session has been revoked"), + ("claims", "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0="), + } + }; + + private const string KeyVaultChallenge = "Bearer authorization=\"https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47\", resource=\"https://vault.azure.net\""; + private static readonly Challenge ParsedKeyVaultChallenge = new Challenge + { + Scheme = "Bearer", + Parameters = + { + ("authorization", "https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47"), + ("resource", "https://vault.azure.net"), + } + }; + + private const string ArmChallenge = "Bearer authorization_uri=\"https://login.windows.net/\", error=\"invalid_token\", error_description=\"The authentication failed because of missing 'Authorization' header.\""; + private static readonly Challenge ParsedArmChallenge = new Challenge() + { + Scheme = "Bearer", + Parameters = + { + ("authorization_uri", "https://login.windows.net/"), + ("error", "invalid_token"), + ("error_description", "The authentication failed because of missing 'Authorization' header."), + } + }; + + private const string StorageChallenge = "Bearer authorization_uri=https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/oauth2/authorize resource_id=https://storage.azure.com"; + private static readonly Challenge ParsedStorageChallenge = new Challenge() + { + Scheme = "Bearer", + Parameters = + { + ("authorization_uri", "https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/oauth2/authorize"), + ("resource_id", "https://storage.azure.com"), + } + }; + private static readonly Challenge ParsedMultipleChallenges = new Challenge + { + Scheme = "Bearer", + Parameters = + { + ("authorization_uri", "https://login.windows-ppe.net/"), + ("error", "invalid_token"), + ("error_description", "User session has been revoked"), + ("claims", "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0="), + } + }; + private static readonly Dictionary ChallengeStrings = new Dictionary() + { + { "CaeInsufficientClaims", CaeInsufficientClaimsChallenge }, + { "CaeSessionsRevoked", CaeSessionsRevokedClaimsChallenge }, + { "KeyVault", KeyVaultChallenge }, + { "Arm", ArmChallenge }, + { "Storage", StorageChallenge }, + }; + + private static readonly Dictionary ParsedChallenges = new Dictionary() + { + { "CaeInsufficientClaims", ParsedCaeInsufficientClaimsChallenge }, + { "CaeSessionsRevoked", ParsedCaeSessionsRevokedClaimsChallenge }, + { "KeyVault", ParsedKeyVaultChallenge }, + { "Arm", ParsedArmChallenge }, + { "Storage", ParsedStorageChallenge } + }; + + private static readonly List MultipleParsedChallenges = new List() + { + { ParsedCaeInsufficientClaimsChallenge }, + { ParsedCaeSessionsRevokedClaimsChallenge }, + { ParsedKeyVaultChallenge }, + { ParsedArmChallenge }, + }; + + private class Challenge + { + public string Scheme { get; set; } + + public List<(string, string)> Parameters { get; } = new List<(string, string)>(); + } + + [Test] + public void BearerTokenChallengeAuthenticationPolicy_ValidateChallengeParsing([Values("CaeInsufficientClaims", "CaeSessionsRevoked", "KeyVault", "Arm", "Storage")] string challengeKey) + { + var challenge = ChallengeStrings[challengeKey].AsSpan(); + + List parsedChallenges = new List(); + + while (BearerTokenChallengeAuthenticationPolicy.TryGetNextChallenge(ref challenge, out var scheme)) + { + Challenge parsedChallenge = new Challenge(); + + parsedChallenge.Scheme = scheme.ToString(); + + while (BearerTokenChallengeAuthenticationPolicy.TryGetNextParameter(ref challenge, out var key, out var value)) + { + parsedChallenge.Parameters.Add((key.ToString(), value.ToString())); + } + + parsedChallenges.Add(parsedChallenge); + } + + Assert.AreEqual(1, parsedChallenges.Count); + + ValidateParsedChallenge(ParsedChallenges[challengeKey], parsedChallenges[0]); + } + + [Test] + public void BearerTokenChallengeAuthenticationPolicy_ValidateChallengeParsingWithMultipleChallenges() + { + var challenge = string.Join(", ", new[] { CaeInsufficientClaimsChallenge, CaeSessionsRevokedClaimsChallenge, KeyVaultChallenge, ArmChallenge }).AsSpan(); + + List parsedChallenges = new List(); + + while (BearerTokenChallengeAuthenticationPolicy.TryGetNextChallenge(ref challenge, out var scheme)) + { + Challenge parsedChallenge = new Challenge(); + + parsedChallenge.Scheme = scheme.ToString(); + + while (BearerTokenChallengeAuthenticationPolicy.TryGetNextParameter(ref challenge, out var key, out var value)) + { + parsedChallenge.Parameters.Add((key.ToString(), value.ToString())); + } + + parsedChallenges.Add(parsedChallenge); + } + + Assert.AreEqual(MultipleParsedChallenges.Count, parsedChallenges.Count); + + for (int i = 0; i < parsedChallenges.Count; i++) + { + ValidateParsedChallenge(MultipleParsedChallenges[i], parsedChallenges[i]); + } + } + + [Test] + public async Task BearerTokenChallengeAuthenticationPolicy_ValidateClaimsChallengeTokenRequest() + { + string currentClaimChallenge = null; + + int tokensRequested = 0; + + var credential = new TokenCredentialStub((r, c) => + { + tokensRequested++; + + Assert.AreEqual(currentClaimChallenge, r.Claims); + + return new AccessToken(Guid.NewGuid().ToString(), DateTimeOffset.UtcNow + TimeSpan.FromDays(1)); + }, IsAsync); + + var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope"); + + var insufficientClaimsChallengeResponse = new MockResponse(401); + + insufficientClaimsChallengeResponse.AddHeader(new HttpHeader("WWW-Authenticate", CaeInsufficientClaimsChallenge)); + + var sessionRevokedChallengeResponse = new MockResponse(401); + + sessionRevokedChallengeResponse.AddHeader(new HttpHeader("WWW-Authenticate", CaeSessionsRevokedClaimsChallenge)); + + var armChallengeResponse = new MockResponse(401); + + armChallengeResponse.AddHeader(new HttpHeader("WWW-Authenticate", ArmChallenge)); + + var keyvaultChallengeResponse = new MockResponse(401); + + keyvaultChallengeResponse.AddHeader(new HttpHeader("WWW-Authenticate", KeyVaultChallenge)); + + MockTransport transport = CreateMockTransport(new MockResponse(200), + insufficientClaimsChallengeResponse, + new MockResponse(200), + sessionRevokedChallengeResponse, + new MockResponse(200), + armChallengeResponse, + keyvaultChallengeResponse); + + var response = await SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: default); + + Assert.AreEqual(tokensRequested, 1); + + Assert.AreEqual(response.Status, 200); + + currentClaimChallenge = Base64Url.DecodeString(CaeInsufficientClaimsChallengeValue); + + response = await SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: default); + + Assert.AreEqual(tokensRequested, 2); + + Assert.AreEqual(response.Status, 200); + + currentClaimChallenge = Base64Url.DecodeString(CaeSessionsRevokedClaimsChallengeValue); + + response = await SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: default); + + Assert.AreEqual(tokensRequested, 3); + + Assert.AreEqual(response.Status, 200); + + currentClaimChallenge = null; + + response = await SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: default); + + Assert.AreEqual(tokensRequested, 3); + + Assert.AreEqual(response.Status, 401); + + response = await SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: default); + + Assert.AreEqual(tokensRequested, 3); + + Assert.AreEqual(response.Status, 401); + } + + private void ValidateParsedChallenge(Challenge expected, Challenge actual) + { + Assert.AreEqual(expected.Scheme, actual.Scheme); + + CollectionAssert.AreEquivalent(expected.Parameters, actual.Parameters); + } + + private class TokenCredentialStub : TokenCredential + { + public TokenCredentialStub(Func handler, bool isAsync) + { + if (isAsync) + { +#pragma warning disable 1998 + _getTokenAsyncHandler = async (r, c) => handler(r, c); +#pragma warning restore 1998 + } + else + { + _getTokenHandler = handler; + } + } + + private readonly Func> _getTokenAsyncHandler; + private readonly Func _getTokenHandler; + + public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) + => _getTokenAsyncHandler(requestContext, cancellationToken); + + public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) + => _getTokenHandler(requestContext, cancellationToken); + } + } +}