Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NatsClient to DI #689

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/NATS.Client.Simplified/NATS.Client.Simplified.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,7 @@
<ProjectReference Include="..\NATS.Client.Serializers.Json\NATS.Client.Serializers.Json.csproj" />
</ItemGroup>

</Project>
<ItemGroup>
<InternalsVisibleTo Include="NATS.Extensions.Microsoft.DependencyInjection, PublicKey=0024000004800000940000000602000000240000525341310004000001000100db7da1f2f89089327b47d26d69666fad20861f24e9acdb13965fb6c64dfee8da589b495df37a75e934ddbacb0752a42c40f3dbc79614eec9bb2a0b6741f9e2ad2876f95e74d54c23eef0063eb4efb1e7d824ee8a695b647c113c92834f04a3a83fb60f435814ddf5c4e5f66a168139c4c1b1a50a3e60c164d180e265b1f000cd"/>
</ItemGroup>
</Project>
2 changes: 2 additions & 0 deletions src/NATS.Client.Simplified/NatsClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ public NatsClient(NatsOpts opts, BoundedChannelFullMode pending = BoundedChannel
Connection = new NatsConnection(opts);
}

internal NatsClient(INatsConnection connection) => Connection = connection;

/// <inheritdoc />
public INatsConnection Connection { get; }

Expand Down
72 changes: 67 additions & 5 deletions src/NATS.Extensions.Microsoft.DependencyInjection/NatsBuilder.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
using System.Text.Json.Serialization;
using System.Threading.Channels;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Client.Core;
using NATS.Net;

namespace NATS.Extensions.Microsoft.DependencyInjection;

Expand All @@ -14,6 +17,8 @@ public class NatsBuilder
private Func<IServiceProvider, NatsOpts, NatsOpts>? _configureOpts;
private Action<IServiceProvider, NatsConnection>? _configureConnection;
private object? _diKey = null;
private BoundedChannelFullMode? _pending = null;
private INatsSerializerRegistry? _serializerRegistry = null;

public NatsBuilder(IServiceCollection services)
=> _services = services;
Expand Down Expand Up @@ -78,6 +83,39 @@ public NatsBuilder WithKey(object key)
}
#endif

/// <summary>
/// Override the default <see cref="BoundedChannelFullMode"/> for the pending messages channel.
/// </summary>
/// <param name="pending">Full mode for the subscription channel.</param>
/// <returns>Builder to allow method chaining.</returns>
/// <remarks>
/// This will be applied to options overriding values set for <c>SubPendingChannelFullMode</c> in options.
/// By default, the pending messages channel will wait for space to be available when full.
/// Note that this is not the same as <c>NatsOpts</c> default <c>SubPendingChannelFullMode</c> which is <c>DropNewest</c>.
/// </remarks>
public NatsBuilder WithSubPendingChannelFullMode(BoundedChannelFullMode pending)
{
_pending = pending;
return this;
}

/// <summary>
/// Override the default <see cref="INatsSerializerRegistry"/> for the options.
/// </summary>
/// <param name="registry">Serializer registry to use.</param>
/// <returns>Builder to allow method chaining.</returns>
/// <remarks>
/// This will be applied to options overriding values set for <c>SerializerRegistry</c> in options.
/// By default, NatsClient registry will be used which allows ad-hoc JSON serialization.
/// Note that this is not the same as <c>NatsOpts</c> default <c>SerializerRegistry</c> which
/// doesn't do ad-hoc JSON serialization.
/// </remarks>
public NatsBuilder WithSerializerRegistry(INatsSerializerRegistry registry)
{
_serializerRegistry = registry;
return this;
}

