From 2dff80e8eacd014c412b4a8416e1f0fdde6742e4 Mon Sep 17 00:00:00 2001 From: Terence Fan Date: Wed, 11 Sep 2024 12:29:12 +0800 Subject: [PATCH] Reformat aad related codes (#2041) --- .../ServiceEndpointProvider.cs | 2 +- .../{Endpoints => Auth}/AccessKey.cs | 0 .../Auth/LogHelper.cs | 2 + .../MicrosoftEntra/AccessKeySynchronizer.cs | 98 +++++++++ .../MicrosoftEntra/IAccessKeySynchronizer.cs | 13 ++ .../MicrosoftEntra/MicrosoftEntraAccessKey.cs | 201 +++++++++++++++++ .../MicrosoftEntraTokenProvider.cs | 19 ++ .../Auth/MicrosoftEntraTokenProvider.cs | 20 -- .../Endpoints/AccessKeyForMicrosoftEntra.cs | 202 ------------------ .../Endpoints/AccessKeySynchronizer.cs | 99 --------- .../Endpoints/IAccessKeySynchronizer.cs | 14 -- .../Endpoints/ServiceEndpoint.cs | 2 +- .../ServiceConnectionBase.cs | 10 +- .../Utilities/ConnectionStringParser.cs | 18 +- .../Utilities/RestApiAccessTokenGenerator.cs | 2 +- .../ServiceEndpointProvider.cs | 2 +- .../Auth/AccessKeyForMicrosoftEntraTests.cs | 20 +- .../Auth/AuthUtilityTests.cs | 2 +- .../Auth/ConnectionStringParserTests.cs | 9 +- .../Auth/MicrosoftEntraApplicationTests.cs | 4 +- .../Endpoints/AccessKeySynchronizerFacts.cs | 1 - .../ServiceEndpointFacts.cs | 11 +- .../ServiceMessageTests.cs | 12 +- 23 files changed, 379 insertions(+), 384 deletions(-) rename src/Microsoft.Azure.SignalR.Common/{Endpoints => Auth}/AccessKey.cs (100%) create mode 100644 src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs create mode 100644 src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/IAccessKeySynchronizer.cs create mode 100644 src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs create mode 100644 src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraTokenProvider.cs delete mode 100644 src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntraTokenProvider.cs delete mode 100644 src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeyForMicrosoftEntra.cs delete mode 100644 src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs delete mode 100644 src/Microsoft.Azure.SignalR.Common/Endpoints/IAccessKeySynchronizer.cs diff --git a/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs b/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs index 1728b143d..dcdc5dc0b 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs @@ -95,7 +95,7 @@ public string GetServerEndpoint(string hubName) public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId) { - if (_accessKey is AccessKeyForMicrosoftEntra key) + if (_accessKey is MicrosoftEntraAccessKey key) { return new MicrosoftEntraTokenProvider(key); } diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKey.cs b/src/Microsoft.Azure.SignalR.Common/Auth/AccessKey.cs similarity index 100% rename from src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKey.cs rename to src/Microsoft.Azure.SignalR.Common/Auth/AccessKey.cs diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/LogHelper.cs b/src/Microsoft.Azure.SignalR.Common/Auth/LogHelper.cs index 00ae780a7..3b892c47c 100644 --- a/src/Microsoft.Azure.SignalR.Common/Auth/LogHelper.cs +++ b/src/Microsoft.Azure.SignalR.Common/Auth/LogHelper.cs @@ -5,6 +5,8 @@ using System; using System.Globalization; +namespace Microsoft.Azure.SignalR; + internal class LogHelper { public static ArgumentNullException LogArgumentNullException(string name) diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs new file mode 100644 index 000000000..8ea9a550f --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Azure.SignalR; + +internal sealed class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposable +{ + private readonly ConcurrentDictionary _endpoints = new ConcurrentDictionary(ReferenceEqualityComparer.Instance); + + private readonly ILoggerFactory _factory; + + private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1)); + + internal IEnumerable AccessKeyForMicrosoftEntraList => _endpoints.Select(e => e.Key.AccessKey).OfType(); + + public AccessKeySynchronizer(ILoggerFactory loggerFactory) : this(loggerFactory, true) + { + } + + /// + /// Test only. + /// + internal AccessKeySynchronizer(ILoggerFactory loggerFactory, bool start) + { + if (start) + { + _ = UpdateAccessKeyAsync(); + } + _factory = loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory)); + } + + public void AddServiceEndpoint(ServiceEndpoint endpoint) + { + if (endpoint.AccessKey is MicrosoftEntraAccessKey key) + { + _ = key.UpdateAccessKeyAsync(); + } + _endpoints.TryAdd(endpoint, null); + } + + public void Dispose() => _timer.Stop(); + + public void UpdateServiceEndpoints(IEnumerable endpoints) + { + _endpoints.Clear(); + foreach (var endpoint in endpoints) + { + AddServiceEndpoint(endpoint); + } + } + + internal bool ContainsServiceEndpoint(ServiceEndpoint e) => _endpoints.ContainsKey(e); + + internal int ServiceEndpointsCount() => _endpoints.Count; + + private async Task UpdateAccessKeyAsync() + { + using (_timer) + { + _timer.Start(); + + while (await _timer) + { + foreach (var key in AccessKeyForMicrosoftEntraList) + { + _ = key.UpdateAccessKeyAsync(); + } + } + } + } + + private sealed class ReferenceEqualityComparer : IEqualityComparer + { + internal static readonly ReferenceEqualityComparer Instance = new ReferenceEqualityComparer(); + + private ReferenceEqualityComparer() + { + } + + public bool Equals(ServiceEndpoint x, ServiceEndpoint y) + { + return ReferenceEquals(x, y); + } + + public int GetHashCode(ServiceEndpoint obj) + { + return RuntimeHelpers.GetHashCode(obj); + } + } +} diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/IAccessKeySynchronizer.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/IAccessKeySynchronizer.cs new file mode 100644 index 000000000..0bf4fc255 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/IAccessKeySynchronizer.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Collections.Generic; + +namespace Microsoft.Azure.SignalR; + +internal interface IAccessKeySynchronizer +{ + public void AddServiceEndpoint(ServiceEndpoint endpoint); + + public void UpdateServiceEndpoints(IEnumerable endpoints); +} diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs new file mode 100644 index 000000000..76df07598 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs @@ -0,0 +1,201 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Security.Claims; +using System.Threading; +using System.Threading.Tasks; + +using Azure.Core; + +using Microsoft.Azure.SignalR.Common; + +using Newtonsoft.Json.Linq; + +namespace Microsoft.Azure.SignalR; + +internal class MicrosoftEntraAccessKey : AccessKey +{ + internal static readonly TimeSpan GetAccessKeyTimeout = TimeSpan.FromSeconds(100); + + private const int GetAccessKeyIntervalInMinute = 55; + + private const int GetAccessKeyMaxRetryTimes = 3; + + private const int GetMicrosoftEntraTokenMaxRetryTimes = 3; + + private const string DefaultScope = "https://signalr.azure.com/.default"; + + private static readonly TokenRequestContext DefaultRequestContext = new TokenRequestContext(new string[] { DefaultScope }); + + private static readonly TimeSpan GetAccessKeyInterval = TimeSpan.FromMinutes(GetAccessKeyIntervalInMinute); + + private static readonly TimeSpan GetAccessKeyIntervalWhenUnauthorized = TimeSpan.FromMinutes(5); + + private static readonly TimeSpan GetAccessKeyRetryInterval = TimeSpan.FromSeconds(3); + + private readonly TaskCompletionSource _initializedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + private volatile bool _isAuthorized = false; + + private Exception _lastException; + + private DateTime _lastUpdatedTime = DateTime.MinValue; + + public bool IsAuthorized + { + get => _isAuthorized; + private set + { + if (value) + { + _lastException = null; + } + _lastUpdatedTime = DateTime.UtcNow; + _isAuthorized = value; + _initializedTcs.TrySetResult(null); + } + } + + public TokenCredential TokenCredential { get; } + + internal string GetAccessKeyUrl { get; } + + internal bool HasExpired => DateTime.UtcNow - _lastUpdatedTime > TimeSpan.FromMinutes(GetAccessKeyIntervalInMinute * 2); + + private Task InitializedTask => _initializedTcs.Task; + + public MicrosoftEntraAccessKey(Uri endpoint, TokenCredential credential, Uri serverEndpoint = null) : base(endpoint) + { + var authorizeUri = (serverEndpoint ?? endpoint).Append("/api/v1/auth/accessKey"); + GetAccessKeyUrl = authorizeUri.AbsoluteUri; + TokenCredential = credential; + } + + public virtual async Task GetMicrosoftEntraTokenAsync(CancellationToken ctoken = default) + { + Exception latest = null; + for (var i = 0; i < GetMicrosoftEntraTokenMaxRetryTimes; i++) + { + try + { + var token = await TokenCredential.GetTokenAsync(DefaultRequestContext, ctoken); + return token.Token; + } + catch (Exception e) + { + latest = e; + } + } + throw latest; + } + + public override async Task GenerateAccessTokenAsync( + string audience, + IEnumerable claims, + TimeSpan lifetime, + AccessTokenAlgorithm algorithm, + CancellationToken ctoken = default) + { + var task = await Task.WhenAny(InitializedTask, ctoken.AsTask()); + + if (task == InitializedTask || InitializedTask.IsCompleted) + { + await task; + return IsAuthorized + ? await base.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm) + : throw new AzureSignalRAccessTokenNotAuthorizedException(TokenCredential.GetType().Name, _lastException); + } + else + { + throw new TaskCanceledException("Timeout reached when authorizing AzureAD identity."); + } + } + + internal void UpdateAccessKey(string kid, string accessKey) + { + Key = new Tuple(kid, accessKey); + IsAuthorized = true; + } + + internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default) + { + var delta = DateTime.UtcNow - _lastUpdatedTime; + if (IsAuthorized && delta < GetAccessKeyInterval) + { + return; + } + else if (!IsAuthorized && delta < GetAccessKeyIntervalWhenUnauthorized) + { + return; + } + + for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++) + { + var source = new CancellationTokenSource(GetAccessKeyTimeout); + var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(source.Token, ctoken); + try + { + var token = await GetMicrosoftEntraTokenAsync(linkedSource.Token); + await GetAccessKeyInternalAsync(token, linkedSource.Token); + return; + } + catch (OperationCanceledException e) + { + _lastException = e; + break; + } + catch (Exception e) + { + _lastException = e; + try + { + await Task.Delay(GetAccessKeyRetryInterval, ctoken); + } + catch (OperationCanceledException) + { + break; + } + } + } + + IsAuthorized = false; + } + + private async Task GetAccessKeyInternalAsync(string accessToken, CancellationToken ctoken = default) + { + var api = new RestApiEndpoint(GetAccessKeyUrl, accessToken); + + await new RestClient().SendAsync( + api, + HttpMethod.Get, + handleExpectedResponseAsync: HandleHttpResponseAsync, + cancellationToken: ctoken); + } + + private async Task HandleHttpResponseAsync(HttpResponseMessage response) + { + if (response.StatusCode != HttpStatusCode.OK) + { + return false; + } + + var json = await response.Content.ReadAsStringAsync(); + var obj = JObject.Parse(json); + + if (!obj.TryGetValue("KeyId", out var keyId) || keyId.Type != JTokenType.String) + { + throw new AzureSignalRException("Missing required field."); + } + if (!obj.TryGetValue("AccessKey", out var key) || key.Type != JTokenType.String) + { + throw new AzureSignalRException("Missing required field."); + } + + UpdateAccessKey(keyId.ToString(), key.ToString()); + return true; + } +} diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraTokenProvider.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraTokenProvider.cs new file mode 100644 index 000000000..cf0077523 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraTokenProvider.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Threading.Tasks; + +namespace Microsoft.Azure.SignalR; + +internal class MicrosoftEntraTokenProvider : IAccessTokenProvider +{ + private readonly MicrosoftEntraAccessKey _accessKey; + + public MicrosoftEntraTokenProvider(MicrosoftEntraAccessKey accessKey) + { + _accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey)); + } + + public Task ProvideAsync() => _accessKey.GetMicrosoftEntraTokenAsync(); +} diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntraTokenProvider.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntraTokenProvider.cs deleted file mode 100644 index 52f799852..000000000 --- a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntraTokenProvider.cs +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System; -using System.Threading.Tasks; - -namespace Microsoft.Azure.SignalR -{ - internal class MicrosoftEntraTokenProvider : IAccessTokenProvider - { - private readonly AccessKeyForMicrosoftEntra _accessKey; - - public MicrosoftEntraTokenProvider(AccessKeyForMicrosoftEntra accessKey) - { - _accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey)); - } - - public Task ProvideAsync() => _accessKey.GetMicrosoftEntraTokenAsync(); - } -} diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeyForMicrosoftEntra.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeyForMicrosoftEntra.cs deleted file mode 100644 index 5b487b4cb..000000000 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeyForMicrosoftEntra.cs +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.Net; -using System.Net.Http; -using System.Security.Claims; -using System.Threading; -using System.Threading.Tasks; - -using Azure.Core; - -using Microsoft.Azure.SignalR.Common; - -using Newtonsoft.Json.Linq; - -namespace Microsoft.Azure.SignalR -{ - internal class AccessKeyForMicrosoftEntra : AccessKey - { - internal static readonly TimeSpan GetAccessKeyTimeout = TimeSpan.FromSeconds(100); - - private const int GetAccessKeyIntervalInMinute = 55; - - private const int GetAccessKeyMaxRetryTimes = 3; - - private const int GetMicrosoftEntraTokenMaxRetryTimes = 3; - - private const string DefaultScope = "https://signalr.azure.com/.default"; - - private static readonly TokenRequestContext DefaultRequestContext = new TokenRequestContext(new string[] { DefaultScope }); - - private static readonly TimeSpan GetAccessKeyInterval = TimeSpan.FromMinutes(GetAccessKeyIntervalInMinute); - - private static readonly TimeSpan GetAccessKeyIntervalWhenUnauthorized = TimeSpan.FromMinutes(5); - - private static readonly TimeSpan GetAccessKeyRetryInterval = TimeSpan.FromSeconds(3); - - private readonly TaskCompletionSource _initializedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - private volatile bool _isAuthorized = false; - - private Exception _lastException; - - private DateTime _lastUpdatedTime = DateTime.MinValue; - - public bool IsAuthorized - { - get => _isAuthorized; - private set - { - if (value) - { - _lastException = null; - } - _lastUpdatedTime = DateTime.UtcNow; - _isAuthorized = value; - _initializedTcs.TrySetResult(null); - } - } - - public TokenCredential TokenCredential { get; } - - internal string GetAccessKeyUrl { get; } - - internal bool HasExpired => DateTime.UtcNow - _lastUpdatedTime > TimeSpan.FromMinutes(GetAccessKeyIntervalInMinute * 2); - - private Task InitializedTask => _initializedTcs.Task; - - public AccessKeyForMicrosoftEntra(Uri endpoint, TokenCredential credential, Uri serverEndpoint = null) : base(endpoint) - { - var authorizeUri = (serverEndpoint ?? endpoint).Append("/api/v1/auth/accessKey"); - GetAccessKeyUrl = authorizeUri.AbsoluteUri; - TokenCredential = credential; - } - - public virtual async Task GetMicrosoftEntraTokenAsync(CancellationToken ctoken = default) - { - Exception latest = null; - for (var i = 0; i < GetMicrosoftEntraTokenMaxRetryTimes; i++) - { - try - { - var token = await TokenCredential.GetTokenAsync(DefaultRequestContext, ctoken); - return token.Token; - } - catch (Exception e) - { - latest = e; - } - } - throw latest; - } - - public override async Task GenerateAccessTokenAsync( - string audience, - IEnumerable claims, - TimeSpan lifetime, - AccessTokenAlgorithm algorithm, - CancellationToken ctoken = default) - { - var task = await Task.WhenAny(InitializedTask, ctoken.AsTask()); - - if (task == InitializedTask || InitializedTask.IsCompleted) - { - await task; - return IsAuthorized - ? await base.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm) - : throw new AzureSignalRAccessTokenNotAuthorizedException(TokenCredential.GetType().Name, _lastException); - } - else - { - throw new TaskCanceledException("Timeout reached when authorizing AzureAD identity."); - } - } - - internal void UpdateAccessKey(string kid, string accessKey) - { - Key = new Tuple(kid, accessKey); - IsAuthorized = true; - } - - internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default) - { - var delta = DateTime.UtcNow - _lastUpdatedTime; - if (IsAuthorized && delta < GetAccessKeyInterval) - { - return; - } - else if (!IsAuthorized && delta < GetAccessKeyIntervalWhenUnauthorized) - { - return; - } - - for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++) - { - var source = new CancellationTokenSource(GetAccessKeyTimeout); - var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(source.Token, ctoken); - try - { - var token = await GetMicrosoftEntraTokenAsync(linkedSource.Token); - await GetAccessKeyInternalAsync(token, linkedSource.Token); - return; - } - catch (OperationCanceledException e) - { - _lastException = e; - break; - } - catch (Exception e) - { - _lastException = e; - try - { - await Task.Delay(GetAccessKeyRetryInterval, ctoken); - } - catch (OperationCanceledException) - { - break; - } - } - } - - IsAuthorized = false; - } - - private async Task GetAccessKeyInternalAsync(string accessToken, CancellationToken ctoken = default) - { - var api = new RestApiEndpoint(GetAccessKeyUrl, accessToken); - - await new RestClient().SendAsync( - api, - HttpMethod.Get, - handleExpectedResponseAsync: HandleHttpResponseAsync, - cancellationToken: ctoken); - } - - private async Task HandleHttpResponseAsync(HttpResponseMessage response) - { - if (response.StatusCode != HttpStatusCode.OK) - { - return false; - } - - var json = await response.Content.ReadAsStringAsync(); - var obj = JObject.Parse(json); - - if (!obj.TryGetValue("KeyId", out var keyId) || keyId.Type != JTokenType.String) - { - throw new AzureSignalRException("Missing required field."); - } - if (!obj.TryGetValue("AccessKey", out var key) || key.Type != JTokenType.String) - { - throw new AzureSignalRException("Missing required field."); - } - - UpdateAccessKey(keyId.ToString(), key.ToString()); - return true; - } - } -} diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs deleted file mode 100644 index fb05f7a67..000000000 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Threading.Tasks; -using Microsoft.Extensions.Logging; - -namespace Microsoft.Azure.SignalR -{ - internal sealed class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposable - { - private readonly ConcurrentDictionary _endpoints = new ConcurrentDictionary(ReferenceEqualityComparer.Instance); - - private readonly ILoggerFactory _factory; - - private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1)); - - internal IEnumerable AccessKeyForMicrosoftEntraList => _endpoints.Select(e => e.Key.AccessKey).OfType(); - - public AccessKeySynchronizer(ILoggerFactory loggerFactory) : this(loggerFactory, true) - { - } - - /// - /// Test only. - /// - internal AccessKeySynchronizer(ILoggerFactory loggerFactory, bool start) - { - if (start) - { - _ = UpdateAccessKeyAsync(); - } - _factory = loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory)); - } - - public void AddServiceEndpoint(ServiceEndpoint endpoint) - { - if (endpoint.AccessKey is AccessKeyForMicrosoftEntra key) - { - _ = key.UpdateAccessKeyAsync(); - } - _endpoints.TryAdd(endpoint, null); - } - - public void Dispose() => _timer.Stop(); - - public void UpdateServiceEndpoints(IEnumerable endpoints) - { - _endpoints.Clear(); - foreach (var endpoint in endpoints) - { - AddServiceEndpoint(endpoint); - } - } - - internal bool ContainsServiceEndpoint(ServiceEndpoint e) => _endpoints.ContainsKey(e); - - internal int ServiceEndpointsCount() => _endpoints.Count; - - private async Task UpdateAccessKeyAsync() - { - using (_timer) - { - _timer.Start(); - - while (await _timer) - { - foreach (var key in AccessKeyForMicrosoftEntraList) - { - _ = key.UpdateAccessKeyAsync(); - } - } - } - } - - private sealed class ReferenceEqualityComparer : IEqualityComparer - { - internal static readonly ReferenceEqualityComparer Instance = new ReferenceEqualityComparer(); - - private ReferenceEqualityComparer() - { - } - - public bool Equals(ServiceEndpoint x, ServiceEndpoint y) - { - return ReferenceEquals(x, y); - } - - public int GetHashCode(ServiceEndpoint obj) - { - return RuntimeHelpers.GetHashCode(obj); - } - } - } -} diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/IAccessKeySynchronizer.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/IAccessKeySynchronizer.cs deleted file mode 100644 index f015891d8..000000000 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/IAccessKeySynchronizer.cs +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System.Collections.Generic; - -namespace Microsoft.Azure.SignalR -{ - internal interface IAccessKeySynchronizer - { - public void AddServiceEndpoint(ServiceEndpoint endpoint); - - public void UpdateServiceEndpoints(IEnumerable endpoints); - } -} diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs index 3d8b18a10..09830d4de 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs @@ -86,7 +86,7 @@ internal AccessKey AccessKey { lock (_lock) { - _accessKey ??= new AccessKeyForMicrosoftEntra(_serviceEndpoint, _tokenCredential, ServerEndpoint); + _accessKey ??= new MicrosoftEntraAccessKey(_serviceEndpoint, _tokenCredential, ServerEndpoint); } } return _accessKey; diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs index 8a544cb06..5eefa9b68 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs @@ -169,7 +169,7 @@ public async Task StartAsync(string target = null) TimerAwaitable syncTimer = null; try { - if (HubEndpoint != null && HubEndpoint.AccessKey is AccessKeyForMicrosoftEntra key) + if (HubEndpoint != null && HubEndpoint.AccessKey is MicrosoftEntraAccessKey key) { syncTimer = new TimerAwaitable(TimeSpan.Zero, DefaultSyncAzureIdentityInterval); _ = UpdateAzureIdentityAsync(key, syncTimer); @@ -405,7 +405,7 @@ private Task OnEventMessageAsync(ServiceEventMessage message) private Task OnAccessKeyMessageAsync(AccessKeyResponseMessage keyMessage) { - if (HubEndpoint.AccessKey is AccessKeyForMicrosoftEntra key) + if (HubEndpoint.AccessKey is MicrosoftEntraAccessKey key) { if (string.IsNullOrEmpty(keyMessage.ErrorType)) { @@ -602,7 +602,7 @@ private async Task ReceiveHandshakeResponseAsync(PipeReader input, Cancell } } - private async Task UpdateAzureIdentityAsync(AccessKeyForMicrosoftEntra key, TimerAwaitable timer) + private async Task UpdateAzureIdentityAsync(MicrosoftEntraAccessKey key, TimerAwaitable timer) { using (timer) { @@ -614,11 +614,11 @@ private async Task UpdateAzureIdentityAsync(AccessKeyForMicrosoftEntra key, Time } } - private async Task SendAccessKeyRequestMessageAsync(AccessKeyForMicrosoftEntra key) + private async Task SendAccessKeyRequestMessageAsync(MicrosoftEntraAccessKey key) { try { - var source = new CancellationTokenSource(AccessKeyForMicrosoftEntra.GetAccessKeyTimeout); + var source = new CancellationTokenSource(MicrosoftEntraAccessKey.GetAccessKeyTimeout); var token = await key.GetMicrosoftEntraTokenAsync(source.Token); var message = new AccessKeyRequestMessage(token); await SafeWriteAsync(message); diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs index d0bf24b8e..94d7b60d1 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs @@ -169,11 +169,11 @@ private static AccessKey BuildAzureADAccessKey(Uri uri, Uri serverEndpointUri, D { if (dict.TryGetValue(ClientSecretProperty, out var clientSecret)) { - return new AccessKeyForMicrosoftEntra(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri); + return new MicrosoftEntraAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri); } else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath)) { - return new AccessKeyForMicrosoftEntra(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri); + return new MicrosoftEntraAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri); } else { @@ -182,12 +182,12 @@ private static AccessKey BuildAzureADAccessKey(Uri uri, Uri serverEndpointUri, D } else { - return new AccessKeyForMicrosoftEntra(uri, new ManagedIdentityCredential(clientId), serverEndpointUri); + return new MicrosoftEntraAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri); } } else { - return new AccessKeyForMicrosoftEntra(uri, new ManagedIdentityCredential(), serverEndpointUri); + return new MicrosoftEntraAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri); } } @@ -200,7 +200,7 @@ private static AccessKey BuildAccessKey(Uri uri, Dictionary dict private static AccessKey BuildAzureAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) { - return new AccessKeyForMicrosoftEntra(uri, new DefaultAzureCredential(), serverEndpointUri); + return new MicrosoftEntraAccessKey(uri, new DefaultAzureCredential(), serverEndpointUri); } private static AccessKey BuildAzureAppAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) @@ -217,11 +217,11 @@ private static AccessKey BuildAzureAppAccessKey(Uri uri, Uri serverEndpointUri, if (dict.TryGetValue(ClientSecretProperty, out var clientSecret)) { - return new AccessKeyForMicrosoftEntra(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri); + return new MicrosoftEntraAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri); } else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath)) { - return new AccessKeyForMicrosoftEntra(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri); + return new MicrosoftEntraAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri); } throw new ArgumentException(MissingClientSecretProperty, ClientSecretProperty); } @@ -229,8 +229,8 @@ private static AccessKey BuildAzureAppAccessKey(Uri uri, Uri serverEndpointUri, private static AccessKey BuildAzureMsiAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) { return dict.TryGetValue(ClientIdProperty, out var clientId) - ? new AccessKeyForMicrosoftEntra(uri, new ManagedIdentityCredential(clientId), serverEndpointUri) - : new AccessKeyForMicrosoftEntra(uri, new ManagedIdentityCredential(), serverEndpointUri); + ? new MicrosoftEntraAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri) + : new MicrosoftEntraAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri); } private static Dictionary ToDictionary(string connectionString) diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiAccessTokenGenerator.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiAccessTokenGenerator.cs index 7b90eba94..fcdb4ed8c 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiAccessTokenGenerator.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiAccessTokenGenerator.cs @@ -27,7 +27,7 @@ public RestApiAccessTokenGenerator(AccessKey accessKey, string serverName = null public Task Generate(string audience, TimeSpan? lifetime = null) { - if (_accessKey is AccessKeyForMicrosoftEntra key) + if (_accessKey is MicrosoftEntraAccessKey key) { return key.GetMicrosoftEntraTokenAsync(); } diff --git a/src/Microsoft.Azure.SignalR/EndpointProvider/ServiceEndpointProvider.cs b/src/Microsoft.Azure.SignalR/EndpointProvider/ServiceEndpointProvider.cs index 9b3b18735..a2bda2b4f 100644 --- a/src/Microsoft.Azure.SignalR/EndpointProvider/ServiceEndpointProvider.cs +++ b/src/Microsoft.Azure.SignalR/EndpointProvider/ServiceEndpointProvider.cs @@ -61,7 +61,7 @@ public string GetClientEndpoint(string hubName, string originalPath, string quer public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId) { - if (_accessKey is AccessKeyForMicrosoftEntra key) + if (_accessKey is MicrosoftEntraAccessKey key) { return new MicrosoftEntraTokenProvider(key); } diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeyForMicrosoftEntraTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeyForMicrosoftEntraTests.cs index 909511014..1684c1ddd 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeyForMicrosoftEntraTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeyForMicrosoftEntraTests.cs @@ -23,7 +23,7 @@ public class AccessKeyForMicrosoftEntraTests [InlineData("https://a.bc:443", "https://a.bc/api/v1/auth/accessKey")] public void TestExpectedGetAccessKeyUrl(string endpoint, string expectedGetAccessKeyUrl) { - var key = new AccessKeyForMicrosoftEntra(new Uri(endpoint), new DefaultAzureCredential()); + var key = new MicrosoftEntraAccessKey(new Uri(endpoint), new DefaultAzureCredential()); Assert.Equal(expectedGetAccessKeyUrl, key.GetAccessKeyUrl); } @@ -35,7 +35,7 @@ public async Task TestUpdateAccessKey() It.IsAny(), It.IsAny())) .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new AccessKeyForMicrosoftEntra(DefaultEndpoint, mockCredential.Object); + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); var audience = "http://localhost/chat"; var claims = Array.Empty(); @@ -68,19 +68,19 @@ public async Task TestUpdateAccessKeyAsyncShouldSkip(bool isAuthorized, int time It.IsAny(), It.IsAny())) .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new AccessKeyForMicrosoftEntra(DefaultEndpoint, mockCredential.Object); - var isAuthorizedField = typeof(AccessKeyForMicrosoftEntra).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance); + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); + var isAuthorizedField = typeof(MicrosoftEntraAccessKey).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance); isAuthorizedField.SetValue(key, isAuthorized); Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key)); var lastUpdatedTime = DateTime.UtcNow - TimeSpan.FromMinutes(timeElapsed); - var lastUpdatedTimeField = typeof(AccessKeyForMicrosoftEntra).GetField("_lastUpdatedTime", BindingFlags.NonPublic | BindingFlags.Instance); + var lastUpdatedTimeField = typeof(MicrosoftEntraAccessKey).GetField("_lastUpdatedTime", BindingFlags.NonPublic | BindingFlags.Instance); lastUpdatedTimeField.SetValue(key, lastUpdatedTime); - var initializedTcsField = typeof(AccessKeyForMicrosoftEntra).GetField("_initializedTcs", BindingFlags.NonPublic | BindingFlags.Instance); + var initializedTcsField = typeof(MicrosoftEntraAccessKey).GetField("_initializedTcs", BindingFlags.NonPublic | BindingFlags.Instance); var initializedTcs = (TaskCompletionSource)initializedTcsField.GetValue(key); - var lastExceptionFields = typeof(AccessKeyForMicrosoftEntra).GetField("_lastException", BindingFlags.NonPublic | BindingFlags.Instance); + var lastExceptionFields = typeof(MicrosoftEntraAccessKey).GetField("_lastException", BindingFlags.NonPublic | BindingFlags.Instance); await key.UpdateAccessKeyAsync().OrTimeout(TimeSpan.FromSeconds(30)); var actualLastUpdatedTime = Assert.IsType(lastUpdatedTimeField.GetValue(key)); @@ -109,7 +109,7 @@ public async Task TestInitializeFailed() It.IsAny(), It.IsAny())) .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new AccessKeyForMicrosoftEntra(DefaultEndpoint, mockCredential.Object); + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); var audience = "http://localhost/chat"; var claims = Array.Empty(); @@ -132,7 +132,7 @@ public async Task TestUpdateAccessKeyAfterInitializeFailed() It.IsAny(), It.IsAny())) .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new AccessKeyForMicrosoftEntra(DefaultEndpoint, mockCredential.Object); + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); var audience = "http://localhost/chat"; var claims = Array.Empty(); @@ -146,7 +146,7 @@ public async Task TestUpdateAccessKeyAfterInitializeFailed() ); Assert.IsType(exception.InnerException); - var lastExceptionFields = typeof(AccessKeyForMicrosoftEntra).GetField("_lastException", BindingFlags.NonPublic | BindingFlags.Instance); + var lastExceptionFields = typeof(MicrosoftEntraAccessKey).GetField("_lastException", BindingFlags.NonPublic | BindingFlags.Instance); Assert.NotNull(lastExceptionFields.GetValue(key)); var (kid, accessKey) = ("foo", DefaultSigningKey); diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs index 1466ee0b9..31a25bce0 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs @@ -41,7 +41,7 @@ public class CachingTestData : IEnumerable public IEnumerator GetEnumerator() { yield return new object[] { new AccessKey("http://localhost:443", SigningKey), true }; - var key = new AccessKeyForMicrosoftEntra(new Uri("http://localhost"), new DefaultAzureCredential()); + var key = new MicrosoftEntraAccessKey(new Uri("http://localhost"), new DefaultAzureCredential()); key.UpdateAccessKey("foo", SigningKey); yield return new object[] { key, false }; } diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs index 01a81c8b6..439f43a2b 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs @@ -6,7 +6,6 @@ using System.Collections.Generic; using Azure.Identity; - using Xunit; namespace Microsoft.Azure.SignalR.Common.Tests.Auth @@ -101,7 +100,7 @@ public void TestAzureApplication(string connectionString) { var r = ConnectionStringParser.Parse(connectionString); - var key = Assert.IsType(r.AccessKey); + var key = Assert.IsType(r.AccessKey); Assert.IsType(key.TokenCredential); Assert.Same(r.Endpoint, r.AccessKey.Endpoint); Assert.Null(r.Version); @@ -148,7 +147,7 @@ internal void TestDefaultAzureCredential(string expectedEndpoint, string connect var r = ConnectionStringParser.Parse(connectionString); Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); - var key = Assert.IsType(r.AccessKey); + var key = Assert.IsType(r.AccessKey); Assert.IsType(key.TokenCredential); Assert.Same(r.Endpoint, r.AccessKey.Endpoint); } @@ -165,7 +164,7 @@ internal void TestManagedIdentity(string expectedEndpoint, string connectionStri var r = ConnectionStringParser.Parse(connectionString); Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); - var key = Assert.IsType(r.AccessKey); + var key = Assert.IsType(r.AccessKey); Assert.IsType(key.TokenCredential); Assert.Same(r.Endpoint, r.AccessKey.Endpoint); Assert.Null(r.ClientEndpoint); @@ -180,7 +179,7 @@ internal void TestManagedIdentity(string expectedEndpoint, string connectionStri internal void TestAzureADWithServerEndpoint(string connectionString, string expectedAuthorizeUrl) { var r = ConnectionStringParser.Parse(connectionString); - var key = Assert.IsType(r.AccessKey); + var key = Assert.IsType(r.AccessKey); Assert.Equal(expectedAuthorizeUrl, key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); } diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraApplicationTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraApplicationTests.cs index 55c1e689e..5c46bc0b2 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraApplicationTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraApplicationTests.cs @@ -26,7 +26,7 @@ public class MicrosoftEntraApplicationTests public async Task TestAcquireAccessToken() { var options = new ClientSecretCredential(TestTenantId, TestClientId, TestClientSecret); - var key = new AccessKeyForMicrosoftEntra(new Uri("https://localhost:8080"), options); + var key = new MicrosoftEntraAccessKey(new Uri("https://localhost:8080"), options); var token = await key.GetMicrosoftEntraTokenAsync(); Assert.NotNull(token); } @@ -73,7 +73,7 @@ public async Task TestGetMicrosoftEntraTokenAndAuthenticate() internal async Task TestAuthenticateAsync() { var options = new ClientSecretCredential(TestTenantId, TestClientId, TestClientSecret); - var key = new AccessKeyForMicrosoftEntra(new Uri("https://localhost:8080"), options); + var key = new MicrosoftEntraAccessKey(new Uri("https://localhost:8080"), options); await key.UpdateAccessKeyAsync(); Assert.True(key.IsAuthorized); diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Endpoints/AccessKeySynchronizerFacts.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Endpoints/AccessKeySynchronizerFacts.cs index 6e35d1e5a..a2bc80411 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Endpoints/AccessKeySynchronizerFacts.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Endpoints/AccessKeySynchronizerFacts.cs @@ -1,5 +1,4 @@ using System.Collections.Generic; - using Microsoft.Azure.SignalR.Tests.Common; using Microsoft.Extensions.Logging.Abstractions; using Xunit; diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs b/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs index 37446112b..64257482b 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs @@ -6,7 +6,6 @@ using System.Collections.Generic; using Azure.Identity; - using Xunit; namespace Microsoft.Azure.SignalR.Common.Tests @@ -117,7 +116,7 @@ public void TestAzureADConstructor(string url, string expectedEndpoint, int port { var uri = new Uri(url); var serviceEndpoint = new ServiceEndpoint(uri, new DefaultAzureCredential()); - Assert.IsType(serviceEndpoint.AccessKey); + Assert.IsType(serviceEndpoint.AccessKey); Assert.Equal(expectedEndpoint, serviceEndpoint.Endpoint); Assert.Equal("", serviceEndpoint.Name); Assert.Equal(port, serviceEndpoint.AccessKey.Endpoint.Port); @@ -150,7 +149,7 @@ public void TestAzureADConstructorWithKey(string key, string name, EndpointType { var uri = new Uri("http://localhost"); var serviceEndpoint = new ServiceEndpoint(key, uri, new DefaultAzureCredential()); - Assert.IsType(serviceEndpoint.AccessKey); + Assert.IsType(serviceEndpoint.AccessKey); Assert.Equal(name, serviceEndpoint.Name); Assert.Equal(type, serviceEndpoint.EndpointType); TestCopyConstructor(serviceEndpoint); @@ -166,12 +165,12 @@ public void TestAzureADConstructorWithServerEndpoint() { ServerEndpoint = serverEndpoint1 }; - var key = Assert.IsType(endpoint.AccessKey); + var key = Assert.IsType(endpoint.AccessKey); Assert.Same(key, endpoint.AccessKey); Assert.Equal("http://serverEndpoint:123/api/v1/auth/accessKey", key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); endpoint = new ServiceEndpoint(new Uri(serviceEndpoint), new DefaultAzureCredential(), serverEndpoint: serverEndpoint2); - key = Assert.IsType(endpoint.AccessKey); + key = Assert.IsType(endpoint.AccessKey); Assert.Same(key, endpoint.AccessKey); Assert.Equal("http://serverEndpoint:123/path/api/v1/auth/accessKey", key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); @@ -179,7 +178,7 @@ public void TestAzureADConstructorWithServerEndpoint() { ServerEndpoint = serverEndpoint2 // property initialize should override constructor param. }; - key = Assert.IsType(endpoint.AccessKey); + key = Assert.IsType(endpoint.AccessKey); Assert.Same(key, endpoint.AccessKey); Assert.Equal("http://serverEndpoint:123/path/api/v1/auth/accessKey", key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); } diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs index 962d79f58..395f9a201 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs @@ -155,7 +155,7 @@ public async Task TestCloseConnectionMessage() [Theory] [InlineData(typeof(AccessKey))] - [InlineData(typeof(AccessKeyForMicrosoftEntra))] + [InlineData(typeof(MicrosoftEntraAccessKey))] public async Task TestAccessKeyRequestMessage(Type keyType) { var endpoint = MockServiceEndpoint(keyType.Name); @@ -180,7 +180,7 @@ public async Task TestAccessKeyRequestMessage(Type keyType) [Theory] [InlineData(typeof(AccessKey))] - [InlineData(typeof(AccessKeyForMicrosoftEntra))] + [InlineData(typeof(MicrosoftEntraAccessKey))] public async Task TestAccessKeyResponseMessage(Type keyType) { var endpoint = MockServiceEndpoint(keyType.Name); @@ -229,9 +229,9 @@ public async Task TestAccessKeyResponseMessageWithError(int minutesElapsed, int { var endpoint = new TestHubServiceEndpoint(endpoint: new TestServiceEndpoint(new DefaultAzureCredential())); - if (endpoint.AccessKey is AccessKeyForMicrosoftEntra key) + if (endpoint.AccessKey is MicrosoftEntraAccessKey key) { - var field = typeof(AccessKeyForMicrosoftEntra).GetField("_lastUpdatedTime", BindingFlags.NonPublic | BindingFlags.Instance); + var field = typeof(MicrosoftEntraAccessKey).GetField("_lastUpdatedTime", BindingFlags.NonPublic | BindingFlags.Instance); field.SetValue(key, DateTime.UtcNow - TimeSpan.FromMinutes(minutesElapsed)); } @@ -318,7 +318,7 @@ private ServiceEndpoint MockServiceEndpoint(string keyTypeName) case nameof(AccessKey): return new ServiceEndpoint(LocalConnectionString); - case nameof(AccessKeyForMicrosoftEntra): + case nameof(MicrosoftEntraAccessKey): var endpoint = new ServiceEndpoint(MicrosoftEntraConnectionString); var p = typeof(ServiceEndpoint).GetProperty("AccessKey", BindingFlags.NonPublic | BindingFlags.Instance); p.SetValue(endpoint, new TestAadAccessKey()); @@ -329,7 +329,7 @@ private ServiceEndpoint MockServiceEndpoint(string keyTypeName) } } - private class TestAadAccessKey : AccessKeyForMicrosoftEntra + private class TestAadAccessKey : MicrosoftEntraAccessKey { public string Token { get; } = Guid.NewGuid().ToString();