Skip to content

Commit

Permalink
Fix #13687: Move to latest App Services MSI API version (#14770)
Browse files Browse the repository at this point in the history
* Fix #13687: Move to latest App Services MSI API version

* Delete unused type

* Move request builders creation into factory methods

* Move maximum of code into IManagedIdentitySource implementations

* Fix formatting

* Update sdk/identity/Azure.Identity/src/ManagedIdentityCredential.cs

Typo fix

Co-authored-by: Charles Lowell <chlowe@microsoft.com>

Co-authored-by: Charles Lowell <chlowe@microsoft.com>
  • Loading branch information
AlexanderSher and chlowell authored Sep 4, 2020
1 parent 772c659 commit 8d38ce2
Show file tree
Hide file tree
Showing 14 changed files with 591 additions and 530 deletions.
13 changes: 1 addition & 12 deletions sdk/core/Azure.Core/src/Shared/ClientDiagnostics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,9 @@ public RequestFailedException CreateRequestFailedExceptionWithContent(
return exception;
}

public ValueTask<string> CreateRequestFailedMessageAsync(Response response, string? message = null, string? errorCode= null, IDictionary<string, string>? additionalInfo = null)
{
return CreateRequestFailedMessageAsync(response, message, errorCode, additionalInfo, true);
}

public string CreateRequestFailedMessage(Response response, string? message = null, string? errorCode = null, IDictionary<string, string>? additionalInfo = null)
{
return CreateRequestFailedMessageAsync(response, message, errorCode, additionalInfo, false).EnsureCompleted();
}

private async ValueTask<string> CreateRequestFailedMessageAsync(Response response, string? message, string? errorCode, IDictionary<string, string>? additionalInfo, bool async)
public async ValueTask<string> CreateRequestFailedMessageAsync(Response response, string? message, string? errorCode, IDictionary<string, string>? additionalInfo, bool async)
{
var content = await ReadContentAsync(response, async).ConfigureAwait(false);

return CreateRequestFailedMessageWithContent(response, message, content, errorCode, additionalInfo);
}

