Skip to content

Commit

Permalink
feat: strategies return null on invalid context type (#885)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: Included strategies for ASP.NET Core would throw an exception if the passed context was not an `HttpContext` type. Now they will return null indicating no identifier was found.
  • Loading branch information
AndrewTriesToCode authored Oct 12, 2024
1 parent 3263eff commit 9834575
Show file tree
Hide file tree
Showing 12 changed files with 226 additions and 212 deletions.
21 changes: 10 additions & 11 deletions src/Finbuckle.MultiTenant.AspNetCore/Strategies/BasePathStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,19 @@ public class BasePathStrategy : IMultiTenantStrategy
{
public Task<string?> GetIdentifierAsync(object context)
{
if (!(context is HttpContext httpContext))
throw new MultiTenantException(null,
new ArgumentException($"\"{nameof(context)}\" type must be of type HttpContext", nameof(context)));
if (context is not HttpContext httpContext)
return Task.FromResult<string?>(null);

var path = httpContext.Request.Path;
var path = httpContext.Request.Path;

var pathSegments =
path.Value?.Split('/', 2, StringSplitOptions.RemoveEmptyEntries);
var pathSegments =
path.Value?.Split('/', 2, StringSplitOptions.RemoveEmptyEntries);

if (pathSegments is null || pathSegments.Length == 0)
return Task.FromResult<string?>(null);
if (pathSegments is null || pathSegments.Length == 0)
return Task.FromResult<string?>(null);

string identifier = pathSegments[0];
string identifier = pathSegments[0];

return Task.FromResult<string?>(identifier);
}
return Task.FromResult<string?>(identifier);
}
}
83 changes: 43 additions & 40 deletions src/Finbuckle.MultiTenant.AspNetCore/Strategies/ClaimStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,53 +15,56 @@ namespace Finbuckle.MultiTenant.AspNetCore.Strategies;
// ReSharper disable once ClassNeverInstantiated.Global
public class ClaimStrategy : IMultiTenantStrategy
{
private readonly string _tenantKey;
private readonly string? _authenticationScheme;
private readonly string _tenantKey;
private readonly string? _authenticationScheme;

public ClaimStrategy(string template) : this(template, null)
{
}
public ClaimStrategy(string template) : this(template, null)
{
}

public ClaimStrategy(string template, string? authenticationScheme)
{
if (string.IsNullOrWhiteSpace(template))
throw new ArgumentException(nameof(template));
public ClaimStrategy(string template, string? authenticationScheme)
{
if (string.IsNullOrWhiteSpace(template))
throw new ArgumentException(nameof(template));

_tenantKey = template;
_authenticationScheme = authenticationScheme;
}
_tenantKey = template;
_authenticationScheme = authenticationScheme;
}

public async Task<string?> GetIdentifierAsync(object context)
{
if (!(context is HttpContext httpContext))
throw new MultiTenantException(null, new ArgumentException($@"""{nameof(context)}"" type must be of type HttpContext", nameof(context)));
public async Task<string?> GetIdentifierAsync(object context)
{
if (context is not HttpContext httpContext)
return null;

if (httpContext.User.Identity is { IsAuthenticated: true })
return httpContext.User.FindFirst(_tenantKey)?.Value;
if (httpContext.User.Identity is { IsAuthenticated: true })
return httpContext.User.FindFirst(_tenantKey)?.Value;

AuthenticationScheme? authScheme;
var schemeProvider = httpContext.RequestServices.GetRequiredService<IAuthenticationSchemeProvider>();
if (_authenticationScheme is null)
{
authScheme = await schemeProvider.GetDefaultAuthenticateSchemeAsync();
}
else
{
authScheme = (await schemeProvider.GetAllSchemesAsync()).FirstOrDefault(x => x.Name == _authenticationScheme);
}
AuthenticationScheme? authScheme;
var schemeProvider = httpContext.RequestServices.GetRequiredService<IAuthenticationSchemeProvider>();
if (_authenticationScheme is null)
{
authScheme = await schemeProvider.GetDefaultAuthenticateSchemeAsync();
}
else
{
authScheme =
(await schemeProvider.GetAllSchemesAsync()).FirstOrDefault(x => x.Name == _authenticationScheme);
}

if (authScheme is null)
{
return null;
}
if (authScheme is null)
{
return null;
}

var handler = (IAuthenticationHandler)ActivatorUtilities.CreateInstance(httpContext.RequestServices, authScheme.HandlerType);
await handler.InitializeAsync(authScheme, httpContext);
httpContext.Items[$"{Constants.TenantToken}__bypass_validate_principal__"] = "true"; // Value doesn't matter.
var handlerResult = await handler.AuthenticateAsync();
httpContext.Items.Remove($"{Constants.TenantToken}__bypass_validate_principal__");
var handler =
(IAuthenticationHandler)ActivatorUtilities.CreateInstance(httpContext.RequestServices,
authScheme.HandlerType);
await handler.InitializeAsync(authScheme, httpContext);
httpContext.Items[$"{Constants.TenantToken}__bypass_validate_principal__"] = "true"; // Value doesn't matter.
var handlerResult = await handler.AuthenticateAsync();
httpContext.Items.Remove($"{Constants.TenantToken}__bypass_validate_principal__");

var identifier = handlerResult.Principal?.FindFirst(_tenantKey)?.Value;
return identifier;
}
var identifier = handlerResult.Principal?.FindFirst(_tenantKey)?.Value;
return identifier;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ public HeaderStrategy(string headerKey)

public Task<string?> GetIdentifierAsync(object context)
{
if (!(context is HttpContext httpContext))
throw new MultiTenantException(null,
new ArgumentException($"\"{nameof(context)}\" type must be of type HttpContext", nameof(context)));
if (context is not HttpContext httpContext)
return Task.FromResult<string?>(null);

return Task.FromResult(httpContext?.Request.Headers[_headerKey].FirstOrDefault());
}
Expand Down
112 changes: 57 additions & 55 deletions src/Finbuckle.MultiTenant.AspNetCore/Strategies/HostStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,74 +16,76 @@ public class HostStrategy : IMultiTenantStrategy

public HostStrategy(string template)
{
// New in 2.1, match whole domain if just "__tenant__".
if (template == Constants.TenantToken)
// match whole domain if just "__tenant__".
if (template == Constants.TenantToken)
{
template = template.Replace(Constants.TenantToken, "(?<identifier>.+)");
}
else
{
// Check for valid template.
// Template cannot be null or whitespace.
if (string.IsNullOrWhiteSpace(template))
{
throw new MultiTenantException("Template cannot be null or whitespace.");
}

// Wildcard "*" must be only occur once in template.
if (Regex.Match(template, @"\*(?=.*\*)").Success)
{
throw new MultiTenantException("Wildcard \"*\" must be only occur once in template.");
}

// Wildcard "*" must be only token in template segment.
if (Regex.Match(template, @"\*[^\.]|[^\.]\*").Success)
{
throw new MultiTenantException("\"*\" wildcard must be only token in template segment.");
}

// Wildcard "?" must be only token in template segment.
if (Regex.Match(template, @"\?[^\.]|[^\.]\?").Success)
{
template = template.Replace(Constants.TenantToken, "(?<identifier>.+)");
throw new MultiTenantException("\"?\" wildcard must be only token in template segment.");
}
else

template = template.Trim().Replace(".", @"\.");
string wildcardSegmentsPattern = @"(\.[^\.]+)*";
string singleSegmentPattern = @"[^\.]+";
if (template.Substring(template.Length - 3, 3) == @"\.*")
{
// Check for valid template.
// Template cannot be null or whitespace.
if (string.IsNullOrWhiteSpace(template))
{
throw new MultiTenantException("Template cannot be null or whitespace.");
}
// Wildcard "*" must be only occur once in template.
if (Regex.Match(template, @"\*(?=.*\*)").Success)
{
throw new MultiTenantException("Wildcard \"*\" must be only occur once in template.");
}
// Wildcard "*" must be only token in template segment.
if (Regex.Match(template, @"\*[^\.]|[^\.]\*").Success)
{
throw new MultiTenantException("\"*\" wildcard must be only token in template segment.");
}
// Wildcard "?" must be only token in template segment.
if (Regex.Match(template, @"\?[^\.]|[^\.]\?").Success)
{
throw new MultiTenantException("\"?\" wildcard must be only token in template segment.");
}

template = template.Trim().Replace(".", @"\.");
string wildcardSegmentsPattern = @"(\.[^\.]+)*";
string singleSegmentPattern = @"[^\.]+";
if (template.Substring(template.Length - 3, 3) == @"\.*")
{
template = template.Substring(0, template.Length - 3) + wildcardSegmentsPattern;
}

wildcardSegmentsPattern = @"([^\.]+\.)*";
template = template.Replace(@"*\.", wildcardSegmentsPattern);
template = template.Replace("?", singleSegmentPattern);
template = template.Replace(Constants.TenantToken, @"(?<identifier>[^\.]+)");
template = template.Substring(0, template.Length - 3) + wildcardSegmentsPattern;
}

this.regex = $"^{template}$";
wildcardSegmentsPattern = @"([^\.]+\.)*";
template = template.Replace(@"*\.", wildcardSegmentsPattern);
template = template.Replace("?", singleSegmentPattern);
template = template.Replace(Constants.TenantToken, @"(?<identifier>[^\.]+)");
}

this.regex = $"^{template}$";
}

public Task<string?> GetIdentifierAsync(object context)
{
if (!(context is HttpContext httpContext))
throw new MultiTenantException(null,
new ArgumentException($"\"{nameof(context)}\" type must be of type HttpContext", nameof(context)));
if (context is not HttpContext httpContext)
return Task.FromResult<string?>(null);

var host = httpContext.Request.Host;
var host = httpContext.Request.Host;

if (host.HasValue == false)
return Task.FromResult<string?>(null);
if (host.HasValue == false)
return Task.FromResult<string?>(null);

string? identifier = null;
string? identifier = null;

var match = Regex.Match(host.Host, regex,
RegexOptions.ExplicitCapture,
TimeSpan.FromMilliseconds(100));
var match = Regex.Match(host.Host, regex,
RegexOptions.ExplicitCapture,
TimeSpan.FromMilliseconds(100));

if (match.Success)
{
identifier = match.Groups["identifier"].Value;
}

return Task.FromResult(identifier);
if (match.Success)
{
identifier = match.Groups["identifier"].Value;
}

return Task.FromResult(identifier);
}
}
Loading

0 comments on commit 9834575

Please sign in to comment.