Skip to content

Commit

Permalink
Use internal class instead of interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-Sindo committed Sep 28, 2022
1 parent b41b6b5 commit 66de17a
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 43 deletions.
4 changes: 2 additions & 2 deletions src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ public ServiceLifetimeManager(

public override Task OnConnectedAsync(HubConnectionContext connection)
{
var userIdFeature = connection.Features.Get<IServiceUserIdFeature>();
var userIdFeature = connection.Features.Get<ServiceUserIdFeature>();
if (userIdFeature != null)
{
connection.UserIdentifier = userIdFeature.UserId;
connection.Features.Set<IServiceUserIdFeature>(null);
connection.Features.Set<ServiceUserIdFeature>(null);
}
return base.OnConnectedAsync(connection);
}
Expand Down
17 changes: 0 additions & 17 deletions src/Microsoft.Azure.SignalR/Internals/IServiceUserNameFeature.cs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using Microsoft.AspNetCore.SignalR;

namespace Microsoft.Azure.SignalR
{
internal class ServiceUserIdFeature : IServiceUserIdFeature
/// <summary>
/// When clients negotiate with Management SDK and connect to SignalR server, the <see cref="IUserIdProvider"/> might not work as the user Id is set directly in the Management SDK.
/// To make <see cref="HubConnectionContext.UserIdentifier"/> have the valid value in this case, we should set it before the server can access it. <see cref="HubLifetimeManager{THub}.OnConnectedAsync(HubConnectionContext)"/> is the only chance we can set the value. However, we cannot access the <see cref="Constants.ClaimType.UserId"/> as ASRS system claims're trimmed there. <see cref="HubConnectionContext.Features"/> is the place where we can store the user Id.
/// https://github.com/dotnet/aspnetcore/blob/v6.0.9/src/SignalR/server/Core/src/HubConnectionHandler.cs#L132-L141
/// </summary>
internal class ServiceUserIdFeature
{
public string UserId { get; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ private FeatureCollection BuildFeatures(OpenConnectionMessage serviceMessage)
var userIdClaim = serviceMessage.Claims.FirstOrDefault(c => c.Type == Constants.ClaimType.UserId);
if (userIdClaim != default)
{
features.Set<IServiceUserIdFeature>(new ServiceUserIdFeature(userIdClaim.Value));
features.Set(new ServiceUserIdFeature(userIdClaim.Value));
}
return features;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public void SetUserIdFeatureTest()
{
var claims = new Claim[] { new(Constants.ClaimType.UserId, "testUser") };
var connection = new ClientConnectionContext(new("connectionId", claims));
var feature = connection.Features.Get<IServiceUserIdFeature>();
var feature = connection.Features.Get<ServiceUserIdFeature>();
Assert.NotNull(feature);
Assert.Equal("testUser", feature.UserId);
}
Expand All @@ -23,7 +23,7 @@ public void SetUserIdFeatureTest()
public void DoNotSetUserIdFeatureWithoutUserIdClaimTest()
{
var connection = new ClientConnectionContext(new("connectionId", Array.Empty<Claim>()));
var feature = connection.Features.Get<IServiceUserIdFeature>();
var feature = connection.Features.Get<ServiceUserIdFeature>();
Assert.Null(feature);
}
}
Expand Down
41 changes: 21 additions & 20 deletions test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ namespace Microsoft.Azure.SignalR.Tests
{
public class ServiceLifetimeManagerFacts
{
private static readonly List<string> TestUsers = new List<string> {"user1", "user2"};
private static readonly List<string> TestUsers = new List<string> { "user1", "user2" };

private static readonly List<string> TestGroups = new List<string> {"group1", "group2"};
private static readonly List<string> TestGroups = new List<string> { "group1", "group2" };

private const string MockProtocol = "blazorpack";

private const string TestMethod = "TestMethod";

private static readonly object[] TestArgs = {"TestArgs"};
private static readonly object[] TestArgs = { "TestArgs" };

private static readonly List<string> TestConnectionIds = new List<string> {"connection1", "connection2"};
private static readonly List<string> TestConnectionIds = new List<string> { "connection1", "connection2" };

private static readonly IHubProtocolResolver HubProtocolResolver =
new DefaultHubProtocolResolver(new IHubProtocol[]
Expand Down Expand Up @@ -229,7 +229,7 @@ public async void TestSendConnectionAsyncisOverwrittenWhenClientConnectionExiste
public async void SetUserIdTest()
{
var connectionContext = new TestConnectionContext();
connectionContext.Features.Set<IServiceUserIdFeature>(new ServiceUserIdFeature("testUser"));
connectionContext.Features.Set(new ServiceUserIdFeature("testUser"));

var hubConnectionContext = new HubConnectionContext(connectionContext, new(), NullLoggerFactory.Instance);
var serviceLifetimeManager = MockLifetimeManager(new TestServiceConnectionManager<TestHub>());
Expand All @@ -248,6 +248,7 @@ public async void DoNotSetUserIdWithoutFeatureTest()
await serviceLifetimeManager.OnConnectedAsync(hubConnectionContext);

Assert.Null(hubConnectionContext.UserIdentifier);
Assert.Null(hubConnectionContext.Features.Get<ServiceUserIdFeature>());
}

private HubLifetimeManager<TestHub> MockLifetimeManager(IServiceConnectionManager<TestHub> serviceConnectionManager, IClientConnectionManager clientConnectionManager = null, IBlazorDetector blazorDetector = null)
Expand Down Expand Up @@ -324,41 +325,41 @@ private static void VerifyServiceMessage(string methodName, ServiceMessage servi
switch (methodName)
{
case "SendAllAsync":
Assert.Null(((BroadcastDataMessage) serviceMessage).ExcludedList);
Assert.Null(((BroadcastDataMessage)serviceMessage).ExcludedList);
break;
case "SendAllExceptAsync":
Assert.Equal(TestConnectionIds, ((BroadcastDataMessage) serviceMessage).ExcludedList);
Assert.Equal(TestConnectionIds, ((BroadcastDataMessage)serviceMessage).ExcludedList);
break;
case "SendConnectionAsync":
Assert.Equal(TestConnectionIds[0], ((MultiConnectionDataMessage) serviceMessage).ConnectionList[0]);
Assert.Equal(TestConnectionIds[0], ((MultiConnectionDataMessage)serviceMessage).ConnectionList[0]);
break;
case "SendConnectionsAsync":
Assert.Equal(TestConnectionIds, ((MultiConnectionDataMessage) serviceMessage).ConnectionList);
Assert.Equal(TestConnectionIds, ((MultiConnectionDataMessage)serviceMessage).ConnectionList);
break;
case "SendGroupAsync":
Assert.Equal(TestGroups[0], ((GroupBroadcastDataMessage) serviceMessage).GroupName);
Assert.Null(((GroupBroadcastDataMessage) serviceMessage).ExcludedList);
Assert.Equal(TestGroups[0], ((GroupBroadcastDataMessage)serviceMessage).GroupName);
Assert.Null(((GroupBroadcastDataMessage)serviceMessage).ExcludedList);
break;
case "SendGroupsAsync":
Assert.Equal(TestGroups, ((MultiGroupBroadcastDataMessage) serviceMessage).GroupList);
Assert.Equal(TestGroups, ((MultiGroupBroadcastDataMessage)serviceMessage).GroupList);
break;
case "SendGroupExceptAsync":
Assert.Equal(TestGroups[0], ((GroupBroadcastDataMessage) serviceMessage).GroupName);
Assert.Equal(TestConnectionIds, ((GroupBroadcastDataMessage) serviceMessage).ExcludedList);
Assert.Equal(TestGroups[0], ((GroupBroadcastDataMessage)serviceMessage).GroupName);
Assert.Equal(TestConnectionIds, ((GroupBroadcastDataMessage)serviceMessage).ExcludedList);
break;
case "SendUserAsync":
Assert.Equal(TestUsers[0], ((UserDataMessage) serviceMessage).UserId);
Assert.Equal(TestUsers[0], ((UserDataMessage)serviceMessage).UserId);
break;
case "SendUsersAsync":
Assert.Equal(TestUsers, ((MultiUserDataMessage) serviceMessage).UserList);
Assert.Equal(TestUsers, ((MultiUserDataMessage)serviceMessage).UserList);
break;
case "AddToGroupAsync":
Assert.Equal(TestConnectionIds[0], ((JoinGroupWithAckMessage) serviceMessage).ConnectionId);
Assert.Equal(TestGroups[0], ((JoinGroupWithAckMessage) serviceMessage).GroupName);
Assert.Equal(TestConnectionIds[0], ((JoinGroupWithAckMessage)serviceMessage).ConnectionId);
Assert.Equal(TestGroups[0], ((JoinGroupWithAckMessage)serviceMessage).GroupName);
break;
case "RemoveFromGroupAsync":
Assert.Equal(TestConnectionIds[0], ((LeaveGroupWithAckMessage) serviceMessage).ConnectionId);
Assert.Equal(TestGroups[0], ((LeaveGroupWithAckMessage) serviceMessage).GroupName);
Assert.Equal(TestConnectionIds[0], ((LeaveGroupWithAckMessage)serviceMessage).ConnectionId);
Assert.Equal(TestGroups[0], ((LeaveGroupWithAckMessage)serviceMessage).GroupName);
break;
default:
break;
Expand Down

0 comments on commit 66de17a

Please sign in to comment.