Expand Down
50 changes: 50 additions & 0 deletions sdk/core/Azure.Core/src/Shared/TaskExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ namespace Azure.Core.Pipeline
{
internal static class TaskExtensions
{
public static WithCancellationTaskAwaitable AwaitWithCancellation(this Task task, CancellationToken cancellationToken)
=> new WithCancellationTaskAwaitable(task, cancellationToken);

public static WithCancellationTaskAwaitable<T> AwaitWithCancellation<T>(this Task<T> task, CancellationToken cancellationToken)
=> new WithCancellationTaskAwaitable<T>(task, cancellationToken);

Expand Down Expand Up @@ -141,6 +144,20 @@ private static void VerifyTaskCompleted(bool isCompleted)
#pragma warning restore AZC0107 // Do not call public asynchronous method in synchronous scope.
}

public readonly struct WithCancellationTaskAwaitable
{
private readonly CancellationToken _cancellationToken;
private readonly ConfiguredTaskAwaitable _awaitable;

public WithCancellationTaskAwaitable(Task task, CancellationToken cancellationToken)
{
_awaitable = task.ConfigureAwait(false);
_cancellationToken = cancellationToken;
}

public WithCancellationTaskAwaiter GetAwaiter() => new WithCancellationTaskAwaiter(_awaitable.GetAwaiter(), _cancellationToken);
}

public readonly struct WithCancellationTaskAwaitable<T>
{
private readonly CancellationToken _cancellationToken;
Expand Down Expand Up @@ -169,6 +186,39 @@ public WithCancellationValueTaskAwaitable(ValueTask<T> task, CancellationToken c
public WithCancellationValueTaskAwaiter<T> GetAwaiter() => new WithCancellationValueTaskAwaiter<T>(_awaitable.GetAwaiter(), _cancellationToken);
}

public readonly struct WithCancellationTaskAwaiter : ICriticalNotifyCompletion
{
private readonly CancellationToken _cancellationToken;
private readonly ConfiguredTaskAwaitable.ConfiguredTaskAwaiter _taskAwaiter;

public WithCancellationTaskAwaiter(ConfiguredTaskAwaitable.ConfiguredTaskAwaiter awaiter, CancellationToken cancellationToken)
{
_taskAwaiter = awaiter;
_cancellationToken = cancellationToken;
}

public bool IsCompleted => _taskAwaiter.IsCompleted || _cancellationToken.IsCancellationRequested;

public void OnCompleted(Action continuation) => _taskAwaiter.OnCompleted(WrapContinuation(continuation));

public void UnsafeOnCompleted(Action continuation) => _taskAwaiter.UnsafeOnCompleted(WrapContinuation(continuation));

public void GetResult()
{
Debug.Assert(IsCompleted);
if (!_taskAwaiter.IsCompleted)
{
_cancellationToken.ThrowIfCancellationRequested();
}
_taskAwaiter.GetResult();
}

private Action WrapContinuation(in Action originalContinuation)
=> _cancellationToken.CanBeCanceled
? new WithCancellationContinuationWrapper(originalContinuation, _cancellationToken).Continuation
: originalContinuation;
}

public readonly struct WithCancellationTaskAwaiter<T> : ICriticalNotifyCompletion
{
private readonly CancellationToken _cancellationToken;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Globalization;
using System.Text.Json;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;

namespace Azure.Identity
{
internal class AppServiceV2017ManagedIdentitySource : IManagedIdentitySource
{
// MSI Constants. Docs for MSI are available here https://docs.microsoft.com/en-us/azure/app-service/overview-managed-identity
private const string AppServiceMsiApiVersion = "2017-09-01";
private const string MsiEndpointInvalidUriError = "The environment variable MSI_ENDPOINT contains an invalid Uri.";

private readonly HttpPipeline _pipeline;
private readonly Uri _endpoint;
private readonly string _secret;
private readonly string _clientId;

public static IManagedIdentitySource TryCreate(HttpPipeline pipeline, string clientId)
{
string msiEndpoint = EnvironmentVariables.MsiEndpoint;
string msiSecret = EnvironmentVariables.MsiSecret;

// if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the MsiType is AppService
if (string.IsNullOrEmpty(msiEndpoint) || string.IsNullOrEmpty(msiSecret))
{
return default;
}

Uri endpointUri;
try
{
endpointUri = new Uri(msiEndpoint);
}
catch (FormatException ex)
{
throw new AuthenticationFailedException(MsiEndpointInvalidUriError, ex);
}

return new AppServiceV2017ManagedIdentitySource(pipeline, endpointUri, msiSecret, clientId);
}

private AppServiceV2017ManagedIdentitySource(HttpPipeline pipeline, Uri endpoint, string secret, string clientId)
{
_pipeline = pipeline;
_endpoint = endpoint;
_secret = secret;
_clientId = clientId;
}

public Request CreateRequest(string[] scopes)
{
// covert the scopes to a resource string
string resource = ScopeUtilities.ScopesToResource(scopes);

Request request = _pipeline.CreateRequest();

request.Method = RequestMethod.Get;
request.Headers.Add("secret", _secret);
request.Uri.Reset(_endpoint);
request.Uri.AppendQuery("api-version", AppServiceMsiApiVersion);
request.Uri.AppendQuery("resource", resource);

if (!string.IsNullOrEmpty(_clientId))
{
request.Uri.AppendQuery("clientid", _clientId);
}

return request;
}

public AccessToken GetAccessTokenFromJson(in JsonElement jsonAccessToken, in JsonElement jsonExpiresOn)
{
// AppService version 2017-09-01 sends expires_on as a string formatted datetimeoffset
if (DateTimeOffset.TryParse(jsonExpiresOn.GetString(), CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset expiresOn))
{
return new AccessToken(jsonAccessToken.GetString(), expiresOn);
}

throw new AuthenticationFailedException(ManagedIdentityClient.AuthenticationResponseInvalidFormatError);
}

public ValueTask HandleFailedRequestAsync(Response response, ClientDiagnostics diagnostics, bool async) => new ValueTask();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Text.Json;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;

namespace Azure.Identity
{
internal class AppServiceV2019ManagedIdentitySource : IManagedIdentitySource
{
private const string AppServiceMsiApiVersion = "2019-08-01";
private const string IdentityEndpointInvalidUriError = "The environment variable IDENTITY_ENDPOINT contains an invalid Uri.";

private readonly HttpPipeline _pipeline;
private readonly Uri _endpoint;
private readonly string _secret;
private readonly string _clientId;

public static IManagedIdentitySource TryCreate(HttpPipeline pipeline, string clientId)
{
string identityEndpoint = EnvironmentVariables.IdentityEndpoint;
string identityHeader = EnvironmentVariables.IdentityHeader;

if (string.IsNullOrEmpty(identityEndpoint) || string.IsNullOrEmpty(identityHeader))
{
return default;
}

Uri endpointUri;
try
{
endpointUri = new Uri(identityEndpoint);
}
catch (FormatException ex)
{
throw new AuthenticationFailedException(IdentityEndpointInvalidUriError, ex);
}

return new AppServiceV2019ManagedIdentitySource(pipeline, endpointUri, identityHeader, clientId);
}

private AppServiceV2019ManagedIdentitySource(HttpPipeline pipeline, Uri endpoint, string secret, string clientId)
{
_pipeline = pipeline;
_endpoint = endpoint;
_secret = secret;
_clientId = clientId;
}

public Request CreateRequest(string[] scopes)
{
// covert the scopes to a resource string
string resource = ScopeUtilities.ScopesToResource(scopes);

Request request = _pipeline.CreateRequest();

request.Method = RequestMethod.Get;
request.Headers.Add("X-IDENTITY-HEADER", _secret);
request.Uri.Reset(_endpoint);
request.Uri.AppendQuery("api-version", AppServiceMsiApiVersion);
request.Uri.AppendQuery("resource", resource);

if (!string.IsNullOrEmpty(_clientId))
{
request.Uri.AppendQuery("client_id", _clientId);
}

return request;
}

public AccessToken GetAccessTokenFromJson(in JsonElement jsonAccessToken, in JsonElement jsonExpiresOn)
{
// the seconds from epoch may be returned as a Json number or a Json string which is a number
// depending on the environment. If neither of these are the case we throw an AuthException.
if (jsonExpiresOn.ValueKind == JsonValueKind.Number && jsonExpiresOn.TryGetInt64(out long expiresOnSec) ||
jsonExpiresOn.ValueKind == JsonValueKind.String && long.TryParse(jsonExpiresOn.GetString(), out expiresOnSec))
{
return new AccessToken(jsonAccessToken.GetString(), DateTimeOffset.FromUnixTimeSeconds(expiresOnSec));
}

throw new AuthenticationFailedException(ManagedIdentityClient.AuthenticationResponseInvalidFormatError);
}

public ValueTask HandleFailedRequestAsync(Response response, ClientDiagnostics diagnostics, bool async) => new ValueTask();
}
}
92 changes: 92 additions & 0 deletions sdk/identity/Azure.Identity/src/CloudShellManagedIdentitySource.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;

namespace Azure.Identity
{
internal class CloudShellManagedIdentitySource : IManagedIdentitySource
{
private readonly HttpPipeline _pipeline;
private readonly Uri _endpoint;
private readonly string _clientId;
private const string MsiEndpointInvalidUriError = "The environment variable MSI_ENDPOINT contains an invalid Uri.";

public static IManagedIdentitySource TryCreate(HttpPipeline pipeline, string clientId)
{
string msiEndpoint = EnvironmentVariables.MsiEndpoint;

// if ONLY the env var MSI_ENDPOINT is set the MsiType is CloudShell
if (string.IsNullOrEmpty(msiEndpoint))
{
return default;
}

Uri endpointUri;
try
{
endpointUri = new Uri(msiEndpoint);
}
catch (FormatException ex)
{
throw new AuthenticationFailedException(MsiEndpointInvalidUriError, ex);
}

return new CloudShellManagedIdentitySource(pipeline, endpointUri, clientId);
}

private CloudShellManagedIdentitySource(HttpPipeline pipeline, Uri endpoint, string clientId)
{
_pipeline = pipeline;
_endpoint = endpoint;
_clientId = clientId;
}

public Request CreateRequest(string[] scopes)
{
// covert the scopes to a resource string
string resource = ScopeUtilities.ScopesToResource(scopes);

Request request = _pipeline.CreateRequest();

request.Method = RequestMethod.Post;

request.Headers.Add(HttpHeader.Common.FormUrlEncodedContentType);

request.Uri.Reset(_endpoint);

request.Headers.Add("Metadata", "true");

var bodyStr = $"resource={Uri.EscapeDataString(resource)}";

if (!string.IsNullOrEmpty(_clientId))
{
bodyStr += $"&client_id={Uri.EscapeDataString(_clientId)}";
}

ReadOnlyMemory<byte> content = Encoding.UTF8.GetBytes(bodyStr).AsMemory();
request.Content = RequestContent.Create(content);
return request;
}

public AccessToken GetAccessTokenFromJson(in JsonElement jsonAccessToken, in JsonElement jsonExpiresOn)
{
// the seconds from epoch may be returned as a Json number or a Json string which is a number
// depending on the environment. If neither of these are the case we throw an AuthException.
if (jsonExpiresOn.ValueKind == JsonValueKind.Number && jsonExpiresOn.TryGetInt64(out long expiresOnSec) ||
jsonExpiresOn.ValueKind == JsonValueKind.String && long.TryParse(jsonExpiresOn.GetString(), out expiresOnSec))
{
return new AccessToken(jsonAccessToken.GetString(), DateTimeOffset.FromUnixTimeSeconds(expiresOnSec));
}

throw new AuthenticationFailedException(ManagedIdentityClient.AuthenticationResponseInvalidFormatError);
}

public ValueTask HandleFailedRequestAsync(Response response, ClientDiagnostics diagnostics, bool async) => new ValueTask();
}
}
2 changes: 2 additions & 0 deletions sdk/identity/Azure.Identity/src/EnvironmentVariables.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ internal class EnvironmentVariables
public static string ClientSecret => Environment.GetEnvironmentVariable("AZURE_CLIENT_SECRET");
public static string ClientCertificatePath => Environment.GetEnvironmentVariable("AZURE_CLIENT_CERTIFICATE_PATH");

public static string IdentityEndpoint => Environment.GetEnvironmentVariable("IDENTITY_ENDPOINT");
public static string IdentityHeader => Environment.GetEnvironmentVariable("IDENTITY_HEADER");
public static string MsiEndpoint => Environment.GetEnvironmentVariable("MSI_ENDPOINT");
public static string MsiSecret => Environment.GetEnvironmentVariable("MSI_SECRET");

Expand Down
17 changes: 17 additions & 0 deletions sdk/identity/Azure.Identity/src/IManagedIdentitySource.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Text.Json;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;

namespace Azure.Identity
{
internal interface IManagedIdentitySource
{
Request CreateRequest(string[] scopes);
AccessToken GetAccessTokenFromJson(in JsonElement jsonAccessToken, in JsonElement jsonExpiresOn);
ValueTask HandleFailedRequestAsync(Response response, ClientDiagnostics diagnostics, bool async);
}
}
Loading

0 comments on commit 8d38ce2

Please sign in to comment.