diff --git a/sdk/core/Azure.Core/src/Shared/BearerTokenChallengeAuthenticationPolicy.cs b/sdk/core/Azure.Core/src/Shared/BearerTokenChallengeAuthenticationPolicy.cs
new file mode 100644
index 0000000000000..f3a85b89fbc0b
--- /dev/null
+++ b/sdk/core/Azure.Core/src/Shared/BearerTokenChallengeAuthenticationPolicy.cs
@@ -0,0 +1,480 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.Core.Diagnostics;
+
+#nullable enable
+
+namespace Azure.Core.Pipeline
+{
+ ///
+ /// A policy that sends an provided by a as an Authentication header.
+ /// Note: This class is currently in preview and is therefore subject to possible future breaking changes.
+ ///
+ internal class BearerTokenChallengeAuthenticationPolicy : HttpPipelinePolicy
+ {
+ private const string ChallengeHeader = "WWW-Authenticate";
+ private readonly AccessTokenCache _accessTokenCache;
+ private string[] _scopes;
+
+ ///
+ /// Creates a new instance of using provided token credential and scope to authenticate for.
+ ///
+ /// The token credential to use for authentication.
+ /// The scope to authenticate for.
+ public BearerTokenChallengeAuthenticationPolicy(TokenCredential credential, string scope) : this(credential, new[] { scope }) { }
+
+ ///
+ /// Creates a new instance of using provided token credential and scopes to authenticate for.
+ ///
+ /// The token credential to use for authentication.
+ /// Scopes to authenticate for.
+ public BearerTokenChallengeAuthenticationPolicy(TokenCredential credential, IEnumerable scopes)
+ : this(credential, scopes, TimeSpan.FromMinutes(5), TimeSpan.FromSeconds(30)) { }
+
+ internal BearerTokenChallengeAuthenticationPolicy(TokenCredential credential, IEnumerable scopes, TimeSpan tokenRefreshOffset, TimeSpan tokenRefreshRetryDelay)
+ {
+ Argument.AssertNotNull(credential, nameof(credential));
+ Argument.AssertNotNull(scopes, nameof(scopes));
+
+ _scopes = scopes.ToArray();
+ _accessTokenCache = new AccessTokenCache(credential, tokenRefreshOffset, tokenRefreshRetryDelay, scopes.ToArray());
+ }
+
+ ///
+ public override ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline)
+ {
+ return ProcessAsync(message, pipeline, true);
+ }
+
+ ///
+ public override void Process(HttpMessage message, ReadOnlyMemory pipeline)
+ {
+ ProcessAsync(message, pipeline, false).EnsureCompleted();
+ }
+
+ ///
+ /// Executed in the event a 401 response with a WWW-Authenticate authentication challenge header is received after the initial request.
+ ///
+ /// This implementation handles common authentication challenges such as claims challenges. Service client libraries may derive from this and extend to handle service specific authentication challenges.
+ /// The to be authenticated.
+ /// If the return value is true, a .
+ /// A boolean indicated whether the request contained a valid challenge and a was successfully initialized with it.
+ protected virtual bool TryGetTokenRequestContextFromChallenge(HttpMessage message, out TokenRequestContext context)
+ {
+ context = default;
+
+ var claimsChallenge = GetClaimsChallenge(message.Response);
+
+ if (claimsChallenge != null)
+ {
+ context = new TokenRequestContext(_scopes, message.Request.ClientRequestId, claimsChallenge);
+ return true;
+ }
+
+ return false;
+ }
+
+ private async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline, bool async)
+ {
+ if (message.Request.Uri.Scheme != Uri.UriSchemeHttps)
+ {
+ throw new InvalidOperationException("Bearer token authentication is not permitted for non TLS protected (https) endpoints.");
+ }
+
+ TokenRequestContext context;
+
+ // If the message already has a challenge response due to a sub-class pre-processing the request, get the context from the challenge.
+ if (message.HasResponse && message.Response.Status == (int)HttpStatusCode.Unauthorized && message.Response.Headers.Contains(ChallengeHeader))
+ {
+ if (!TryGetTokenRequestContextFromChallenge(message, out context))
+ {
+ // We were unsuccessful in handling the challenge, so bail out now.
+ return;
+ }
+ _scopes = context.Scopes;
+ }
+ else
+ {
+ context = new TokenRequestContext(_scopes, message.Request.ClientRequestId);
+ }
+
+ await AuthenticateRequestAsync(message, context, async).ConfigureAwait(false);
+
+ if (async)
+ {
+ await ProcessNextAsync(message, pipeline).ConfigureAwait(false);
+ }
+ else
+ {
+ ProcessNext(message, pipeline);
+ }
+
+ // Check if we have received a challenge or we have not yet issued the first request.
+ if (message.Response.Status == (int)HttpStatusCode.Unauthorized && message.Response.Headers.Contains(ChallengeHeader))
+ {
+ // Attempt to get the TokenRequestContext based on the challenge.
+ // If we fail to get the context, the challenge was not present or invalid.
+ // If we succeed in getting the context, authenticate the request and pass it up the policy chain.
+ if (TryGetTokenRequestContextFromChallenge(message, out context))
+ {
+ // Ensure the scopes are consistent with what was set by .
+ _scopes = context.Scopes;
+
+ await AuthenticateRequestAsync(message, context, async).ConfigureAwait(false);
+
+ if (async)
+ {
+ await ProcessNextAsync(message, pipeline).ConfigureAwait(false);
+ }
+ else
+ {
+ ProcessNext(message, pipeline);
+ }
+ }
+ }
+ }
+
+ private async Task AuthenticateRequestAsync(HttpMessage message, TokenRequestContext context, bool async)
+ {
+ string headerValue;
+ if (async)
+ {
+ headerValue = await _accessTokenCache.GetHeaderValueAsync(message, context, async).ConfigureAwait(false);
+ }
+ else
+ {
+ headerValue = _accessTokenCache.GetHeaderValueAsync(message, context, async).EnsureCompleted();
+ }
+
+ message.Request.SetHeader(HttpHeader.Names.Authorization, headerValue);
+ }
+
+ private static string? GetClaimsChallenge(Response response)
+ {
+ if (response.Status != (int)HttpStatusCode.Unauthorized || !response.Headers.TryGetValue(ChallengeHeader, out string? headerValue))
+ {
+ return null;
+ }
+
+ ReadOnlySpan bearer = "Bearer".AsSpan();
+ ReadOnlySpan claims = "claims".AsSpan();
+ ReadOnlySpan headerSpan = headerValue.AsSpan();
+
+ // Iterate through each challenge value.
+ while (TryGetNextChallenge(ref headerSpan, out var challengeKey))
+ {
+ // Enumerate each key=value parameter until we find the 'claims' key on the 'Bearer' challenge.
+ while (TryGetNextParameter(ref headerSpan, out var key, out var value))
+ {
+ if (challengeKey.Equals(bearer, StringComparison.OrdinalIgnoreCase) && key.Equals(claims, StringComparison.OrdinalIgnoreCase))
+ {
+ return Base64Url.DecodeString(value.ToString());
+ }
+ }
+ }
+
+ return null;
+ }
+
+ ///
+ /// Iterates through the challenge schemes present in a challenge header.
+ ///
+ ///
+ /// The header value which will be sliced to remove the first parsed .
+ ///
+ /// The parsed challenge scheme.
+ ///
+ /// true if a challenge scheme was successfully parsed.
+ /// The value of should be passed to to parse the challenge parameters if true.
+ ///
+ internal static bool TryGetNextChallenge(ref ReadOnlySpan headerValue, out ReadOnlySpan challengeKey)
+ {
+ challengeKey = default;
+
+ headerValue = headerValue.TrimStart(' ');
+ int endOfChallengeKey = headerValue.IndexOf(' ');
+
+ if (endOfChallengeKey < 0)
+ {
+ return false;
+ }
+
+ challengeKey = headerValue.Slice(0, endOfChallengeKey);
+
+ // Slice the challenge key from the headerValue
+ headerValue = headerValue.Slice(endOfChallengeKey + 1);
+
+ return true;
+ }
+
+ ///
+ /// Iterates through a challenge header value after being parsed by .
+ ///
+ /// The header value after being parsed by .
+ /// The parsed challenge parameter key.
+ /// The parsed challenge parameter value.
+ /// The challenge parameter key / value pair separator. The default is '='.
+ ///
+ /// true if the next available challenge parameter was successfully parsed.
+ /// false if there are no more parameters for the current challenge scheme or an additional challenge scheme was encountered in the .
+ /// The value of should be passed again to to attempt to parse any additional challenge schemes if false.
+ ///
+ internal static bool TryGetNextParameter(ref ReadOnlySpan headerValue, out ReadOnlySpan paramKey, out ReadOnlySpan paramValue, char separator = '=')
+ {
+ paramKey = default;
+ paramValue = default;
+ var spaceOrComma = " ,".AsSpan();
+
+ // Trim any separater prefixes.
+ headerValue = headerValue.TrimStart(spaceOrComma);
+
+ int nextSpace = headerValue.IndexOf(' ');
+ int nextSeparator = headerValue.IndexOf(separator);
+
+ if (nextSpace < nextSeparator && nextSpace != -1)
+ {
+ // we encountered another challenge value.
+ return false;
+ }
+
+ if (nextSeparator < 0)
+ return false;
+
+ // Get the paramKey.
+ paramKey = headerValue.Slice(0, nextSeparator).Trim();
+
+ // Slice to remove the 'paramKey=' from the parameters.
+ headerValue = headerValue.Slice(nextSeparator + 1);
+
+ // The start of paramValue will usually be a quoted string. Find the first quote.
+ int quoteIndex = headerValue.IndexOf('\"');
+
+ // Get the paramValue, which is delimited by the trailing quote.
+ headerValue = headerValue.Slice(quoteIndex + 1);
+ if (quoteIndex >= 0)
+ {
+ // The values are quote wrapped
+ paramValue = headerValue.Slice(0, headerValue.IndexOf('\"'));
+ }
+ else
+ {
+ //the values are not quote wrapped (storage is one example of this)
+ // either find the next space indicating the delimiter to the next value, or go to the end since this is the last value.
+ int trailingDelimiterIndex = headerValue.IndexOfAny(spaceOrComma);
+ if (trailingDelimiterIndex >= 0)
+ {
+ paramValue = headerValue.Slice(0, trailingDelimiterIndex);
+ }
+ else
+ {
+ paramValue = headerValue;
+ }
+ }
+
+ // Slice to remove the '"paramValue"' from the parameters.
+ if (headerValue != paramValue)
+ headerValue = headerValue.Slice(paramValue.Length + 1);
+
+ return true;
+ }
+
+ private class AccessTokenCache
+ {
+ private readonly object _syncObj = new object();
+ private readonly TokenCredential _credential;
+ private readonly TimeSpan _tokenRefreshOffset;
+ private readonly TimeSpan _tokenRefreshRetryDelay;
+
+ private TokenRequestContext? _currentContext;
+ private TaskCompletionSource? _infoTcs;
+ private TaskCompletionSource? _backgroundUpdateTcs;
+ public AccessTokenCache(TokenCredential credential, TimeSpan tokenRefreshOffset, TimeSpan tokenRefreshRetryDelay, string[] initialScopes)
+ {
+ _credential = credential;
+ _tokenRefreshOffset = tokenRefreshOffset;
+ _tokenRefreshRetryDelay = tokenRefreshRetryDelay;
+ _currentContext = new TokenRequestContext(initialScopes);
+ }
+
+ public async ValueTask GetHeaderValueAsync(HttpMessage message, TokenRequestContext context, bool async)
+ {
+ bool getTokenFromCredential;
+ TaskCompletionSource headerValueTcs;
+ TaskCompletionSource? backgroundUpdateTcs;
+ (headerValueTcs, backgroundUpdateTcs, getTokenFromCredential) = GetTaskCompletionSources(context);
+ HeaderValueInfo info;
+
+ if (getTokenFromCredential)
+ {
+ if (backgroundUpdateTcs != null)
+ {
+ if (async)
+ {
+ info = await headerValueTcs.Task.ConfigureAwait(false);
+ }
+ else
+ {
+#pragma warning disable AZC0104 // Use EnsureCompleted() directly on asynchronous method return value.
+ info = headerValueTcs.Task.EnsureCompleted();
+#pragma warning restore AZC0104 // Use EnsureCompleted() directly on asynchronous method return value.
+ }
+ _ = Task.Run(() => GetHeaderValueFromCredentialInBackgroundAsync(backgroundUpdateTcs, info, context, async));
+ return info.HeaderValue;
+ }
+
+ try
+ {
+ info = await GetHeaderValueFromCredentialAsync(context, async, message.CancellationToken).ConfigureAwait(false);
+ headerValueTcs.SetResult(info);
+ }
+ catch (OperationCanceledException)
+ {
+ headerValueTcs.SetCanceled();
+ throw;
+ }
+ catch (Exception exception)
+ {
+ headerValueTcs.SetException(exception);
+ throw;
+ }
+ }
+
+ var headerValueTask = headerValueTcs.Task;
+ if (!headerValueTask.IsCompleted)
+ {
+ if (async)
+ {
+ await headerValueTask.AwaitWithCancellation(message.CancellationToken);
+ }
+ else
+ {
+ try
+ {
+ headerValueTask.Wait(message.CancellationToken);
+ }
+ catch (AggregateException) { } // ignore exception here to rethrow it with EnsureCompleted
+ }
+ }
+ if (async)
+ {
+ info = await headerValueTcs.Task.ConfigureAwait(false);
+ }
+ else
+ {
+#pragma warning disable AZC0104 // Use EnsureCompleted() directly on asynchronous method return value.
+ info = headerValueTcs.Task.EnsureCompleted();
+#pragma warning restore AZC0104 // Use EnsureCompleted() directly on asynchronous method return value.
+ }
+
+ return info.HeaderValue;
+ }
+
+ private (TaskCompletionSource, TaskCompletionSource?, bool) GetTaskCompletionSources(TokenRequestContext context)
+ {
+ lock (_syncObj)
+ {
+ // Initial state. GetTaskCompletionSources has been called for the first time
+ if (_infoTcs == null || RequestRequiresNewToken(context))
+ {
+ _currentContext = context;
+ _infoTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ _backgroundUpdateTcs = default;
+ return (_infoTcs, _backgroundUpdateTcs, true);
+ }
+
+ // Getting new access token is in progress, wait for it
+ if (!_infoTcs.Task.IsCompleted)
+ {
+ _backgroundUpdateTcs = default;
+ return (_infoTcs, _backgroundUpdateTcs, false);
+ }
+
+ DateTimeOffset now = DateTimeOffset.UtcNow;
+ // Access token has been successfully acquired in background and it is not expired yet, use it instead of current one
+ if (_backgroundUpdateTcs != null && _backgroundUpdateTcs.Task.Status == TaskStatus.RanToCompletion && _backgroundUpdateTcs.Task.Result.ExpiresOn > now)
+ {
+ _infoTcs = _backgroundUpdateTcs;
+ _backgroundUpdateTcs = default;
+ }
+
+ // Attempt to get access token has failed or it has already expired. Need to get a new one
+ if (_infoTcs.Task.Status != TaskStatus.RanToCompletion || now >= _infoTcs.Task.Result.ExpiresOn)
+ {
+ _infoTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ return (_infoTcs, default, true);
+ }
+
+ // Access token is still valid but is about to expire, try to get it in background
+ if (now >= _infoTcs.Task.Result.RefreshOn && _backgroundUpdateTcs == null)
+ {
+ _backgroundUpdateTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ return (_infoTcs, _backgroundUpdateTcs, true);
+ }
+
+ // Access token is valid, use it
+ return (_infoTcs, default, false);
+ }
+ }
+
+ // must be called under lock (_syncObj)
+ private bool RequestRequiresNewToken(TokenRequestContext context) =>
+ _currentContext == null ||
+ (context.Scopes != null && !context.Scopes.SequenceEqual(_currentContext.Value.Scopes)) ||
+ (context.Claims != null && !string.Equals(context.Claims, _currentContext.Value.Claims));
+
+ private async ValueTask GetHeaderValueFromCredentialInBackgroundAsync(TaskCompletionSource backgroundUpdateTcs, HeaderValueInfo info, TokenRequestContext context, bool async)
+ {
+ var cts = new CancellationTokenSource(_tokenRefreshRetryDelay);
+ try
+ {
+ HeaderValueInfo newInfo = await GetHeaderValueFromCredentialAsync(context, async, cts.Token).ConfigureAwait(false);
+ backgroundUpdateTcs.SetResult(newInfo);
+ }
+ catch (OperationCanceledException oce) when (cts.IsCancellationRequested)
+ {
+ backgroundUpdateTcs.SetResult(new HeaderValueInfo(info.HeaderValue, info.ExpiresOn, DateTimeOffset.UtcNow));
+ AzureCoreEventSource.Singleton.BackgroundRefreshFailed(context.ParentRequestId ?? string.Empty, oce.ToString());
+ }
+ catch (Exception e)
+ {
+ backgroundUpdateTcs.SetResult(new HeaderValueInfo(info.HeaderValue, info.ExpiresOn, DateTimeOffset.UtcNow + _tokenRefreshRetryDelay));
+ AzureCoreEventSource.Singleton.BackgroundRefreshFailed(context.ParentRequestId ?? string.Empty, e.ToString());
+ }
+ finally
+ {
+ cts.Dispose();
+ }
+ }
+
+ private async ValueTask GetHeaderValueFromCredentialAsync(TokenRequestContext context, bool async, CancellationToken cancellationToken)
+ {
+ AccessToken token = async
+ ? await _credential.GetTokenAsync(context, cancellationToken).ConfigureAwait(false)
+ : _credential.GetToken(context, cancellationToken);
+
+ return new HeaderValueInfo("Bearer " + token.Token, token.ExpiresOn, token.ExpiresOn - _tokenRefreshOffset);
+ }
+
+ private readonly struct HeaderValueInfo
+ {
+ public string HeaderValue { get; }
+ public DateTimeOffset ExpiresOn { get; }
+ public DateTimeOffset RefreshOn { get; }
+
+ public HeaderValueInfo(string headerValue, DateTimeOffset expiresOn, DateTimeOffset refreshOn)
+ {
+ HeaderValue = headerValue;
+ ExpiresOn = expiresOn;
+ RefreshOn = refreshOn;
+ }
+ }
+ }
+ }
+}
diff --git a/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj b/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj
index 19c3058fc5cb8..593a7415302f5 100644
--- a/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj
+++ b/sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj
@@ -24,6 +24,7 @@
+
diff --git a/sdk/core/Azure.Core/tests/BearerTokenChallengeAuthenticationPolicyTests.cs b/sdk/core/Azure.Core/tests/BearerTokenChallengeAuthenticationPolicyTests.cs
new file mode 100644
index 0000000000000..a9bbf035b6885
--- /dev/null
+++ b/sdk/core/Azure.Core/tests/BearerTokenChallengeAuthenticationPolicyTests.cs
@@ -0,0 +1,815 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.Core.Pipeline;
+using Azure.Core.TestFramework;
+using Moq;
+using NUnit.Framework;
+
+namespace Azure.Core.Tests
+{
+ public class BearerTokenChallengeAuthenticationPolicyTests : SyncAsyncPolicyTestBase
+ {
+ public BearerTokenChallengeAuthenticationPolicyTests(bool isAsync) : base(isAsync) { }
+
+ [Test]
+ public async Task BearerTokenChallengeAuthenticationPolicy_UsesTokenProvidedByCredentials()
+ {
+ var credential = new TokenCredentialStub(
+ (r, c) => r.Scopes.SequenceEqual(new[] { "scope1", "scope2" }) ? new AccessToken("token", DateTimeOffset.MaxValue) : default, IsAsync);
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, new[] { "scope1", "scope2" });
+
+ MockTransport transport = CreateMockTransport(new MockResponse(200));
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com"));
+
+ Assert.True(transport.SingleRequest.Headers.TryGetValue("Authorization", out string authValue));
+ Assert.AreEqual("Bearer token", authValue);
+ }
+
+ [Test]
+ public async Task BearerTokenChallengeAuthenticationPolicy_RequestsTokenEveryRequest()
+ {
+ var accessTokens = new Queue();
+ accessTokens.Enqueue(new AccessToken("token1", DateTimeOffset.UtcNow));
+ accessTokens.Enqueue(new AccessToken("token2", DateTimeOffset.UtcNow));
+
+ var credential = new TokenCredentialStub(
+ (r, c) => r.Scopes.SequenceEqual(new[] { "scope1", "scope2" }) ? accessTokens.Dequeue() : default, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, new[] { "scope1", "scope2" });
+ MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200));
+
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com"));
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com"));
+
+ Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value));
+ Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value));
+
+ Assert.AreEqual("Bearer token1", auth1Value);
+ Assert.AreEqual("Bearer token2", auth2Value);
+ }
+
+ [Test]
+ public async Task BearerTokenChallengeAuthenticationPolicy_CachesHeaderValue()
+ {
+ var credential = new TokenCredentialStub(
+ (r, c) => r.Scopes.SequenceEqual(new[] { "scope" }) ? new AccessToken("token", DateTimeOffset.MaxValue) : default, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope");
+ MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200));
+
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com"));
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com"));
+
+ Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value));
+ Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value));
+
+ Assert.AreSame(auth1Value, auth1Value);
+ Assert.AreEqual("Bearer token", auth2Value);
+ }
+
+ [Test]
+ public void BearerTokenChallengeAuthenticationPolicy_ThrowsForNonTlsEndpoint()
+ {
+ var credential = new TokenCredentialStub(
+ (r, c) => r.Scopes.SequenceEqual(new[] { "scope" }) ? new AccessToken("token", DateTimeOffset.MaxValue) : default, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope");
+ MockTransport transport = CreateMockTransport();
+
+ Assert.ThrowsAsync(async () => await SendGetRequest(transport, policy, uri: new Uri("http://example.com")));
+ }
+
+ [Test]
+ public void BearerTokenChallengeAuthenticationPolicy_ThrowsForEmptyToken()
+ {
+ var credential = new TokenCredentialStub((r, c) => new AccessToken(string.Empty, DateTimeOffset.MaxValue), IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope");
+ MockTransport transport = CreateMockTransport();
+
+ Assert.ThrowsAsync(async () => await SendGetRequest(transport, policy, uri: new Uri("http://example.com")));
+ }
+
+ [Test]
+ public async Task BearerTokenChallengeAuthenticationPolicy_OneHundredConcurrentCalls()
+ {
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ Thread.Sleep(100);
+ return new AccessToken(Guid.NewGuid().ToString(), DateTimeOffset.UtcNow.AddMinutes(30));
+ }, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope");
+ MockTransport transport = CreateMockTransport(r => new MockResponse(200));
+ var requestTasks = new Task[100];
+
+ for (int i = 0; i < requestTasks.Length; i++)
+ {
+ requestTasks[i] = SendGetRequest(transport, policy, uri: new Uri("https://example.com"));
+ }
+
+ await Task.WhenAll(requestTasks);
+ Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value));
+
+ for (int i = 1; i < requestTasks.Length; i++)
+ {
+ Assert.True(transport.Requests[i].Headers.TryGetValue("Authorization", out string authValue));
+ Assert.AreEqual(auth1Value, authValue);
+ }
+ }
+
+ [Test]
+ public async Task BearerTokenChallengeAuthenticationPolicy_GatedConcurrentCalls()
+ {
+ var requestMre = new ManualResetEventSlim(false);
+ var responseMre = new ManualResetEventSlim(false);
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ requestMre.Set();
+ responseMre.Wait(c);
+ return new AccessToken(Guid.NewGuid().ToString(), DateTimeOffset.UtcNow.AddMinutes(30));
+ }, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope");
+ MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200));
+
+ var firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com"));
+ requestMre.Wait();
+
+ var secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com"));
+ responseMre.Set();
+
+ await Task.WhenAll(firstRequestTask, secondRequestTask);
+
+ Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value));
+ Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value));
+
+ Assert.AreEqual(auth1Value, auth2Value);
+ }
+
+ [Test]
+ public async Task BearerTokenChallengeAuthenticationPolicy_SucceededFailedSucceeded()
+ {
+ var requestMre = new ManualResetEventSlim(false);
+ var callCount = 0;
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ Interlocked.Increment(ref callCount);
+ var offsetTime = DateTimeOffset.UtcNow;
+ requestMre.Set();
+
+ return callCount == 2
+ ? throw new InvalidOperationException("Call Failed")
+ : new AccessToken(Guid.NewGuid().ToString(), offsetTime.AddMilliseconds(1000));
+ }, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, new[] { "scope" }, TimeSpan.FromMilliseconds(100), TimeSpan.FromSeconds(30));
+ MockTransport transport = CreateMockTransport(r => new MockResponse(200));
+
+ var firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/1"));
+ var secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/2"));
+
+ requestMre.Wait();
+ await Task.Delay(200);
+
+ await Task.WhenAll(firstRequestTask, secondRequestTask);
+ await Task.Delay(1000);
+
+ Assert.AreEqual(1, callCount);
+ requestMre.Reset();
+
+ var failedTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/3/failed"));
+ requestMre.Wait();
+
+ Assert.AreEqual(2, callCount);
+ Assert.ThrowsAsync(async () => await failedTask);
+
+ requestMre.Reset();
+
+ firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/4"));
+ secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/5"));
+
+ requestMre.Wait();
+
+ await Task.WhenAll(firstRequestTask, secondRequestTask);
+
+ Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value));
+ Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value));
+ Assert.True(transport.Requests[2].Headers.TryGetValue("Authorization", out string auth3Value));
+ Assert.True(transport.Requests[3].Headers.TryGetValue("Authorization", out string auth4Value));
+
+ Assert.AreEqual(3, callCount);
+ Assert.AreEqual(auth1Value, auth2Value);
+ Assert.AreNotEqual(auth2Value, auth3Value);
+ Assert.AreEqual(auth3Value, auth4Value);
+ }
+
+ [Test]
+ public async Task BearerTokenChallengeAuthenticationPolicy_TokenAlmostExpired()
+ {
+ var requestMre = new ManualResetEventSlim(true);
+ var responseMre = new ManualResetEventSlim(true);
+ var currentTime = DateTimeOffset.UtcNow;
+ var expires = new Queue(new[] { currentTime.AddMinutes(2), currentTime.AddMinutes(30) });
+ var callCount = 0;
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ requestMre.Set();
+ responseMre.Wait(c);
+ requestMre.Reset();
+ callCount++;
+
+ return new AccessToken(Guid.NewGuid().ToString(), expires.Dequeue());
+ }, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope");
+ MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200), new MockResponse(200), new MockResponse(200));
+
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com/1/Original"));
+ responseMre.Reset();
+
+ Task requestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/3/Refresh"));
+ requestMre.Wait();
+
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com/2/AlmostExpired"));
+ await requestTask;
+ responseMre.Set();
+ await Task.Delay(1_000);
+
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com/4/AfterRefresh"));
+
+ Assert.AreEqual(2, callCount);
+
+ Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value));
+ Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value));
+ Assert.True(transport.Requests[2].Headers.TryGetValue("Authorization", out string auth3Value));
+ Assert.True(transport.Requests[3].Headers.TryGetValue("Authorization", out string auth4Value));
+
+ Assert.AreEqual(auth1Value, auth2Value);
+ Assert.AreEqual(auth2Value, auth3Value);
+ Assert.AreNotEqual(auth3Value, auth4Value);
+ }
+
+ [Test]
+ public async Task BearerTokenChallengeAuthenticationPolicy_TokenAlmostExpired_NoRefresh()
+ {
+ var requestMre = new ManualResetEventSlim(true);
+ var responseMre = new ManualResetEventSlim(true);
+ var currentTime = DateTimeOffset.UtcNow;
+ var callCount = 0;
+
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ callCount++;
+ responseMre.Wait(c);
+ requestMre.Set();
+
+ return new AccessToken(Guid.NewGuid().ToString(), currentTime.AddMinutes(2));
+ }, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope");
+ MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200), new MockResponse(200), new MockResponse(200));
+
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com/1/Original"));
+ requestMre.Wait();
+ responseMre.Reset();
+
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com/2/AlmostExpired"));
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com/3/AlmostExpired"));
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com/4/AlmostExpired"));
+
+ requestMre.Reset();
+ responseMre.Set();
+ requestMre.Wait();
+
+ Assert.AreEqual(2, callCount);
+
+ Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value));
+ Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value));
+ Assert.True(transport.Requests[2].Headers.TryGetValue("Authorization", out string auth3Value));
+ Assert.True(transport.Requests[3].Headers.TryGetValue("Authorization", out string auth4Value));
+
+ Assert.AreEqual(auth1Value, auth2Value);
+ Assert.AreEqual(auth2Value, auth3Value);
+ Assert.AreEqual(auth3Value, auth4Value);
+ }
+
+ [Test]
+ public async Task BearerTokenChallengeAuthenticationPolicy_TokenExpired()
+ {
+ var requestMre = new ManualResetEventSlim(true);
+ var responseMre = new ManualResetEventSlim(true);
+ var currentTime = DateTimeOffset.UtcNow;
+ var expires = new Queue(new[] { currentTime.AddSeconds(2), currentTime.AddMinutes(30) });
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ requestMre.Set();
+ responseMre.Wait(c);
+ return new AccessToken(Guid.NewGuid().ToString(), expires.Dequeue());
+ }, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, new[] { "scope" }, TimeSpan.FromSeconds(2), TimeSpan.FromMilliseconds(50));
+ MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200), new MockResponse(200));
+
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com/0"));
+ Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string authValue));
+
+ await Task.Delay(3_000);
+
+ requestMre.Reset();
+ responseMre.Reset();
+
+ var firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/1"));
+ var secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com/2"));
+ requestMre.Wait();
+ await Task.Delay(1_000);
+ responseMre.Set();
+
+ await Task.WhenAll(firstRequestTask, secondRequestTask);
+
+ Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth1Value));
+ Assert.True(transport.Requests[2].Headers.TryGetValue("Authorization", out string auth2Value));
+
+ Assert.AreNotEqual(authValue, auth1Value);
+ Assert.AreEqual(auth1Value, auth2Value);
+ }
+
+ [Test]
+ public void BearerTokenChallengeAuthenticationPolicy_OneHundredConcurrentCallsFailed()
+ {
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ Thread.Sleep(100);
+ throw new InvalidOperationException("Error");
+ }, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope");
+ MockTransport transport = CreateMockTransport(r => new MockResponse(200));
+ var requestTasks = new Task[100];
+
+ for (int i = 0; i < requestTasks.Length; i++)
+ {
+ requestTasks[i] = SendGetRequest(transport, policy, uri: new Uri("https://example.com"));
+ }
+
+ Assert.CatchAsync(async () => await Task.WhenAll(requestTasks));
+
+ foreach (Task task in requestTasks)
+ {
+ Assert.IsTrue(task.IsFaulted);
+ }
+ }
+
+ [Test]
+ public async Task BearerTokenChallengeAuthenticationPolicy_GatedConcurrentCallsFailed()
+ {
+ var requestMre = new ManualResetEventSlim(false);
+ var responseMre = new ManualResetEventSlim(false);
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ requestMre.Set();
+ responseMre.Wait(c);
+ throw new InvalidOperationException("Error");
+ }, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope");
+ MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200));
+
+ var firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com"));
+ var secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com"));
+
+ requestMre.Wait();
+ await Task.Delay(1_000);
+ responseMre.Set();
+
+ Assert.CatchAsync(async () => await Task.WhenAll(firstRequestTask, secondRequestTask));
+
+ Assert.IsTrue(firstRequestTask.IsFaulted);
+ Assert.IsTrue(secondRequestTask.IsFaulted);
+ Assert.AreEqual(firstRequestTask.Exception.InnerException, secondRequestTask.Exception.InnerException);
+ }
+
+ [Test]
+ public async Task BearerTokenChallengeAuthenticationPolicy_TokenExpiredThenFailed()
+ {
+ var requestMre = new ManualResetEventSlim(true);
+ var responseMre = new ManualResetEventSlim(true);
+ var fail = false;
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ requestMre.Set();
+ responseMre.Wait(c);
+ if (fail)
+ {
+ throw new InvalidOperationException("Error");
+ }
+
+ fail = true;
+ return new AccessToken(Guid.NewGuid().ToString(), DateTimeOffset.UtcNow.AddSeconds(2));
+ }, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, new[] { "scope" }, TimeSpan.FromSeconds(2), TimeSpan.FromMilliseconds(50));
+ MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200), new MockResponse(200));
+
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com/0"));
+ Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string _));
+
+ await Task.Delay(3_000);
+
+ requestMre.Reset();
+ responseMre.Reset();
+
+ var firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com"));
+ var secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com"));
+
+ requestMre.Wait();
+ await Task.Delay(1_000);
+ responseMre.Set();
+
+ Assert.CatchAsync(async () => await Task.WhenAll(firstRequestTask, secondRequestTask));
+
+ Assert.IsTrue(firstRequestTask.IsFaulted);
+ Assert.IsTrue(secondRequestTask.IsFaulted);
+ Assert.AreEqual(firstRequestTask.Exception.InnerException, secondRequestTask.Exception.InnerException);
+ }
+
+ [Test]
+ [Ignore("https://github.com/Azure/azure-sdk-for-net/issues/14612")]
+ public async Task BearerTokenChallengeAuthenticationPolicy_TokenAlmostExpiredThenFailed()
+ {
+ var requestMre = new ManualResetEventSlim(true);
+ var responseMre = new ManualResetEventSlim(true);
+ var credentialMre = new ManualResetEventSlim(false);
+
+ var getTokenRequestTimes = new List();
+ var transportCallCount = 0;
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ if (transportCallCount > 0)
+ {
+ credentialMre.Set();
+ getTokenRequestTimes.Add(DateTimeOffset.UtcNow);
+ throw new InvalidOperationException("Error");
+ }
+
+ return new AccessToken(Guid.NewGuid().ToString(), DateTimeOffset.UtcNow.AddMinutes(1.5));
+ }, IsAsync);
+
+ var tokenRefreshRetryDelay = TimeSpan.FromSeconds(2);
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, new[] { "scope" }, TimeSpan.FromMinutes(2), tokenRefreshRetryDelay);
+ MockTransport transport = CreateMockTransport(r =>
+ {
+ requestMre.Set();
+ responseMre.Wait();
+ if (Interlocked.Increment(ref transportCallCount) == 4)
+ {
+ credentialMre.Wait();
+ }
+ return new MockResponse(200);
+ });
+
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com/1"));
+ Assert.True(transport.Requests[0].Headers.TryGetValue("Authorization", out string auth1Value));
+
+ requestMre.Reset();
+ responseMre.Reset();
+
+ Task requestTask1 = SendGetRequest(transport, policy, uri: new Uri("https://example.com/2/TokenFromCache/RefreshInBackground"));
+ Task requestTask2 = SendGetRequest(transport, policy, uri: new Uri("https://example.com/3/TokenFromCache/"));
+
+ requestMre.Wait();
+ responseMre.Set();
+
+ await Task.WhenAll(requestTask1, requestTask2);
+
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com/4/TokenFromCache"));
+
+ await Task.Delay((int)tokenRefreshRetryDelay.TotalMilliseconds + 1_000);
+ credentialMre.Reset();
+
+ await SendGetRequest(transport, policy, uri: new Uri("https://example.com/5/TokenFromCache/GetTokenFailed"));
+ credentialMre.Wait();
+
+ Assert.True(transport.Requests[1].Headers.TryGetValue("Authorization", out string auth2Value));
+ Assert.True(transport.Requests[2].Headers.TryGetValue("Authorization", out string auth3Value));
+ Assert.True(transport.Requests[3].Headers.TryGetValue("Authorization", out string auth4Value));
+ Assert.True(transport.Requests[4].Headers.TryGetValue("Authorization", out string auth5Value));
+
+ Assert.AreEqual(auth1Value, auth2Value);
+ Assert.AreEqual(auth2Value, auth3Value);
+ Assert.AreEqual(auth3Value, auth4Value);
+ Assert.AreEqual(auth4Value, auth5Value);
+
+ Assert.AreEqual(2, getTokenRequestTimes.Count);
+ Assert.True(getTokenRequestTimes[1] - getTokenRequestTimes[0] > tokenRefreshRetryDelay);
+ }
+
+ [Test]
+ public void BearerTokenChallengeAuthenticationPolicy_GatedConcurrentCallsCancelled()
+ {
+ var requestMre = new ManualResetEventSlim(false);
+ var responseMre = new ManualResetEventSlim(false);
+ var cts = new CancellationTokenSource();
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ requestMre.Set();
+ responseMre.Wait(c);
+ throw new InvalidOperationException("Error");
+ }, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope");
+ MockTransport transport = CreateMockTransport(new MockResponse(200), new MockResponse(200));
+
+ var firstRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: default);
+ requestMre.Wait();
+
+ var secondRequestTask = SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: cts.Token);
+ cts.Cancel();
+
+ Assert.CatchAsync(async () => await secondRequestTask);
+ responseMre.Set();
+
+ Assert.CatchAsync(async () => await firstRequestTask);
+ }
+
+ private const string CaeInsufficientClaimsChallenge = "Bearer realm=\"\", authorization_uri=\"https://login.microsoftonline.com/common/oauth2/authorize\", client_id=\"00000003-0000-0000-c000-000000000000\", error=\"insufficient_claims\", claims=\"eyJhY2Nlc3NfdG9rZW4iOiB7ImZvbyI6ICJiYXIifX0=\"";
+ private const string CaeInsufficientClaimsChallengeValue = "eyJhY2Nlc3NfdG9rZW4iOiB7ImZvbyI6ICJiYXIifX0=";
+ private static readonly Challenge ParsedCaeInsufficientClaimsChallenge = new Challenge
+ {
+ Scheme = "Bearer",
+ Parameters =
+ {
+ ("realm", ""),
+ ("authorization_uri", "https://login.microsoftonline.com/common/oauth2/authorize"),
+ ("client_id", "00000003-0000-0000-c000-000000000000"),
+ ("error", "insufficient_claims"),
+ ("claims", "eyJhY2Nlc3NfdG9rZW4iOiB7ImZvbyI6ICJiYXIifX0="),
+ }
+ };
+
+ private const string CaeSessionsRevokedClaimsChallenge = "Bearer authorization_uri=\"https://login.windows-ppe.net/\", error=\"invalid_token\", error_description=\"User session has been revoked\", claims=\"eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=\"";
+ private const string CaeSessionsRevokedClaimsChallengeValue = "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=";
+ private static readonly Challenge ParsedCaeSessionsRevokedClaimsChallenge = new Challenge
+ {
+ Scheme = "Bearer",
+ Parameters =
+ {
+ ("authorization_uri", "https://login.windows-ppe.net/"),
+ ("error", "invalid_token"),
+ ("error_description", "User session has been revoked"),
+ ("claims", "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0="),
+ }
+ };
+
+ private const string KeyVaultChallenge = "Bearer authorization=\"https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47\", resource=\"https://vault.azure.net\"";
+ private static readonly Challenge ParsedKeyVaultChallenge = new Challenge
+ {
+ Scheme = "Bearer",
+ Parameters =
+ {
+ ("authorization", "https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47"),
+ ("resource", "https://vault.azure.net"),
+ }
+ };
+
+ private const string ArmChallenge = "Bearer authorization_uri=\"https://login.windows.net/\", error=\"invalid_token\", error_description=\"The authentication failed because of missing 'Authorization' header.\"";
+ private static readonly Challenge ParsedArmChallenge = new Challenge()
+ {
+ Scheme = "Bearer",
+ Parameters =
+ {
+ ("authorization_uri", "https://login.windows.net/"),
+ ("error", "invalid_token"),
+ ("error_description", "The authentication failed because of missing 'Authorization' header."),
+ }
+ };
+
+ private const string StorageChallenge = "Bearer authorization_uri=https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/oauth2/authorize resource_id=https://storage.azure.com";
+ private static readonly Challenge ParsedStorageChallenge = new Challenge()
+ {
+ Scheme = "Bearer",
+ Parameters =
+ {
+ ("authorization_uri", "https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/oauth2/authorize"),
+ ("resource_id", "https://storage.azure.com"),
+ }
+ };
+ private static readonly Challenge ParsedMultipleChallenges = new Challenge
+ {
+ Scheme = "Bearer",
+ Parameters =
+ {
+ ("authorization_uri", "https://login.windows-ppe.net/"),
+ ("error", "invalid_token"),
+ ("error_description", "User session has been revoked"),
+ ("claims", "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0="),
+ }
+ };
+ private static readonly Dictionary ChallengeStrings = new Dictionary()
+ {
+ { "CaeInsufficientClaims", CaeInsufficientClaimsChallenge },
+ { "CaeSessionsRevoked", CaeSessionsRevokedClaimsChallenge },
+ { "KeyVault", KeyVaultChallenge },
+ { "Arm", ArmChallenge },
+ { "Storage", StorageChallenge },
+ };
+
+ private static readonly Dictionary ParsedChallenges = new Dictionary()
+ {
+ { "CaeInsufficientClaims", ParsedCaeInsufficientClaimsChallenge },
+ { "CaeSessionsRevoked", ParsedCaeSessionsRevokedClaimsChallenge },
+ { "KeyVault", ParsedKeyVaultChallenge },
+ { "Arm", ParsedArmChallenge },
+ { "Storage", ParsedStorageChallenge }
+ };
+
+ private static readonly List MultipleParsedChallenges = new List()
+ {
+ { ParsedCaeInsufficientClaimsChallenge },
+ { ParsedCaeSessionsRevokedClaimsChallenge },
+ { ParsedKeyVaultChallenge },
+ { ParsedArmChallenge },
+ };
+
+ private class Challenge
+ {
+ public string Scheme { get; set; }
+
+ public List<(string, string)> Parameters { get; } = new List<(string, string)>();
+ }
+
+ [Test]
+ public void BearerTokenChallengeAuthenticationPolicy_ValidateChallengeParsing([Values("CaeInsufficientClaims", "CaeSessionsRevoked", "KeyVault", "Arm", "Storage")] string challengeKey)
+ {
+ var challenge = ChallengeStrings[challengeKey].AsSpan();
+
+ List parsedChallenges = new List();
+
+ while (BearerTokenChallengeAuthenticationPolicy.TryGetNextChallenge(ref challenge, out var scheme))
+ {
+ Challenge parsedChallenge = new Challenge();
+
+ parsedChallenge.Scheme = scheme.ToString();
+
+ while (BearerTokenChallengeAuthenticationPolicy.TryGetNextParameter(ref challenge, out var key, out var value))
+ {
+ parsedChallenge.Parameters.Add((key.ToString(), value.ToString()));
+ }
+
+ parsedChallenges.Add(parsedChallenge);
+ }
+
+ Assert.AreEqual(1, parsedChallenges.Count);
+
+ ValidateParsedChallenge(ParsedChallenges[challengeKey], parsedChallenges[0]);
+ }
+
+ [Test]
+ public void BearerTokenChallengeAuthenticationPolicy_ValidateChallengeParsingWithMultipleChallenges()
+ {
+ var challenge = string.Join(", ", new[] { CaeInsufficientClaimsChallenge, CaeSessionsRevokedClaimsChallenge, KeyVaultChallenge, ArmChallenge }).AsSpan();
+
+ List parsedChallenges = new List();
+
+ while (BearerTokenChallengeAuthenticationPolicy.TryGetNextChallenge(ref challenge, out var scheme))
+ {
+ Challenge parsedChallenge = new Challenge();
+
+ parsedChallenge.Scheme = scheme.ToString();
+
+ while (BearerTokenChallengeAuthenticationPolicy.TryGetNextParameter(ref challenge, out var key, out var value))
+ {
+ parsedChallenge.Parameters.Add((key.ToString(), value.ToString()));
+ }
+
+ parsedChallenges.Add(parsedChallenge);
+ }
+
+ Assert.AreEqual(MultipleParsedChallenges.Count, parsedChallenges.Count);
+
+ for (int i = 0; i < parsedChallenges.Count; i++)
+ {
+ ValidateParsedChallenge(MultipleParsedChallenges[i], parsedChallenges[i]);
+ }
+ }
+
+ [Test]
+ public async Task BearerTokenChallengeAuthenticationPolicy_ValidateClaimsChallengeTokenRequest()
+ {
+ string currentClaimChallenge = null;
+
+ int tokensRequested = 0;
+
+ var credential = new TokenCredentialStub((r, c) =>
+ {
+ tokensRequested++;
+
+ Assert.AreEqual(currentClaimChallenge, r.Claims);
+
+ return new AccessToken(Guid.NewGuid().ToString(), DateTimeOffset.UtcNow + TimeSpan.FromDays(1));
+ }, IsAsync);
+
+ var policy = new BearerTokenChallengeAuthenticationPolicy(credential, "scope");
+
+ var insufficientClaimsChallengeResponse = new MockResponse(401);
+
+ insufficientClaimsChallengeResponse.AddHeader(new HttpHeader("WWW-Authenticate", CaeInsufficientClaimsChallenge));
+
+ var sessionRevokedChallengeResponse = new MockResponse(401);
+
+ sessionRevokedChallengeResponse.AddHeader(new HttpHeader("WWW-Authenticate", CaeSessionsRevokedClaimsChallenge));
+
+ var armChallengeResponse = new MockResponse(401);
+
+ armChallengeResponse.AddHeader(new HttpHeader("WWW-Authenticate", ArmChallenge));
+
+ var keyvaultChallengeResponse = new MockResponse(401);
+
+ keyvaultChallengeResponse.AddHeader(new HttpHeader("WWW-Authenticate", KeyVaultChallenge));
+
+ MockTransport transport = CreateMockTransport(new MockResponse(200),
+ insufficientClaimsChallengeResponse,
+ new MockResponse(200),
+ sessionRevokedChallengeResponse,
+ new MockResponse(200),
+ armChallengeResponse,
+ keyvaultChallengeResponse);
+
+ var response = await SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: default);
+
+ Assert.AreEqual(tokensRequested, 1);
+
+ Assert.AreEqual(response.Status, 200);
+
+ currentClaimChallenge = Base64Url.DecodeString(CaeInsufficientClaimsChallengeValue);
+
+ response = await SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: default);
+
+ Assert.AreEqual(tokensRequested, 2);
+
+ Assert.AreEqual(response.Status, 200);
+
+ currentClaimChallenge = Base64Url.DecodeString(CaeSessionsRevokedClaimsChallengeValue);
+
+ response = await SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: default);
+
+ Assert.AreEqual(tokensRequested, 3);
+
+ Assert.AreEqual(response.Status, 200);
+
+ currentClaimChallenge = null;
+
+ response = await SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: default);
+
+ Assert.AreEqual(tokensRequested, 3);
+
+ Assert.AreEqual(response.Status, 401);
+
+ response = await SendGetRequest(transport, policy, uri: new Uri("https://example.com"), cancellationToken: default);
+
+ Assert.AreEqual(tokensRequested, 3);
+
+ Assert.AreEqual(response.Status, 401);
+ }
+
+ private void ValidateParsedChallenge(Challenge expected, Challenge actual)
+ {
+ Assert.AreEqual(expected.Scheme, actual.Scheme);
+
+ CollectionAssert.AreEquivalent(expected.Parameters, actual.Parameters);
+ }
+
+ private class TokenCredentialStub : TokenCredential
+ {
+ public TokenCredentialStub(Func handler, bool isAsync)
+ {
+ if (isAsync)
+ {
+#pragma warning disable 1998
+ _getTokenAsyncHandler = async (r, c) => handler(r, c);
+#pragma warning restore 1998
+ }
+ else
+ {
+ _getTokenHandler = handler;
+ }
+ }
+
+ private readonly Func> _getTokenAsyncHandler;
+ private readonly Func _getTokenHandler;
+
+ public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken)
+ => _getTokenAsyncHandler(requestContext, cancellationToken);
+
+ public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken)
+ => _getTokenHandler(requestContext, cancellationToken);
+ }
+ }
+}