From b41b6b52315e9b02ffb2e921e0f94f79ba81ed7b Mon Sep 17 00:00:00 2001 From: Zitong Yang Date: Wed, 28 Sep 2022 10:20:06 +0800 Subject: [PATCH] Fix HubConnectionContext.UserIdentifier is null When clients negotiatie with Management SDK and connect to SignalR server, IUserIdProvider might not work as the user ID is set directly in the Management SDK. To make HubConnectionContext.UserIdentifier have the valid value in this case, we should set it before the server accesses it. HubLifetimeManager{THub}.OnConnectedAsync(HubConnectionContext) is the only chance we can set the value. However, we cannot access the Constants.ClaimType.UserId as ASRS system claims are trimmed there. HubConnectionContext.Features is the place where we can store the user Id. The following code is the injection point. https://github.com/dotnet/aspnetcore/blob/v6.0.9/src/SignalR/server/Core/src/HubConnectionHandler.cs#L132-L141 Fixes #1679 --- .../HubHost/ServiceLifetimeManager.cs | 11 +++++++ .../Internals/IServiceUserNameFeature.cs | 17 +++++++++++ .../Internals/ServiceUserIdFeature.cs | 15 ++++++++++ .../ClientConnectionContext.cs | 10 +++++-- .../ClientConnectionContextFacts.cs | 30 +++++++++++++++++++ .../ServiceLifetimeManagerFacts.cs | 25 ++++++++++++++++ 6 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 src/Microsoft.Azure.SignalR/Internals/IServiceUserNameFeature.cs create mode 100644 src/Microsoft.Azure.SignalR/Internals/ServiceUserIdFeature.cs create mode 100644 test/Microsoft.Azure.SignalR.Tests/ClientConnectionContextFacts.cs diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs index 47f9697e0..1f62113b9 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs @@ -51,6 +51,17 @@ public ServiceLifetimeManager( } } + public override Task OnConnectedAsync(HubConnectionContext connection) + { + var userIdFeature = connection.Features.Get(); + if (userIdFeature != null) + { + connection.UserIdentifier = userIdFeature.UserId; + connection.Features.Set(null); + } + return base.OnConnectedAsync(connection); + } + public override async Task SendConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default) { if (IsInvalidArgument(connectionId)) diff --git a/src/Microsoft.Azure.SignalR/Internals/IServiceUserNameFeature.cs b/src/Microsoft.Azure.SignalR/Internals/IServiceUserNameFeature.cs new file mode 100644 index 000000000..35f8241db --- /dev/null +++ b/src/Microsoft.Azure.SignalR/Internals/IServiceUserNameFeature.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.AspNetCore.SignalR; + +namespace Microsoft.Azure.SignalR +{ + /// + /// When clients negotiate with Management SDK and connect to SignalR server, the might not work as the user Id is set directly in the Management SDK. + /// To make have the valid value in this case, we should set it before the server can access it. is the only chance we can set the value. However, we cannot access the as ASRS system claims're trimmed there. is the place where we can store the user Id. + /// https://github.com/dotnet/aspnetcore/blob/v6.0.9/src/SignalR/server/Core/src/HubConnectionHandler.cs#L132-L141 + /// + internal interface IServiceUserIdFeature + { + string UserId { get; } + } +} diff --git a/src/Microsoft.Azure.SignalR/Internals/ServiceUserIdFeature.cs b/src/Microsoft.Azure.SignalR/Internals/ServiceUserIdFeature.cs new file mode 100644 index 000000000..812fa0066 --- /dev/null +++ b/src/Microsoft.Azure.SignalR/Internals/ServiceUserIdFeature.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.SignalR +{ + internal class ServiceUserIdFeature : IServiceUserIdFeature + { + public string UserId { get; } + + public ServiceUserIdFeature(string userId) + { + UserId = userId; + } + } +} diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs index 24263c794..bbfd17573 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs @@ -131,7 +131,7 @@ public ClientConnectionContext(OpenConnectionMessage serviceMessage, Action(this); @@ -244,6 +244,12 @@ private FeatureCollection BuildFeatures() features.Set(this); features.Set(this); features.Set(this); + + var userIdClaim = serviceMessage.Claims.FirstOrDefault(c => c.Type == Constants.ClaimType.UserId); + if (userIdClaim != default) + { + features.Set(new ServiceUserIdFeature(userIdClaim.Value)); + } return features; } diff --git a/test/Microsoft.Azure.SignalR.Tests/ClientConnectionContextFacts.cs b/test/Microsoft.Azure.SignalR.Tests/ClientConnectionContextFacts.cs new file mode 100644 index 000000000..ad7fce474 --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Tests/ClientConnectionContextFacts.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Security.Claims; +using Xunit; + +namespace Microsoft.Azure.SignalR.Tests +{ + public class ClientConnectionContextFacts + { + [Fact] + public void SetUserIdFeatureTest() + { + var claims = new Claim[] { new(Constants.ClaimType.UserId, "testUser") }; + var connection = new ClientConnectionContext(new("connectionId", claims)); + var feature = connection.Features.Get(); + Assert.NotNull(feature); + Assert.Equal("testUser", feature.UserId); + } + + [Fact] + public void DoNotSetUserIdFeatureWithoutUserIdClaimTest() + { + var connection = new ClientConnectionContext(new("connectionId", Array.Empty())); + var feature = connection.Features.Get(); + Assert.Null(feature); + } + } +} diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs index 4c8506b71..ccaa048b7 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs @@ -225,6 +225,31 @@ public async void TestSendConnectionAsyncisOverwrittenWhenClientConnectionExiste Assert.True(false); } + [Fact] + public async void SetUserIdTest() + { + var connectionContext = new TestConnectionContext(); + connectionContext.Features.Set(new ServiceUserIdFeature("testUser")); + + var hubConnectionContext = new HubConnectionContext(connectionContext, new(), NullLoggerFactory.Instance); + var serviceLifetimeManager = MockLifetimeManager(new TestServiceConnectionManager()); + await serviceLifetimeManager.OnConnectedAsync(hubConnectionContext); + + Assert.Equal("testUser", hubConnectionContext.UserIdentifier); + } + + [Fact] + public async void DoNotSetUserIdWithoutFeatureTest() + { + var connectionContext = new TestConnectionContext(); + + var hubConnectionContext = new HubConnectionContext(connectionContext, new(), NullLoggerFactory.Instance); + var serviceLifetimeManager = MockLifetimeManager(new TestServiceConnectionManager()); + await serviceLifetimeManager.OnConnectedAsync(hubConnectionContext); + + Assert.Null(hubConnectionContext.UserIdentifier); + } + private HubLifetimeManager MockLifetimeManager(IServiceConnectionManager serviceConnectionManager, IClientConnectionManager clientConnectionManager = null, IBlazorDetector blazorDetector = null) { clientConnectionManager ??= new ClientConnectionManager();