internal IServiceCollection Build()
{
if (_poolSize != 1)
Expand All @@ -88,14 +126,16 @@ internal IServiceCollection Build()
_services.TryAddSingleton<INatsConnectionPool>(static provider => provider.GetRequiredService<NatsConnectionPool>());
_services.TryAddTransient<NatsConnection>(static provider => PooledConnectionFactory(provider, null));
_services.TryAddTransient<INatsConnection>(static provider => provider.GetRequiredService<NatsConnection>());
_services.TryAddTransient<INatsClient>(static provider => provider.GetRequiredService<NatsConnection>());
}
else
{
#if NET8_0_OR_GREATER
_services.TryAddKeyedSingleton<NatsConnectionPool>(_diKey, PoolFactory);
_services.TryAddKeyedSingleton<INatsConnectionPool>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnectionPool>(key));
_services.TryAddKeyedTransient(_diKey, PooledConnectionFactory);
_services.TryAddKeyedTransient<NatsConnection>(_diKey, PooledConnectionFactory);
_services.TryAddKeyedTransient<INatsConnection>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
_services.TryAddKeyedTransient<INatsClient>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
#endif
}
}
Expand All @@ -105,12 +145,14 @@ internal IServiceCollection Build()
{
_services.TryAddSingleton<NatsConnection>(provider => SingleConnectionFactory(provider));
_services.TryAddSingleton<INatsConnection>(static provider => provider.GetRequiredService<NatsConnection>());
_services.TryAddSingleton<INatsClient>(static provider => provider.GetRequiredService<NatsConnection>());
}
else
{
#if NET8_0_OR_GREATER
_services.TryAddKeyedSingleton(_diKey, SingleConnectionFactory);
_services.TryAddKeyedSingleton<INatsConnection>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
_services.TryAddKeyedSingleton<INatsClient>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
#endif
}
}
Expand All @@ -133,20 +175,40 @@ private static NatsConnection PooledConnectionFactory(IServiceProvider provider,

private NatsConnectionPool PoolFactory(IServiceProvider provider, object? diKey = null)
{
var options = NatsOpts.Default with { LoggerFactory = provider.GetRequiredService<ILoggerFactory>() };
options = _configureOpts?.Invoke(provider, options) ?? options;
var options = GetNatsOpts(provider);

return new NatsConnectionPool(_poolSize, options, con => _configureConnection?.Invoke(provider, con));
}

private NatsConnection SingleConnectionFactory(IServiceProvider provider, object? diKey = null)
{
var options = NatsOpts.Default with { LoggerFactory = provider.GetRequiredService<ILoggerFactory>() };
options = _configureOpts?.Invoke(provider, options) ?? options;
var options = GetNatsOpts(provider);

var conn = new NatsConnection(options);
_configureConnection?.Invoke(provider, conn);

return conn;
}

private NatsOpts GetNatsOpts(IServiceProvider provider)
{
var options = NatsOpts.Default with { LoggerFactory = provider.GetService<ILoggerFactory>() ?? NullLoggerFactory.Instance };
options = _configureOpts?.Invoke(provider, options) ?? options;

if (_serializerRegistry != null)
{
options = options with { SerializerRegistry = _serializerRegistry };
}
else
{
if (ReferenceEquals(options.SerializerRegistry, NatsOpts.Default.SerializerRegistry))
{
options = options with { SerializerRegistry = NatsClientDefaultSerializerRegistry.Default, };
}
}

options = options with { SubPendingChannelFullMode = _pending ?? BoundedChannelFullMode.Wait };

return options;
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using System.Threading.Channels;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Client.Core;
using NATS.Client.Core.Tests;
using NATS.Net;

namespace NATS.Extensions.Microsoft.DependencyInjection.Tests;

Expand All @@ -21,6 +23,14 @@ public void AddNatsClient_RegistersNatsConnectionAsSingleton_WhenPoolSizeIsOne()

Assert.NotNull(natsConnection1);
Assert.Same(natsConnection1, natsConnection2); // Singleton should return the same instance

var natsClient1 = provider.GetRequiredService<INatsClient>();
var natsClient2 = provider.GetRequiredService<INatsClient>();

Assert.NotNull(natsClient1);
Assert.Same(natsClient1, natsClient2);
Assert.Same(natsClient1, natsConnection1); // Same Connection implements INatsClient
Assert.Same(natsClient1.Connection, natsConnection1);
}

[Fact]
Expand All @@ -33,9 +43,138 @@ public void AddNatsClient_RegistersNatsConnectionAsTransient_WhenPoolSizeIsGreat
var provider = services.BuildServiceProvider();
var natsConnection1 = provider.GetRequiredService<INatsConnection>();
var natsConnection2 = provider.GetRequiredService<INatsConnection>();
var natsConnection3 = provider.GetRequiredService<INatsConnection>();
var natsConnection4 = provider.GetRequiredService<INatsConnection>();

Assert.NotNull(natsConnection1);
Assert.NotSame(natsConnection1, natsConnection2); // Transient should return different instances
Assert.NotSame(natsConnection3, natsConnection4);
Assert.Same(natsConnection1, natsConnection3); // The pool is round-robin
Assert.Same(natsConnection2, natsConnection4);

var natsClient1 = provider.GetRequiredService<INatsClient>();
var natsClient2 = provider.GetRequiredService<INatsClient>();
var natsClient3 = provider.GetRequiredService<INatsClient>();
var natsClient4 = provider.GetRequiredService<INatsClient>();

Assert.NotNull(natsClient1);
Assert.NotSame(natsClient1, natsClient2);
Assert.NotSame(natsClient3, natsClient4);
Assert.Same(natsClient1, natsClient3);
Assert.Same(natsClient2, natsClient4);
Assert.Same(natsClient1, natsConnection1);
Assert.Same(natsClient1.Connection, natsConnection1);
}

