diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs index 47f9697e0..676326386 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs @@ -51,6 +51,17 @@ public ServiceLifetimeManager( } } + public override Task OnConnectedAsync(HubConnectionContext connection) + { + var userIdFeature = connection.Features.Get(); + if (userIdFeature != null) + { + connection.UserIdentifier = userIdFeature.UserId; + connection.Features.Set(null); + } + return base.OnConnectedAsync(connection); + } + public override async Task SendConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default) { if (IsInvalidArgument(connectionId)) diff --git a/src/Microsoft.Azure.SignalR/Internals/ServiceUserIdFeature.cs b/src/Microsoft.Azure.SignalR/Internals/ServiceUserIdFeature.cs new file mode 100644 index 000000000..235fce1f3 --- /dev/null +++ b/src/Microsoft.Azure.SignalR/Internals/ServiceUserIdFeature.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.AspNetCore.SignalR; + +namespace Microsoft.Azure.SignalR +{ + /// + /// When clients negotiate with Management SDK and connect to SignalR server, the might not work as the user Id is set directly in the Management SDK. + /// To make have the valid value in this case, we should set it before the server can access it. is the only chance we can set the value. However, we cannot access the as ASRS system claims're trimmed there. is the place where we can store the user Id. + /// https://github.com/dotnet/aspnetcore/blob/v6.0.9/src/SignalR/server/Core/src/HubConnectionHandler.cs#L132-L141 + /// + internal class ServiceUserIdFeature + { + public string UserId { get; } + + public ServiceUserIdFeature(string userId) + { + UserId = userId; + } + } +} diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs index 73c5092df..877166c2c 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs @@ -134,7 +134,7 @@ public ClientConnectionContext(OpenConnectionMessage serviceMessage, Action(this); @@ -247,6 +247,12 @@ private FeatureCollection BuildFeatures() features.Set(this); features.Set(this); features.Set(this); + + var userIdClaim = serviceMessage.Claims?.FirstOrDefault(c => c.Type == Constants.ClaimType.UserId); + if (userIdClaim != default) + { + features.Set(new ServiceUserIdFeature(userIdClaim.Value)); + } return features; } diff --git a/test/Microsoft.Azure.SignalR.Tests/ClientConnectionContextFacts.cs b/test/Microsoft.Azure.SignalR.Tests/ClientConnectionContextFacts.cs new file mode 100644 index 000000000..7edd7cce3 --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Tests/ClientConnectionContextFacts.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Security.Claims; +using Xunit; + +namespace Microsoft.Azure.SignalR.Tests +{ + public class ClientConnectionContextFacts + { + [Fact] + public void SetUserIdFeatureTest() + { + var claims = new Claim[] { new(Constants.ClaimType.UserId, "testUser") }; + var connection = new ClientConnectionContext(new("connectionId", claims)); + var feature = connection.Features.Get(); + Assert.NotNull(feature); + Assert.Equal("testUser", feature.UserId); + } + + [Fact] + public void DoNotSetUserIdFeatureWithoutUserIdClaimTest() + { + var connection = new ClientConnectionContext(new("connectionId", Array.Empty())); + var feature = connection.Features.Get(); + Assert.Null(feature); + } + } +} diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs index 4c8506b71..fdb69f6b7 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs @@ -225,6 +225,32 @@ public async void TestSendConnectionAsyncisOverwrittenWhenClientConnectionExiste Assert.True(false); } + [Fact] + public async void SetUserIdTest() + { + var connectionContext = new TestConnectionContext(); + connectionContext.Features.Set(new ServiceUserIdFeature("testUser")); + + var hubConnectionContext = new HubConnectionContext(connectionContext, new(), NullLoggerFactory.Instance); + var serviceLifetimeManager = MockLifetimeManager(new TestServiceConnectionManager()); + await serviceLifetimeManager.OnConnectedAsync(hubConnectionContext); + + Assert.Equal("testUser", hubConnectionContext.UserIdentifier); + } + + [Fact] + public async void DoNotSetUserIdWithoutFeatureTest() + { + var connectionContext = new TestConnectionContext(); + + var hubConnectionContext = new HubConnectionContext(connectionContext, new(), NullLoggerFactory.Instance); + var serviceLifetimeManager = MockLifetimeManager(new TestServiceConnectionManager()); + await serviceLifetimeManager.OnConnectedAsync(hubConnectionContext); + + Assert.Null(hubConnectionContext.UserIdentifier); + Assert.Null(hubConnectionContext.Features.Get()); + } + private HubLifetimeManager MockLifetimeManager(IServiceConnectionManager serviceConnectionManager, IClientConnectionManager clientConnectionManager = null, IBlazorDetector blazorDetector = null) { clientConnectionManager ??= new ClientConnectionManager();