diff --git a/sdk/core/Azure.Core/src/Shared/ClientDiagnostics.cs b/sdk/core/Azure.Core/src/Shared/ClientDiagnostics.cs index 45871fff9c779..fff0f60ae9674 100644 --- a/sdk/core/Azure.Core/src/Shared/ClientDiagnostics.cs +++ b/sdk/core/Azure.Core/src/Shared/ClientDiagnostics.cs @@ -81,20 +81,9 @@ public RequestFailedException CreateRequestFailedExceptionWithContent( return exception; } - public ValueTask CreateRequestFailedMessageAsync(Response response, string? message = null, string? errorCode= null, IDictionary? additionalInfo = null) - { - return CreateRequestFailedMessageAsync(response, message, errorCode, additionalInfo, true); - } - - public string CreateRequestFailedMessage(Response response, string? message = null, string? errorCode = null, IDictionary? additionalInfo = null) - { - return CreateRequestFailedMessageAsync(response, message, errorCode, additionalInfo, false).EnsureCompleted(); - } - - private async ValueTask CreateRequestFailedMessageAsync(Response response, string? message, string? errorCode, IDictionary? additionalInfo, bool async) + public async ValueTask CreateRequestFailedMessageAsync(Response response, string? message, string? errorCode, IDictionary? additionalInfo, bool async) { var content = await ReadContentAsync(response, async).ConfigureAwait(false); - return CreateRequestFailedMessageWithContent(response, message, content, errorCode, additionalInfo); } diff --git a/sdk/core/Azure.Core/src/Shared/TaskExtensions.cs b/sdk/core/Azure.Core/src/Shared/TaskExtensions.cs index a22738192f8bd..74aba2bb6a540 100644 --- a/sdk/core/Azure.Core/src/Shared/TaskExtensions.cs +++ b/sdk/core/Azure.Core/src/Shared/TaskExtensions.cs @@ -15,6 +15,9 @@ namespace Azure.Core.Pipeline { internal static class TaskExtensions { + public static WithCancellationTaskAwaitable AwaitWithCancellation(this Task task, CancellationToken cancellationToken) + => new WithCancellationTaskAwaitable(task, cancellationToken); + public static WithCancellationTaskAwaitable AwaitWithCancellation(this Task task, CancellationToken cancellationToken) => new WithCancellationTaskAwaitable(task, cancellationToken); @@ -141,6 +144,20 @@ private static void VerifyTaskCompleted(bool isCompleted) #pragma warning restore AZC0107 // Do not call public asynchronous method in synchronous scope. } + public readonly struct WithCancellationTaskAwaitable + { + private readonly CancellationToken _cancellationToken; + private readonly ConfiguredTaskAwaitable _awaitable; + + public WithCancellationTaskAwaitable(Task task, CancellationToken cancellationToken) + { + _awaitable = task.ConfigureAwait(false); + _cancellationToken = cancellationToken; + } + + public WithCancellationTaskAwaiter GetAwaiter() => new WithCancellationTaskAwaiter(_awaitable.GetAwaiter(), _cancellationToken); + } + public readonly struct WithCancellationTaskAwaitable { private readonly CancellationToken _cancellationToken; @@ -169,6 +186,39 @@ public WithCancellationValueTaskAwaitable(ValueTask task, CancellationToken c public WithCancellationValueTaskAwaiter GetAwaiter() => new WithCancellationValueTaskAwaiter(_awaitable.GetAwaiter(), _cancellationToken); } + public readonly struct WithCancellationTaskAwaiter : ICriticalNotifyCompletion + { + private readonly CancellationToken _cancellationToken; + private readonly ConfiguredTaskAwaitable.ConfiguredTaskAwaiter _taskAwaiter; + + public WithCancellationTaskAwaiter(ConfiguredTaskAwaitable.ConfiguredTaskAwaiter awaiter, CancellationToken cancellationToken) + { + _taskAwaiter = awaiter; + _cancellationToken = cancellationToken; + } + + public bool IsCompleted => _taskAwaiter.IsCompleted || _cancellationToken.IsCancellationRequested; + + public void OnCompleted(Action continuation) => _taskAwaiter.OnCompleted(WrapContinuation(continuation)); + + public void UnsafeOnCompleted(Action continuation) => _taskAwaiter.UnsafeOnCompleted(WrapContinuation(continuation)); + + public void GetResult() + { + Debug.Assert(IsCompleted); + if (!_taskAwaiter.IsCompleted) + { + _cancellationToken.ThrowIfCancellationRequested(); + } + _taskAwaiter.GetResult(); + } + + private Action WrapContinuation(in Action originalContinuation) + => _cancellationToken.CanBeCanceled + ? new WithCancellationContinuationWrapper(originalContinuation, _cancellationToken).Continuation + : originalContinuation; + } + public readonly struct WithCancellationTaskAwaiter : ICriticalNotifyCompletion { private readonly CancellationToken _cancellationToken; diff --git a/sdk/identity/Azure.Identity/src/AppServiceV2017ManagedIdentitySource.cs b/sdk/identity/Azure.Identity/src/AppServiceV2017ManagedIdentitySource.cs new file mode 100644 index 0000000000000..9cb56d52f31dd --- /dev/null +++ b/sdk/identity/Azure.Identity/src/AppServiceV2017ManagedIdentitySource.cs @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Globalization; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Core.Pipeline; + +namespace Azure.Identity +{ + internal class AppServiceV2017ManagedIdentitySource : IManagedIdentitySource + { + // MSI Constants. Docs for MSI are available here https://docs.microsoft.com/en-us/azure/app-service/overview-managed-identity + private const string AppServiceMsiApiVersion = "2017-09-01"; + private const string MsiEndpointInvalidUriError = "The environment variable MSI_ENDPOINT contains an invalid Uri."; + + private readonly HttpPipeline _pipeline; + private readonly Uri _endpoint; + private readonly string _secret; + private readonly string _clientId; + + public static IManagedIdentitySource TryCreate(HttpPipeline pipeline, string clientId) + { + string msiEndpoint = EnvironmentVariables.MsiEndpoint; + string msiSecret = EnvironmentVariables.MsiSecret; + + // if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the MsiType is AppService + if (string.IsNullOrEmpty(msiEndpoint) || string.IsNullOrEmpty(msiSecret)) + { + return default; + } + + Uri endpointUri; + try + { + endpointUri = new Uri(msiEndpoint); + } + catch (FormatException ex) + { + throw new AuthenticationFailedException(MsiEndpointInvalidUriError, ex); + } + + return new AppServiceV2017ManagedIdentitySource(pipeline, endpointUri, msiSecret, clientId); + } + + private AppServiceV2017ManagedIdentitySource(HttpPipeline pipeline, Uri endpoint, string secret, string clientId) + { + _pipeline = pipeline; + _endpoint = endpoint; + _secret = secret; + _clientId = clientId; + } + + public Request CreateRequest(string[] scopes) + { + // covert the scopes to a resource string + string resource = ScopeUtilities.ScopesToResource(scopes); + + Request request = _pipeline.CreateRequest(); + + request.Method = RequestMethod.Get; + request.Headers.Add("secret", _secret); + request.Uri.Reset(_endpoint); + request.Uri.AppendQuery("api-version", AppServiceMsiApiVersion); + request.Uri.AppendQuery("resource", resource); + + if (!string.IsNullOrEmpty(_clientId)) + { + request.Uri.AppendQuery("clientid", _clientId); + } + + return request; + } + + public AccessToken GetAccessTokenFromJson(in JsonElement jsonAccessToken, in JsonElement jsonExpiresOn) + { + // AppService version 2017-09-01 sends expires_on as a string formatted datetimeoffset + if (DateTimeOffset.TryParse(jsonExpiresOn.GetString(), CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset expiresOn)) + { + return new AccessToken(jsonAccessToken.GetString(), expiresOn); + } + + throw new AuthenticationFailedException(ManagedIdentityClient.AuthenticationResponseInvalidFormatError); + } + + public ValueTask HandleFailedRequestAsync(Response response, ClientDiagnostics diagnostics, bool async) => new ValueTask(); + } +} diff --git a/sdk/identity/Azure.Identity/src/AppServiceV2019ManagedIdentitySource.cs b/sdk/identity/Azure.Identity/src/AppServiceV2019ManagedIdentitySource.cs new file mode 100644 index 0000000000000..9885afd12288f --- /dev/null +++ b/sdk/identity/Azure.Identity/src/AppServiceV2019ManagedIdentitySource.cs @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Core.Pipeline; + +namespace Azure.Identity +{ + internal class AppServiceV2019ManagedIdentitySource : IManagedIdentitySource + { + private const string AppServiceMsiApiVersion = "2019-08-01"; + private const string IdentityEndpointInvalidUriError = "The environment variable IDENTITY_ENDPOINT contains an invalid Uri."; + + private readonly HttpPipeline _pipeline; + private readonly Uri _endpoint; + private readonly string _secret; + private readonly string _clientId; + + public static IManagedIdentitySource TryCreate(HttpPipeline pipeline, string clientId) + { + string identityEndpoint = EnvironmentVariables.IdentityEndpoint; + string identityHeader = EnvironmentVariables.IdentityHeader; + + if (string.IsNullOrEmpty(identityEndpoint) || string.IsNullOrEmpty(identityHeader)) + { + return default; + } + + Uri endpointUri; + try + { + endpointUri = new Uri(identityEndpoint); + } + catch (FormatException ex) + { + throw new AuthenticationFailedException(IdentityEndpointInvalidUriError, ex); + } + + return new AppServiceV2019ManagedIdentitySource(pipeline, endpointUri, identityHeader, clientId); + } + + private AppServiceV2019ManagedIdentitySource(HttpPipeline pipeline, Uri endpoint, string secret, string clientId) + { + _pipeline = pipeline; + _endpoint = endpoint; + _secret = secret; + _clientId = clientId; + } + + public Request CreateRequest(string[] scopes) + { + // covert the scopes to a resource string + string resource = ScopeUtilities.ScopesToResource(scopes); + + Request request = _pipeline.CreateRequest(); + + request.Method = RequestMethod.Get; + request.Headers.Add("X-IDENTITY-HEADER", _secret); + request.Uri.Reset(_endpoint); + request.Uri.AppendQuery("api-version", AppServiceMsiApiVersion); + request.Uri.AppendQuery("resource", resource); + + if (!string.IsNullOrEmpty(_clientId)) + { + request.Uri.AppendQuery("client_id", _clientId); + } + + return request; + } + + public AccessToken GetAccessTokenFromJson(in JsonElement jsonAccessToken, in JsonElement jsonExpiresOn) + { + // the seconds from epoch may be returned as a Json number or a Json string which is a number + // depending on the environment. If neither of these are the case we throw an AuthException. + if (jsonExpiresOn.ValueKind == JsonValueKind.Number && jsonExpiresOn.TryGetInt64(out long expiresOnSec) || + jsonExpiresOn.ValueKind == JsonValueKind.String && long.TryParse(jsonExpiresOn.GetString(), out expiresOnSec)) + { + return new AccessToken(jsonAccessToken.GetString(), DateTimeOffset.FromUnixTimeSeconds(expiresOnSec)); + } + + throw new AuthenticationFailedException(ManagedIdentityClient.AuthenticationResponseInvalidFormatError); + } + + public ValueTask HandleFailedRequestAsync(Response response, ClientDiagnostics diagnostics, bool async) => new ValueTask(); + } +} diff --git a/sdk/identity/Azure.Identity/src/CloudShellManagedIdentitySource.cs b/sdk/identity/Azure.Identity/src/CloudShellManagedIdentitySource.cs new file mode 100644 index 0000000000000..1a2670a0c20d4 --- /dev/null +++ b/sdk/identity/Azure.Identity/src/CloudShellManagedIdentitySource.cs @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Core.Pipeline; + +namespace Azure.Identity +{ + internal class CloudShellManagedIdentitySource : IManagedIdentitySource + { + private readonly HttpPipeline _pipeline; + private readonly Uri _endpoint; + private readonly string _clientId; + private const string MsiEndpointInvalidUriError = "The environment variable MSI_ENDPOINT contains an invalid Uri."; + + public static IManagedIdentitySource TryCreate(HttpPipeline pipeline, string clientId) + { + string msiEndpoint = EnvironmentVariables.MsiEndpoint; + + // if ONLY the env var MSI_ENDPOINT is set the MsiType is CloudShell + if (string.IsNullOrEmpty(msiEndpoint)) + { + return default; + } + + Uri endpointUri; + try + { + endpointUri = new Uri(msiEndpoint); + } + catch (FormatException ex) + { + throw new AuthenticationFailedException(MsiEndpointInvalidUriError, ex); + } + + return new CloudShellManagedIdentitySource(pipeline, endpointUri, clientId); + } + + private CloudShellManagedIdentitySource(HttpPipeline pipeline, Uri endpoint, string clientId) + { + _pipeline = pipeline; + _endpoint = endpoint; + _clientId = clientId; + } + + public Request CreateRequest(string[] scopes) + { + // covert the scopes to a resource string + string resource = ScopeUtilities.ScopesToResource(scopes); + + Request request = _pipeline.CreateRequest(); + + request.Method = RequestMethod.Post; + + request.Headers.Add(HttpHeader.Common.FormUrlEncodedContentType); + + request.Uri.Reset(_endpoint); + + request.Headers.Add("Metadata", "true"); + + var bodyStr = $"resource={Uri.EscapeDataString(resource)}"; + + if (!string.IsNullOrEmpty(_clientId)) + { + bodyStr += $"&client_id={Uri.EscapeDataString(_clientId)}"; + } + + ReadOnlyMemory content = Encoding.UTF8.GetBytes(bodyStr).AsMemory(); + request.Content = RequestContent.Create(content); + return request; + } + + public AccessToken GetAccessTokenFromJson(in JsonElement jsonAccessToken, in JsonElement jsonExpiresOn) + { + // the seconds from epoch may be returned as a Json number or a Json string which is a number + // depending on the environment. If neither of these are the case we throw an AuthException. + if (jsonExpiresOn.ValueKind == JsonValueKind.Number && jsonExpiresOn.TryGetInt64(out long expiresOnSec) || + jsonExpiresOn.ValueKind == JsonValueKind.String && long.TryParse(jsonExpiresOn.GetString(), out expiresOnSec)) + { + return new AccessToken(jsonAccessToken.GetString(), DateTimeOffset.FromUnixTimeSeconds(expiresOnSec)); + } + + throw new AuthenticationFailedException(ManagedIdentityClient.AuthenticationResponseInvalidFormatError); + } + + public ValueTask HandleFailedRequestAsync(Response response, ClientDiagnostics diagnostics, bool async) => new ValueTask(); + } +} diff --git a/sdk/identity/Azure.Identity/src/EnvironmentVariables.cs b/sdk/identity/Azure.Identity/src/EnvironmentVariables.cs index a2057c087ab4e..5fe545eb380b6 100644 --- a/sdk/identity/Azure.Identity/src/EnvironmentVariables.cs +++ b/sdk/identity/Azure.Identity/src/EnvironmentVariables.cs @@ -14,6 +14,8 @@ internal class EnvironmentVariables public static string ClientSecret => Environment.GetEnvironmentVariable("AZURE_CLIENT_SECRET"); public static string ClientCertificatePath => Environment.GetEnvironmentVariable("AZURE_CLIENT_CERTIFICATE_PATH"); + public static string IdentityEndpoint => Environment.GetEnvironmentVariable("IDENTITY_ENDPOINT"); + public static string IdentityHeader => Environment.GetEnvironmentVariable("IDENTITY_HEADER"); public static string MsiEndpoint => Environment.GetEnvironmentVariable("MSI_ENDPOINT"); public static string MsiSecret => Environment.GetEnvironmentVariable("MSI_SECRET"); diff --git a/sdk/identity/Azure.Identity/src/IManagedIdentitySource.cs b/sdk/identity/Azure.Identity/src/IManagedIdentitySource.cs new file mode 100644 index 0000000000000..9d3555b0f92a7 --- /dev/null +++ b/sdk/identity/Azure.Identity/src/IManagedIdentitySource.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Text.Json; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Core.Pipeline; + +namespace Azure.Identity +{ + internal interface IManagedIdentitySource + { + Request CreateRequest(string[] scopes); + AccessToken GetAccessTokenFromJson(in JsonElement jsonAccessToken, in JsonElement jsonExpiresOn); + ValueTask HandleFailedRequestAsync(Response response, ClientDiagnostics diagnostics, bool async); + } +} diff --git a/sdk/identity/Azure.Identity/src/ImdsManagedIdentitySource.cs b/sdk/identity/Azure.Identity/src/ImdsManagedIdentitySource.cs new file mode 100644 index 0000000000000..42a4db16e55a8 --- /dev/null +++ b/sdk/identity/Azure.Identity/src/ImdsManagedIdentitySource.cs @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Net; +using System.Net.Sockets; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Core.Pipeline; + +namespace Azure.Identity +{ + internal class ImdsManagedIdentitySource : IManagedIdentitySource + { + // IMDS constants. Docs for IMDS are available here https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http + private static readonly Uri s_imdsEndpoint = new Uri("http://169.254.169.254/metadata/identity/oauth2/token"); + private static readonly IPAddress s_imdsHostIp = IPAddress.Parse("169.254.169.254"); + private const int s_imdsPort = 80; + private const int ImdsAvailableTimeoutMs = 1000; + private const string ImdsApiVersion = "2018-02-01"; + + internal const string IdentityUnavailableError = "ManagedIdentityCredential authentication unavailable. The requested identity has not been assigned to this resource."; + + private readonly HttpPipeline _pipeline; + private readonly string _clientId; + + private string _identityUnavailableErrorMessage; + + public static async ValueTask TryCreateAsync(HttpPipeline pipeline, string clientId, bool async, CancellationToken cancellationToken) + { + AzureIdentityEventSource.Singleton.ProbeImdsEndpoint(s_imdsEndpoint); + + bool available; + // try to create a TCP connection to the IMDS IP address. If the connection can be established + // we assume that IMDS is available. If connecting times out or fails to connect assume that + // IMDS is not available in this environment. + try + { + using var client = new TcpClient(); + Task connectTask = client.ConnectAsync(s_imdsHostIp, s_imdsPort); + + if (async) + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + + cts.CancelAfter(ImdsAvailableTimeoutMs); + await connectTask.AwaitWithCancellation(cts.Token); + available = client.Connected; + } + else + { + available = connectTask.Wait(ImdsAvailableTimeoutMs, cancellationToken) && client.Connected; + } + } + catch + { + available = false; + } + + if (available) + { + AzureIdentityEventSource.Singleton.ImdsEndpointFound(s_imdsEndpoint); + } + else + { + AzureIdentityEventSource.Singleton.ImdsEndpointUnavailable(s_imdsEndpoint); + } + + return available ? new ImdsManagedIdentitySource(pipeline, clientId) : default; + } + + internal ImdsManagedIdentitySource(HttpPipeline pipeline, string clientId) + { + _pipeline = pipeline; + _clientId = clientId; + } + + public Request CreateRequest(string[] scopes) + { + if (_identityUnavailableErrorMessage != default) + { + throw new CredentialUnavailableException(_identityUnavailableErrorMessage); + } + + // covert the scopes to a resource string + string resource = ScopeUtilities.ScopesToResource(scopes); + + Request request = _pipeline.CreateRequest(); + request.Method = RequestMethod.Get; + request.Headers.Add("Metadata", "true"); + request.Uri.Reset(s_imdsEndpoint); + request.Uri.AppendQuery("api-version", ImdsApiVersion); + + request.Uri.AppendQuery("resource", resource); + + if (!string.IsNullOrEmpty(_clientId)) + { + request.Uri.AppendQuery("client_id", _clientId); + } + + return request; + } + + public AccessToken GetAccessTokenFromJson(in JsonElement jsonAccessToken, in JsonElement jsonExpiresOn) + { + // the seconds from epoch may be returned as a Json number or a Json string which is a number + // depending on the environment. If neither of these are the case we throw an AuthException. + if (jsonExpiresOn.ValueKind == JsonValueKind.Number && jsonExpiresOn.TryGetInt64(out long expiresOnSec) || + jsonExpiresOn.ValueKind == JsonValueKind.String && long.TryParse(jsonExpiresOn.GetString(), out expiresOnSec)) + { + return new AccessToken(jsonAccessToken.GetString(), DateTimeOffset.FromUnixTimeSeconds(expiresOnSec)); + } + + throw new AuthenticationFailedException(ManagedIdentityClient.AuthenticationResponseInvalidFormatError); + } + + public async ValueTask HandleFailedRequestAsync(Response response, ClientDiagnostics diagnostics, bool async) + { + if (response.Status == 400) + { + string message = _identityUnavailableErrorMessage ?? await diagnostics + .CreateRequestFailedMessageAsync(response, IdentityUnavailableError, null, null, async) + .ConfigureAwait(false); + + Interlocked.CompareExchange(ref _identityUnavailableErrorMessage, message, null); + throw new CredentialUnavailableException(message); + } + } + } +} diff --git a/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs b/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs index 553f0c995c2b5..13b71e8450dd1 100644 --- a/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs +++ b/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs @@ -1,12 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using System; -using System.Globalization; -using System.IO; -using System.Net; -using System.Net.Sockets; -using System.Text; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -16,25 +10,10 @@ namespace Azure.Identity { internal class ManagedIdentityClient { - private const string AuthenticationResponseInvalidFormatError = "Invalid response, the authentication response was not in the expected format."; - private const string MsiEndpointInvalidUriError = "The environment variable MSI_ENDPOINT contains an invalid Uri."; + internal const string AuthenticationResponseInvalidFormatError = "Invalid response, the authentication response was not in the expected format."; internal const string MsiUnavailableError = "ManagedIdentityCredential authentication unavailable. No Managed Identity endpoint found."; - internal const string IdentityUnavailableError = "ManagedIdentityCredential authentication unavailable. The requested identity has not been assigned to this resource."; - - // IMDS constants. Docs for IMDS are available here https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http - private static readonly Uri s_imdsEndpoint = new Uri("http://169.254.169.254/metadata/identity/oauth2/token"); - private static readonly IPAddress s_imdsHostIp = IPAddress.Parse("169.254.169.254"); - private const int s_imdsPort = 80; - private const string ImdsApiVersion = "2018-02-01"; - private const int ImdsAvailableTimeoutMs = 1000; - - // MSI Constants. Docs for MSI are available here https://docs.microsoft.com/en-us/azure/app-service/overview-managed-identity - private const string AppServiceMsiApiVersion = "2017-09-01"; - - private static readonly SemaphoreSlim _initLock = new SemaphoreSlim(1, 1); - private MsiType _msiType; - private Uri _endpoint; + private readonly AsyncLockWithValue _identitySourceAsyncLock = new AsyncLockWithValue(); private readonly CredentialPipeline _pipeline; protected ManagedIdentityClient() @@ -44,391 +23,71 @@ protected ManagedIdentityClient() public ManagedIdentityClient(CredentialPipeline pipeline, string clientId = null) { _pipeline = pipeline; - ClientId = clientId; } protected string ClientId { get; } - public virtual AccessToken Authenticate(string[] scopes, CancellationToken cancellationToken) + public virtual async ValueTask AuthenticateAsync(bool async, string[] scopes, CancellationToken cancellationToken) { - MsiType msiType = GetMsiType(cancellationToken); + IManagedIdentitySource identitySource = await GetManagedIdentitySourceAsync(async, cancellationToken).ConfigureAwait(false); // if msi is unavailable or we were unable to determine the type return CredentialUnavailable exception that no endpoint was found - if (msiType == MsiType.Unavailable || msiType == MsiType.Unknown) + if (identitySource == default) { throw new CredentialUnavailableException(MsiUnavailableError); } - using Request request = CreateAuthRequest(msiType, scopes); - - Response response = _pipeline.HttpPipeline.SendRequest(request, cancellationToken); + using Request request = identitySource.CreateRequest(scopes); + Response response = async + ? await _pipeline.HttpPipeline.SendRequestAsync(request, cancellationToken).ConfigureAwait(false) + : _pipeline.HttpPipeline.SendRequest(request, cancellationToken); if (response.Status == 200) { - AccessToken result = Deserialize(response.ContentStream); + using JsonDocument json = async + ? await JsonDocument.ParseAsync(response.ContentStream, default, cancellationToken).ConfigureAwait(false) + : JsonDocument.Parse(response.ContentStream); - return result; + (JsonElement accessToken, JsonElement expiresOnProp) = GetAccessTokenProperties(json.RootElement); + return identitySource.GetAccessTokenFromJson(accessToken, expiresOnProp); } - if (response.Status == 400 && msiType == MsiType.Imds) - { - _msiType = MsiType.Unavailable; - - string message = _pipeline.Diagnostics.CreateRequestFailedMessage(response, message: IdentityUnavailableError); + await identitySource.HandleFailedRequestAsync(response, _pipeline.Diagnostics, async).ConfigureAwait(false); - throw new CredentialUnavailableException(message); - } - - throw _pipeline.Diagnostics.CreateRequestFailedException(response); + throw async + ? await _pipeline.Diagnostics.CreateRequestFailedExceptionAsync(response).ConfigureAwait(false) + : _pipeline.Diagnostics.CreateRequestFailedException(response); } - public virtual async Task AuthenticateAsync(string[] scopes, CancellationToken cancellationToken) + private protected virtual async ValueTask GetManagedIdentitySourceAsync(bool async, CancellationToken cancellationToken) { - MsiType msiType = await GetMsiTypeAsync(cancellationToken).ConfigureAwait(false); - - // if msi is unavailable or we were unable to determine the type return CredentialUnavailable exception that no endpoint was found - if (msiType == MsiType.Unavailable || msiType == MsiType.Unknown) - { - throw new CredentialUnavailableException(MsiUnavailableError); - } - - using Request request = CreateAuthRequest(msiType, scopes); - - Response response = await _pipeline.HttpPipeline.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); - - if (response.Status == 200) - { - AccessToken result = await DeserializeAsync(response.ContentStream, cancellationToken).ConfigureAwait(false); - - return result; - } - - if (response.Status == 400 && msiType == MsiType.Imds) + using var asyncLock = await _identitySourceAsyncLock.GetLockOrValueAsync(async, cancellationToken).ConfigureAwait(false); + if (asyncLock.HasValue) { - _msiType = MsiType.Unavailable; - - string message = await _pipeline.Diagnostics.CreateRequestFailedMessageAsync(response, message: IdentityUnavailableError, errorCode: null).ConfigureAwait(false); - - throw new CredentialUnavailableException(message); + return asyncLock.Value; } - throw await _pipeline.Diagnostics.CreateRequestFailedExceptionAsync(response).ConfigureAwait(false); - } - - protected virtual MsiType GetMsiType(CancellationToken cancellationToken) - { - // if we haven't already determined the msi type - if (_msiType == MsiType.Unknown) - { - // acquire the init lock - _initLock.Wait(cancellationToken); - - try - { - // check again if the we already determined the msiType now that we hold the lock - if (_msiType == MsiType.Unknown) - { - string endpointEnvVar = EnvironmentVariables.MsiEndpoint; - string secretEnvVar = EnvironmentVariables.MsiSecret; - - // if the env var MSI_ENDPOINT is set - if (!string.IsNullOrEmpty(endpointEnvVar)) - { - try - { - _endpoint = new Uri(endpointEnvVar); - } - catch (FormatException ex) - { - throw new AuthenticationFailedException(MsiEndpointInvalidUriError, ex); - } - - // if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the MsiType is AppService - if (!string.IsNullOrEmpty(secretEnvVar)) - { - _msiType = MsiType.AppService; - } - // if ONLY the env var MSI_ENDPOINT is set the MsiType is CloudShell - else - { - _msiType = MsiType.CloudShell; - } - } - // if MSI_ENDPOINT is NOT set AND the IMDS endpoint is available the MsiType is Imds - else if (ImdsAvailable(cancellationToken)) - { - _endpoint = s_imdsEndpoint; - _msiType = MsiType.Imds; - } - // if MSI_ENDPOINT is NOT set and IMDS endpoint is not available ManagedIdentity is not available - else - { - _msiType = MsiType.Unavailable; - } - } - } - // release the init lock - finally - { - _initLock.Release(); - } - } + IManagedIdentitySource identitySource = AppServiceV2019ManagedIdentitySource.TryCreate(_pipeline.HttpPipeline, ClientId) ?? + AppServiceV2017ManagedIdentitySource.TryCreate(_pipeline.HttpPipeline, ClientId) ?? + CloudShellManagedIdentitySource.TryCreate(_pipeline.HttpPipeline, ClientId) ?? + await ImdsManagedIdentitySource.TryCreateAsync(_pipeline.HttpPipeline, ClientId, async, cancellationToken).ConfigureAwait(false); - return _msiType; + asyncLock.SetValue(identitySource); + return identitySource; } - protected virtual async Task GetMsiTypeAsync(CancellationToken cancellationToken) + private static (JsonElement accessToken, JsonElement expiresOnProp) GetAccessTokenProperties(in JsonElement root) { - // if we haven't already determined the msi type - if (_msiType == MsiType.Unknown) - { - // acquire the init lock - await _initLock.WaitAsync(cancellationToken).ConfigureAwait(false); - - try - { - // check again if the we already determined the msiType now that we hold the lock - if (_msiType == MsiType.Unknown) - { - string endpointEnvVar = EnvironmentVariables.MsiEndpoint; - string secretEnvVar = EnvironmentVariables.MsiSecret; - - // if the env var MSI_ENDPOINT is set - if (!string.IsNullOrEmpty(endpointEnvVar)) - { - try - { - _endpoint = new Uri(endpointEnvVar); - } - catch (FormatException ex) - { - throw new AuthenticationFailedException(MsiEndpointInvalidUriError, ex); - } - - // if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the MsiType is AppService - if (!string.IsNullOrEmpty(secretEnvVar)) - { - _msiType = MsiType.AppService; - } - // if ONLY the env var MSI_ENDPOINT is set the MsiType is CloudShell - else - { - _msiType = MsiType.CloudShell; - } - } - // if MSI_ENDPOINT is NOT set AND the IMDS endpoint is available the MsiType is Imds - else if (await ImdsAvailableAsync(cancellationToken).ConfigureAwait(false)) - { - _endpoint = s_imdsEndpoint; - _msiType = MsiType.Imds; - } - // if MSI_ENDPOINT is NOT set and IMDS endpoint is not available ManagedIdentity is not available - else - { - _msiType = MsiType.Unavailable; - } - } - } - // release the init lock - finally - { - _initLock.Release(); - } - } - - return _msiType; - } - - protected virtual bool ImdsAvailable(CancellationToken cancellationToken) - { - AzureIdentityEventSource.Singleton.ProbeImdsEndpoint(s_imdsEndpoint); - - bool available; - // try to create a TCP connection to the IMDS IP address. If the connection can be established - // we assume that IMDS is available. If connecting times out or fails to connect assume that - // IMDS is not available in this environment. - try - { - using (var client = new TcpClient()) - { - var result = client.BeginConnect(s_imdsHostIp, s_imdsPort, null, null); - - var success = result.AsyncWaitHandle.WaitOne(ImdsAvailableTimeoutMs); - - available = success && client.Connected; - } - } - catch - { - available = false; - } - - if (available) - { - AzureIdentityEventSource.Singleton.ImdsEndpointFound(s_imdsEndpoint); - } - else - { - AzureIdentityEventSource.Singleton.ImdsEndpointUnavailable(s_imdsEndpoint); - } - - return available; - } - - protected virtual async Task ImdsAvailableAsync(CancellationToken cancellationToken) - { - AzureIdentityEventSource.Singleton.ProbeImdsEndpoint(s_imdsEndpoint); - - bool available; - // try to create a TCP connection to the IMDS IP address. If the connection can be established - // we assume that IMDS is available. If connecting times out or fails to connect assume that - // IMDS is not available in this environment. - try - { - using (var client = new TcpClient()) - { - var result = client.BeginConnect(s_imdsHostIp, s_imdsPort, null, null); - - var success = await Task.Run(() => result.AsyncWaitHandle.WaitOne(ImdsAvailableTimeoutMs), cancellationToken).ConfigureAwait(false); - - available = success && client.Connected; - } - } - catch - { - available = false; - } - - if (available) - { - AzureIdentityEventSource.Singleton.ImdsEndpointFound(s_imdsEndpoint); - } - else - { - AzureIdentityEventSource.Singleton.ImdsEndpointUnavailable(s_imdsEndpoint); - } - - return available; - } - - private Request CreateAuthRequest(MsiType msiType, string[] scopes) - { - return msiType switch - { - MsiType.Imds => CreateImdsAuthRequest(scopes), - MsiType.AppService => CreateAppServiceAuthRequest(scopes), - MsiType.CloudShell => CreateCloudShellAuthRequest(scopes), - _ => default, - }; - } - - private Request CreateImdsAuthRequest(string[] scopes) - { - // covert the scopes to a resource string - string resource = ScopeUtilities.ScopesToResource(scopes); - - Request request = _pipeline.HttpPipeline.CreateRequest(); - - request.Method = RequestMethod.Get; - - request.Headers.Add("Metadata", "true"); - - request.Uri.Reset(_endpoint); - - request.Uri.AppendQuery("api-version", ImdsApiVersion); - - request.Uri.AppendQuery("resource", resource); - - if (!string.IsNullOrEmpty(ClientId)) - { - request.Uri.AppendQuery("client_id", ClientId); - } - - return request; - } - - private Request CreateAppServiceAuthRequest(string[] scopes) - { - // covert the scopes to a resource string - string resource = ScopeUtilities.ScopesToResource(scopes); - - Request request = _pipeline.HttpPipeline.CreateRequest(); - - request.Method = RequestMethod.Get; - - request.Headers.Add("secret", EnvironmentVariables.MsiSecret); - - request.Uri.Reset(_endpoint); - - request.Uri.AppendQuery("api-version", AppServiceMsiApiVersion); - - request.Uri.AppendQuery("resource", resource); - - if (!string.IsNullOrEmpty(ClientId)) - { - request.Uri.AppendQuery("clientid", ClientId); - } - - return request; - } - - private Request CreateCloudShellAuthRequest(string[] scopes) - { - // covert the scopes to a resource string - string resource = ScopeUtilities.ScopesToResource(scopes); - - Request request = _pipeline.HttpPipeline.CreateRequest(); - - request.Method = RequestMethod.Post; - - request.Headers.Add(HttpHeader.Common.FormUrlEncodedContentType); - - request.Uri.Reset(_endpoint); - - request.Headers.Add("Metadata", "true"); - - var bodyStr = $"resource={Uri.EscapeDataString(resource)}"; - - if (!string.IsNullOrEmpty(ClientId)) - { - bodyStr += $"&client_id={Uri.EscapeDataString(ClientId)}"; - } - - ReadOnlyMemory content = Encoding.UTF8.GetBytes(bodyStr).AsMemory(); - - request.Content = RequestContent.Create(content); - - return request; - } - - private async Task DeserializeAsync(Stream content, CancellationToken cancellationToken) - { - using (JsonDocument json = await JsonDocument.ParseAsync(content, default, cancellationToken).ConfigureAwait(false)) - { - return Deserialize(json.RootElement); - } - } - - private AccessToken Deserialize(Stream content) - { - using (JsonDocument json = JsonDocument.Parse(content)) - { - return Deserialize(json.RootElement); - } - } - - private AccessToken Deserialize(JsonElement json) - { - string accessToken = null; + JsonElement? accessToken = null; JsonElement? expiresOnProp = null; - foreach (JsonProperty prop in json.EnumerateObject()) + foreach (JsonProperty prop in root.EnumerateObject()) { switch (prop.Name) { case "access_token": - accessToken = prop.Value.GetString(); + accessToken = prop.Value; break; case "expires_on": @@ -437,42 +96,9 @@ private AccessToken Deserialize(JsonElement json) } } - if (accessToken is null || !expiresOnProp.HasValue) - { - throw new AuthenticationFailedException(AuthenticationResponseInvalidFormatError); - } - - DateTimeOffset expiresOn; - // if s_msiType is AppService expires_on will be a string formatted datetimeoffset - if (_msiType == MsiType.AppService) - { - if (!DateTimeOffset.TryParse(expiresOnProp.Value.GetString(), CultureInfo.InvariantCulture, DateTimeStyles.None, out expiresOn)) - { - throw new AuthenticationFailedException(AuthenticationResponseInvalidFormatError); - } - } - // otherwise expires_on will be a unix timestamp seconds from epoch - else - { - // the seconds from epoch may be returned as a Json number or a Json string which is a number - // depending on the environment. If neither of these are the case we throw an AuthException. - if (!(expiresOnProp.Value.ValueKind == JsonValueKind.Number && expiresOnProp.Value.TryGetInt64(out long expiresOnSec)) && - !(expiresOnProp.Value.ValueKind == JsonValueKind.String && long.TryParse(expiresOnProp.Value.GetString(), out expiresOnSec))) - { - throw new AuthenticationFailedException(AuthenticationResponseInvalidFormatError); - } - - expiresOn = DateTimeOffset.FromUnixTimeSeconds(expiresOnSec); - } - - return new AccessToken(accessToken, expiresOn); - } - - private struct Error - { - public string Code { get; set; } - - public string Message { get; set; } + return accessToken.HasValue && expiresOnProp.HasValue + ? (accessToken.Value, expiresOnProp.Value) + : throw new AuthenticationFailedException(AuthenticationResponseInvalidFormatError); } } } diff --git a/sdk/identity/Azure.Identity/src/ManagedIdentityCredential.cs b/sdk/identity/Azure.Identity/src/ManagedIdentityCredential.cs index f65f06217d4ec..0121884dcf344 100644 --- a/sdk/identity/Azure.Identity/src/ManagedIdentityCredential.cs +++ b/sdk/identity/Azure.Identity/src/ManagedIdentityCredential.cs @@ -87,8 +87,7 @@ private async ValueTask GetTokenImplAsync(bool async, TokenRequestC try { - AccessToken result = async ? await _client.AuthenticateAsync(requestContext.Scopes, cancellationToken).ConfigureAwait(false) : _client.Authenticate(requestContext.Scopes, cancellationToken); - + AccessToken result = await _client.AuthenticateAsync(async, requestContext.Scopes, cancellationToken).ConfigureAwait(false); return scope.Succeeded(result); } catch (Exception e) diff --git a/sdk/identity/Azure.Identity/src/MsiType.cs b/sdk/identity/Azure.Identity/src/MsiType.cs deleted file mode 100644 index d8cca612f49a6..0000000000000 --- a/sdk/identity/Azure.Identity/src/MsiType.cs +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - - -namespace Azure.Identity -{ - internal enum MsiType - { - Unknown = 0, - Imds = 1, - AppService = 2, - CloudShell = 3, - Unavailable = 4 - } -} diff --git a/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialImdsLiveTests.cs b/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialImdsLiveTests.cs index d68e81a78ca82..49731f8b225fd 100644 --- a/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialImdsLiveTests.cs +++ b/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialImdsLiveTests.cs @@ -77,7 +77,9 @@ private ManagedIdentityCredential CreateManagedIdentityCredential(string clientI var pipeline = CredentialPipeline.GetInstance(options); // if we're in playback mode we need to mock the ImdsAvailable call since we won't be able to open a connection - var client = (Mode == RecordedTestMode.Playback) ? new MockManagedIdentityClient(pipeline, clientId) { ImdsAvailableFunc = _ => true } : new ManagedIdentityClient(pipeline, clientId); + var client = (Mode == RecordedTestMode.Playback) + ? new MockManagedIdentityClient(pipeline, clientId) { AuthRequestBuilderFactory = () => new ImdsManagedIdentitySource(pipeline.HttpPipeline, clientId) } + : new ManagedIdentityClient(pipeline, clientId); var cred = new ManagedIdentityCredential(pipeline, client); diff --git a/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs b/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs index 48a2830d99087..31765c45df7ad 100644 --- a/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs @@ -39,7 +39,9 @@ public async Task VerifyImdsRequestMockAsync() var pipeline = CredentialPipeline.GetInstance(options); - ManagedIdentityCredential credential = InstrumentClient(new ManagedIdentityCredential(pipeline, new MockManagedIdentityClient(pipeline, "mock-client-id") { ImdsAvailableFunc = _ => true })); + var client = new MockManagedIdentityClient(pipeline, "mock-client-id") { AuthRequestBuilderFactory = () => new ImdsManagedIdentitySource(pipeline.HttpPipeline, "mock-client-id") }; + + ManagedIdentityCredential credential = InstrumentClient(new ManagedIdentityCredential(pipeline, client)); AccessToken actualToken = await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default)); @@ -78,7 +80,9 @@ public async Task VerifyImdsRequestWithClientIdMockAsync() var pipeline = CredentialPipeline.GetInstance(options); - ManagedIdentityCredential credential = InstrumentClient(new ManagedIdentityCredential(pipeline, new MockManagedIdentityClient(pipeline, "mock-client-id") { ImdsAvailableFunc = _ => true })); + var client = new MockManagedIdentityClient(pipeline, "mock-client-id") { AuthRequestBuilderFactory = () => new ImdsManagedIdentitySource(pipeline.HttpPipeline, "mock-client-id") }; + + ManagedIdentityCredential credential = InstrumentClient(new ManagedIdentityCredential(pipeline, client)); AccessToken actualToken = await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default)); @@ -102,30 +106,7 @@ public async Task VerifyImdsRequestWithClientIdMockAsync() [NonParallelizable] [Test] - public void VerifyImdsAvailableUserCanceledMockAsync() - { - using (new TestEnvVar("MSI_ENDPOINT", null)) - using (new TestEnvVar("MSI_SECRET", null)) - { - var mockTransport = new MockTransport(); - - var options = new TokenCredentialOptions() { Transport = mockTransport }; - - CancellationTokenSource cancellationSource = new CancellationTokenSource(); - - cancellationSource.Cancel(); - - var pipeline = CredentialPipeline.GetInstance(options); - - ManagedIdentityCredential credential = InstrumentClient(new ManagedIdentityCredential(pipeline, new MockManagedIdentityClient(pipeline, "mock-client-id") { ImdsAvailableFunc = ct => { ct.ThrowIfCancellationRequested(); return true; } })); - - Assert.CatchAsync(async () => await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default), cancellationSource.Token)); - } - } - - [NonParallelizable] - [Test] - public async Task VerifyAppServiceMsiRequestMockAsync() + public async Task VerifyAppService2017RequestMockAsync() { using (new TestEnvVar("MSI_ENDPOINT", "https://mock.msi.endpoint/")) using (new TestEnvVar("MSI_SECRET", "mock-msi-secret")) @@ -164,7 +145,7 @@ public async Task VerifyAppServiceMsiRequestMockAsync() [NonParallelizable] [Test] - public async Task VerifyAppServiceMsiRequestWithClientIdMockAsync() + public async Task VerifyAppService2017RequestWithClientIdMockAsync() { using (new TestEnvVar("MSI_ENDPOINT", "https://mock.msi.endpoint/")) using (new TestEnvVar("MSI_SECRET", "mock-msi-secret")) @@ -203,6 +184,69 @@ public async Task VerifyAppServiceMsiRequestWithClientIdMockAsync() } } + [NonParallelizable] + [Test] + public async Task VerifyAppService2019RequestMockAsync() + { + using (new TestEnvVar("IDENTITY_ENDPOINT", "https://identity.endpoint/")) + using (new TestEnvVar("IDENTITY_HEADER", "mock-identity-header")) + { + var expectedToken = "mock-access-token"; + var response = new MockResponse(200); + response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": \"3600\" }}"); + + var mockTransport = new MockTransport(response); + var options = new TokenCredentialOptions { Transport = mockTransport }; + + ManagedIdentityCredential credential = InstrumentClient(new ManagedIdentityCredential(options: options)); + AccessToken token = await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default)); + + Assert.AreEqual(expectedToken, token.Token); + + MockRequest request = mockTransport.Requests[0]; + Assert.IsTrue(request.Uri.ToString().StartsWith(EnvironmentVariables.IdentityEndpoint)); + + string query = request.Uri.Query; + Assert.IsTrue(query.Contains("api-version=2019-08-01")); + Assert.IsTrue(query.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); + Assert.IsTrue(request.Headers.TryGetValue("X-IDENTITY-HEADER", out string identityHeader)); + + Assert.AreEqual(EnvironmentVariables.IdentityHeader, identityHeader); + } + } + + [NonParallelizable] + [Test] + public async Task VerifyAppService2019RequestWithClientIdMockAsync() + { + using (new TestEnvVar("IDENTITY_ENDPOINT", "https://identity.endpoint/")) + using (new TestEnvVar("IDENTITY_HEADER", "mock-identity-header")) + { + var expectedToken = "mock-access-token"; + var response = new MockResponse(200); + response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": \"3600\" }}"); + + var mockTransport = new MockTransport(response); + var options = new TokenCredentialOptions { Transport = mockTransport }; + + ManagedIdentityCredential credential = InstrumentClient(new ManagedIdentityCredential("mock-client-id", options)); + AccessToken actualToken = await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default)); + + Assert.AreEqual(expectedToken, actualToken.Token); + + MockRequest request = mockTransport.SingleRequest; + Assert.IsTrue(request.Uri.ToString().StartsWith(EnvironmentVariables.IdentityEndpoint)); + + string query = request.Uri.Query; + Assert.IsTrue(query.Contains("api-version=2019-08-01")); + Assert.IsTrue(query.Contains("client_id=mock-client-id")); + Assert.IsTrue(query.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); + Assert.IsTrue(request.Headers.TryGetValue("X-IDENTITY-HEADER", out string identityHeader)); + + Assert.AreEqual(EnvironmentVariables.IdentityHeader, identityHeader); + } + } + [NonParallelizable] [Test] public async Task VerifyCloudShellMsiRequestMockAsync() @@ -298,7 +342,7 @@ public async Task VerifyCloudShellMsiRequestWithClientIdMockAsync() [Test] public async Task VerifyMsiUnavailableCredentialException() { - var mockClient = new MockManagedIdentityClient { MsiTypeFactory = () => MsiType.Unavailable }; + var mockClient = new MockManagedIdentityClient { AuthRequestBuilderFactory = () => default }; var credential = InstrumentClient(new ManagedIdentityCredential(CredentialPipeline.GetInstance(null), mockClient)); @@ -312,7 +356,7 @@ public async Task VerifyMsiUnavailableCredentialException() [Test] public async Task VerifyClientGetMsiTypeThrows() { - var mockClient = new MockManagedIdentityClient { MsiTypeFactory = () => throw new MockClientException("message") }; + var mockClient = new MockManagedIdentityClient { AuthRequestBuilderFactory = () => throw new MockClientException("message") }; var credential = InstrumentClient(new ManagedIdentityCredential(CredentialPipeline.GetInstance(null), mockClient)); @@ -326,7 +370,7 @@ public async Task VerifyClientGetMsiTypeThrows() [Test] public async Task VerifyClientAuthenticateThrows() { - var mockClient = new MockManagedIdentityClient { MsiTypeFactory = () => MsiType.Imds, TokenFactory = () => throw new MockClientException("message") }; + var mockClient = new MockManagedIdentityClient { AuthRequestBuilderFactory = () => new ImdsManagedIdentitySource(default, default), TokenFactory = () => throw new MockClientException("message") }; var credential = InstrumentClient(new ManagedIdentityCredential(CredentialPipeline.GetInstance(null), mockClient)); diff --git a/sdk/identity/Azure.Identity/tests/Mock/MockManagedIdentityClient.cs b/sdk/identity/Azure.Identity/tests/Mock/MockManagedIdentityClient.cs index 1b8f175207116..0410ba8057930 100644 --- a/sdk/identity/Azure.Identity/tests/Mock/MockManagedIdentityClient.cs +++ b/sdk/identity/Azure.Identity/tests/Mock/MockManagedIdentityClient.cs @@ -27,70 +27,14 @@ public MockManagedIdentityClient(CredentialPipeline pipeline, string clientId) { } - public Func MsiTypeFactory { get; set; } + public Func AuthRequestBuilderFactory { get; set; } public Func TokenFactory { get; set; } - public Func ImdsAvailableFunc { get; set; } + public override ValueTask AuthenticateAsync(bool async, string[] scopes, CancellationToken cancellationToken) + => TokenFactory != null ? new ValueTask(TokenFactory()) : base.AuthenticateAsync(async, scopes, cancellationToken); - public override AccessToken Authenticate(string[] scopes, CancellationToken cancellationToken) - { - if (TokenFactory != null) - { - return TokenFactory(); - } - - return base.Authenticate(scopes, cancellationToken); - } - - public override Task AuthenticateAsync(string[] scopes, CancellationToken cancellationToken) - { - if (TokenFactory != null) - { - return Task.FromResult(TokenFactory()); - } - - return base.AuthenticateAsync(scopes, cancellationToken); - } - - protected override MsiType GetMsiType(CancellationToken cancellationToken) - { - if (MsiTypeFactory != null) - { - return MsiTypeFactory(); - } - - return base.GetMsiType(cancellationToken); - } - - protected override Task GetMsiTypeAsync(CancellationToken cancellationToken) - { - if (MsiTypeFactory != null) - { - return Task.FromResult(MsiTypeFactory()); - } - - return base.GetMsiTypeAsync(cancellationToken); - } - - protected override bool ImdsAvailable(CancellationToken cancellationToken) - { - if (ImdsAvailableFunc != null) - { - return ImdsAvailableFunc(cancellationToken); - } - - return base.ImdsAvailable(cancellationToken); - } - - protected override Task ImdsAvailableAsync(CancellationToken cancellationToken) - { - if (ImdsAvailableFunc != null) - { - return Task.FromResult(ImdsAvailableFunc(cancellationToken)); - } - - return base.ImdsAvailableAsync(cancellationToken); - } + private protected override ValueTask GetManagedIdentitySourceAsync(bool async, CancellationToken cancellationToken) + => AuthRequestBuilderFactory != null ? new ValueTask(AuthRequestBuilderFactory()) : base.GetManagedIdentitySourceAsync(async, cancellationToken); } }