[Fact]
public Task AddNatsClient_OptionsWithDefaults()
{
var services = new ServiceCollection();
services.AddNatsClient();

var provider = services.BuildServiceProvider();
var nats = provider.GetRequiredService<INatsConnection>();

Assert.Same(NullLoggerFactory.Instance, nats.Opts.LoggerFactory);

// These defaults are different from NatsOptions defaults but same as NatsClient defaults
// for ease of use for new users
Assert.Same(NatsClientDefaultSerializerRegistry.Default, nats.Opts.SerializerRegistry);
Assert.Equal(BoundedChannelFullMode.Wait, nats.Opts.SubPendingChannelFullMode);

return Task.CompletedTask;
}

[Fact]
public Task AddNatsClient_WithDefaultSerializerExplicitlySet()
{
var services = new ServiceCollection();
services.AddNatsClient(nats =>
{
// These two settings make the options same as NatsOptions defaults
nats.WithSerializerRegistry(NatsDefaultSerializerRegistry.Default)
.WithSubPendingChannelFullMode(BoundedChannelFullMode.DropNewest);
});

var provider = services.BuildServiceProvider();
var nats = provider.GetRequiredService<INatsConnection>();

Assert.Same(NatsDefaultSerializerRegistry.Default, nats.Opts.SerializerRegistry);
Assert.Equal(BoundedChannelFullMode.DropNewest, nats.Opts.SubPendingChannelFullMode);

return Task.CompletedTask;
}

[Fact]
public Task AddNatsClient_WithSerializerExplicitlySet()
{
var mySerializerRegistry = new NatsJsonContextSerializerRegistry(MyJsonContext.Default);

var services = new ServiceCollection();
services.AddNatsClient(nats =>
{
nats.ConfigureOptions(opts => opts with { SerializerRegistry = mySerializerRegistry });
});

var provider = services.BuildServiceProvider();
var nats = provider.GetRequiredService<INatsConnection>();

Assert.Same(mySerializerRegistry, nats.Opts.SerializerRegistry);

// You can only override this using .WithSubPendingChannelFullMode() on builder above
Assert.Equal(BoundedChannelFullMode.Wait, nats.Opts.SubPendingChannelFullMode);

return Task.CompletedTask;
}

[Fact]
public async Task AddNatsClient_WithDefaultSerializer()
{
await using var server = NatsServer.Start();
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
var cancellationToken = cts.Token;

// Default JSON serialization
{
var services = new ServiceCollection();
services.AddSingleton<ILoggerFactory, NullLoggerFactory>();
services.AddNatsClient(nats =>
{
nats.ConfigureOptions(opts => server.ClientOpts(opts));
});

var provider = services.BuildServiceProvider();
var nats = provider.GetRequiredService<INatsConnection>();

// Ad-hoc JSON serialization
await using var sub = await nats.SubscribeCoreAsync<MyAdHocData>("foo", cancellationToken: cancellationToken);
await nats.PingAsync(cancellationToken);
await nats.PublishAsync("foo", new MyAdHocData(1, "bar"), cancellationToken: cancellationToken);

var msg = await sub.Msgs.ReadAsync(cancellationToken);
Assert.Equal(1, msg.Data?.Id);
Assert.Equal("bar", msg.Data?.Name);
}

// Default raw serialization
{
var services = new ServiceCollection();
services.AddSingleton<ILoggerFactory, NullLoggerFactory>();
services.AddNatsClient(nats =>
{
nats.ConfigureOptions(opts => server.ClientOpts(opts));
nats.WithSerializerRegistry(NatsDefaultSerializerRegistry.Default);
});

var provider = services.BuildServiceProvider();
var nats = provider.GetRequiredService<INatsConnection>();

var exception = await Assert.ThrowsAsync<NatsException>(async () =>
{
await nats.PublishAsync("foo", new MyAdHocData(1, "bar"), cancellationToken: cancellationToken);
});
Assert.Matches("Can't serialize.*MyAdHocData", exception.Message);
}
}

[Fact]
Expand Down Expand Up @@ -208,3 +347,5 @@ public void AddNats_RegistersKeyedNatsConnection_WhenKeyIsProvided_pooled()
}
#endif
}

public record MyAdHocData(int Id, string Name);
Loading