diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/EndpointRoutingMode.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/EndpointRoutingMode.cs
new file mode 100644
index 000000000..aefa0202d
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/EndpointRoutingMode.cs
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+namespace Microsoft.Azure.SignalR
+{
+ public enum EndpointRoutingMode
+ {
+ ///
+ /// Choose endpoint randomly by weight.
+ /// The weight is defined as (the remaining connection quota / the connection capacity).
+ /// This is the default mode.
+ ///
+ Weighted,
+
+ ///
+ /// Choose the endpoint with least connection count.
+ /// This mode distributes connections evenly among endpoints.
+ ///
+ LeastConnection,
+
+ ///
+ /// Choose the endpoint randomly
+ ///
+ Random,
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs b/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs
index 5f98d718b..c185fccb3 100644
--- a/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs
+++ b/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs
@@ -6,29 +6,41 @@
using System.Linq;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.SignalR.Common;
+using Microsoft.Extensions.Options;
namespace Microsoft.Azure.SignalR
{
internal class DefaultEndpointRouter : DefaultMessageRouter, IEndpointRouter
{
+ private readonly EndpointRoutingMode _mode;
+
+ public DefaultEndpointRouter(IOptions options)
+ {
+ _mode = options?.Value.EndpointRoutingMode ?? EndpointRoutingMode.Weighted;
+ }
+
///
- /// Randomly select from the available endpoints
+ /// Select an endpoint for negotiate request according to the mode
///
/// The http context of the incoming request
/// All the available endpoints
- ///
public ServiceEndpoint GetNegotiateEndpoint(HttpContext context, IEnumerable endpoints)
{
// get primary endpoints snapshot
var availableEndpoints = GetNegotiateEndpoints(endpoints);
- return GetEndpointAccordingToWeight(availableEndpoints);
+ return _mode switch
+ {
+ EndpointRoutingMode.Random => GetEndpointRandomly(availableEndpoints),
+ EndpointRoutingMode.LeastConnection => GetEndpointWithLeastConnection(availableEndpoints),
+ _ => GetEndpointAccordingToWeight(availableEndpoints),
+ };
}
///
/// Only primary endpoints will be returned by client /negotiate
/// If no primary endpoint is available, promote one secondary endpoint
///
- /// The availbale endpoints
+ /// The available endpoints
private ServiceEndpoint[] GetNegotiateEndpoints(IEnumerable endpoints)
{
var primary = endpoints.Where(s => s.Online && s.EndpointType == EndpointType.Primary).ToArray();
@@ -49,8 +61,7 @@ private ServiceEndpoint[] GetNegotiateEndpoints(IEnumerable end
///
/// Choose endpoint randomly by weight.
- /// The weight is defined as the remaining connection quota.
- /// The least weight is set to 1. So instance with no connection quota still has chance.
+ /// The weight is defined as (the remaining connection quota / the connection capacity).
///
private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] availableEndpoints)
{
@@ -69,7 +80,7 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available
var remain = endpointMetrics.ConnectionCapacity -
(endpointMetrics.ClientConnectionCount +
endpointMetrics.ServerConnectionCount);
- var weight = remain > 0 ? remain : 1;
+ var weight = Math.Max((int)((double)remain / endpointMetrics.ConnectionCapacity * 1000), 1);
totalCapacity += weight;
we[i] = totalCapacity;
}
@@ -78,6 +89,34 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available
return availableEndpoints[Array.FindLastIndex(we, x => x <= index) + 1];
}
+
+ ///
+ /// Choose endpoint with least connection count
+ ///
+ private ServiceEndpoint GetEndpointWithLeastConnection(ServiceEndpoint[] availableEndpoints)
+ {
+ //first check if weight is available or necessary
+ if (availableEndpoints.Any(endpoint => endpoint.EndpointMetrics.ConnectionCapacity == 0) ||
+ availableEndpoints.Length == 1)
+ {
+ return GetEndpointRandomly(availableEndpoints);
+ }
+
+ var leastConnectionCount = int.MaxValue;
+ var index = 0;
+ for (var i = 0; i < availableEndpoints.Length; i++)
+ {
+ var endpointMetrics = availableEndpoints[i].EndpointMetrics;
+ var connectionCount = endpointMetrics.ClientConnectionCount + endpointMetrics.ServerConnectionCount;
+ if (connectionCount < leastConnectionCount)
+ {
+ leastConnectionCount = connectionCount;
+ index = i;
+ }
+ }
+
+ return availableEndpoints[index];
+ }
private static ServiceEndpoint GetEndpointRandomly(ServiceEndpoint[] availableEndpoints)
{
diff --git a/src/Microsoft.Azure.SignalR/EndpointRouters/EndpointRouterDecorator.cs b/src/Microsoft.Azure.SignalR/EndpointRouters/EndpointRouterDecorator.cs
index 3618e9361..851f6967e 100644
--- a/src/Microsoft.Azure.SignalR/EndpointRouters/EndpointRouterDecorator.cs
+++ b/src/Microsoft.Azure.SignalR/EndpointRouters/EndpointRouterDecorator.cs
@@ -12,7 +12,7 @@ public class EndpointRouterDecorator : IEndpointRouter
public EndpointRouterDecorator(IEndpointRouter router = null)
{
- _inner = router ?? new DefaultEndpointRouter();
+ _inner = router ?? new DefaultEndpointRouter(null);
}
public virtual ServiceEndpoint GetNegotiateEndpoint(HttpContext context, IEnumerable endpoints)
diff --git a/src/Microsoft.Azure.SignalR/ServiceOptions.cs b/src/Microsoft.Azure.SignalR/ServiceOptions.cs
index 9160b147a..f301a6909 100644
--- a/src/Microsoft.Azure.SignalR/ServiceOptions.cs
+++ b/src/Microsoft.Azure.SignalR/ServiceOptions.cs
@@ -121,5 +121,11 @@ public int ConnectionCount
/// Gets or sets a function which accepts and returns a bitmask combining one or more values that specify what transports the service should use to receive HTTP requests.
///
public Func TransportTypeDetector { get; set; } = null;
+
+ ///
+ /// Gets or sets the default endpoint routing mode when using multiple endpoints.
+ /// by default.
+ ///
+ public EndpointRoutingMode EndpointRoutingMode { get; set; } = EndpointRoutingMode.Weighted;
}
}
diff --git a/test/Microsoft.Azure.SignalR.Tests/EndpointRouterTests.cs b/test/Microsoft.Azure.SignalR.Tests/EndpointRouterTests.cs
index 48cb54859..e4e4e8ffc 100644
--- a/test/Microsoft.Azure.SignalR.Tests/EndpointRouterTests.cs
+++ b/test/Microsoft.Azure.SignalR.Tests/EndpointRouterTests.cs
@@ -3,6 +3,8 @@
using System;
using System.Collections.Generic;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.DependencyInjection;
using Xunit;
namespace Microsoft.Azure.SignalR.Tests
@@ -10,26 +12,72 @@ namespace Microsoft.Azure.SignalR.Tests
public class EndpointRouterTests
{
[Fact]
- public void TestDefaultEndpointWeightedRouter()
+ public void TestDefaultEndpointRouterWeightedMode()
{
- const int loops = 1000;
+ var drt = GetEndpointRouter(EndpointRoutingMode.Weighted);
+
+ const int loops = 20;
+ var context = new RandomContext();
+
+ const string small = "small_instance", large = "large_instance";
+ var uSmall = GenerateServiceEndpoint(10, 0, 9, small);
+ var uLarge = GenerateServiceEndpoint(1000, 0, 900, large);
+ var el = new List() { uLarge, uSmall };
+ context.BenchTest(loops, () =>
+ {
+ var ep = drt.GetNegotiateEndpoint(null, el);
+ ep.EndpointMetrics.ClientConnectionCount++;
+ return ep.Name;
+ });
+ var uLargeCount = context.GetCount(large);
+ const int smallVar = 3;
+ var uSmallCount = context.GetCount(small);
+ Assert.True(uLargeCount is >= loops - smallVar and <= loops);
+ Assert.True(uSmallCount is >= 1 and <= smallVar);
+ context.Reset();
+ }
+
+ [Fact]
+ public void TestDefaultEndpointRouterLeastConnectionMode()
+ {
+ var drt = GetEndpointRouter(EndpointRoutingMode.LeastConnection);
+
+ const int loops = 10;
var context = new RandomContext();
- var drt = new DefaultEndpointRouter();
- const string u1Full = "u1_full", u1Empty = "u1_empty";
- var u1F = GenerateServiceEndpoint(1000, 10, 990, u1Full);
- var u1E = GenerateServiceEndpoint(1000, 10, 0, u1Empty);
- var el = new List() { u1E, u1F };
+ const string small = "small_instance", large = "large_instance";
+ var uSmall = GenerateServiceEndpoint(100, 0, 90, small);
+ var uLarge = GenerateServiceEndpoint(1000, 0, 200, large);
+ var el = new List() { uLarge, uSmall };
context.BenchTest(loops, () =>
- drt.GetNegotiateEndpoint(null, el).Name);
- var u1ECount = context.GetCount(u1Empty);
- const int smallVar = 10;
- Assert.True(u1ECount is > loops - smallVar and <= loops);
- var u1FCount = context.GetCount(u1Full);
- Assert.True(u1FCount <= smallVar);
+ {
+ var ep = drt.GetNegotiateEndpoint(null, el);
+ ep.EndpointMetrics.ClientConnectionCount++;
+ return ep.Name;
+ });
+ var uLargeCount = context.GetCount(large);
+ var uSmallCount = context.GetCount(small);
+ Assert.Equal(0, uLargeCount);
+ Assert.Equal(10, uSmallCount);
context.Reset();
}
+ private static IEndpointRouter GetEndpointRouter(EndpointRoutingMode mode)
+ {
+ var config = new ConfigurationBuilder().Build();
+ var serviceProvider = new ServiceCollection()
+ .AddSignalR()
+ .AddAzureSignalR(
+ o =>
+ {
+ o.EndpointRoutingMode = mode;
+ })
+ .Services
+ .AddSingleton(config)
+ .BuildServiceProvider();
+
+ return serviceProvider.GetRequiredService();
+ }
private static ServiceEndpoint GenerateServiceEndpoint(int capacity, int serverConnectionCount,
int clientConnectionCount, string name)
diff --git a/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs b/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs
index cd2a1c60e..1d009414d 100644
--- a/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs
+++ b/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs
@@ -245,7 +245,7 @@ public async Task TestContainerWithOneEndpointWithAllDisconnectedConnectionThrow
{
var endpoint = new ServiceEndpoint(ConnectionString1);
var sem = new TestServiceEndpointManager(endpoint);
- var router = new DefaultEndpointRouter();
+ var router = new DefaultEndpointRouter(null);
var container = new TestMultiEndpointServiceConnectionContainer("hub",
e => new TestServiceConnectionContainer(new List {