diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/EndpointRoutingMode.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/EndpointRoutingMode.cs new file mode 100644 index 000000000..aefa0202d --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/EndpointRoutingMode.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.SignalR +{ + public enum EndpointRoutingMode + { + /// + /// Choose endpoint randomly by weight. + /// The weight is defined as (the remaining connection quota / the connection capacity). + /// This is the default mode. + /// + Weighted, + + /// + /// Choose the endpoint with least connection count. + /// This mode distributes connections evenly among endpoints. + /// + LeastConnection, + + /// + /// Choose the endpoint randomly + /// + Random, + } +} \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs b/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs index 5f98d718b..c185fccb3 100644 --- a/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs +++ b/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs @@ -6,29 +6,41 @@ using System.Linq; using Microsoft.AspNetCore.Http; using Microsoft.Azure.SignalR.Common; +using Microsoft.Extensions.Options; namespace Microsoft.Azure.SignalR { internal class DefaultEndpointRouter : DefaultMessageRouter, IEndpointRouter { + private readonly EndpointRoutingMode _mode; + + public DefaultEndpointRouter(IOptions options) + { + _mode = options?.Value.EndpointRoutingMode ?? EndpointRoutingMode.Weighted; + } + /// - /// Randomly select from the available endpoints + /// Select an endpoint for negotiate request according to the mode /// /// The http context of the incoming request /// All the available endpoints - /// public ServiceEndpoint GetNegotiateEndpoint(HttpContext context, IEnumerable endpoints) { // get primary endpoints snapshot var availableEndpoints = GetNegotiateEndpoints(endpoints); - return GetEndpointAccordingToWeight(availableEndpoints); + return _mode switch + { + EndpointRoutingMode.Random => GetEndpointRandomly(availableEndpoints), + EndpointRoutingMode.LeastConnection => GetEndpointWithLeastConnection(availableEndpoints), + _ => GetEndpointAccordingToWeight(availableEndpoints), + }; } /// /// Only primary endpoints will be returned by client /negotiate /// If no primary endpoint is available, promote one secondary endpoint /// - /// The availbale endpoints + /// The available endpoints private ServiceEndpoint[] GetNegotiateEndpoints(IEnumerable endpoints) { var primary = endpoints.Where(s => s.Online && s.EndpointType == EndpointType.Primary).ToArray(); @@ -49,8 +61,7 @@ private ServiceEndpoint[] GetNegotiateEndpoints(IEnumerable end /// /// Choose endpoint randomly by weight. - /// The weight is defined as the remaining connection quota. - /// The least weight is set to 1. So instance with no connection quota still has chance. + /// The weight is defined as (the remaining connection quota / the connection capacity). /// private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] availableEndpoints) { @@ -69,7 +80,7 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available var remain = endpointMetrics.ConnectionCapacity - (endpointMetrics.ClientConnectionCount + endpointMetrics.ServerConnectionCount); - var weight = remain > 0 ? remain : 1; + var weight = Math.Max((int)((double)remain / endpointMetrics.ConnectionCapacity * 1000), 1); totalCapacity += weight; we[i] = totalCapacity; } @@ -78,6 +89,34 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available return availableEndpoints[Array.FindLastIndex(we, x => x <= index) + 1]; } + + /// + /// Choose endpoint with least connection count + /// + private ServiceEndpoint GetEndpointWithLeastConnection(ServiceEndpoint[] availableEndpoints) + { + //first check if weight is available or necessary + if (availableEndpoints.Any(endpoint => endpoint.EndpointMetrics.ConnectionCapacity == 0) || + availableEndpoints.Length == 1) + { + return GetEndpointRandomly(availableEndpoints); + } + + var leastConnectionCount = int.MaxValue; + var index = 0; + for (var i = 0; i < availableEndpoints.Length; i++) + { + var endpointMetrics = availableEndpoints[i].EndpointMetrics; + var connectionCount = endpointMetrics.ClientConnectionCount + endpointMetrics.ServerConnectionCount; + if (connectionCount < leastConnectionCount) + { + leastConnectionCount = connectionCount; + index = i; + } + } + + return availableEndpoints[index]; + } private static ServiceEndpoint GetEndpointRandomly(ServiceEndpoint[] availableEndpoints) { diff --git a/src/Microsoft.Azure.SignalR/EndpointRouters/EndpointRouterDecorator.cs b/src/Microsoft.Azure.SignalR/EndpointRouters/EndpointRouterDecorator.cs index 3618e9361..851f6967e 100644 --- a/src/Microsoft.Azure.SignalR/EndpointRouters/EndpointRouterDecorator.cs +++ b/src/Microsoft.Azure.SignalR/EndpointRouters/EndpointRouterDecorator.cs @@ -12,7 +12,7 @@ public class EndpointRouterDecorator : IEndpointRouter public EndpointRouterDecorator(IEndpointRouter router = null) { - _inner = router ?? new DefaultEndpointRouter(); + _inner = router ?? new DefaultEndpointRouter(null); } public virtual ServiceEndpoint GetNegotiateEndpoint(HttpContext context, IEnumerable endpoints) diff --git a/src/Microsoft.Azure.SignalR/ServiceOptions.cs b/src/Microsoft.Azure.SignalR/ServiceOptions.cs index 9160b147a..f301a6909 100644 --- a/src/Microsoft.Azure.SignalR/ServiceOptions.cs +++ b/src/Microsoft.Azure.SignalR/ServiceOptions.cs @@ -121,5 +121,11 @@ public int ConnectionCount /// Gets or sets a function which accepts and returns a bitmask combining one or more values that specify what transports the service should use to receive HTTP requests. /// public Func TransportTypeDetector { get; set; } = null; + + /// + /// Gets or sets the default endpoint routing mode when using multiple endpoints. + /// by default. + /// + public EndpointRoutingMode EndpointRoutingMode { get; set; } = EndpointRoutingMode.Weighted; } } diff --git a/test/Microsoft.Azure.SignalR.Tests/EndpointRouterTests.cs b/test/Microsoft.Azure.SignalR.Tests/EndpointRouterTests.cs index 48cb54859..e4e4e8ffc 100644 --- a/test/Microsoft.Azure.SignalR.Tests/EndpointRouterTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/EndpointRouterTests.cs @@ -3,6 +3,8 @@ using System; using System.Collections.Generic; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; using Xunit; namespace Microsoft.Azure.SignalR.Tests @@ -10,26 +12,72 @@ namespace Microsoft.Azure.SignalR.Tests public class EndpointRouterTests { [Fact] - public void TestDefaultEndpointWeightedRouter() + public void TestDefaultEndpointRouterWeightedMode() { - const int loops = 1000; + var drt = GetEndpointRouter(EndpointRoutingMode.Weighted); + + const int loops = 20; + var context = new RandomContext(); + + const string small = "small_instance", large = "large_instance"; + var uSmall = GenerateServiceEndpoint(10, 0, 9, small); + var uLarge = GenerateServiceEndpoint(1000, 0, 900, large); + var el = new List() { uLarge, uSmall }; + context.BenchTest(loops, () => + { + var ep = drt.GetNegotiateEndpoint(null, el); + ep.EndpointMetrics.ClientConnectionCount++; + return ep.Name; + }); + var uLargeCount = context.GetCount(large); + const int smallVar = 3; + var uSmallCount = context.GetCount(small); + Assert.True(uLargeCount is >= loops - smallVar and <= loops); + Assert.True(uSmallCount is >= 1 and <= smallVar); + context.Reset(); + } + + [Fact] + public void TestDefaultEndpointRouterLeastConnectionMode() + { + var drt = GetEndpointRouter(EndpointRoutingMode.LeastConnection); + + const int loops = 10; var context = new RandomContext(); - var drt = new DefaultEndpointRouter(); - const string u1Full = "u1_full", u1Empty = "u1_empty"; - var u1F = GenerateServiceEndpoint(1000, 10, 990, u1Full); - var u1E = GenerateServiceEndpoint(1000, 10, 0, u1Empty); - var el = new List() { u1E, u1F }; + const string small = "small_instance", large = "large_instance"; + var uSmall = GenerateServiceEndpoint(100, 0, 90, small); + var uLarge = GenerateServiceEndpoint(1000, 0, 200, large); + var el = new List() { uLarge, uSmall }; context.BenchTest(loops, () => - drt.GetNegotiateEndpoint(null, el).Name); - var u1ECount = context.GetCount(u1Empty); - const int smallVar = 10; - Assert.True(u1ECount is > loops - smallVar and <= loops); - var u1FCount = context.GetCount(u1Full); - Assert.True(u1FCount <= smallVar); + { + var ep = drt.GetNegotiateEndpoint(null, el); + ep.EndpointMetrics.ClientConnectionCount++; + return ep.Name; + }); + var uLargeCount = context.GetCount(large); + var uSmallCount = context.GetCount(small); + Assert.Equal(0, uLargeCount); + Assert.Equal(10, uSmallCount); context.Reset(); } + private static IEndpointRouter GetEndpointRouter(EndpointRoutingMode mode) + { + var config = new ConfigurationBuilder().Build(); + var serviceProvider = new ServiceCollection() + .AddSignalR() + .AddAzureSignalR( + o => + { + o.EndpointRoutingMode = mode; + }) + .Services + .AddSingleton(config) + .BuildServiceProvider(); + + return serviceProvider.GetRequiredService(); + } private static ServiceEndpoint GenerateServiceEndpoint(int capacity, int serverConnectionCount, int clientConnectionCount, string name) diff --git a/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs b/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs index cd2a1c60e..1d009414d 100644 --- a/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs @@ -245,7 +245,7 @@ public async Task TestContainerWithOneEndpointWithAllDisconnectedConnectionThrow { var endpoint = new ServiceEndpoint(ConnectionString1); var sem = new TestServiceEndpointManager(endpoint); - var router = new DefaultEndpointRouter(); + var router = new DefaultEndpointRouter(null); var container = new TestMultiEndpointServiceConnectionContainer("hub", e => new TestServiceConnectionContainer(new List {