Skip to content

Commit

Permalink
Support different EndpointRoutingMode (#1842)
Browse files Browse the repository at this point in the history
  • Loading branch information
bjqian authored Oct 16, 2023
1 parent cd929a9 commit f595647
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Microsoft.Azure.SignalR
{
public enum EndpointRoutingMode
{
/// <summary>
/// Choose endpoint randomly by weight.
/// The weight is defined as (the remaining connection quota / the connection capacity).
/// This is the default mode.
/// </summary>
Weighted,

/// <summary>
/// Choose the endpoint with least connection count.
/// This mode distributes connections evenly among endpoints.
/// </summary>
LeastConnection,

/// <summary>
/// Choose the endpoint randomly
/// </summary>
Random,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,41 @@
using System.Linq;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.SignalR.Common;
using Microsoft.Extensions.Options;

namespace Microsoft.Azure.SignalR
{
internal class DefaultEndpointRouter : DefaultMessageRouter, IEndpointRouter
{
private readonly EndpointRoutingMode _mode;

public DefaultEndpointRouter(IOptions<ServiceOptions> options)
{
_mode = options?.Value.EndpointRoutingMode ?? EndpointRoutingMode.Weighted;
}

/// <summary>
/// Randomly select from the available endpoints
/// Select an endpoint for negotiate request according to the mode
/// </summary>
/// <param name="context">The http context of the incoming request</param>
/// <param name="endpoints">All the available endpoints</param>
/// <returns></returns>
public ServiceEndpoint GetNegotiateEndpoint(HttpContext context, IEnumerable<ServiceEndpoint> endpoints)
{
// get primary endpoints snapshot
var availableEndpoints = GetNegotiateEndpoints(endpoints);
return GetEndpointAccordingToWeight(availableEndpoints);
return _mode switch
{
EndpointRoutingMode.Random => GetEndpointRandomly(availableEndpoints),
EndpointRoutingMode.LeastConnection => GetEndpointWithLeastConnection(availableEndpoints),
_ => GetEndpointAccordingToWeight(availableEndpoints),
};
}

/// <summary>
/// Only primary endpoints will be returned by client /negotiate
/// If no primary endpoint is available, promote one secondary endpoint
/// </summary>
/// <returns>The availbale endpoints</returns>
/// <returns>The available endpoints</returns>
private ServiceEndpoint[] GetNegotiateEndpoints(IEnumerable<ServiceEndpoint> endpoints)
{
var primary = endpoints.Where(s => s.Online && s.EndpointType == EndpointType.Primary).ToArray();
Expand All @@ -49,8 +61,7 @@ private ServiceEndpoint[] GetNegotiateEndpoints(IEnumerable<ServiceEndpoint> end

/// <summary>
/// Choose endpoint randomly by weight.
/// The weight is defined as the remaining connection quota.
/// The least weight is set to 1. So instance with no connection quota still has chance.
/// The weight is defined as (the remaining connection quota / the connection capacity).
/// </summary>
private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] availableEndpoints)
{
Expand All @@ -69,7 +80,7 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available
var remain = endpointMetrics.ConnectionCapacity -
(endpointMetrics.ClientConnectionCount +
endpointMetrics.ServerConnectionCount);
var weight = remain > 0 ? remain : 1;
var weight = Math.Max((int)((double)remain / endpointMetrics.ConnectionCapacity * 1000), 1);
totalCapacity += weight;
we[i] = totalCapacity;
}
Expand All @@ -78,6 +89,34 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available

return availableEndpoints[Array.FindLastIndex(we, x => x <= index) + 1];
}

/// <summary>
/// Choose endpoint with least connection count
/// </summary>
private ServiceEndpoint GetEndpointWithLeastConnection(ServiceEndpoint[] availableEndpoints)
{
//first check if weight is available or necessary
if (availableEndpoints.Any(endpoint => endpoint.EndpointMetrics.ConnectionCapacity == 0) ||
availableEndpoints.Length == 1)
{
return GetEndpointRandomly(availableEndpoints);
}

var leastConnectionCount = int.MaxValue;
var index = 0;
for (var i = 0; i < availableEndpoints.Length; i++)
{
var endpointMetrics = availableEndpoints[i].EndpointMetrics;
var connectionCount = endpointMetrics.ClientConnectionCount + endpointMetrics.ServerConnectionCount;
if (connectionCount < leastConnectionCount)
{
leastConnectionCount = connectionCount;
index = i;
}
}

return availableEndpoints[index];
}

private static ServiceEndpoint GetEndpointRandomly(ServiceEndpoint[] availableEndpoints)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public class EndpointRouterDecorator : IEndpointRouter

public EndpointRouterDecorator(IEndpointRouter router = null)
{
_inner = router ?? new DefaultEndpointRouter();
_inner = router ?? new DefaultEndpointRouter(null);
}

public virtual ServiceEndpoint GetNegotiateEndpoint(HttpContext context, IEnumerable<ServiceEndpoint> endpoints)
Expand Down
6 changes: 6 additions & 0 deletions src/Microsoft.Azure.SignalR/ServiceOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,11 @@ public int ConnectionCount
/// Gets or sets a function which accepts <see cref="HttpContext"/> and returns a bitmask combining one or more <see cref="HttpTransportType"/> values that specify what transports the service should use to receive HTTP requests.
/// </summary>
public Func<HttpContext, HttpTransportType> TransportTypeDetector { get; set; } = null;

/// <summary>
/// Gets or sets the default endpoint routing mode when using multiple endpoints.
/// <see cref="EndpointRoutingMode.Weighted"/> by default.
/// </summary>
public EndpointRoutingMode EndpointRoutingMode { get; set; } = EndpointRoutingMode.Weighted;
}
}
74 changes: 61 additions & 13 deletions test/Microsoft.Azure.SignalR.Tests/EndpointRouterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,81 @@

using System;
using System.Collections.Generic;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Xunit;

namespace Microsoft.Azure.SignalR.Tests
{
public class EndpointRouterTests
{
[Fact]
public void TestDefaultEndpointWeightedRouter()
public void TestDefaultEndpointRouterWeightedMode()
{
const int loops = 1000;
var drt = GetEndpointRouter(EndpointRoutingMode.Weighted);

const int loops = 20;
var context = new RandomContext();

const string small = "small_instance", large = "large_instance";
var uSmall = GenerateServiceEndpoint(10, 0, 9, small);
var uLarge = GenerateServiceEndpoint(1000, 0, 900, large);
var el = new List<ServiceEndpoint>() { uLarge, uSmall };
context.BenchTest(loops, () =>
{
var ep = drt.GetNegotiateEndpoint(null, el);
ep.EndpointMetrics.ClientConnectionCount++;
return ep.Name;
});
var uLargeCount = context.GetCount(large);
const int smallVar = 3;
var uSmallCount = context.GetCount(small);
Assert.True(uLargeCount is >= loops - smallVar and <= loops);
Assert.True(uSmallCount is >= 1 and <= smallVar);
context.Reset();
}

[Fact]
public void TestDefaultEndpointRouterLeastConnectionMode()
{
var drt = GetEndpointRouter(EndpointRoutingMode.LeastConnection);

const int loops = 10;
var context = new RandomContext();
var drt = new DefaultEndpointRouter();

const string u1Full = "u1_full", u1Empty = "u1_empty";
var u1F = GenerateServiceEndpoint(1000, 10, 990, u1Full);
var u1E = GenerateServiceEndpoint(1000, 10, 0, u1Empty);
var el = new List<ServiceEndpoint>() { u1E, u1F };
const string small = "small_instance", large = "large_instance";
var uSmall = GenerateServiceEndpoint(100, 0, 90, small);
var uLarge = GenerateServiceEndpoint(1000, 0, 200, large);
var el = new List<ServiceEndpoint>() { uLarge, uSmall };
context.BenchTest(loops, () =>
drt.GetNegotiateEndpoint(null, el).Name);
var u1ECount = context.GetCount(u1Empty);
const int smallVar = 10;
Assert.True(u1ECount is > loops - smallVar and <= loops);
var u1FCount = context.GetCount(u1Full);
Assert.True(u1FCount <= smallVar);
{
var ep = drt.GetNegotiateEndpoint(null, el);
ep.EndpointMetrics.ClientConnectionCount++;
return ep.Name;
});
var uLargeCount = context.GetCount(large);
var uSmallCount = context.GetCount(small);
Assert.Equal(0, uLargeCount);
Assert.Equal(10, uSmallCount);
context.Reset();
}

private static IEndpointRouter GetEndpointRouter(EndpointRoutingMode mode)
{
var config = new ConfigurationBuilder().Build();
var serviceProvider = new ServiceCollection()
.AddSignalR()
.AddAzureSignalR(
o =>
{
o.EndpointRoutingMode = mode;
})
.Services
.AddSingleton<IConfiguration>(config)
.BuildServiceProvider();

return serviceProvider.GetRequiredService<IEndpointRouter>();
}

private static ServiceEndpoint GenerateServiceEndpoint(int capacity, int serverConnectionCount,
int clientConnectionCount, string name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ public async Task TestContainerWithOneEndpointWithAllDisconnectedConnectionThrow
{
var endpoint = new ServiceEndpoint(ConnectionString1);
var sem = new TestServiceEndpointManager(endpoint);
var router = new DefaultEndpointRouter();
var router = new DefaultEndpointRouter(null);

var container = new TestMultiEndpointServiceConnectionContainer("hub",
e => new TestServiceConnectionContainer(new List<IServiceConnection> {
Expand Down

0 comments on commit f595647

Please sign in to comment.