diff --git a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/BasePathStrategy.cs b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/BasePathStrategy.cs index 782693ae..52a356a4 100644 --- a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/BasePathStrategy.cs +++ b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/BasePathStrategy.cs @@ -12,20 +12,19 @@ public class BasePathStrategy : IMultiTenantStrategy { public Task GetIdentifierAsync(object context) { - if (!(context is HttpContext httpContext)) - throw new MultiTenantException(null, - new ArgumentException($"\"{nameof(context)}\" type must be of type HttpContext", nameof(context))); + if (context is not HttpContext httpContext) + return Task.FromResult(null); - var path = httpContext.Request.Path; + var path = httpContext.Request.Path; - var pathSegments = - path.Value?.Split('/', 2, StringSplitOptions.RemoveEmptyEntries); + var pathSegments = + path.Value?.Split('/', 2, StringSplitOptions.RemoveEmptyEntries); - if (pathSegments is null || pathSegments.Length == 0) - return Task.FromResult(null); + if (pathSegments is null || pathSegments.Length == 0) + return Task.FromResult(null); - string identifier = pathSegments[0]; + string identifier = pathSegments[0]; - return Task.FromResult(identifier); - } + return Task.FromResult(identifier); + } } \ No newline at end of file diff --git a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/ClaimStrategy.cs b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/ClaimStrategy.cs index f72c7ebd..9ffeabb4 100644 --- a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/ClaimStrategy.cs +++ b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/ClaimStrategy.cs @@ -15,53 +15,56 @@ namespace Finbuckle.MultiTenant.AspNetCore.Strategies; // ReSharper disable once ClassNeverInstantiated.Global public class ClaimStrategy : IMultiTenantStrategy { - private readonly string _tenantKey; - private readonly string? _authenticationScheme; + private readonly string _tenantKey; + private readonly string? _authenticationScheme; - public ClaimStrategy(string template) : this(template, null) - { - } + public ClaimStrategy(string template) : this(template, null) + { + } - public ClaimStrategy(string template, string? authenticationScheme) - { - if (string.IsNullOrWhiteSpace(template)) - throw new ArgumentException(nameof(template)); + public ClaimStrategy(string template, string? authenticationScheme) + { + if (string.IsNullOrWhiteSpace(template)) + throw new ArgumentException(nameof(template)); - _tenantKey = template; - _authenticationScheme = authenticationScheme; - } + _tenantKey = template; + _authenticationScheme = authenticationScheme; + } - public async Task GetIdentifierAsync(object context) - { - if (!(context is HttpContext httpContext)) - throw new MultiTenantException(null, new ArgumentException($@"""{nameof(context)}"" type must be of type HttpContext", nameof(context))); + public async Task GetIdentifierAsync(object context) + { + if (context is not HttpContext httpContext) + return null; - if (httpContext.User.Identity is { IsAuthenticated: true }) - return httpContext.User.FindFirst(_tenantKey)?.Value; + if (httpContext.User.Identity is { IsAuthenticated: true }) + return httpContext.User.FindFirst(_tenantKey)?.Value; - AuthenticationScheme? authScheme; - var schemeProvider = httpContext.RequestServices.GetRequiredService(); - if (_authenticationScheme is null) - { - authScheme = await schemeProvider.GetDefaultAuthenticateSchemeAsync(); - } - else - { - authScheme = (await schemeProvider.GetAllSchemesAsync()).FirstOrDefault(x => x.Name == _authenticationScheme); - } + AuthenticationScheme? authScheme; + var schemeProvider = httpContext.RequestServices.GetRequiredService(); + if (_authenticationScheme is null) + { + authScheme = await schemeProvider.GetDefaultAuthenticateSchemeAsync(); + } + else + { + authScheme = + (await schemeProvider.GetAllSchemesAsync()).FirstOrDefault(x => x.Name == _authenticationScheme); + } - if (authScheme is null) - { - return null; - } + if (authScheme is null) + { + return null; + } - var handler = (IAuthenticationHandler)ActivatorUtilities.CreateInstance(httpContext.RequestServices, authScheme.HandlerType); - await handler.InitializeAsync(authScheme, httpContext); - httpContext.Items[$"{Constants.TenantToken}__bypass_validate_principal__"] = "true"; // Value doesn't matter. - var handlerResult = await handler.AuthenticateAsync(); - httpContext.Items.Remove($"{Constants.TenantToken}__bypass_validate_principal__"); + var handler = + (IAuthenticationHandler)ActivatorUtilities.CreateInstance(httpContext.RequestServices, + authScheme.HandlerType); + await handler.InitializeAsync(authScheme, httpContext); + httpContext.Items[$"{Constants.TenantToken}__bypass_validate_principal__"] = "true"; // Value doesn't matter. + var handlerResult = await handler.AuthenticateAsync(); + httpContext.Items.Remove($"{Constants.TenantToken}__bypass_validate_principal__"); - var identifier = handlerResult.Principal?.FindFirst(_tenantKey)?.Value; - return identifier; - } + var identifier = handlerResult.Principal?.FindFirst(_tenantKey)?.Value; + return identifier; + } } \ No newline at end of file diff --git a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/HeaderStrategy.cs b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/HeaderStrategy.cs index 633a83ef..1ab9b69f 100644 --- a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/HeaderStrategy.cs +++ b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/HeaderStrategy.cs @@ -21,9 +21,8 @@ public HeaderStrategy(string headerKey) public Task GetIdentifierAsync(object context) { - if (!(context is HttpContext httpContext)) - throw new MultiTenantException(null, - new ArgumentException($"\"{nameof(context)}\" type must be of type HttpContext", nameof(context))); + if (context is not HttpContext httpContext) + return Task.FromResult(null); return Task.FromResult(httpContext?.Request.Headers[_headerKey].FirstOrDefault()); } diff --git a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/HostStrategy.cs b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/HostStrategy.cs index c198cd20..cfd530db 100644 --- a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/HostStrategy.cs +++ b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/HostStrategy.cs @@ -16,74 +16,76 @@ public class HostStrategy : IMultiTenantStrategy public HostStrategy(string template) { - // New in 2.1, match whole domain if just "__tenant__". - if (template == Constants.TenantToken) + // match whole domain if just "__tenant__". + if (template == Constants.TenantToken) + { + template = template.Replace(Constants.TenantToken, "(?.+)"); + } + else + { + // Check for valid template. + // Template cannot be null or whitespace. + if (string.IsNullOrWhiteSpace(template)) + { + throw new MultiTenantException("Template cannot be null or whitespace."); + } + + // Wildcard "*" must be only occur once in template. + if (Regex.Match(template, @"\*(?=.*\*)").Success) + { + throw new MultiTenantException("Wildcard \"*\" must be only occur once in template."); + } + + // Wildcard "*" must be only token in template segment. + if (Regex.Match(template, @"\*[^\.]|[^\.]\*").Success) + { + throw new MultiTenantException("\"*\" wildcard must be only token in template segment."); + } + + // Wildcard "?" must be only token in template segment. + if (Regex.Match(template, @"\?[^\.]|[^\.]\?").Success) { - template = template.Replace(Constants.TenantToken, "(?.+)"); + throw new MultiTenantException("\"?\" wildcard must be only token in template segment."); } - else + + template = template.Trim().Replace(".", @"\."); + string wildcardSegmentsPattern = @"(\.[^\.]+)*"; + string singleSegmentPattern = @"[^\.]+"; + if (template.Substring(template.Length - 3, 3) == @"\.*") { - // Check for valid template. - // Template cannot be null or whitespace. - if (string.IsNullOrWhiteSpace(template)) - { - throw new MultiTenantException("Template cannot be null or whitespace."); - } - // Wildcard "*" must be only occur once in template. - if (Regex.Match(template, @"\*(?=.*\*)").Success) - { - throw new MultiTenantException("Wildcard \"*\" must be only occur once in template."); - } - // Wildcard "*" must be only token in template segment. - if (Regex.Match(template, @"\*[^\.]|[^\.]\*").Success) - { - throw new MultiTenantException("\"*\" wildcard must be only token in template segment."); - } - // Wildcard "?" must be only token in template segment. - if (Regex.Match(template, @"\?[^\.]|[^\.]\?").Success) - { - throw new MultiTenantException("\"?\" wildcard must be only token in template segment."); - } - - template = template.Trim().Replace(".", @"\."); - string wildcardSegmentsPattern = @"(\.[^\.]+)*"; - string singleSegmentPattern = @"[^\.]+"; - if (template.Substring(template.Length - 3, 3) == @"\.*") - { - template = template.Substring(0, template.Length - 3) + wildcardSegmentsPattern; - } - - wildcardSegmentsPattern = @"([^\.]+\.)*"; - template = template.Replace(@"*\.", wildcardSegmentsPattern); - template = template.Replace("?", singleSegmentPattern); - template = template.Replace(Constants.TenantToken, @"(?[^\.]+)"); + template = template.Substring(0, template.Length - 3) + wildcardSegmentsPattern; } - this.regex = $"^{template}$"; + wildcardSegmentsPattern = @"([^\.]+\.)*"; + template = template.Replace(@"*\.", wildcardSegmentsPattern); + template = template.Replace("?", singleSegmentPattern); + template = template.Replace(Constants.TenantToken, @"(?[^\.]+)"); } + this.regex = $"^{template}$"; + } + public Task GetIdentifierAsync(object context) { - if (!(context is HttpContext httpContext)) - throw new MultiTenantException(null, - new ArgumentException($"\"{nameof(context)}\" type must be of type HttpContext", nameof(context))); + if (context is not HttpContext httpContext) + return Task.FromResult(null); - var host = httpContext.Request.Host; + var host = httpContext.Request.Host; - if (host.HasValue == false) - return Task.FromResult(null); + if (host.HasValue == false) + return Task.FromResult(null); - string? identifier = null; + string? identifier = null; - var match = Regex.Match(host.Host, regex, - RegexOptions.ExplicitCapture, - TimeSpan.FromMilliseconds(100)); + var match = Regex.Match(host.Host, regex, + RegexOptions.ExplicitCapture, + TimeSpan.FromMilliseconds(100)); - if (match.Success) - { - identifier = match.Groups["identifier"].Value; - } - - return Task.FromResult(identifier); + if (match.Success) + { + identifier = match.Groups["identifier"].Value; } + + return Task.FromResult(identifier); + } } \ No newline at end of file diff --git a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/RemoteAuthenticationCallbackStrategy.cs b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/RemoteAuthenticationCallbackStrategy.cs index 87821945..913aa59d 100644 --- a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/RemoteAuthenticationCallbackStrategy.cs +++ b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/RemoteAuthenticationCallbackStrategy.cs @@ -19,89 +19,96 @@ public class RemoteAuthenticationCallbackStrategy : IMultiTenantStrategy { private readonly ILogger logger; - public int Priority { get => -900; } + public int Priority + { + get => -900; + } public RemoteAuthenticationCallbackStrategy(ILogger logger) { - this.logger = logger; - } + this.logger = logger; + } public async virtual Task GetIdentifierAsync(object context) { - if (!(context is HttpContext httpContext)) - throw new MultiTenantException(null, - new ArgumentException($"\"{nameof(context)}\" type must be of type HttpContext", nameof(context))); + if (context is not HttpContext httpContext) + return null; - var schemes = httpContext.RequestServices.GetRequiredService(); + var schemes = httpContext.RequestServices.GetRequiredService(); - foreach (var scheme in (await schemes.GetRequestHandlerSchemesAsync()). - Where(s => typeof(IAuthenticationRequestHandler).IsAssignableFrom(s.HandlerType))) - // Where(s => s.HandlerType.ImplementsOrInheritsUnboundGeneric(typeof(RemoteAuthenticationHandler<>)))) - { - // TODO verify this comment - // Unfortunately we can't rely on the ShouldHandleAsync method since OpenId Connect handler doesn't use it. - // Instead we'll get the paths to check from the options. - var optionsType = scheme.HandlerType.GetProperty("Options")?.PropertyType; + foreach (var scheme in (await schemes.GetRequestHandlerSchemesAsync()).Where(s => + typeof(IAuthenticationRequestHandler).IsAssignableFrom(s.HandlerType))) + { + // TODO verify this comment (still true as of net8.0) + // Unfortunately we can't rely on the ShouldHandleAsync method since OpenId Connect handler doesn't use it. + // Instead we'll get the paths to check from the options. + var optionsType = scheme.HandlerType.GetProperty("Options")?.PropertyType; - if (optionsType is null) - { - continue; - } + if (optionsType is null) + { + continue; + } - var optionsMonitorType = typeof(IOptionsMonitor<>).MakeGenericType(optionsType); - var optionsMonitor = httpContext.RequestServices.GetRequiredService(optionsMonitorType); - var options = optionsMonitorType?.GetMethod("Get")?.Invoke(optionsMonitor, new[] { scheme.Name }) as RemoteAuthenticationOptions; + var optionsMonitorType = typeof(IOptionsMonitor<>).MakeGenericType(optionsType); + var optionsMonitor = httpContext.RequestServices.GetRequiredService(optionsMonitorType); + var options = + optionsMonitorType?.GetMethod("Get")?.Invoke(optionsMonitor, new[] { scheme.Name }) as + RemoteAuthenticationOptions; - if (options is null) - { - continue; - } + if (options is null) + { + continue; + } - var callbackPath = (PathString)(optionsType.GetProperty("CallbackPath")?.GetValue(options) ?? PathString.Empty); - var signedOutCallbackPath = (PathString)(optionsType.GetProperty("SignedOutCallbackPath")?.GetValue(options) ?? PathString.Empty); + var callbackPath = + (PathString)(optionsType.GetProperty("CallbackPath")?.GetValue(options) ?? PathString.Empty); + var signedOutCallbackPath = + (PathString)(optionsType.GetProperty("SignedOutCallbackPath")?.GetValue(options) ?? PathString.Empty); - if (callbackPath.HasValue && callbackPath == httpContext.Request.Path || - signedOutCallbackPath.HasValue && signedOutCallbackPath == httpContext.Request.Path) + if (callbackPath.HasValue && callbackPath == httpContext.Request.Path || + signedOutCallbackPath.HasValue && signedOutCallbackPath == httpContext.Request.Path) + { + try { - try + string? state = null; + + if (string.Equals(httpContext.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) { - string? state = null; - - if (string.Equals(httpContext.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) - { - state = httpContext.Request.Query["state"]; - } - // Assumption: it is safe to read the form, limit to 1MB form size. - else if (string.Equals(httpContext.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) - && httpContext.Request.HasFormContentType - && httpContext.Request.Body.CanRead) - { - var formOptions = new FormOptions { BufferBody = true, MemoryBufferThreshold = 1048576 }; - - var form = await httpContext.Request.ReadFormAsync(formOptions); - state = form.Single(i => string.Equals(i.Key, "state", StringComparison.OrdinalIgnoreCase)).Value; - } - - var properties = ((dynamic)options).StateDataFormat.Unprotect(state) as AuthenticationProperties; - - if (properties == null) - { - if (logger != null) - logger.LogWarning("A tenant could not be determined because no state parameter passed with the remote authentication callback."); - return null; - } - - properties.Items.TryGetValue(Constants.TenantToken, out var identifier); - - return identifier; + state = httpContext.Request.Query["state"]; } - catch (Exception e) + // Assumption: it is safe to read the form, limit to 1MB form size. + else if (string.Equals(httpContext.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) + && httpContext.Request.HasFormContentType + && httpContext.Request.Body.CanRead) { - throw new MultiTenantException("Error occurred resolving tenant for remote authentication.", e); + var formOptions = new FormOptions { BufferBody = true, MemoryBufferThreshold = 1048576 }; + + var form = await httpContext.Request.ReadFormAsync(formOptions); + state = form.Single(i => string.Equals(i.Key, "state", StringComparison.OrdinalIgnoreCase)) + .Value; } + + var properties = ((dynamic)options).StateDataFormat.Unprotect(state) as AuthenticationProperties; + + if (properties == null) + { + if (logger != null) + logger.LogWarning( + "A tenant could not be determined because no state parameter passed with the remote authentication callback."); + return null; + } + + properties.Items.TryGetValue(Constants.TenantToken, out var identifier); + + return identifier; + } + catch (Exception e) + { + throw new MultiTenantException("Error occurred resolving tenant for remote authentication.", e); } } - - return null; } + + return null; + } } \ No newline at end of file diff --git a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/RouteStrategy.cs b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/RouteStrategy.cs index d8d3b54e..b2b51045 100644 --- a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/RouteStrategy.cs +++ b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/RouteStrategy.cs @@ -25,12 +25,10 @@ public RouteStrategy(string tenantParam) public Task GetIdentifierAsync(object context) { - if (!(context is HttpContext httpContext)) - throw new MultiTenantException(null, - new ArgumentException($"\"{nameof(context)}\" type must be of type HttpContext", nameof(context))); + if (context is not HttpContext httpContext) + return Task.FromResult(null); - object? identifier; - httpContext.Request.RouteValues.TryGetValue(TenantParam, out identifier); + httpContext.Request.RouteValues.TryGetValue(TenantParam, out var identifier); return Task.FromResult(identifier as string); } diff --git a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/SessionStrategy.cs b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/SessionStrategy.cs index a1beb852..2beb66b1 100644 --- a/src/Finbuckle.MultiTenant.AspNetCore/Strategies/SessionStrategy.cs +++ b/src/Finbuckle.MultiTenant.AspNetCore/Strategies/SessionStrategy.cs @@ -14,21 +14,20 @@ public class SessionStrategy : IMultiTenantStrategy public SessionStrategy(string tenantKey) { - if (string.IsNullOrWhiteSpace(tenantKey)) - { - throw new ArgumentException("message", nameof(tenantKey)); - } - - this.tenantKey = tenantKey; + if (string.IsNullOrWhiteSpace(tenantKey)) + { + throw new ArgumentException("message", nameof(tenantKey)); } + this.tenantKey = tenantKey; + } + public Task GetIdentifierAsync(object context) { - if(!(context is HttpContext httpContext)) - throw new MultiTenantException(null, - new ArgumentException($"\"{nameof(context)}\" type must be of type HttpContext", nameof(context))); + if (context is not HttpContext httpContext) + return Task.FromResult(null); - var identifier = httpContext.Session.GetString(tenantKey); - return Task.FromResult(identifier); // Prevent the compiler warning that no await exists. - } + var identifier = httpContext.Session.GetString(tenantKey); + return Task.FromResult(identifier); + } } \ No newline at end of file diff --git a/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/BasePathStrategyShould.cs b/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/BasePathStrategyShould.cs index 3760b851..342cdbed 100644 --- a/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/BasePathStrategyShould.cs +++ b/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/BasePathStrategyShould.cs @@ -98,12 +98,12 @@ public async void ReturnExpectedIdentifier(string path, string? expected) } [Fact] - public async void ThrowIfContextIsNotHttpContext() + public async void ReturnNullIfContextIsNotHttpContext() { - var context = new Object(); + var context = new object(); var strategy = new BasePathStrategy(); - await Assert.ThrowsAsync(() => strategy.GetIdentifierAsync(context)); + Assert.Null(await strategy.GetIdentifierAsync(context)); } [Fact] diff --git a/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/HostStrategyShould.cs b/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/HostStrategyShould.cs index 9ff92660..0c0471fc 100644 --- a/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/HostStrategyShould.cs +++ b/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/HostStrategyShould.cs @@ -69,11 +69,11 @@ public void ThrowIfInvalidTemplate(string? template) } [Fact] - public async void ThrowIfContextIsNotHttpContext() + public async void ReturnNullIfContextIsNotHttpContext() { - var context = new Object(); + var context = new object(); var strategy = new HostStrategy("__tenant__.*"); - await Assert.ThrowsAsync(() => strategy.GetIdentifierAsync(context)); + Assert.Null(await strategy.GetIdentifierAsync(context)); } } \ No newline at end of file diff --git a/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/RemoteAuthenticationCallbackStrategyShould.cs b/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/RemoteAuthenticationCallbackStrategyShould.cs index fd237386..c7a6e308 100644 --- a/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/RemoteAuthenticationCallbackStrategyShould.cs +++ b/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/RemoteAuthenticationCallbackStrategyShould.cs @@ -13,7 +13,16 @@ public class RemoteAuthenticationCallbackStrategyShould [Fact] public void HavePriorityNeg900() { - var strategy = new RemoteAuthenticationCallbackStrategy(null!); - Assert.Equal(-900, strategy.Priority); - } + var strategy = new RemoteAuthenticationCallbackStrategy(null!); + Assert.Equal(-900, strategy.Priority); + } + + [Fact] + public async void ReturnNullIfContextIsNotHttpContext() + { + var context = new object(); + var strategy = new RemoteAuthenticationCallbackStrategy(null!); + + Assert.Null(await strategy.GetIdentifierAsync(context)); + } } \ No newline at end of file diff --git a/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/RouteStrategyShould.cs b/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/RouteStrategyShould.cs index baaa3a1d..5b2c8f0c 100644 --- a/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/RouteStrategyShould.cs +++ b/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/RouteStrategyShould.cs @@ -32,12 +32,12 @@ public async Task ReturnExpectedIdentifier(string path, string identifier, strin } [Fact] - public async void ThrowIfContextIsNotHttpContext() + public async void ReturnNullIfContextIsNotHttpContext() { - var context = new Object(); + var context = new object(); var strategy = new RouteStrategy("__tenant__"); - await Assert.ThrowsAsync(() => strategy.GetIdentifierAsync(context)); + Assert.Null(await strategy.GetIdentifierAsync(context)); } [Fact] diff --git a/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/SessionStrategyShould.cs b/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/SessionStrategyShould.cs index 4582205d..e67b57f8 100644 --- a/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/SessionStrategyShould.cs +++ b/test/Finbuckle.MultiTenant.AspNetCore.Test/Strategies/SessionStrategyShould.cs @@ -53,12 +53,12 @@ private static IWebHostBuilder GetTestHostBuilder(string identifier, string sess } [Fact] - public async void ThrowIfContextIsNotHttpContext() + public async void ReturnNullIfContextIsNotHttpContext() { - var context = new Object(); + var context = new object(); var strategy = new SessionStrategy("__tenant__"); - - await Assert.ThrowsAsync(() => strategy.GetIdentifierAsync(context)); + + Assert.Null(await strategy.GetIdentifierAsync(context)); } [Fact] @@ -66,12 +66,10 @@ public async Task ReturnNullIfNoSessionValue() { var hostBuilder = GetTestHostBuilder("test_tenant", "__tenant__"); - using (var server = new TestServer(hostBuilder)) - { - var client = server.CreateClient(); - var response = await client.GetStringAsync("/test_tenant"); - Assert.Equal("", response); - } + using var server = new TestServer(hostBuilder); + var client = server.CreateClient(); + var response = await client.GetStringAsync("/test_tenant"); + Assert.Equal("", response); } // TODO: Figure out how to test this