Skip to content

Commit

Permalink
Hosting keyed services fixes (#386)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtmk authored Feb 9, 2024
1 parent d80fa38 commit a912924
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 20 deletions.
56 changes: 41 additions & 15 deletions src/NATS.Client.Hosting/NatsHostingExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ public static IServiceCollection AddNats(
Func<NatsOpts, NatsOpts>? configureOpts = null,
Action<NatsConnection>? configureConnection = null
#if NET8_0_OR_GREATER
, string? key = null // This parameter is only available in .NET 8 or greater
, object? key = null // This parameter is only available in .NET 8 or greater
#endif
)
{
string? diKey = null;
object? diKey = null;
#if NET8_0_OR_GREATER
diKey = key;
#endif
Expand All @@ -34,40 +34,64 @@ public static IServiceCollection AddNats(

if (poolSize != 1)
{
services.TryAddSingleton<NatsConnectionPool>(provider =>
NatsConnectionPool PoolFactory(IServiceProvider provider)
{
var options = NatsOpts.Default with { LoggerFactory = provider.GetRequiredService<ILoggerFactory>() };
var options = NatsOpts.Default with
{
LoggerFactory = provider.GetRequiredService<ILoggerFactory>(),
};
if (configureOpts != null)
{
options = configureOpts(options);
}

return new NatsConnectionPool(poolSize, options, configureConnection ?? (_ => { }));
});
}

services.TryAddSingleton<INatsConnectionPool>(static provider => provider.GetRequiredService<NatsConnectionPool>());
services.TryAddTransient<NatsConnection>(static provider =>
static NatsConnection ConnectionFactory(IServiceProvider provider, object? key)
{
#if NET8_0_OR_GREATER
if (key == null)
{
var pool = provider.GetRequiredService<NatsConnectionPool>();
return (pool.GetConnection() as NatsConnection)!;
}
else
{
var pool = provider.GetRequiredKeyedService<NatsConnectionPool>(key);
return (pool.GetConnection() as NatsConnection)!;
}
#else
var pool = provider.GetRequiredService<NatsConnectionPool>();
return (pool.GetConnection() as NatsConnection)!;
});
#endif
}

if (string.IsNullOrEmpty(diKey))
if (diKey == null)
{
services.TryAddSingleton(PoolFactory);
services.TryAddSingleton<INatsConnectionPool>(static provider => provider.GetRequiredService<NatsConnectionPool>());
services.TryAddTransient<NatsConnection>(static provider => ConnectionFactory(provider, null));
services.TryAddTransient<INatsConnection>(static provider => provider.GetRequiredService<NatsConnection>());
}
else
{
#if NET8_0_OR_GREATER
services.AddKeyedTransient<INatsConnection>(diKey, static (provider, _) => provider.GetRequiredService<NatsConnection>());
services.TryAddKeyedSingleton(diKey, (provider, _) => PoolFactory(provider));
services.TryAddKeyedSingleton<INatsConnectionPool>(diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnectionPool>(key));
services.TryAddKeyedTransient<NatsConnection>(diKey, static (provider, key) => ConnectionFactory(provider, key));
services.TryAddKeyedTransient<INatsConnection>(diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
#endif
}
}
else
{
services.TryAddSingleton<NatsConnection>(provider =>
NatsConnection Factory(IServiceProvider provider)
{
var options = NatsOpts.Default with { LoggerFactory = provider.GetRequiredService<ILoggerFactory>() };
var options = NatsOpts.Default with
{
LoggerFactory = provider.GetRequiredService<ILoggerFactory>(),
};
if (configureOpts != null)
{
options = configureOpts(options);
Expand All @@ -80,16 +104,18 @@ public static IServiceCollection AddNats(
}

return conn;
});
}

if (string.IsNullOrEmpty(diKey))
if (diKey == null)
{
services.TryAddSingleton(Factory);
services.TryAddSingleton<INatsConnection>(static provider => provider.GetRequiredService<NatsConnection>());
}
else
{
#if NET8_0_OR_GREATER
services.AddKeyedSingleton<INatsConnection>(diKey, static (provider, _) => provider.GetRequiredService<NatsConnection>());
services.TryAddKeyedSingleton<NatsConnection>(diKey, (provider, _) => Factory(provider));
services.TryAddKeyedSingleton<INatsConnection>(diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
#endif
}
}
Expand Down
68 changes: 63 additions & 5 deletions tests/NATS.Client.Hosting.Tests/NatsHostingExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using NATS.Client.Core;

namespace NATS.Client.Hosting.Tests;

public class NatsHostingExtensionsTests
{
[Fact]
Expand Down Expand Up @@ -43,16 +42,75 @@ public void AddNats_RegistersNatsConnectionAsTransient_WhenPoolSizeIsGreaterThan
[Fact]
public void AddNats_RegistersKeyedNatsConnection_WhenKeyIsProvided()
{
var key = "TestKey";
var key1 = "TestKey1";
var key2 = "TestKey2";

var services = new ServiceCollection();
services.AddSingleton<ILoggerFactory, NullLoggerFactory>();

services.AddNats(poolSize: 1, key: key1);
services.AddNats(poolSize: 1, key: key2);
var provider = services.BuildServiceProvider();

var natsConnection1A = provider.GetKeyedService<INatsConnection>(key1);
Assert.NotNull(natsConnection1A);
var natsConnection1B = provider.GetKeyedService<INatsConnection>(key1);
Assert.NotNull(natsConnection1B);
Assert.Same(natsConnection1A, natsConnection1B);

var natsConnection2 = provider.GetKeyedService<INatsConnection>(key2);
Assert.NotNull(natsConnection2);
Assert.NotSame(natsConnection2, natsConnection1A);
}

[Fact]
public void AddNats_RegistersKeyedNatsConnection_WhenKeyIsProvided_pooled()
{
var key1 = "TestKey1";
var key2 = "TestKey2";

var services = new ServiceCollection();
services.AddSingleton<ILoggerFactory, NullLoggerFactory>();

services.AddNats(poolSize: 1, key: key);
services.AddNats(poolSize: 2, key: key1);
services.AddNats(poolSize: 2, key: key2);
var provider = services.BuildServiceProvider();

var natsConnection = provider.GetKeyedService<INatsConnection>(key);
Assert.NotNull(natsConnection);
Dictionary<string, List<object>> connections = new();
foreach (var key in new[] { key1, key2 })
{
var nats1 = provider.GetKeyedService<INatsConnection>(key);
Assert.NotNull(nats1);
var nats2 = provider.GetKeyedService<INatsConnection>(key);
Assert.NotNull(nats2);
var nats3 = provider.GetKeyedService<INatsConnection>(key);
Assert.NotNull(nats3);
var nats4 = provider.GetKeyedService<INatsConnection>(key);
Assert.NotNull(nats4);

// relying on the fact that the pool size is 2 and connections are returned in a round-robin fashion
Assert.NotSame(nats1, nats2);
Assert.Same(nats1, nats3);
Assert.NotSame(nats2, nats3);
Assert.Same(nats2, nats4);

if (!connections.TryGetValue(key, out var list))
{
list = new List<object>();
connections.Add(key, list);
}

list.Add(nats1);
list.Add(nats2);
list.Add(nats3);
list.Add(nats4);
}

foreach (var obj1 in connections[key1])
{
foreach (var obj2 in connections[key2])
Assert.NotSame(obj1, obj2);
}
}
#endif
}

0 comments on commit a912924

Please sign in to comment.