diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 5121d551c53c..47d3525deff8 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -4,13 +4,22 @@
#
# https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
-# DevOps for Actions and other workflow changes
-.github/workflows @bitwarden/dept-devops
+## Docker files have shared ownership ##
+**/Dockerfile
+**/*.Dockerfile
+**/.dockerignore
+**/entrypoint.sh
-# DevOps for Docker changes
-**/Dockerfile @bitwarden/dept-devops
-**/*.Dockerfile @bitwarden/dept-devops
-**/.dockerignore @bitwarden/dept-devops
+## BRE team owns these workflows ##
+.github/workflows/publish.yml @bitwarden/dept-bre
+
+## These are shared workflows ##
+.github/workflows/_move_finalization_db_scripts.yml
+.github/workflows/build.yml
+.github/workflows/cleanup-after-pr.yml
+.github/workflows/cleanup-rc-branch.yml
+.github/workflows/release.yml
+.github/workflows/repository-management.yml
# Database Operations for database changes
src/Sql/** @bitwarden/dept-dbops
@@ -60,6 +69,6 @@ src/EventsProcessor @bitwarden/team-admin-console-dev
src/Admin/Controllers/ToolsController.cs @bitwarden/team-billing-dev
src/Admin/Views/Tools @bitwarden/team-billing-dev
-# Multiple owners - DO NOT REMOVE (DevOps)
+# Multiple owners - DO NOT REMOVE (BRE)
**/packages.lock.json
Directory.Build.props
diff --git a/.github/workflows/_move_finalization_db_scripts.yml b/.github/workflows/_move_finalization_db_scripts.yml
index 3eb3777cef27..6e3825733a34 100644
--- a/.github/workflows/_move_finalization_db_scripts.yml
+++ b/.github/workflows/_move_finalization_db_scripts.yml
@@ -1,4 +1,3 @@
----
name: _move_finalization_db_scripts
run-name: Move finalization database scripts
diff --git a/.github/workflows/automatic-issue-responses.yml b/.github/workflows/automatic-issue-responses.yml
index 0e6b9041b93b..b6a6a1ebf5df 100644
--- a/.github/workflows/automatic-issue-responses.yml
+++ b/.github/workflows/automatic-issue-responses.yml
@@ -1,4 +1,3 @@
----
name: Automatic responses
on:
issues:
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 6df666417491..6043e1e21e5f 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -1,4 +1,3 @@
----
name: Build
on:
@@ -408,7 +407,7 @@ jobs:
name: swagger.json
path: swagger.json
if-no-files-found: error
-
+
- name: Build Internal API Swagger
run: |
cd ./src/Api
@@ -416,17 +415,17 @@ jobs:
dotnet tool restore
echo "Publish API"
dotnet publish -c "Release" -o obj/build-output/publish
-
+
dotnet swagger tofile --output ../../internal.json --host https://api.bitwarden.com \
./obj/build-output/publish/Api.dll internal
-
+
cd ../Identity
-
+
echo "Restore Identity tools"
dotnet tool restore
echo "Publish Identity"
dotnet publish -c "Release" -o obj/build-output/publish
-
+
dotnet swagger tofile --output ../../identity.json --host https://identity.bitwarden.com \
./obj/build-output/publish/Identity.dll v1
cd ../..
@@ -448,7 +447,7 @@ jobs:
with:
name: identity.json
path: identity.json
- if-no-files-found: error
+ if-no-files-found: error
build-mssqlmigratorutility:
name: Build MSSQL migrator utility
@@ -565,7 +564,7 @@ jobs:
tag: 'main'
}
})
-
+
trigger-ee-updates:
name: Trigger Ephemeral Environment updates
if: github.ref != 'refs/heads/main' && contains(github.event.pull_request.labels.*.name, 'ephemeral-environment')
@@ -595,7 +594,7 @@ jobs:
workflow_id: '_update_ephemeral_tags.yml',
ref: 'main',
inputs: {
- ephemeral_env_branch: '${{ github.head_ref }}'
+ ephemeral_env_branch: process.env.GITHUB_HEAD_REF
}
})
diff --git a/.github/workflows/cleanup-after-pr.yml b/.github/workflows/cleanup-after-pr.yml
index 1bed3542d98c..c36dc4a034b5 100644
--- a/.github/workflows/cleanup-after-pr.yml
+++ b/.github/workflows/cleanup-after-pr.yml
@@ -1,4 +1,3 @@
----
name: Container registry cleanup
on:
diff --git a/.github/workflows/cleanup-ephemeral-environment.yml b/.github/workflows/cleanup-ephemeral-environment.yml
new file mode 100644
index 000000000000..d5c34a7bb4d7
--- /dev/null
+++ b/.github/workflows/cleanup-ephemeral-environment.yml
@@ -0,0 +1,59 @@
+name: Ephemeral environment cleanup
+
+on:
+ pull_request:
+ types: [unlabeled]
+
+jobs:
+ validate-pr:
+ name: Validate PR
+ runs-on: ubuntu-24.04
+ outputs:
+ config-exists: ${{ steps.validate-config.outputs.config-exists }}
+ steps:
+ - name: Checkout PR
+ uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1
+
+ - name: Validate config exists in path
+ id: validate-config
+ run: |
+ if [[ -f "ephemeral-environments/$GITHUB_HEAD_REF.yaml" ]]; then
+ echo "Ephemeral environment config found in path, continuing."
+ echo "config-exists=true" >> $GITHUB_OUTPUT
+ fi
+
+
+ cleanup-config:
+ name: Cleanup ephemeral environment
+ runs-on: ubuntu-24.04
+ needs: validate-pr
+ if: ${{ needs.validate-pr.outputs.config-exists }}
+ steps:
+ - name: Log in to Azure - CI subscription
+ uses: Azure/login@e15b166166a8746d1a47596803bd8c1b595455cf # v1.6.0
+ with:
+ creds: ${{ secrets.AZURE_KV_CI_SERVICE_PRINCIPAL }}
+
+ - name: Retrieve GitHub PAT secrets
+ id: retrieve-secret-pat
+ uses: bitwarden/gh-actions/get-keyvault-secrets@main
+ with:
+ keyvault: "bitwarden-ci"
+ secrets: "github-pat-bitwarden-devops-bot-repo-scope"
+
+ - name: Trigger Ephemeral Environment cleanup
+ uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
+ with:
+ github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }}
+ script: |
+ await github.rest.actions.createWorkflowDispatch({
+ owner: 'bitwarden',
+ repo: 'devops',
+ workflow_id: '_ephemeral_environment_pr_manager.yml',
+ ref: 'main',
+ inputs: {
+ ephemeral_env_branch: process.env.GITHUB_HEAD_REF,
+ cleanup_config: true,
+ project: 'server'
+ }
+ })
diff --git a/.github/workflows/cleanup-rc-branch.yml b/.github/workflows/cleanup-rc-branch.yml
index 1eba867a9c1d..e037c18f93e9 100644
--- a/.github/workflows/cleanup-rc-branch.yml
+++ b/.github/workflows/cleanup-rc-branch.yml
@@ -1,4 +1,3 @@
----
name: Cleanup RC Branch
on:
diff --git a/.github/workflows/enforce-labels.yml b/.github/workflows/enforce-labels.yml
index 160ee15b9650..11d5654937e1 100644
--- a/.github/workflows/enforce-labels.yml
+++ b/.github/workflows/enforce-labels.yml
@@ -1,4 +1,3 @@
----
name: Enforce PR labels
on:
@@ -7,13 +6,13 @@ on:
types: [labeled, unlabeled, opened, reopened, synchronize]
jobs:
enforce-label:
- if: ${{ contains(github.event.*.labels.*.name, 'hold') || contains(github.event.*.labels.*.name, 'needs-qa') || contains(github.event.*.labels.*.name, 'DB-migrations-changed') }}
+ if: ${{ contains(github.event.*.labels.*.name, 'hold') || contains(github.event.*.labels.*.name, 'needs-qa') || contains(github.event.*.labels.*.name, 'DB-migrations-changed') || contains(github.event.*.labels.*.name, 'ephemeral-environment') }}
name: Enforce label
runs-on: ubuntu-22.04
steps:
- name: Check for label
run: |
- echo "PRs with the hold or needs-qa labels cannot be merged"
- echo "### :x: PRs with the hold or needs-qa labels cannot be merged" >> $GITHUB_STEP_SUMMARY
+ echo "PRs with the hold, needs-qa or ephemeral-environment labels cannot be merged"
+ echo "### :x: PRs with the hold, needs-qa or ephemeral-environment labels cannot be merged" >> $GITHUB_STEP_SUMMARY
exit 1
diff --git a/.github/workflows/protect-files.yml b/.github/workflows/protect-files.yml
index 10924f656bae..95d57180df61 100644
--- a/.github/workflows/protect-files.yml
+++ b/.github/workflows/protect-files.yml
@@ -1,7 +1,6 @@
# Runs if there are changes to the paths: list.
# Starts a matrix job to check for modified files, then sets output based on the results.
# The input decides if the label job is ran, adding a label to the PR.
----
name: Protect files
on:
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 77ea9dca4f50..4454ea1f3c28 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -1,4 +1,3 @@
----
name: Publish
run-name: Publish ${{ inputs.publish_type }}
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 0c89a01c2f13..9d5dcb74d8ca 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -1,4 +1,3 @@
----
name: Release
run-name: Release ${{ inputs.release_type }}
diff --git a/.github/workflows/stale-bot.yml b/.github/workflows/stale-bot.yml
index d3dc92cd70ea..f8a25288f283 100644
--- a/.github/workflows/stale-bot.yml
+++ b/.github/workflows/stale-bot.yml
@@ -1,4 +1,3 @@
----
name: Staleness
on:
workflow_dispatch:
diff --git a/.github/workflows/test-database.yml b/.github/workflows/test-database.yml
index 09a4b7a182aa..7a38b0f3bd4e 100644
--- a/.github/workflows/test-database.yml
+++ b/.github/workflows/test-database.yml
@@ -1,4 +1,3 @@
----
name: Database testing
on:
@@ -55,7 +54,7 @@ jobs:
# I've seen the SQL Server container not be ready for commands right after starting up and just needing a bit longer to be ready
- name: Sleep
run: sleep 15s
-
+
- name: Checking pending model changes (MySQL)
working-directory: "util/MySqlMigrations"
run: 'dotnet ef migrations has-pending-model-changes -- --GlobalSettings:MySql:ConnectionString="$CONN_STR"'
@@ -114,7 +113,7 @@ jobs:
BW_TEST_DATABASES__3__CONNECTIONSTRING: "Data Source=${{ runner.temp }}/test.db"
run: dotnet test --logger "trx;LogFileName=infrastructure-test-results.trx"
shell: pwsh
-
+
- name: Print MySQL Logs
if: failure()
run: 'docker logs $(docker ps --quiet --filter "name=mysql")'
diff --git a/Directory.Build.props b/Directory.Build.props
index 46a12ea3db5f..5cd12bfb7114 100644
--- a/Directory.Build.props
+++ b/Directory.Build.props
@@ -3,7 +3,7 @@
net8.0
- 2024.10.0
+ 2024.10.1
Bit.$(MSBuildProjectName)
enable
diff --git a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs
index 70c09a539bfe..efab8620c0d7 100644
--- a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs
+++ b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs
@@ -11,7 +11,6 @@
using Bit.Core.Billing.Services;
using Bit.Core.Context;
using Bit.Core.Enums;
-using Bit.Core.Exceptions;
using Bit.Core.Models.OrganizationConnectionConfigs;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Repositories;
@@ -236,7 +235,8 @@ public async Task Edit(Guid id, OrganizationEditModel model)
if (organization.UseSecretsManager &&
!StaticStore.GetPlan(organization.PlanType).SupportsSecretsManager)
{
- throw new BadRequestException("Plan does not support Secrets Manager");
+ TempData["Error"] = "Plan does not support Secrets Manager";
+ return RedirectToAction("Edit", new { id });
}
await _organizationRepository.ReplaceAsync(organization);
diff --git a/src/Admin/AdminConsole/Models/OrganizationEditModel.cs b/src/Admin/AdminConsole/Models/OrganizationEditModel.cs
index 04079138d423..4ba22130f718 100644
--- a/src/Admin/AdminConsole/Models/OrganizationEditModel.cs
+++ b/src/Admin/AdminConsole/Models/OrganizationEditModel.cs
@@ -181,7 +181,6 @@ public OrganizationEditModel(
*/
public object GetPlansHelper() =>
StaticStore.Plans
- .Where(p => p.SupportsSecretsManager)
.Select(p =>
{
var plan = new
diff --git a/src/Api/AdminConsole/Controllers/PoliciesController.cs b/src/Api/AdminConsole/Controllers/PoliciesController.cs
index b3be852dbca2..7bfd13c4088e 100644
--- a/src/Api/AdminConsole/Controllers/PoliciesController.cs
+++ b/src/Api/AdminConsole/Controllers/PoliciesController.cs
@@ -25,7 +25,6 @@ public class PoliciesController : Controller
{
private readonly IPolicyRepository _policyRepository;
private readonly IPolicyService _policyService;
- private readonly IOrganizationService _organizationService;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IUserService _userService;
private readonly ICurrentContext _currentContext;
@@ -36,7 +35,6 @@ public class PoliciesController : Controller
public PoliciesController(
IPolicyRepository policyRepository,
IPolicyService policyService,
- IOrganizationService organizationService,
IOrganizationUserRepository organizationUserRepository,
IUserService userService,
ICurrentContext currentContext,
@@ -46,7 +44,6 @@ public PoliciesController(
{
_policyRepository = policyRepository;
_policyService = policyService;
- _organizationService = organizationService;
_organizationUserRepository = organizationUserRepository;
_userService = userService;
_currentContext = currentContext;
@@ -185,7 +182,7 @@ public async Task Put(string orgId, int type, [FromBody] Po
}
var userId = _userService.GetProperUserId(User);
- await _policyService.SaveAsync(policy, _organizationService, userId);
+ await _policyService.SaveAsync(policy, userId);
return new PolicyResponseModel(policy);
}
}
diff --git a/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs b/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs
index 2d83bd70559f..71e03a547ec7 100644
--- a/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs
+++ b/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs
@@ -6,7 +6,6 @@
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Context;
-using Bit.Core.Services;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
@@ -18,18 +17,15 @@ public class PoliciesController : Controller
{
private readonly IPolicyRepository _policyRepository;
private readonly IPolicyService _policyService;
- private readonly IOrganizationService _organizationService;
private readonly ICurrentContext _currentContext;
public PoliciesController(
IPolicyRepository policyRepository,
IPolicyService policyService,
- IOrganizationService organizationService,
ICurrentContext currentContext)
{
_policyRepository = policyRepository;
_policyService = policyService;
- _organizationService = organizationService;
_currentContext = currentContext;
}
@@ -96,7 +92,7 @@ public async Task Put(PolicyType type, [FromBody] PolicyUpdateReq
{
policy = model.ToPolicy(policy);
}
- await _policyService.SaveAsync(policy, _organizationService, null);
+ await _policyService.SaveAsync(policy, null);
var response = new PolicyResponseModel(policy);
return new JsonResult(response);
}
diff --git a/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs b/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs
index 624eab1fece8..b5f9ab2f5910 100644
--- a/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs
+++ b/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs
@@ -3,8 +3,11 @@
namespace Bit.Api.Billing.Models.Responses;
public record OrganizationMetadataResponse(
+ bool IsEligibleForSelfHost,
bool IsOnSecretsManagerStandalone)
{
public static OrganizationMetadataResponse From(OrganizationMetadata metadata)
- => new(metadata.IsOnSecretsManagerStandalone);
+ => new(
+ metadata.IsEligibleForSelfHost,
+ metadata.IsOnSecretsManagerStandalone);
}
diff --git a/src/Api/Controllers/PushController.cs b/src/Api/Controllers/PushController.cs
index c83eb200b8df..38398051060e 100644
--- a/src/Api/Controllers/PushController.cs
+++ b/src/Api/Controllers/PushController.cs
@@ -46,7 +46,7 @@ await _pushRegistrationService.CreateOrUpdateRegistrationAsync(model.PushToken,
public async Task PostDelete([FromBody] PushDeviceRequestModel model)
{
CheckUsage();
- await _pushRegistrationService.DeleteRegistrationAsync(Prefix(model.Id), model.Type);
+ await _pushRegistrationService.DeleteRegistrationAsync(Prefix(model.Id));
}
[HttpPut("add-organization")]
@@ -54,7 +54,7 @@ public async Task PutAddOrganization([FromBody] PushUpdateRequestModel model)
{
CheckUsage();
await _pushRegistrationService.AddUserRegistrationOrganizationAsync(
- model.Devices.Select(d => new KeyValuePair(Prefix(d.Id), d.Type)),
+ model.Devices.Select(d => Prefix(d.Id)),
Prefix(model.OrganizationId));
}
@@ -63,7 +63,7 @@ public async Task PutDeleteOrganization([FromBody] PushUpdateRequestModel model)
{
CheckUsage();
await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync(
- model.Devices.Select(d => new KeyValuePair(Prefix(d.Id), d.Type)),
+ model.Devices.Select(d => Prefix(d.Id)),
Prefix(model.OrganizationId));
}
diff --git a/src/Core/AdminConsole/Enums/PolicyType.cs b/src/Core/AdminConsole/Enums/PolicyType.cs
index 0e1786cf528f..bdde3e424ec3 100644
--- a/src/Core/AdminConsole/Enums/PolicyType.cs
+++ b/src/Core/AdminConsole/Enums/PolicyType.cs
@@ -16,3 +16,30 @@ public enum PolicyType : byte
ActivateAutofill = 11,
AutomaticAppLogIn = 12,
}
+
+public static class PolicyTypeExtensions
+{
+ ///
+ /// Returns the name of the policy for display to the user.
+ /// Do not include the word "policy" in the return value.
+ ///
+ public static string GetName(this PolicyType type)
+ {
+ return type switch
+ {
+ PolicyType.TwoFactorAuthentication => "Require two-step login",
+ PolicyType.MasterPassword => "Master password requirements",
+ PolicyType.PasswordGenerator => "Password generator",
+ PolicyType.SingleOrg => "Single organization",
+ PolicyType.RequireSso => "Require single sign-on authentication",
+ PolicyType.PersonalOwnership => "Remove individual vault",
+ PolicyType.DisableSend => "Remove Send",
+ PolicyType.SendOptions => "Send options",
+ PolicyType.ResetPassword => "Account recovery administration",
+ PolicyType.MaximumVaultTimeout => "Vault timeout",
+ PolicyType.DisablePersonalVaultExport => "Remove individual vault export",
+ PolicyType.ActivateAutofill => "Active auto-fill",
+ PolicyType.AutomaticAppLogIn => "Automatically log in users for allowed applications",
+ };
+ }
+}
diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs
index 09444306e6e5..e6d56ea878a2 100644
--- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs
+++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs
@@ -162,12 +162,12 @@ private async Task RepositoryDeleteUserAsync(OrganizationUser orgUser, Guid? del
}
}
- private async Task>> GetUserDeviceIdsAsync(Guid userId)
+ private async Task> GetUserDeviceIdsAsync(Guid userId)
{
var devices = await _deviceRepository.GetManyByUserIdAsync(userId);
return devices
.Where(d => !string.IsNullOrWhiteSpace(d.PushToken))
- .Select(d => new KeyValuePair(d.Id.ToString(), d.Type));
+ .Select(d => d.Id.ToString());
}
private async Task DeleteAndPushUserRegistrationAsync(Guid organizationId, Guid userId)
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs
new file mode 100644
index 000000000000..6aef9f248b89
--- /dev/null
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs
@@ -0,0 +1,43 @@
+#nullable enable
+
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+
+namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies;
+
+///
+/// Defines behavior and functionality for a given PolicyType.
+///
+public interface IPolicyValidator
+{
+ ///
+ /// The PolicyType that this definition relates to.
+ ///
+ public PolicyType Type { get; }
+
+ ///
+ /// PolicyTypes that must be enabled before this policy can be enabled, if any.
+ /// These dependencies will be checked when this policy is enabled and when any required policy is disabled.
+ ///
+ public IEnumerable RequiredPolicies { get; }
+
+ ///
+ /// Validates a policy before saving it.
+ /// Do not use this for simple dependencies between different policies - see instead.
+ /// Implementation is optional; by default it will not perform any validation.
+ ///
+ /// The policy update request
+ /// The current policy, if any
+ /// A validation error if validation was unsuccessful, otherwise an empty string
+ public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy);
+
+ ///
+ /// Performs side effects after a policy is validated but before it is saved.
+ /// For example, this can be used to remove non-compliant users from the organization.
+ /// Implementation is optional; by default it will not perform any side effects.
+ ///
+ /// The policy update request
+ /// The current policy, if any
+ public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy);
+}
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/ISavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/ISavePolicyCommand.cs
new file mode 100644
index 000000000000..5bfdfc6aa7d2
--- /dev/null
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/ISavePolicyCommand.cs
@@ -0,0 +1,8 @@
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+
+namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies;
+
+public interface ISavePolicyCommand
+{
+ Task SaveAsync(PolicyUpdate policy);
+}
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs
new file mode 100644
index 000000000000..01ffce2cc670
--- /dev/null
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs
@@ -0,0 +1,129 @@
+#nullable enable
+
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+using Bit.Core.AdminConsole.Repositories;
+using Bit.Core.Enums;
+using Bit.Core.Exceptions;
+using Bit.Core.Services;
+
+namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations;
+
+public class SavePolicyCommand : ISavePolicyCommand
+{
+ private readonly IApplicationCacheService _applicationCacheService;
+ private readonly IEventService _eventService;
+ private readonly IPolicyRepository _policyRepository;
+ private readonly IReadOnlyDictionary _policyValidators;
+ private readonly TimeProvider _timeProvider;
+
+ public SavePolicyCommand(
+ IApplicationCacheService applicationCacheService,
+ IEventService eventService,
+ IPolicyRepository policyRepository,
+ IEnumerable policyValidators,
+ TimeProvider timeProvider)
+ {
+ _applicationCacheService = applicationCacheService;
+ _eventService = eventService;
+ _policyRepository = policyRepository;
+ _timeProvider = timeProvider;
+
+ var policyValidatorsDict = new Dictionary();
+ foreach (var policyValidator in policyValidators)
+ {
+ if (!policyValidatorsDict.TryAdd(policyValidator.Type, policyValidator))
+ {
+ throw new Exception($"Duplicate PolicyValidator for {policyValidator.Type} policy.");
+ }
+ }
+
+ _policyValidators = policyValidatorsDict;
+ }
+
+ public async Task SaveAsync(PolicyUpdate policyUpdate)
+ {
+ var org = await _applicationCacheService.GetOrganizationAbilityAsync(policyUpdate.OrganizationId);
+ if (org == null)
+ {
+ throw new BadRequestException("Organization not found");
+ }
+
+ if (!org.UsePolicies)
+ {
+ throw new BadRequestException("This organization cannot use policies.");
+ }
+
+ if (_policyValidators.TryGetValue(policyUpdate.Type, out var validator))
+ {
+ await RunValidatorAsync(validator, policyUpdate);
+ }
+
+ var policy = await _policyRepository.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type)
+ ?? new Policy
+ {
+ OrganizationId = policyUpdate.OrganizationId,
+ Type = policyUpdate.Type,
+ CreationDate = _timeProvider.GetUtcNow().UtcDateTime
+ };
+
+ policy.Enabled = policyUpdate.Enabled;
+ policy.Data = policyUpdate.Data;
+ policy.RevisionDate = _timeProvider.GetUtcNow().UtcDateTime;
+
+ await _policyRepository.UpsertAsync(policy);
+ await _eventService.LogPolicyEventAsync(policy, EventType.Policy_Updated);
+ }
+
+ private async Task RunValidatorAsync(IPolicyValidator validator, PolicyUpdate policyUpdate)
+ {
+ var savedPolicies = await _policyRepository.GetManyByOrganizationIdAsync(policyUpdate.OrganizationId);
+ // Note: policies may be missing from this dict if they have never been enabled
+ var savedPoliciesDict = savedPolicies.ToDictionary(p => p.Type);
+ var currentPolicy = savedPoliciesDict.GetValueOrDefault(policyUpdate.Type);
+
+ // If enabling this policy - check that all policy requirements are satisfied
+ if (currentPolicy is not { Enabled: true } && policyUpdate.Enabled)
+ {
+ var missingRequiredPolicyTypes = validator.RequiredPolicies
+ .Where(requiredPolicyType =>
+ savedPoliciesDict.GetValueOrDefault(requiredPolicyType) is not { Enabled: true })
+ .ToList();
+
+ if (missingRequiredPolicyTypes.Count != 0)
+ {
+ throw new BadRequestException($"Turn on the {missingRequiredPolicyTypes.First().GetName()} policy because it is required for the {validator.Type.GetName()} policy.");
+ }
+ }
+
+ // If disabling this policy - ensure it's not required by any other policy
+ if (currentPolicy is { Enabled: true } && !policyUpdate.Enabled)
+ {
+ var dependentPolicyTypes = _policyValidators.Values
+ .Where(otherValidator => otherValidator.RequiredPolicies.Contains(policyUpdate.Type))
+ .Select(otherValidator => otherValidator.Type)
+ .Where(otherPolicyType => savedPoliciesDict.ContainsKey(otherPolicyType) &&
+ savedPoliciesDict[otherPolicyType].Enabled)
+ .ToList();
+
+ switch (dependentPolicyTypes)
+ {
+ case { Count: 1 }:
+ throw new BadRequestException($"Turn off the {dependentPolicyTypes.First().GetName()} policy because it requires the {validator.Type.GetName()} policy.");
+ case { Count: > 1 }:
+ throw new BadRequestException($"Turn off all of the policies that require the {validator.Type.GetName()} policy.");
+ }
+ }
+
+ // Run other validation
+ var validationError = await validator.ValidateAsync(policyUpdate, currentPolicy);
+ if (!string.IsNullOrEmpty(validationError))
+ {
+ throw new BadRequestException(validationError);
+ }
+
+ // Run side effects
+ await validator.OnSaveSideEffectsAsync(policyUpdate, currentPolicy);
+ }
+}
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs
new file mode 100644
index 000000000000..117a7ec73396
--- /dev/null
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs
@@ -0,0 +1,28 @@
+#nullable enable
+
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
+using Bit.Core.Utilities;
+
+namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+
+///
+/// A request for SavePolicyCommand to update a policy
+///
+public record PolicyUpdate
+{
+ public Guid OrganizationId { get; set; }
+ public PolicyType Type { get; set; }
+ public string? Data { get; set; }
+ public bool Enabled { get; set; }
+
+ public T GetDataModel() where T : IPolicyDataModel, new()
+ {
+ return CoreHelpers.LoadClassFromJsonData(Data);
+ }
+
+ public void SetDataModel(T dataModel) where T : IPolicyDataModel, new()
+ {
+ Data = CoreHelpers.ClassToJsonData(dataModel);
+ }
+}
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs
new file mode 100644
index 000000000000..81096ef6022a
--- /dev/null
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs
@@ -0,0 +1,22 @@
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+using Bit.Core.AdminConsole.Services;
+using Bit.Core.AdminConsole.Services.Implementations;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies;
+
+public static class PolicyServiceCollectionExtensions
+{
+ public static void AddPolicyServices(this IServiceCollection services)
+ {
+ services.AddScoped();
+ services.AddScoped();
+
+ services.AddScoped();
+ services.AddScoped();
+ services.AddScoped();
+ services.AddScoped();
+ services.AddScoped();
+ }
+}
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs
new file mode 100644
index 000000000000..bfd4dcfe0d3a
--- /dev/null
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs
@@ -0,0 +1,15 @@
+#nullable enable
+
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+
+namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+
+public class MaximumVaultTimeoutPolicyValidator : IPolicyValidator
+{
+ public PolicyType Type => PolicyType.MaximumVaultTimeout;
+ public IEnumerable RequiredPolicies => [PolicyType.SingleOrg];
+ public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult("");
+ public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(0);
+}
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/PolicyValidatorHelpers.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/PolicyValidatorHelpers.cs
new file mode 100644
index 000000000000..1bbaf1aa1e84
--- /dev/null
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/PolicyValidatorHelpers.cs
@@ -0,0 +1,33 @@
+#nullable enable
+
+using Bit.Core.Auth.Entities;
+using Bit.Core.Auth.Enums;
+
+namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+
+public static class PolicyValidatorHelpers
+{
+ ///
+ /// Validate that given Member Decryption Options are not enabled.
+ /// Used for validation when disabling a policy that is required by certain Member Decryption Options.
+ ///
+ /// The Member Decryption Options that require the policy to be enabled.
+ /// A validation error if validation was unsuccessful, otherwise an empty string
+ public static string ValidateDecryptionOptionsNotEnabled(this SsoConfig? ssoConfig,
+ MemberDecryptionType[] decryptionOptions)
+ {
+ if (ssoConfig is not { Enabled: true })
+ {
+ return "";
+ }
+
+ return ssoConfig.GetData().MemberDecryptionType switch
+ {
+ MemberDecryptionType.KeyConnector when decryptionOptions.Contains(MemberDecryptionType.KeyConnector)
+ => "Key Connector is enabled and requires this policy.",
+ MemberDecryptionType.TrustedDeviceEncryption when decryptionOptions.Contains(MemberDecryptionType
+ .TrustedDeviceEncryption) => "Trusted device encryption is on and requires this policy.",
+ _ => ""
+ };
+ }
+}
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs
new file mode 100644
index 000000000000..2082d4305fa2
--- /dev/null
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs
@@ -0,0 +1,38 @@
+#nullable enable
+
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+using Bit.Core.Auth.Enums;
+using Bit.Core.Auth.Repositories;
+
+namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+
+public class RequireSsoPolicyValidator : IPolicyValidator
+{
+ private readonly ISsoConfigRepository _ssoConfigRepository;
+
+ public RequireSsoPolicyValidator(ISsoConfigRepository ssoConfigRepository)
+ {
+ _ssoConfigRepository = ssoConfigRepository;
+ }
+
+ public PolicyType Type => PolicyType.RequireSso;
+ public IEnumerable RequiredPolicies => [PolicyType.SingleOrg];
+
+ public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
+ {
+ if (policyUpdate is not { Enabled: true })
+ {
+ var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(policyUpdate.OrganizationId);
+ return ssoConfig.ValidateDecryptionOptionsNotEnabled([
+ MemberDecryptionType.KeyConnector,
+ MemberDecryptionType.TrustedDeviceEncryption
+ ]);
+ }
+
+ return "";
+ }
+
+ public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(0);
+}
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs
new file mode 100644
index 000000000000..1126c4b922b0
--- /dev/null
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs
@@ -0,0 +1,36 @@
+#nullable enable
+
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+using Bit.Core.Auth.Enums;
+using Bit.Core.Auth.Repositories;
+
+namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+
+public class ResetPasswordPolicyValidator : IPolicyValidator
+{
+ private readonly ISsoConfigRepository _ssoConfigRepository;
+ public PolicyType Type => PolicyType.ResetPassword;
+ public IEnumerable RequiredPolicies => [PolicyType.SingleOrg];
+
+ public ResetPasswordPolicyValidator(ISsoConfigRepository ssoConfigRepository)
+ {
+ _ssoConfigRepository = ssoConfigRepository;
+ }
+
+ public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
+ {
+ if (policyUpdate is not { Enabled: true } ||
+ policyUpdate.GetDataModel().AutoEnrollEnabled == false)
+ {
+ var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(policyUpdate.OrganizationId);
+ return ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.TrustedDeviceEncryption]);
+ }
+
+ return "";
+ }
+
+ public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(0);
+}
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs
new file mode 100644
index 000000000000..3e1f8d26c81d
--- /dev/null
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs
@@ -0,0 +1,101 @@
+#nullable enable
+
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+using Bit.Core.Auth.Enums;
+using Bit.Core.Auth.Repositories;
+using Bit.Core.Context;
+using Bit.Core.Enums;
+using Bit.Core.Exceptions;
+using Bit.Core.Repositories;
+using Bit.Core.Services;
+
+namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+
+public class SingleOrgPolicyValidator : IPolicyValidator
+{
+ public PolicyType Type => PolicyType.SingleOrg;
+
+ private readonly IOrganizationUserRepository _organizationUserRepository;
+ private readonly IMailService _mailService;
+ private readonly IOrganizationRepository _organizationRepository;
+ private readonly ISsoConfigRepository _ssoConfigRepository;
+ private readonly ICurrentContext _currentContext;
+ private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
+
+ public SingleOrgPolicyValidator(
+ IOrganizationUserRepository organizationUserRepository,
+ IMailService mailService,
+ IOrganizationRepository organizationRepository,
+ ISsoConfigRepository ssoConfigRepository,
+ ICurrentContext currentContext,
+ IRemoveOrganizationUserCommand removeOrganizationUserCommand)
+ {
+ _organizationUserRepository = organizationUserRepository;
+ _mailService = mailService;
+ _organizationRepository = organizationRepository;
+ _ssoConfigRepository = ssoConfigRepository;
+ _currentContext = currentContext;
+ _removeOrganizationUserCommand = removeOrganizationUserCommand;
+ }
+
+ public IEnumerable RequiredPolicies => [];
+
+ public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
+ {
+ if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true })
+ {
+ await RemoveNonCompliantUsersAsync(policyUpdate.OrganizationId);
+ }
+ }
+
+ private async Task RemoveNonCompliantUsersAsync(Guid organizationId)
+ {
+ // Remove non-compliant users
+ var savingUserId = _currentContext.UserId;
+ // Note: must get OrganizationUserUserDetails so that Email is always populated from the User object
+ var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
+ var org = await _organizationRepository.GetByIdAsync(organizationId);
+ if (org == null)
+ {
+ throw new NotFoundException("Organization not found.");
+ }
+
+ var removableOrgUsers = orgUsers.Where(ou =>
+ ou.Status != OrganizationUserStatusType.Invited &&
+ ou.Status != OrganizationUserStatusType.Revoked &&
+ ou.Type != OrganizationUserType.Owner &&
+ ou.Type != OrganizationUserType.Admin &&
+ ou.UserId != savingUserId
+ ).ToList();
+
+ var userOrgs = await _organizationUserRepository.GetManyByManyUsersAsync(
+ removableOrgUsers.Select(ou => ou.UserId!.Value));
+ foreach (var orgUser in removableOrgUsers)
+ {
+ if (userOrgs.Any(ou => ou.UserId == orgUser.UserId
+ && ou.OrganizationId != org.Id
+ && ou.Status != OrganizationUserStatusType.Invited))
+ {
+ await _removeOrganizationUserCommand.RemoveUserAsync(organizationId, orgUser.Id,
+ savingUserId);
+
+ await _mailService.SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(
+ org.DisplayName(), orgUser.Email);
+ }
+ }
+ }
+
+ public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
+ {
+ if (policyUpdate is not { Enabled: true })
+ {
+ var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(policyUpdate.OrganizationId);
+ return ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.KeyConnector]);
+ }
+
+ return "";
+ }
+}
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs
new file mode 100644
index 000000000000..ef896bbb9b57
--- /dev/null
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs
@@ -0,0 +1,87 @@
+#nullable enable
+
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
+using Bit.Core.Context;
+using Bit.Core.Enums;
+using Bit.Core.Exceptions;
+using Bit.Core.Repositories;
+using Bit.Core.Services;
+
+namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+
+public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator
+{
+ private readonly IOrganizationUserRepository _organizationUserRepository;
+ private readonly IMailService _mailService;
+ private readonly IOrganizationRepository _organizationRepository;
+ private readonly ICurrentContext _currentContext;
+ private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery;
+ private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
+
+ public PolicyType Type => PolicyType.TwoFactorAuthentication;
+ public IEnumerable RequiredPolicies => [];
+
+ public TwoFactorAuthenticationPolicyValidator(
+ IOrganizationUserRepository organizationUserRepository,
+ IMailService mailService,
+ IOrganizationRepository organizationRepository,
+ ICurrentContext currentContext,
+ ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery,
+ IRemoveOrganizationUserCommand removeOrganizationUserCommand)
+ {
+ _organizationUserRepository = organizationUserRepository;
+ _mailService = mailService;
+ _organizationRepository = organizationRepository;
+ _currentContext = currentContext;
+ _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery;
+ _removeOrganizationUserCommand = removeOrganizationUserCommand;
+ }
+
+ public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
+ {
+ if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true })
+ {
+ await RemoveNonCompliantUsersAsync(policyUpdate.OrganizationId);
+ }
+ }
+
+ private async Task RemoveNonCompliantUsersAsync(Guid organizationId)
+ {
+ var org = await _organizationRepository.GetByIdAsync(organizationId);
+ var savingUserId = _currentContext.UserId;
+
+ var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
+ var organizationUsersTwoFactorEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(orgUsers);
+ var removableOrgUsers = orgUsers.Where(ou =>
+ ou.Status != OrganizationUserStatusType.Invited && ou.Status != OrganizationUserStatusType.Revoked &&
+ ou.Type != OrganizationUserType.Owner && ou.Type != OrganizationUserType.Admin &&
+ ou.UserId != savingUserId);
+
+ // Reorder by HasMasterPassword to prioritize checking users without a master if they have 2FA enabled
+ foreach (var orgUser in removableOrgUsers.OrderBy(ou => ou.HasMasterPassword))
+ {
+ var userTwoFactorEnabled = organizationUsersTwoFactorEnabled.FirstOrDefault(u => u.user.Id == orgUser.Id)
+ .twoFactorIsEnabled;
+ if (!userTwoFactorEnabled)
+ {
+ if (!orgUser.HasMasterPassword)
+ {
+ throw new BadRequestException(
+ "Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page.");
+ }
+
+ await _removeOrganizationUserCommand.RemoveUserAsync(organizationId, orgUser.Id,
+ savingUserId);
+
+ await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(
+ org!.DisplayName(), orgUser.Email);
+ }
+ }
+ }
+
+ public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult("");
+}
diff --git a/src/Core/AdminConsole/Services/IPolicyService.cs b/src/Core/AdminConsole/Services/IPolicyService.cs
index 6d92a3a4f749..16ff2f4fa11c 100644
--- a/src/Core/AdminConsole/Services/IPolicyService.cs
+++ b/src/Core/AdminConsole/Services/IPolicyService.cs
@@ -4,13 +4,12 @@
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
-using Bit.Core.Services;
namespace Bit.Core.AdminConsole.Services;
public interface IPolicyService
{
- Task SaveAsync(Policy policy, IOrganizationService organizationService, Guid? savingUserId);
+ Task SaveAsync(Policy policy, Guid? savingUserId);
///
/// Get the combined master password policy options for the specified user.
diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs
index 50a2ed84eb92..f44ce686f4ed 100644
--- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs
+++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs
@@ -1838,12 +1838,12 @@ await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync(devices,
}
- private async Task>> GetUserDeviceIdsAsync(Guid userId)
+ private async Task> GetUserDeviceIdsAsync(Guid userId)
{
var devices = await _deviceRepository.GetManyByUserIdAsync(userId);
return devices
.Where(d => !string.IsNullOrWhiteSpace(d.PushToken))
- .Select(d => new KeyValuePair(d.Id.ToString(), d.Type));
+ .Select(d => d.Id.ToString());
}
public async Task ReplaceAndUpdateCacheAsync(Organization org, EventType? orgEvent = null)
diff --git a/src/Core/AdminConsole/Services/Implementations/PolicyService.cs b/src/Core/AdminConsole/Services/Implementations/PolicyService.cs
index 7e689f0342fe..6ab90afe04c6 100644
--- a/src/Core/AdminConsole/Services/Implementations/PolicyService.cs
+++ b/src/Core/AdminConsole/Services/Implementations/PolicyService.cs
@@ -2,6 +2,8 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Repositories;
@@ -27,6 +29,8 @@ public class PolicyService : IPolicyService
private readonly IMailService _mailService;
private readonly GlobalSettings _globalSettings;
private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery;
+ private readonly IFeatureService _featureService;
+ private readonly ISavePolicyCommand _savePolicyCommand;
private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
public PolicyService(
@@ -39,6 +43,8 @@ public PolicyService(
IMailService mailService,
GlobalSettings globalSettings,
ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery,
+ IFeatureService featureService,
+ ISavePolicyCommand savePolicyCommand,
IRemoveOrganizationUserCommand removeOrganizationUserCommand)
{
_applicationCacheService = applicationCacheService;
@@ -50,11 +56,28 @@ public PolicyService(
_mailService = mailService;
_globalSettings = globalSettings;
_twoFactorIsEnabledQuery = twoFactorIsEnabledQuery;
+ _featureService = featureService;
+ _savePolicyCommand = savePolicyCommand;
_removeOrganizationUserCommand = removeOrganizationUserCommand;
}
- public async Task SaveAsync(Policy policy, IOrganizationService organizationService, Guid? savingUserId)
+ public async Task SaveAsync(Policy policy, Guid? savingUserId)
{
+ if (_featureService.IsEnabled(FeatureFlagKeys.Pm13322AddPolicyDefinitions))
+ {
+ // Transitional mapping - this will be moved to callers once the feature flag is removed
+ var policyUpdate = new PolicyUpdate
+ {
+ OrganizationId = policy.OrganizationId,
+ Type = policy.Type,
+ Enabled = policy.Enabled,
+ Data = policy.Data
+ };
+
+ await _savePolicyCommand.SaveAsync(policyUpdate);
+ return;
+ }
+
var org = await _organizationRepository.GetByIdAsync(policy.OrganizationId);
if (org == null)
{
@@ -88,7 +111,7 @@ public async Task SaveAsync(Policy policy, IOrganizationService organizationServ
return;
}
- await EnablePolicyAsync(policy, org, organizationService, savingUserId);
+ await EnablePolicyAsync(policy, org, savingUserId);
}
public async Task GetMasterPasswordPolicyForUserAsync(User user)
@@ -262,7 +285,7 @@ private async Task SetPolicyConfiguration(Policy policy)
await _eventService.LogPolicyEventAsync(policy, EventType.Policy_Updated);
}
- private async Task EnablePolicyAsync(Policy policy, Organization org, IOrganizationService organizationService, Guid? savingUserId)
+ private async Task EnablePolicyAsync(Policy policy, Organization org, Guid? savingUserId)
{
var currentPolicy = await _policyRepository.GetByIdAsync(policy.Id);
if (!currentPolicy?.Enabled ?? true)
diff --git a/src/Core/Auth/Services/Implementations/SsoConfigService.cs b/src/Core/Auth/Services/Implementations/SsoConfigService.cs
index fdf7e278e008..532f000394c7 100644
--- a/src/Core/Auth/Services/Implementations/SsoConfigService.cs
+++ b/src/Core/Auth/Services/Implementations/SsoConfigService.cs
@@ -20,7 +20,6 @@ public class SsoConfigService : ISsoConfigService
private readonly IPolicyService _policyService;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
- private readonly IOrganizationService _organizationService;
private readonly IEventService _eventService;
public SsoConfigService(
@@ -29,7 +28,6 @@ public SsoConfigService(
IPolicyService policyService,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
- IOrganizationService organizationService,
IEventService eventService)
{
_ssoConfigRepository = ssoConfigRepository;
@@ -37,7 +35,6 @@ public SsoConfigService(
_policyService = policyService;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
- _organizationService = organizationService;
_eventService = eventService;
}
@@ -71,20 +68,20 @@ public async Task SaveAsync(SsoConfig config, Organization organization)
singleOrgPolicy.Enabled = true;
- await _policyService.SaveAsync(singleOrgPolicy, _organizationService, null);
+ await _policyService.SaveAsync(singleOrgPolicy, null);
var resetPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.ResetPassword) ??
new Policy { OrganizationId = config.OrganizationId, Type = PolicyType.ResetPassword, };
resetPolicy.Enabled = true;
resetPolicy.SetDataModel(new ResetPasswordDataModel { AutoEnrollEnabled = true });
- await _policyService.SaveAsync(resetPolicy, _organizationService, null);
+ await _policyService.SaveAsync(resetPolicy, null);
var ssoRequiredPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.RequireSso) ??
new Policy { OrganizationId = config.OrganizationId, Type = PolicyType.RequireSso, };
ssoRequiredPolicy.Enabled = true;
- await _policyService.SaveAsync(ssoRequiredPolicy, _organizationService, null);
+ await _policyService.SaveAsync(ssoRequiredPolicy, null);
}
await LogEventsAsync(config, oldConfig);
diff --git a/src/Core/Billing/Migration/Models/ProviderMigrationTracker.cs b/src/Core/Billing/Migration/Models/ProviderMigrationTracker.cs
index b6a58f82ff83..7bfef8a9317a 100644
--- a/src/Core/Billing/Migration/Models/ProviderMigrationTracker.cs
+++ b/src/Core/Billing/Migration/Models/ProviderMigrationTracker.cs
@@ -3,17 +3,14 @@
public enum ProviderMigrationProgress
{
Started = 1,
- ClientsMigrated = 2,
- TeamsPlanConfigured = 3,
- EnterprisePlanConfigured = 4,
- CustomerSetup = 5,
- SubscriptionSetup = 6,
- CreditApplied = 7,
- Completed = 8,
-
- Reversing = 9,
- ReversedClientMigrations = 10,
- RemovedProviderPlans = 11
+ NoClients = 2,
+ ClientsMigrated = 3,
+ TeamsPlanConfigured = 4,
+ EnterprisePlanConfigured = 5,
+ CustomerSetup = 6,
+ SubscriptionSetup = 7,
+ CreditApplied = 8,
+ Completed = 9,
}
public class ProviderMigrationTracker
diff --git a/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs b/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs
index 0da384ae27bd..9ca515a26030 100644
--- a/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs
+++ b/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs
@@ -41,7 +41,18 @@ public async Task Migrate(Guid providerId)
await migrationTrackerCache.StartTracker(provider);
- await MigrateClientsAsync(providerId);
+ var organizations = await GetClientsAsync(provider.Id);
+
+ if (organizations.Count == 0)
+ {
+ logger.LogInformation("CB: Skipping migration for provider ({ProviderID}) with no clients", providerId);
+
+ await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.NoClients);
+
+ return;
+ }
+
+ await MigrateClientsAsync(providerId, organizations);
await ConfigureTeamsPlanAsync(providerId);
@@ -65,6 +76,16 @@ public async Task GetResult(Guid providerId)
return null;
}
+ if (providerTracker.Progress == ProviderMigrationProgress.NoClients)
+ {
+ return new ProviderMigrationResult
+ {
+ ProviderId = providerTracker.ProviderId,
+ ProviderName = providerTracker.ProviderName,
+ Result = providerTracker.Progress.ToString()
+ };
+ }
+
var clientTrackers = await Task.WhenAll(providerTracker.OrganizationIds.Select(organizationId =>
migrationTrackerCache.GetTracker(providerId, organizationId)));
@@ -99,12 +120,10 @@ public async Task GetResult(Guid providerId)
#region Steps
- private async Task MigrateClientsAsync(Guid providerId)
+ private async Task MigrateClientsAsync(Guid providerId, List organizations)
{
logger.LogInformation("CB: Migrating clients for provider ({ProviderID})", providerId);
- var organizations = await GetEnabledClientsAsync(providerId);
-
var organizationIds = organizations.Select(organization => organization.Id);
await migrationTrackerCache.SetOrganizationIds(providerId, organizationIds);
@@ -129,7 +148,7 @@ private async Task ConfigureTeamsPlanAsync(Guid providerId)
{
logger.LogInformation("CB: Configuring Teams plan for provider ({ProviderID})", providerId);
- var organizations = await GetEnabledClientsAsync(providerId);
+ var organizations = await GetClientsAsync(providerId);
var teamsSeats = organizations
.Where(IsTeams)
@@ -172,7 +191,7 @@ private async Task ConfigureEnterprisePlanAsync(Guid providerId)
{
logger.LogInformation("CB: Configuring Enterprise plan for provider ({ProviderID})", providerId);
- var organizations = await GetEnabledClientsAsync(providerId);
+ var organizations = await GetClientsAsync(providerId);
var enterpriseSeats = organizations
.Where(IsEnterprise)
@@ -215,7 +234,7 @@ private async Task SetupCustomerAsync(Provider provider)
{
if (string.IsNullOrEmpty(provider.GatewayCustomerId))
{
- var organizations = await GetEnabledClientsAsync(provider.Id);
+ var organizations = await GetClientsAsync(provider.Id);
var sampleOrganization = organizations.FirstOrDefault(organization => !string.IsNullOrEmpty(organization.GatewayCustomerId));
@@ -299,28 +318,43 @@ private async Task SetupSubscriptionAsync(Provider provider)
private async Task ApplyCreditAsync(Provider provider)
{
- var organizations = await GetEnabledClientsAsync(provider.Id);
+ var organizations = await GetClientsAsync(provider.Id);
var organizationCustomers =
await Task.WhenAll(organizations.Select(organization => stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId)));
var organizationCancellationCredit = organizationCustomers.Sum(customer => customer.Balance);
- var legacyOrganizations = organizations.Where(organization =>
- organization.PlanType is
+ await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId,
+ new CustomerBalanceTransactionCreateOptions
+ {
+ Amount = organizationCancellationCredit,
+ Currency = "USD",
+ Description = "Unused, prorated time for client organization subscriptions."
+ });
+
+ var migrationRecords = await Task.WhenAll(organizations.Select(organization =>
+ clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id)));
+
+ var legacyOrganizationMigrationRecords = migrationRecords.Where(migrationRecord =>
+ migrationRecord.PlanType is
PlanType.EnterpriseAnnually2020 or
- PlanType.EnterpriseMonthly2020 or
- PlanType.TeamsAnnually2020 or
- PlanType.TeamsMonthly2020);
+ PlanType.TeamsAnnually2020);
- var legacyOrganizationCredit = legacyOrganizations.Sum(organization => organization.Seats ?? 0);
+ var legacyOrganizationCredit = legacyOrganizationMigrationRecords.Sum(migrationRecord => migrationRecord.Seats) * 12 * -100;
- await stripeAdapter.CustomerUpdateAsync(provider.GatewayCustomerId, new CustomerUpdateOptions
+ if (legacyOrganizationCredit < 0)
{
- Balance = organizationCancellationCredit + legacyOrganizationCredit
- });
+ await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId,
+ new CustomerBalanceTransactionCreateOptions
+ {
+ Amount = legacyOrganizationCredit,
+ Currency = "USD",
+ Description = "1 year rebate for legacy client organizations."
+ });
+ }
- logger.LogInformation("CB: Applied {Credit} credit to provider ({ProviderID})", organizationCancellationCredit, provider.Id);
+ logger.LogInformation("CB: Applied {Credit} credit to provider ({ProviderID})", organizationCancellationCredit + legacyOrganizationCredit, provider.Id);
await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.CreditApplied);
}
@@ -340,13 +374,12 @@ private async Task UpdateProviderAsync(Provider provider)
#region Utilities
- private async Task> GetEnabledClientsAsync(Guid providerId)
+ private async Task> GetClientsAsync(Guid providerId)
{
var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId);
return (await Task.WhenAll(providerOrganizations.Select(providerOrganization =>
organizationRepository.GetByIdAsync(providerOrganization.OrganizationId))))
- .Where(organization => organization.Enabled)
.ToList();
}
diff --git a/src/Core/Billing/Models/OrganizationMetadata.cs b/src/Core/Billing/Models/OrganizationMetadata.cs
index decc35ffd8c5..136964d7c10a 100644
--- a/src/Core/Billing/Models/OrganizationMetadata.cs
+++ b/src/Core/Billing/Models/OrganizationMetadata.cs
@@ -1,8 +1,10 @@
namespace Bit.Core.Billing.Models;
public record OrganizationMetadata(
+ bool IsEligibleForSelfHost,
bool IsOnSecretsManagerStandalone)
{
public static OrganizationMetadata Default() => new(
- IsOnSecretsManagerStandalone: default);
+ IsEligibleForSelfHost: false,
+ IsOnSecretsManagerStandalone: false);
}
diff --git a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs
index 3c5938cab831..7db886203296 100644
--- a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs
+++ b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs
@@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models;
@@ -26,6 +27,7 @@ public class OrganizationBillingService(
IGlobalSettings globalSettings,
ILogger logger,
IOrganizationRepository organizationRepository,
+ IProviderRepository providerRepository,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService) : IOrganizationBillingService
@@ -69,14 +71,11 @@ public async Task Finalize(OrganizationSale sale)
var subscription = await subscriberService.GetSubscription(organization);
- if (customer == null || subscription == null)
- {
- return OrganizationMetadata.Default();
- }
+ var isEligibleForSelfHost = await IsEligibleForSelfHost(organization, subscription);
var isOnSecretsManagerStandalone = IsOnSecretsManagerStandalone(organization, customer, subscription);
- return new OrganizationMetadata(isOnSecretsManagerStandalone);
+ return new OrganizationMetadata(isEligibleForSelfHost, isOnSecretsManagerStandalone);
}
public async Task UpdatePaymentMethod(
@@ -340,11 +339,38 @@ private async Task CreateSubscriptionAsync(
return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
}
+ private async Task IsEligibleForSelfHost(
+ Organization organization,
+ Subscription? organizationSubscription)
+ {
+ if (organization.Status != OrganizationStatusType.Managed)
+ {
+ return organization.Plan.Contains("Families") ||
+ organization.Plan.Contains("Enterprise") && IsActive(organizationSubscription);
+ }
+
+ var provider = await providerRepository.GetByOrganizationIdAsync(organization.Id);
+
+ var providerSubscription = await subscriberService.GetSubscriptionOrThrow(provider);
+
+ return organization.Plan.Contains("Enterprise") && IsActive(providerSubscription);
+
+ bool IsActive(Subscription? subscription) => subscription?.Status is
+ StripeConstants.SubscriptionStatus.Active or
+ StripeConstants.SubscriptionStatus.Trialing or
+ StripeConstants.SubscriptionStatus.PastDue;
+ }
+
private static bool IsOnSecretsManagerStandalone(
Organization organization,
- Customer customer,
- Subscription subscription)
+ Customer? customer,
+ Subscription? subscription)
{
+ if (customer == null || subscription == null)
+ {
+ return false;
+ }
+
var plan = StaticStore.GetPlan(organization.PlanType);
if (!plan.SupportsSecretsManager)
diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs
index 6b4cf7e971ea..ecbe190ccd2a 100644
--- a/src/Core/Constants.cs
+++ b/src/Core/Constants.cs
@@ -146,6 +146,7 @@ public static class FeatureFlagKeys
public const string RemoveServerVersionHeader = "remove-server-version-header";
public const string AccessIntelligence = "pm-13227-access-intelligence";
public const string VerifiedSsoDomainEndpoint = "pm-12337-refactor-sso-details-endpoint";
+ public const string Pm13322AddPolicyDefinitions = "pm-13322-add-policy-definitions";
public const string LimitCollectionCreationDeletionSplit = "pm-10863-limit-collection-creation-deletion-split";
public static List GetAllKeys()
diff --git a/src/Core/Models/Api/Request/PushDeviceRequestModel.cs b/src/Core/Models/Api/Request/PushDeviceRequestModel.cs
index e1866b6f2735..8b97dcc3600f 100644
--- a/src/Core/Models/Api/Request/PushDeviceRequestModel.cs
+++ b/src/Core/Models/Api/Request/PushDeviceRequestModel.cs
@@ -1,5 +1,4 @@
using System.ComponentModel.DataAnnotations;
-using Bit.Core.Enums;
namespace Bit.Core.Models.Api;
@@ -7,6 +6,4 @@ public class PushDeviceRequestModel
{
[Required]
public string Id { get; set; }
- [Required]
- public DeviceType Type { get; set; }
}
diff --git a/src/Core/Models/Api/Request/PushUpdateRequestModel.cs b/src/Core/Models/Api/Request/PushUpdateRequestModel.cs
index 9f7ed5f28851..f8c2d296fd4a 100644
--- a/src/Core/Models/Api/Request/PushUpdateRequestModel.cs
+++ b/src/Core/Models/Api/Request/PushUpdateRequestModel.cs
@@ -1,5 +1,4 @@
using System.ComponentModel.DataAnnotations;
-using Bit.Core.Enums;
namespace Bit.Core.Models.Api;
@@ -8,9 +7,9 @@ public class PushUpdateRequestModel
public PushUpdateRequestModel()
{ }
- public PushUpdateRequestModel(IEnumerable> devices, string organizationId)
+ public PushUpdateRequestModel(IEnumerable deviceIds, string organizationId)
{
- Devices = devices.Select(d => new PushDeviceRequestModel { Id = d.Key, Type = d.Value });
+ Devices = deviceIds.Select(d => new PushDeviceRequestModel { Id = d });
OrganizationId = organizationId;
}
diff --git a/src/Core/Models/Data/InstallationDeviceEntity.cs b/src/Core/Models/Data/InstallationDeviceEntity.cs
index 3186efc661c7..a3d960b242c0 100644
--- a/src/Core/Models/Data/InstallationDeviceEntity.cs
+++ b/src/Core/Models/Data/InstallationDeviceEntity.cs
@@ -37,4 +37,25 @@ public static bool IsInstallationDeviceId(string deviceId)
{
return deviceId != null && deviceId.Length == 73 && deviceId[36] == '_';
}
+ public static bool TryParse(string deviceId, out InstallationDeviceEntity installationDeviceEntity)
+ {
+ installationDeviceEntity = null;
+ var installationId = Guid.Empty;
+ var deviceIdGuid = Guid.Empty;
+ if (!IsInstallationDeviceId(deviceId))
+ {
+ return false;
+ }
+ var parts = deviceId.Split("_");
+ if (parts.Length < 2)
+ {
+ return false;
+ }
+ if (!Guid.TryParse(parts[0], out installationId) || !Guid.TryParse(parts[1], out deviceIdGuid))
+ {
+ return false;
+ }
+ installationDeviceEntity = new InstallationDeviceEntity(installationId, deviceIdGuid);
+ return true;
+ }
}
diff --git a/src/Core/NotificationCenter/Commands/MarkNotificationDeletedCommand.cs b/src/Core/NotificationCenter/Commands/MarkNotificationDeletedCommand.cs
index fed9fd04699c..2ca7aa9051ee 100644
--- a/src/Core/NotificationCenter/Commands/MarkNotificationDeletedCommand.cs
+++ b/src/Core/NotificationCenter/Commands/MarkNotificationDeletedCommand.cs
@@ -49,11 +49,11 @@ await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.Us
if (notificationStatus == null)
{
- notificationStatus = new NotificationStatus()
+ notificationStatus = new NotificationStatus
{
NotificationId = notificationId,
UserId = _currentContext.UserId.Value,
- DeletedDate = DateTime.Now
+ DeletedDate = DateTime.UtcNow
};
await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.User, notificationStatus,
diff --git a/src/Core/NotificationCenter/Commands/MarkNotificationReadCommand.cs b/src/Core/NotificationCenter/Commands/MarkNotificationReadCommand.cs
index 93686605011a..400e44463a9d 100644
--- a/src/Core/NotificationCenter/Commands/MarkNotificationReadCommand.cs
+++ b/src/Core/NotificationCenter/Commands/MarkNotificationReadCommand.cs
@@ -49,11 +49,11 @@ await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.Us
if (notificationStatus == null)
{
- notificationStatus = new NotificationStatus()
+ notificationStatus = new NotificationStatus
{
NotificationId = notificationId,
UserId = _currentContext.UserId.Value,
- ReadDate = DateTime.Now
+ ReadDate = DateTime.UtcNow
};
await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.User, notificationStatus,
diff --git a/src/Core/NotificationHub/INotificationHubClientProxy.cs b/src/Core/NotificationHub/INotificationHubClientProxy.cs
new file mode 100644
index 000000000000..82b4d3959107
--- /dev/null
+++ b/src/Core/NotificationHub/INotificationHubClientProxy.cs
@@ -0,0 +1,8 @@
+using Microsoft.Azure.NotificationHubs;
+
+namespace Bit.Core.NotificationHub;
+
+public interface INotificationHubProxy
+{
+ Task<(INotificationHubClient Client, NotificationOutcome Outcome)[]> SendTemplateNotificationAsync(IDictionary properties, string tagExpression);
+}
diff --git a/src/Core/NotificationHub/INotificationHubPool.cs b/src/Core/NotificationHub/INotificationHubPool.cs
new file mode 100644
index 000000000000..7c383d7b9649
--- /dev/null
+++ b/src/Core/NotificationHub/INotificationHubPool.cs
@@ -0,0 +1,9 @@
+using Microsoft.Azure.NotificationHubs;
+
+namespace Bit.Core.NotificationHub;
+
+public interface INotificationHubPool
+{
+ NotificationHubClient ClientFor(Guid comb);
+ INotificationHubProxy AllClients { get; }
+}
diff --git a/src/Core/NotificationHub/NotificationHubClientProxy.cs b/src/Core/NotificationHub/NotificationHubClientProxy.cs
new file mode 100644
index 000000000000..815ac8836393
--- /dev/null
+++ b/src/Core/NotificationHub/NotificationHubClientProxy.cs
@@ -0,0 +1,26 @@
+using Microsoft.Azure.NotificationHubs;
+
+namespace Bit.Core.NotificationHub;
+
+public class NotificationHubClientProxy : INotificationHubProxy
+{
+ private readonly IEnumerable _clients;
+
+ public NotificationHubClientProxy(IEnumerable clients)
+ {
+ _clients = clients;
+ }
+
+ private async Task<(INotificationHubClient, T)[]> ApplyToAllClientsAsync(Func> action)
+ {
+ var tasks = _clients.Select(async c => (c, await action(c)));
+ return await Task.WhenAll(tasks);
+ }
+
+ // partial proxy of INotificationHubClient implementation
+ // Note: Any other methods that are needed can simply be delegated as done here.
+ public async Task<(INotificationHubClient Client, NotificationOutcome Outcome)[]> SendTemplateNotificationAsync(IDictionary properties, string tagExpression)
+ {
+ return await ApplyToAllClientsAsync(async c => await c.SendTemplateNotificationAsync(properties, tagExpression));
+ }
+}
diff --git a/src/Core/NotificationHub/NotificationHubConnection.cs b/src/Core/NotificationHub/NotificationHubConnection.cs
new file mode 100644
index 000000000000..3a1437f70c52
--- /dev/null
+++ b/src/Core/NotificationHub/NotificationHubConnection.cs
@@ -0,0 +1,128 @@
+using Bit.Core.Settings;
+using Bit.Core.Utilities;
+using Microsoft.Azure.NotificationHubs;
+
+class NotificationHubConnection
+{
+ public string HubName { get; init; }
+ public string ConnectionString { get; init; }
+ public bool EnableSendTracing { get; init; }
+ private NotificationHubClient _hubClient;
+ ///
+ /// Gets the NotificationHubClient for this connection.
+ ///
+ /// If the client is null, it will be initialized.
+ ///
+ /// Exception if the connection is invalid.
+ ///
+ public NotificationHubClient HubClient
+ {
+ get
+ {
+ if (_hubClient == null)
+ {
+ if (!IsValid)
+ {
+ throw new Exception("Invalid notification hub settings");
+ }
+ Init();
+ }
+ return _hubClient;
+ }
+ private set
+ {
+ _hubClient = value;
+ }
+ }
+ ///
+ /// Gets the start date for registration.
+ ///
+ /// If null, registration is always disabled.
+ ///
+ public DateTime? RegistrationStartDate { get; init; }
+ ///
+ /// Gets the end date for registration.
+ ///
+ /// If null, registration has no end date.
+ ///
+ public DateTime? RegistrationEndDate { get; init; }
+ ///
+ /// Gets whether all data needed to generate a connection to Notification Hub is present.
+ ///
+ public bool IsValid
+ {
+ get
+ {
+ {
+ var invalid = string.IsNullOrWhiteSpace(HubName) || string.IsNullOrWhiteSpace(ConnectionString);
+ return !invalid;
+ }
+ }
+ }
+
+ public string LogString
+ {
+ get
+ {
+ return $"HubName: {HubName}, EnableSendTracing: {EnableSendTracing}, RegistrationStartDate: {RegistrationStartDate}, RegistrationEndDate: {RegistrationEndDate}";
+ }
+ }
+
+ ///
+ /// Gets whether registration is enabled for the given comb ID.
+ /// This is based off of the generation time encoded in the comb ID.
+ ///
+ ///
+ ///
+ public bool RegistrationEnabled(Guid comb)
+ {
+ var combTime = CoreHelpers.DateFromComb(comb);
+ return RegistrationEnabled(combTime);
+ }
+
+ ///
+ /// Gets whether registration is enabled for the given time.
+ ///
+ /// The time to check
+ ///
+ public bool RegistrationEnabled(DateTime queryTime)
+ {
+ if (queryTime >= RegistrationEndDate || RegistrationStartDate == null)
+ {
+ return false;
+ }
+
+ return RegistrationStartDate < queryTime;
+ }
+
+ private NotificationHubConnection() { }
+
+ ///
+ /// Creates a new NotificationHubConnection from the given settings.
+ ///
+ ///
+ ///
+ public static NotificationHubConnection From(GlobalSettings.NotificationHubSettings settings)
+ {
+ return new()
+ {
+ HubName = settings.HubName,
+ ConnectionString = settings.ConnectionString,
+ EnableSendTracing = settings.EnableSendTracing,
+ // Comb time is not precise enough for millisecond accuracy
+ RegistrationStartDate = settings.RegistrationStartDate.HasValue ? Truncate(settings.RegistrationStartDate.Value, TimeSpan.FromMilliseconds(10)) : null,
+ RegistrationEndDate = settings.RegistrationEndDate
+ };
+ }
+
+ private NotificationHubConnection Init()
+ {
+ HubClient = NotificationHubClient.CreateClientFromConnectionString(ConnectionString, HubName, EnableSendTracing);
+ return this;
+ }
+
+ private static DateTime Truncate(DateTime dateTime, TimeSpan resolution)
+ {
+ return dateTime.AddTicks(-(dateTime.Ticks % resolution.Ticks));
+ }
+}
diff --git a/src/Core/NotificationHub/NotificationHubPool.cs b/src/Core/NotificationHub/NotificationHubPool.cs
new file mode 100644
index 000000000000..7448aad5bda1
--- /dev/null
+++ b/src/Core/NotificationHub/NotificationHubPool.cs
@@ -0,0 +1,62 @@
+using Bit.Core.Settings;
+using Bit.Core.Utilities;
+using Microsoft.Azure.NotificationHubs;
+using Microsoft.Extensions.Logging;
+
+namespace Bit.Core.NotificationHub;
+
+public class NotificationHubPool : INotificationHubPool
+{
+ private List _connections { get; }
+ private readonly IEnumerable _clients;
+ private readonly ILogger _logger;
+ public NotificationHubPool(ILogger logger, GlobalSettings globalSettings)
+ {
+ _logger = logger;
+ _connections = FilterInvalidHubs(globalSettings.NotificationHubPool.NotificationHubs);
+ _clients = _connections.GroupBy(c => c.ConnectionString).Select(g => g.First().HubClient);
+ }
+
+ private List FilterInvalidHubs(IEnumerable hubs)
+ {
+ List result = new();
+ _logger.LogDebug("Filtering {HubCount} notification hubs", hubs.Count());
+ foreach (var hub in hubs)
+ {
+ var connection = NotificationHubConnection.From(hub);
+ if (!connection.IsValid)
+ {
+ _logger.LogWarning("Invalid notification hub settings: {HubName}", hub.HubName ?? "hub name missing");
+ continue;
+ }
+ _logger.LogDebug("Adding notification hub: {ConnectionLogString}", connection.LogString);
+ result.Add(connection);
+ }
+
+ return result;
+ }
+
+
+ ///
+ /// Gets the NotificationHubClient for the given comb ID.
+ ///
+ ///
+ ///
+ /// Thrown when no notification hub is found for a given comb.
+ public NotificationHubClient ClientFor(Guid comb)
+ {
+ var possibleConnections = _connections.Where(c => c.RegistrationEnabled(comb)).ToArray();
+ if (possibleConnections.Length == 0)
+ {
+ throw new InvalidOperationException($"No valid notification hubs are available for the given comb ({comb}).\n" +
+ $"The comb's datetime is {CoreHelpers.DateFromComb(comb)}." +
+ $"Hub start and end times are configured as follows:\n" +
+ string.Join("\n", _connections.Select(c => $"Hub {c.HubName} - Start: {c.RegistrationStartDate}, End: {c.RegistrationEndDate}")));
+ }
+ var resolvedConnection = possibleConnections[CoreHelpers.BinForComb(comb, possibleConnections.Length)];
+ _logger.LogTrace("Resolved notification hub for comb {Comb} out of {HubCount} hubs.\n{ConnectionInfo}", comb, possibleConnections.Length, resolvedConnection.LogString);
+ return resolvedConnection.HubClient;
+ }
+
+ public INotificationHubProxy AllClients { get { return new NotificationHubClientProxy(_clients); } }
+}
diff --git a/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs b/src/Core/NotificationHub/NotificationHubPushNotificationService.cs
similarity index 84%
rename from src/Core/Services/Implementations/NotificationHubPushNotificationService.cs
rename to src/Core/NotificationHub/NotificationHubPushNotificationService.cs
index 480f0dfa9ef8..6143676deffb 100644
--- a/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs
+++ b/src/Core/NotificationHub/NotificationHubPushNotificationService.cs
@@ -6,45 +6,31 @@
using Bit.Core.Models;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
-using Bit.Core.Settings;
+using Bit.Core.Services;
using Bit.Core.Tools.Entities;
using Bit.Core.Vault.Entities;
using Microsoft.AspNetCore.Http;
-using Microsoft.Azure.NotificationHubs;
using Microsoft.Extensions.Logging;
-namespace Bit.Core.Services;
+namespace Bit.Core.NotificationHub;
public class NotificationHubPushNotificationService : IPushNotificationService
{
private readonly IInstallationDeviceRepository _installationDeviceRepository;
- private readonly GlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor;
- private readonly List _clients = [];
private readonly bool _enableTracing = false;
+ private readonly INotificationHubPool _notificationHubPool;
private readonly ILogger _logger;
public NotificationHubPushNotificationService(
IInstallationDeviceRepository installationDeviceRepository,
- GlobalSettings globalSettings,
+ INotificationHubPool notificationHubPool,
IHttpContextAccessor httpContextAccessor,
ILogger logger)
{
_installationDeviceRepository = installationDeviceRepository;
- _globalSettings = globalSettings;
_httpContextAccessor = httpContextAccessor;
-
- foreach (var hub in globalSettings.NotificationHubs)
- {
- var client = NotificationHubClient.CreateClientFromConnectionString(
- hub.ConnectionString,
- hub.HubName,
- hub.EnableSendTracing);
- _clients.Add(client);
-
- _enableTracing = _enableTracing || hub.EnableSendTracing;
- }
-
+ _notificationHubPool = notificationHubPool;
_logger = logger;
}
@@ -264,30 +250,23 @@ private string BuildTag(string tag, string identifier)
private async Task SendPayloadAsync(string tag, PushType type, object payload)
{
- var tasks = new List>();
- foreach (var client in _clients)
- {
- var task = client.SendTemplateNotificationAsync(
- new Dictionary
- {
- { "type", ((byte)type).ToString() },
- { "payload", JsonSerializer.Serialize(payload) }
- }, tag);
- tasks.Add(task);
- }
-
- await Task.WhenAll(tasks);
+ var results = await _notificationHubPool.AllClients.SendTemplateNotificationAsync(
+ new Dictionary
+ {
+ { "type", ((byte)type).ToString() },
+ { "payload", JsonSerializer.Serialize(payload) }
+ }, tag);
if (_enableTracing)
{
- for (var i = 0; i < tasks.Count; i++)
+ foreach (var (client, outcome) in results)
{
- if (_clients[i].EnableTestSend)
+ if (!client.EnableTestSend)
{
- var outcome = await tasks[i];
- _logger.LogInformation("Azure Notification Hub Tracking ID: {id} | {type} push notification with {success} successes and {failure} failures with a payload of {@payload} and result of {@results}",
- outcome.TrackingId, type, outcome.Success, outcome.Failure, payload, outcome.Results);
+ continue;
}
+ _logger.LogInformation("Azure Notification Hub Tracking ID: {Id} | {Type} push notification with {Success} successes and {Failure} failures with a payload of {@Payload} and result of {@Results}",
+ outcome.TrackingId, type, outcome.Success, outcome.Failure, payload, outcome.Results);
}
}
}
diff --git a/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs b/src/Core/NotificationHub/NotificationHubPushRegistrationService.cs
similarity index 64%
rename from src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs
rename to src/Core/NotificationHub/NotificationHubPushRegistrationService.cs
index 87df60e8e3ce..ae32babf4477 100644
--- a/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs
+++ b/src/Core/NotificationHub/NotificationHubPushRegistrationService.cs
@@ -1,50 +1,34 @@
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
+using Bit.Core.Services;
using Bit.Core.Settings;
using Microsoft.Azure.NotificationHubs;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
-namespace Bit.Core.Services;
+namespace Bit.Core.NotificationHub;
public class NotificationHubPushRegistrationService : IPushRegistrationService
{
private readonly IInstallationDeviceRepository _installationDeviceRepository;
private readonly GlobalSettings _globalSettings;
+ private readonly INotificationHubPool _notificationHubPool;
private readonly IServiceProvider _serviceProvider;
private readonly ILogger _logger;
- private Dictionary _clients = [];
public NotificationHubPushRegistrationService(
IInstallationDeviceRepository installationDeviceRepository,
GlobalSettings globalSettings,
+ INotificationHubPool notificationHubPool,
IServiceProvider serviceProvider,
ILogger logger)
{
_installationDeviceRepository = installationDeviceRepository;
_globalSettings = globalSettings;
+ _notificationHubPool = notificationHubPool;
_serviceProvider = serviceProvider;
_logger = logger;
-
- // Is this dirty to do in the ctor?
- void addHub(NotificationHubType type)
- {
- var hubRegistration = globalSettings.NotificationHubs.FirstOrDefault(
- h => h.HubType == type && h.EnableRegistration);
- if (hubRegistration != null)
- {
- var client = NotificationHubClient.CreateClientFromConnectionString(
- hubRegistration.ConnectionString,
- hubRegistration.HubName,
- hubRegistration.EnableSendTracing);
- _clients.Add(type, client);
- }
- }
-
- addHub(NotificationHubType.General);
- addHub(NotificationHubType.iOS);
- addHub(NotificationHubType.Android);
}
public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId,
@@ -117,7 +101,7 @@ public async Task CreateOrUpdateRegistrationAsync(string pushToken, string devic
BuildInstallationTemplate(installation, "badgeMessage", badgeMessageTemplate ?? messageTemplate,
userId, identifier);
- await GetClient(type).CreateOrUpdateInstallationAsync(installation);
+ await ClientFor(GetComb(deviceId)).CreateOrUpdateInstallationAsync(installation);
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
{
await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId));
@@ -152,11 +136,11 @@ private void BuildInstallationTemplate(Installation installation, string templat
installation.Templates.Add(fullTemplateId, template);
}
- public async Task DeleteRegistrationAsync(string deviceId, DeviceType deviceType)
+ public async Task DeleteRegistrationAsync(string deviceId)
{
try
{
- await GetClient(deviceType).DeleteInstallationAsync(deviceId);
+ await ClientFor(GetComb(deviceId)).DeleteInstallationAsync(deviceId);
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
{
await _installationDeviceRepository.DeleteAsync(new InstallationDeviceEntity(deviceId));
@@ -168,31 +152,31 @@ public async Task DeleteRegistrationAsync(string deviceId, DeviceType deviceType
}
}
- public async Task AddUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId)
+ public async Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId)
{
- await PatchTagsForUserDevicesAsync(devices, UpdateOperationType.Add, $"organizationId:{organizationId}");
- if (devices.Any() && InstallationDeviceEntity.IsInstallationDeviceId(devices.First().Key))
+ await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Add, $"organizationId:{organizationId}");
+ if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First()))
{
- var entities = devices.Select(e => new InstallationDeviceEntity(e.Key));
+ var entities = deviceIds.Select(e => new InstallationDeviceEntity(e));
await _installationDeviceRepository.UpsertManyAsync(entities.ToList());
}
}
- public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId)
+ public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId)
{
- await PatchTagsForUserDevicesAsync(devices, UpdateOperationType.Remove,
+ await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Remove,
$"organizationId:{organizationId}");
- if (devices.Any() && InstallationDeviceEntity.IsInstallationDeviceId(devices.First().Key))
+ if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First()))
{
- var entities = devices.Select(e => new InstallationDeviceEntity(e.Key));
+ var entities = deviceIds.Select(e => new InstallationDeviceEntity(e));
await _installationDeviceRepository.UpsertManyAsync(entities.ToList());
}
}
- private async Task PatchTagsForUserDevicesAsync(IEnumerable> devices, UpdateOperationType op,
+ private async Task PatchTagsForUserDevicesAsync(IEnumerable deviceIds, UpdateOperationType op,
string tag)
{
- if (!devices.Any())
+ if (!deviceIds.Any())
{
return;
}
@@ -212,11 +196,11 @@ private async Task PatchTagsForUserDevicesAsync(IEnumerable { operation });
+ await ClientFor(GetComb(deviceId)).PatchInstallationAsync(deviceId, new List { operation });
}
catch (Exception e) when (e.InnerException == null || !e.InnerException.Message.Contains("(404) Not Found"))
{
@@ -225,53 +209,29 @@ private async Task PatchTagsForUserDevicesAsync(IEnumerable> devices, string organizationId);
- Task DeleteUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId);
+ Task DeleteRegistrationAsync(string deviceId);
+ Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId);
+ Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId);
}
diff --git a/src/Core/Services/IStripeAdapter.cs b/src/Core/Services/IStripeAdapter.cs
index bb57f1cd0d7f..a288e1cbedce 100644
--- a/src/Core/Services/IStripeAdapter.cs
+++ b/src/Core/Services/IStripeAdapter.cs
@@ -10,6 +10,8 @@ public interface IStripeAdapter
Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null);
Task CustomerDeleteAsync(string id);
Task> CustomerListPaymentMethods(string id, CustomerListPaymentMethodsOptions options = null);
+ Task CustomerBalanceTransactionCreate(string customerId,
+ CustomerBalanceTransactionCreateOptions options);
Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions subscriptionCreateOptions);
Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null);
Task> SubscriptionListAsync(StripeSubscriptionListOptions subscriptionSearchOptions);
diff --git a/src/Core/Services/Implementations/DeviceService.cs b/src/Core/Services/Implementations/DeviceService.cs
index 9d8315f691ba..5b1e4b0f0171 100644
--- a/src/Core/Services/Implementations/DeviceService.cs
+++ b/src/Core/Services/Implementations/DeviceService.cs
@@ -38,13 +38,13 @@ await _pushRegistrationService.CreateOrUpdateRegistrationAsync(device.PushToken,
public async Task ClearTokenAsync(Device device)
{
await _deviceRepository.ClearPushTokenAsync(device.Id);
- await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString(), device.Type);
+ await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString());
}
public async Task DeleteAsync(Device device)
{
await _deviceRepository.DeleteAsync(device);
- await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString(), device.Type);
+ await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString());
}
public async Task UpdateDevicesTrustAsync(string currentDeviceIdentifier,
diff --git a/src/Core/Services/Implementations/MultiServicePushNotificationService.cs b/src/Core/Services/Implementations/MultiServicePushNotificationService.cs
index 92e29908f50b..00be72c980e0 100644
--- a/src/Core/Services/Implementations/MultiServicePushNotificationService.cs
+++ b/src/Core/Services/Implementations/MultiServicePushNotificationService.cs
@@ -1,61 +1,31 @@
using Bit.Core.Auth.Entities;
using Bit.Core.Enums;
-using Bit.Core.Repositories;
using Bit.Core.Settings;
using Bit.Core.Tools.Entities;
-using Bit.Core.Utilities;
using Bit.Core.Vault.Entities;
-using Microsoft.AspNetCore.Http;
+using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
public class MultiServicePushNotificationService : IPushNotificationService
{
- private readonly List _services = new List();
+ private readonly IEnumerable _services;
private readonly ILogger _logger;
public MultiServicePushNotificationService(
- IHttpClientFactory httpFactory,
- IDeviceRepository deviceRepository,
- IInstallationDeviceRepository installationDeviceRepository,
- GlobalSettings globalSettings,
- IHttpContextAccessor httpContextAccessor,
+ [FromKeyedServices("implementation")] IEnumerable services,
ILogger logger,
- ILogger relayLogger,
- ILogger hubLogger)
+ GlobalSettings globalSettings)
{
- if (globalSettings.SelfHosted)
- {
- if (CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) &&
- globalSettings.Installation?.Id != null &&
- CoreHelpers.SettingHasValue(globalSettings.Installation?.Key))
- {
- _services.Add(new RelayPushNotificationService(httpFactory, deviceRepository, globalSettings,
- httpContextAccessor, relayLogger));
- }
- if (CoreHelpers.SettingHasValue(globalSettings.InternalIdentityKey) &&
- CoreHelpers.SettingHasValue(globalSettings.BaseServiceUri.InternalNotifications))
- {
- _services.Add(new NotificationsApiPushNotificationService(
- httpFactory, globalSettings, httpContextAccessor, hubLogger));
- }
- }
- else
- {
- var generalHub = globalSettings.NotificationHubs?.FirstOrDefault(h => h.HubType == NotificationHubType.General);
- if (CoreHelpers.SettingHasValue(generalHub?.ConnectionString))
- {
- _services.Add(new NotificationHubPushNotificationService(installationDeviceRepository,
- globalSettings, httpContextAccessor, hubLogger));
- }
- if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString))
- {
- _services.Add(new AzureQueuePushNotificationService(globalSettings, httpContextAccessor));
- }
- }
+ _services = services;
_logger = logger;
+ _logger.LogInformation("Hub services: {Services}", _services.Count());
+ globalSettings?.NotificationHubPool?.NotificationHubs?.ForEach(hub =>
+ {
+ _logger.LogInformation("HubName: {HubName}, EnableSendTracing: {EnableSendTracing}, RegistrationStartDate: {RegistrationStartDate}, RegistrationEndDate: {RegistrationEndDate}", hub.HubName, hub.EnableSendTracing, hub.RegistrationStartDate, hub.RegistrationEndDate);
+ });
}
public Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds)
diff --git a/src/Core/Services/Implementations/RelayPushRegistrationService.cs b/src/Core/Services/Implementations/RelayPushRegistrationService.cs
index d9df7d04dcbc..d0f7736e984b 100644
--- a/src/Core/Services/Implementations/RelayPushRegistrationService.cs
+++ b/src/Core/Services/Implementations/RelayPushRegistrationService.cs
@@ -38,37 +38,36 @@ public async Task CreateOrUpdateRegistrationAsync(string pushToken, string devic
await SendAsync(HttpMethod.Post, "push/register", requestModel);
}
- public async Task DeleteRegistrationAsync(string deviceId, DeviceType type)
+ public async Task DeleteRegistrationAsync(string deviceId)
{
var requestModel = new PushDeviceRequestModel
{
Id = deviceId,
- Type = type,
};
await SendAsync(HttpMethod.Post, "push/delete", requestModel);
}
public async Task AddUserRegistrationOrganizationAsync(
- IEnumerable> devices, string organizationId)
+ IEnumerable deviceIds, string organizationId)
{
- if (!devices.Any())
+ if (!deviceIds.Any())
{
return;
}
- var requestModel = new PushUpdateRequestModel(devices, organizationId);
+ var requestModel = new PushUpdateRequestModel(deviceIds, organizationId);
await SendAsync(HttpMethod.Put, "push/add-organization", requestModel);
}
public async Task DeleteUserRegistrationOrganizationAsync(
- IEnumerable> devices, string organizationId)
+ IEnumerable deviceIds, string organizationId)
{
- if (!devices.Any())
+ if (!deviceIds.Any())
{
return;
}
- var requestModel = new PushUpdateRequestModel(devices, organizationId);
+ var requestModel = new PushUpdateRequestModel(deviceIds, organizationId);
await SendAsync(HttpMethod.Put, "push/delete-organization", requestModel);
}
}
diff --git a/src/Core/Services/Implementations/StripeAdapter.cs b/src/Core/Services/Implementations/StripeAdapter.cs
index 100a47f75a3f..e5fee63b9d82 100644
--- a/src/Core/Services/Implementations/StripeAdapter.cs
+++ b/src/Core/Services/Implementations/StripeAdapter.cs
@@ -18,6 +18,7 @@ public class StripeAdapter : IStripeAdapter
private readonly Stripe.PriceService _priceService;
private readonly Stripe.SetupIntentService _setupIntentService;
private readonly Stripe.TestHelpers.TestClockService _testClockService;
+ private readonly CustomerBalanceTransactionService _customerBalanceTransactionService;
public StripeAdapter()
{
@@ -34,6 +35,7 @@ public StripeAdapter()
_priceService = new Stripe.PriceService();
_setupIntentService = new SetupIntentService();
_testClockService = new Stripe.TestHelpers.TestClockService();
+ _customerBalanceTransactionService = new CustomerBalanceTransactionService();
}
public Task CustomerCreateAsync(Stripe.CustomerCreateOptions options)
@@ -63,6 +65,10 @@ public async Task> CustomerListPaymentMethods(string id,
return paymentMethods.Data;
}
+ public async Task CustomerBalanceTransactionCreate(string customerId,
+ CustomerBalanceTransactionCreateOptions options)
+ => await _customerBalanceTransactionService.CreateAsync(customerId, options);
+
public Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions options)
{
return _subscriptionService.CreateAsync(options);
diff --git a/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs b/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs
index fcd0889248a3..f6279c946798 100644
--- a/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs
+++ b/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs
@@ -4,7 +4,7 @@ namespace Bit.Core.Services;
public class NoopPushRegistrationService : IPushRegistrationService
{
- public Task AddUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId)
+ public Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId)
{
return Task.FromResult(0);
}
@@ -15,12 +15,12 @@ public Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, s
return Task.FromResult(0);
}
- public Task DeleteRegistrationAsync(string deviceId, DeviceType deviceType)
+ public Task DeleteRegistrationAsync(string deviceId)
{
return Task.FromResult(0);
}
- public Task DeleteUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId)
+ public Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId)
{
return Task.FromResult(0);
}
diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs
index f99fb3b57d19..793b6ac1c16a 100644
--- a/src/Core/Settings/GlobalSettings.cs
+++ b/src/Core/Settings/GlobalSettings.cs
@@ -1,5 +1,4 @@
using Bit.Core.Auth.Settings;
-using Bit.Core.Enums;
using Bit.Core.Settings.LoggingSettings;
namespace Bit.Core.Settings;
@@ -65,7 +64,7 @@ public virtual string LicenseDirectory
public virtual SentrySettings Sentry { get; set; } = new SentrySettings();
public virtual SyslogSettings Syslog { get; set; } = new SyslogSettings();
public virtual ILogLevelSettings MinLogLevel { get; set; } = new LogLevelSettings();
- public virtual List NotificationHubs { get; set; } = new();
+ public virtual NotificationHubPoolSettings NotificationHubPool { get; set; } = new();
public virtual YubicoSettings Yubico { get; set; } = new YubicoSettings();
public virtual DuoSettings Duo { get; set; } = new DuoSettings();
public virtual BraintreeSettings Braintree { get; set; } = new BraintreeSettings();
@@ -424,7 +423,7 @@ public class NotificationHubSettings
public string ConnectionString
{
get => _connectionString;
- set => _connectionString = value.Trim('"');
+ set => _connectionString = value?.Trim('"');
}
public string HubName { get; set; }
///
@@ -433,10 +432,32 @@ public string ConnectionString
///
public bool EnableSendTracing { get; set; } = false;
///
- /// At least one hub configuration should have registration enabled, preferably the General hub as a safety net.
+ /// The date and time at which registration will be enabled.
+ ///
+ /// **This value should not be updated once set, as it is used to determine installation location of devices.**
+ ///
+ /// If null, registration is disabled.
+ ///
+ ///
+ public DateTime? RegistrationStartDate { get; set; }
+ ///
+ /// The date and time at which registration will be disabled.
+ ///
+ /// **This value should not be updated once set, as it is used to determine installation location of devices.**
+ ///
+ /// If null, hub registration has no yet known expiry.
+ ///
+ public DateTime? RegistrationEndDate { get; set; }
+ }
+
+ public class NotificationHubPoolSettings
+ {
+ ///
+ /// List of Notification Hub settings to use for sending push notifications.
+ ///
+ /// Note that hubs on the same namespace share active device limits, so multiple namespaces should be used to increase capacity.
///
- public bool EnableRegistration { get; set; }
- public NotificationHubType HubType { get; set; }
+ public List NotificationHubs { get; set; } = new();
}
public class YubicoSettings
diff --git a/src/Core/Utilities/CoreHelpers.cs b/src/Core/Utilities/CoreHelpers.cs
index 7fcaab3c4e9b..8e4aa0e4035e 100644
--- a/src/Core/Utilities/CoreHelpers.cs
+++ b/src/Core/Utilities/CoreHelpers.cs
@@ -76,6 +76,39 @@ internal static Guid GenerateComb(Guid startingGuid, DateTime time)
return new Guid(guidArray);
}
+ internal static DateTime DateFromComb(Guid combGuid)
+ {
+ var guidArray = combGuid.ToByteArray();
+ var daysArray = new byte[4];
+ var msecsArray = new byte[4];
+
+ Array.Copy(guidArray, guidArray.Length - 6, daysArray, 2, 2);
+ Array.Copy(guidArray, guidArray.Length - 4, msecsArray, 0, 4);
+
+ Array.Reverse(daysArray);
+ Array.Reverse(msecsArray);
+
+ var days = BitConverter.ToInt32(daysArray, 0);
+ var msecs = BitConverter.ToInt32(msecsArray, 0);
+
+ var time = TimeSpan.FromDays(days) + TimeSpan.FromMilliseconds(msecs * 3.333333);
+ return new DateTime(_baseDateTicks + time.Ticks, DateTimeKind.Utc);
+ }
+
+ internal static long BinForComb(Guid combGuid, int binCount)
+ {
+ // From System.Web.Util.HashCodeCombiner
+ uint CombineHashCodes(uint h1, byte h2)
+ {
+ return (uint)(((h1 << 5) + h1) ^ h2);
+ }
+ var guidArray = combGuid.ToByteArray();
+ var randomArray = new byte[10];
+ Array.Copy(guidArray, 0, randomArray, 0, 10);
+ var hash = randomArray.Aggregate((uint)randomArray.Length, CombineHashCodes);
+ return hash % binCount;
+ }
+
public static string CleanCertificateThumbprint(string thumbprint)
{
// Clean possible garbage characters from thumbprint copy/paste
diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs
index bd3aecf2f5f7..b0a2c42eaded 100644
--- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs
+++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs
@@ -4,6 +4,7 @@
using System.Security.Cryptography.X509Certificates;
using AspNetCoreRateLimit;
using Bit.Core.AdminConsole.Models.Business.Tokenables;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.Services;
using Bit.Core.AdminConsole.Services.Implementations;
using Bit.Core.AdminConsole.Services.NoopImplementations;
@@ -24,6 +25,7 @@
using Bit.Core.HostedServices;
using Bit.Core.Identity;
using Bit.Core.IdentityServer;
+using Bit.Core.NotificationHub;
using Bit.Core.OrganizationFeatures;
using Bit.Core.Repositories;
using Bit.Core.Resources;
@@ -102,9 +104,9 @@ public static void AddBaseServices(this IServiceCollection services, IGlobalSett
services.AddUserServices(globalSettings);
services.AddTrialInitiationServices();
services.AddOrganizationServices(globalSettings);
+ services.AddPolicyServices();
services.AddScoped();
services.AddScoped();
- services.AddScoped();
services.AddScoped();
services.AddScoped();
services.AddSingleton();
@@ -263,16 +265,30 @@ public static void AddDefaultServices(this IServiceCollection services, GlobalSe
}
services.AddSingleton();
- if (globalSettings.SelfHosted &&
- CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) &&
- globalSettings.Installation?.Id != null &&
- CoreHelpers.SettingHasValue(globalSettings.Installation?.Key))
+ if (globalSettings.SelfHosted)
{
- services.AddSingleton();
+ if (CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) &&
+ globalSettings.Installation?.Id != null &&
+ CoreHelpers.SettingHasValue(globalSettings.Installation?.Key))
+ {
+ services.AddKeyedSingleton("implementation");
+ services.AddSingleton();
+ }
+ if (CoreHelpers.SettingHasValue(globalSettings.InternalIdentityKey) &&
+ CoreHelpers.SettingHasValue(globalSettings.BaseServiceUri.InternalNotifications))
+ {
+ services.AddKeyedSingleton("implementation");
+ }
}
else if (!globalSettings.SelfHosted)
{
+ services.AddSingleton();
services.AddSingleton();
+ services.AddKeyedSingleton("implementation");
+ if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString))
+ {
+ services.AddKeyedSingleton("implementation");
+ }
}
else
{
diff --git a/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs
index 70ca599400e1..b46fd307e96b 100644
--- a/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs
+++ b/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs
@@ -52,7 +52,7 @@ public async Task GetMetadataAsync_OK(
{
sutProvider.GetDependency().AccessMembersTab(organizationId).Returns(true);
sutProvider.GetDependency().GetMetadata(organizationId)
- .Returns(new OrganizationMetadata(true));
+ .Returns(new OrganizationMetadata(true, true));
var result = await sutProvider.Sut.GetMetadataAsync(organizationId);
@@ -60,6 +60,7 @@ public async Task GetMetadataAsync_OK(
var organizationMetadataResponse = ((Ok)result).Value;
+ Assert.True(organizationMetadataResponse.IsEligibleForSelfHost);
Assert.True(organizationMetadataResponse.IsOnSecretsManagerStandalone);
}
diff --git a/test/Common/AutoFixture/ControllerCustomization.cs b/test/Common/AutoFixture/ControllerCustomization.cs
index f695f86b5503..91fffbf09971 100644
--- a/test/Common/AutoFixture/ControllerCustomization.cs
+++ b/test/Common/AutoFixture/ControllerCustomization.cs
@@ -1,6 +1,5 @@
using AutoFixture;
using Microsoft.AspNetCore.Mvc;
-using Org.BouncyCastle.Security;
namespace Bit.Test.Common.AutoFixture;
@@ -15,7 +14,7 @@ public ControllerCustomization(Type controllerType)
{
if (!controllerType.IsAssignableTo(typeof(Controller)))
{
- throw new InvalidParameterException($"{nameof(controllerType)} must derive from {typeof(Controller).Name}");
+ throw new Exception($"{nameof(controllerType)} must derive from {typeof(Controller).Name}");
}
_controllerType = controllerType;
diff --git a/test/Common/AutoFixture/SutProvider.cs b/test/Common/AutoFixture/SutProvider.cs
index ac953965bdfc..fefe6c3ebf25 100644
--- a/test/Common/AutoFixture/SutProvider.cs
+++ b/test/Common/AutoFixture/SutProvider.cs
@@ -127,7 +127,6 @@ public object Create(object request, ISpecimenContext context)
return _sutProvider.GetDependency(parameterInfo.ParameterType, "");
}
-
// This is the equivalent of _fixture.Create, but no overload for
// Create(Type type) exists.
var dependency = new SpecimenContext(_fixture).Resolve(new SeededRequest(parameterInfo.ParameterType,
diff --git a/test/Common/AutoFixture/SutProviderExtensions.cs b/test/Common/AutoFixture/SutProviderExtensions.cs
index 1fdf22653943..bdc8604166fd 100644
--- a/test/Common/AutoFixture/SutProviderExtensions.cs
+++ b/test/Common/AutoFixture/SutProviderExtensions.cs
@@ -1,6 +1,7 @@
using AutoFixture;
using Bit.Core.Services;
using Bit.Core.Settings;
+using Microsoft.Extensions.Time.Testing;
using NSubstitute;
using RichardSzalay.MockHttp;
@@ -47,4 +48,19 @@ public static SutProvider ConfigureBaseIdentityClientService(this SutProvi
.SetDependency(mockHttpClientFactory)
.Create();
}
+
+ ///
+ /// Configures SutProvider to use FakeTimeProvider.
+ /// It is registered under both the TimeProvider type and the FakeTimeProvider type
+ /// so that it can be retrieved in a type-safe manner with GetDependency.
+ /// This can be chained with other builder methods; make sure to call
+ /// before use.
+ ///
+ public static SutProvider WithFakeTimeProvider(this SutProvider sutProvider)
+ {
+ var fakeTimeProvider = new FakeTimeProvider();
+ return sutProvider
+ .SetDependency((TimeProvider)fakeTimeProvider)
+ .SetDependency(fakeTimeProvider);
+ }
}
diff --git a/test/Common/Common.csproj b/test/Common/Common.csproj
index 7b9fbe42d5eb..2f11798cef8f 100644
--- a/test/Common/Common.csproj
+++ b/test/Common/Common.csproj
@@ -5,6 +5,7 @@
+
diff --git a/test/Core.Test/AdminConsole/AutoFixture/PolicyFixtures.cs b/test/Core.Test/AdminConsole/AutoFixture/PolicyFixtures.cs
index f70fd579e371..09b112c43c1e 100644
--- a/test/Core.Test/AdminConsole/AutoFixture/PolicyFixtures.cs
+++ b/test/Core.Test/AdminConsole/AutoFixture/PolicyFixtures.cs
@@ -9,10 +9,12 @@ namespace Bit.Core.Test.AdminConsole.AutoFixture;
internal class PolicyCustomization : ICustomization
{
public PolicyType Type { get; set; }
+ public bool Enabled { get; set; }
- public PolicyCustomization(PolicyType type)
+ public PolicyCustomization(PolicyType type, bool enabled)
{
Type = type;
+ Enabled = enabled;
}
public void Customize(IFixture fixture)
@@ -20,21 +22,23 @@ public void Customize(IFixture fixture)
fixture.Customize(composer => composer
.With(o => o.OrganizationId, Guid.NewGuid())
.With(o => o.Type, Type)
- .With(o => o.Enabled, true));
+ .With(o => o.Enabled, Enabled));
}
}
public class PolicyAttribute : CustomizeAttribute
{
private readonly PolicyType _type;
+ private readonly bool _enabled;
- public PolicyAttribute(PolicyType type)
+ public PolicyAttribute(PolicyType type, bool enabled = true)
{
_type = type;
+ _enabled = enabled;
}
public override ICustomization GetCustomization(ParameterInfo parameter)
{
- return new PolicyCustomization(_type);
+ return new PolicyCustomization(_type, _enabled);
}
}
diff --git a/test/Core.Test/AdminConsole/AutoFixture/PolicyUpdateFixtures.cs b/test/Core.Test/AdminConsole/AutoFixture/PolicyUpdateFixtures.cs
new file mode 100644
index 000000000000..dff9b571782c
--- /dev/null
+++ b/test/Core.Test/AdminConsole/AutoFixture/PolicyUpdateFixtures.cs
@@ -0,0 +1,25 @@
+using System.Reflection;
+using AutoFixture;
+using AutoFixture.Xunit2;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+
+namespace Bit.Core.Test.AdminConsole.AutoFixture;
+
+internal class PolicyUpdateCustomization(PolicyType type, bool enabled) : ICustomization
+{
+ public void Customize(IFixture fixture)
+ {
+ fixture.Customize(composer => composer
+ .With(o => o.Type, type)
+ .With(o => o.Enabled, enabled));
+ }
+}
+
+public class PolicyUpdateAttribute(PolicyType type, bool enabled = true) : CustomizeAttribute
+{
+ public override ICustomization GetCustomization(ParameterInfo parameter)
+ {
+ return new PolicyUpdateCustomization(type, enabled);
+ }
+}
diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidatorFixtures.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidatorFixtures.cs
new file mode 100644
index 000000000000..ba4741d8bdf9
--- /dev/null
+++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidatorFixtures.cs
@@ -0,0 +1,43 @@
+#nullable enable
+
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+using NSubstitute;
+
+namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies;
+
+public class FakeSingleOrgPolicyValidator : IPolicyValidator
+{
+ public PolicyType Type => PolicyType.SingleOrg;
+ public IEnumerable RequiredPolicies => Array.Empty();
+
+ public readonly Func> ValidateAsyncMock = Substitute.For>>();
+ public readonly Action OnSaveSideEffectsAsyncMock = Substitute.For>();
+
+ public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
+ {
+ return ValidateAsyncMock(policyUpdate, currentPolicy);
+ }
+
+ public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
+ {
+ OnSaveSideEffectsAsyncMock(policyUpdate, currentPolicy);
+ return Task.FromResult(0);
+ }
+}
+public class FakeRequireSsoPolicyValidator : IPolicyValidator
+{
+ public PolicyType Type => PolicyType.RequireSso;
+ public IEnumerable RequiredPolicies => [PolicyType.SingleOrg];
+ public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult("");
+ public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(0);
+}
+public class FakeVaultTimeoutPolicyValidator : IPolicyValidator
+{
+ public PolicyType Type => PolicyType.MaximumVaultTimeout;
+ public IEnumerable RequiredPolicies => [PolicyType.SingleOrg];
+ public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult("");
+ public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(0);
+}
diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidatorHelpersTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidatorHelpersTests.cs
new file mode 100644
index 000000000000..99f99706fa86
--- /dev/null
+++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidatorHelpersTests.cs
@@ -0,0 +1,64 @@
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+using Bit.Core.Auth.Entities;
+using Bit.Core.Auth.Enums;
+using Bit.Core.Auth.Models.Data;
+using Xunit;
+
+namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies;
+
+public class PolicyValidatorHelpersTests
+{
+ [Fact]
+ public void ValidateDecryptionOptionsNotEnabled_RequiredByKeyConnector_ValidationError()
+ {
+ var ssoConfig = new SsoConfig();
+ ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.KeyConnector });
+
+ var result = ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.KeyConnector]);
+
+ Assert.Contains("Key Connector is enabled", result);
+ }
+
+ [Fact]
+ public void ValidateDecryptionOptionsNotEnabled_RequiredByTDE_ValidationError()
+ {
+ var ssoConfig = new SsoConfig();
+ ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption });
+
+ var result = ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.TrustedDeviceEncryption]);
+
+ Assert.Contains("Trusted device encryption is on", result);
+ }
+
+ [Fact]
+ public void ValidateDecryptionOptionsNotEnabled_NullSsoConfig_NoValidationError()
+ {
+ var ssoConfig = new SsoConfig();
+ var result = ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.KeyConnector]);
+
+ Assert.True(string.IsNullOrEmpty(result));
+ }
+
+ [Fact]
+ public void ValidateDecryptionOptionsNotEnabled_RequiredOptionNotEnabled_NoValidationError()
+ {
+ var ssoConfig = new SsoConfig();
+ ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.KeyConnector });
+
+ var result = ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.TrustedDeviceEncryption]);
+
+ Assert.True(string.IsNullOrEmpty(result));
+ }
+
+ [Fact]
+ public void ValidateDecryptionOptionsNotEnabled_SsoConfigDisabled_NoValidationError()
+ {
+ var ssoConfig = new SsoConfig();
+ ssoConfig.Enabled = false;
+ ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.KeyConnector });
+
+ var result = ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.KeyConnector]);
+
+ Assert.True(string.IsNullOrEmpty(result));
+ }
+}
diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidatorTests.cs
new file mode 100644
index 000000000000..d3af765f799a
--- /dev/null
+++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidatorTests.cs
@@ -0,0 +1,75 @@
+namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+using Bit.Core.Auth.Entities;
+using Bit.Core.Auth.Enums;
+using Bit.Core.Auth.Models.Data;
+using Bit.Core.Auth.Repositories;
+using Bit.Core.Test.AdminConsole.AutoFixture;
+using Bit.Test.Common.AutoFixture;
+using Bit.Test.Common.AutoFixture.Attributes;
+using NSubstitute;
+using Xunit;
+
+[SutProviderCustomize]
+public class RequireSsoPolicyValidatorTests
+{
+ [Theory, BitAutoData]
+ public async Task ValidateAsync_DisablingPolicy_KeyConnectorEnabled_ValidationError(
+ [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.SingleOrg)] Policy policy,
+ SutProvider sutProvider)
+ {
+ policy.OrganizationId = policyUpdate.OrganizationId;
+
+ var ssoConfig = new SsoConfig { Enabled = true };
+ ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.KeyConnector });
+
+ sutProvider.GetDependency()
+ .GetByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns(ssoConfig);
+
+ var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy);
+ Assert.Contains("Key Connector is enabled", result, StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Theory, BitAutoData]
+ public async Task ValidateAsync_DisablingPolicy_TdeEnabled_ValidationError(
+ [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.SingleOrg)] Policy policy,
+ SutProvider sutProvider)
+ {
+ policy.OrganizationId = policyUpdate.OrganizationId;
+
+ var ssoConfig = new SsoConfig { Enabled = true };
+ ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption });
+
+ sutProvider.GetDependency()
+ .GetByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns(ssoConfig);
+
+ var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy);
+ Assert.Contains("Trusted device encryption is on", result, StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Theory, BitAutoData]
+ public async Task ValidateAsync_DisablingPolicy_DecryptionOptionsNotEnabled_Success(
+ [PolicyUpdate(PolicyType.ResetPassword, false)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.ResetPassword)] Policy policy,
+ SutProvider sutProvider)
+ {
+ policy.OrganizationId = policyUpdate.OrganizationId;
+
+ var ssoConfig = new SsoConfig { Enabled = false };
+
+ sutProvider.GetDependency()
+ .GetByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns(ssoConfig);
+
+ var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy);
+ Assert.True(string.IsNullOrEmpty(result));
+ }
+}
diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidatorTests.cs
new file mode 100644
index 000000000000..83939406b590
--- /dev/null
+++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidatorTests.cs
@@ -0,0 +1,71 @@
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+using Bit.Core.Auth.Entities;
+using Bit.Core.Auth.Enums;
+using Bit.Core.Auth.Models.Data;
+using Bit.Core.Auth.Repositories;
+using Bit.Core.Test.AdminConsole.AutoFixture;
+using Bit.Test.Common.AutoFixture;
+using Bit.Test.Common.AutoFixture.Attributes;
+using NSubstitute;
+using Xunit;
+
+namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+
+[SutProviderCustomize]
+public class ResetPasswordPolicyValidatorTests
+{
+ [Theory]
+ [BitAutoData(true, false)]
+ [BitAutoData(false, true)]
+ [BitAutoData(false, false)]
+ public async Task ValidateAsync_DisablingPolicy_TdeEnabled_ValidationError(
+ bool policyEnabled,
+ bool autoEnrollEnabled,
+ [PolicyUpdate(PolicyType.ResetPassword)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.ResetPassword)] Policy policy,
+ SutProvider sutProvider)
+ {
+ policyUpdate.Enabled = policyEnabled;
+ policyUpdate.SetDataModel(new ResetPasswordDataModel
+ {
+ AutoEnrollEnabled = autoEnrollEnabled
+ });
+ policy.OrganizationId = policyUpdate.OrganizationId;
+
+ var ssoConfig = new SsoConfig { Enabled = true };
+ ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption });
+
+ sutProvider.GetDependency()
+ .GetByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns(ssoConfig);
+
+ var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy);
+ Assert.Contains("Trusted device encryption is on and requires this policy.", result, StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Theory, BitAutoData]
+ public async Task ValidateAsync_DisablingPolicy_TdeNotEnabled_Success(
+ [PolicyUpdate(PolicyType.ResetPassword, false)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.ResetPassword)] Policy policy,
+ SutProvider sutProvider)
+ {
+ policyUpdate.SetDataModel(new ResetPasswordDataModel
+ {
+ AutoEnrollEnabled = false
+ });
+ policy.OrganizationId = policyUpdate.OrganizationId;
+
+ var ssoConfig = new SsoConfig { Enabled = false };
+
+ sutProvider.GetDependency()
+ .GetByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns(ssoConfig);
+
+ var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy);
+ Assert.True(string.IsNullOrEmpty(result));
+ }
+}
diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs
new file mode 100644
index 000000000000..76ee5748403f
--- /dev/null
+++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs
@@ -0,0 +1,129 @@
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+using Bit.Core.Auth.Entities;
+using Bit.Core.Auth.Enums;
+using Bit.Core.Auth.Models.Data;
+using Bit.Core.Auth.Repositories;
+using Bit.Core.Context;
+using Bit.Core.Entities;
+using Bit.Core.Enums;
+using Bit.Core.Models.Data.Organizations.OrganizationUsers;
+using Bit.Core.Repositories;
+using Bit.Core.Services;
+using Bit.Core.Test.AdminConsole.AutoFixture;
+using Bit.Test.Common.AutoFixture;
+using Bit.Test.Common.AutoFixture.Attributes;
+using NSubstitute;
+using Xunit;
+
+namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+
+[SutProviderCustomize]
+public class SingleOrgPolicyValidatorTests
+{
+ [Theory, BitAutoData]
+ public async Task ValidateAsync_DisablingPolicy_KeyConnectorEnabled_ValidationError(
+ [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.SingleOrg)] Policy policy,
+ SutProvider sutProvider)
+ {
+ policy.OrganizationId = policyUpdate.OrganizationId;
+
+ var ssoConfig = new SsoConfig { Enabled = true };
+ ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.KeyConnector });
+
+ sutProvider.GetDependency()
+ .GetByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns(ssoConfig);
+
+ var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy);
+ Assert.Contains("Key Connector is enabled", result, StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Theory, BitAutoData]
+ public async Task ValidateAsync_DisablingPolicy_KeyConnectorNotEnabled_Success(
+ [PolicyUpdate(PolicyType.ResetPassword, false)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.ResetPassword)] Policy policy,
+ SutProvider sutProvider)
+ {
+ policy.OrganizationId = policyUpdate.OrganizationId;
+
+ var ssoConfig = new SsoConfig { Enabled = false };
+
+ sutProvider.GetDependency()
+ .GetByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns(ssoConfig);
+
+ var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy);
+ Assert.True(string.IsNullOrEmpty(result));
+ }
+
+ [Theory, BitAutoData]
+ public async Task OnSaveSideEffectsAsync_RemovesNonCompliantUsers(
+ [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.SingleOrg, false)] Policy policy,
+ Guid savingUserId,
+ Guid nonCompliantUserId,
+ Organization organization, SutProvider sutProvider)
+ {
+ policy.OrganizationId = organization.Id = policyUpdate.OrganizationId;
+
+ var compliantUser1 = new OrganizationUserUserDetails
+ {
+ OrganizationId = organization.Id,
+ Type = OrganizationUserType.User,
+ Status = OrganizationUserStatusType.Confirmed,
+ UserId = new Guid(),
+ Email = "user1@example.com"
+ };
+
+ var compliantUser2 = new OrganizationUserUserDetails
+ {
+ OrganizationId = organization.Id,
+ Type = OrganizationUserType.User,
+ Status = OrganizationUserStatusType.Confirmed,
+ UserId = new Guid(),
+ Email = "user2@example.com"
+ };
+
+ var nonCompliantUser = new OrganizationUserUserDetails
+ {
+ OrganizationId = organization.Id,
+ Type = OrganizationUserType.User,
+ Status = OrganizationUserStatusType.Confirmed,
+ UserId = nonCompliantUserId,
+ Email = "user3@example.com"
+ };
+
+ sutProvider.GetDependency()
+ .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId)
+ .Returns([compliantUser1, compliantUser2, nonCompliantUser]);
+
+ var otherOrganizationUser = new OrganizationUser
+ {
+ OrganizationId = new Guid(),
+ UserId = nonCompliantUserId,
+ Status = OrganizationUserStatusType.Confirmed
+ };
+
+ sutProvider.GetDependency()
+ .GetManyByManyUsersAsync(Arg.Is>(ids => ids.Contains(nonCompliantUserId)))
+ .Returns([otherOrganizationUser]);
+
+ sutProvider.GetDependency().UserId.Returns(savingUserId);
+ sutProvider.GetDependency().GetByIdAsync(policyUpdate.OrganizationId).Returns(organization);
+
+ await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, policy);
+
+ await sutProvider.GetDependency()
+ .Received(1)
+ .RemoveUserAsync(policyUpdate.OrganizationId, nonCompliantUser.Id, savingUserId);
+ await sutProvider.GetDependency()
+ .Received(1)
+ .SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(organization.DisplayName(),
+ "user3@example.com");
+ }
+}
diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs
new file mode 100644
index 000000000000..4dce13174934
--- /dev/null
+++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs
@@ -0,0 +1,209 @@
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
+using Bit.Core.Context;
+using Bit.Core.Enums;
+using Bit.Core.Exceptions;
+using Bit.Core.Models.Data.Organizations.OrganizationUsers;
+using Bit.Core.Repositories;
+using Bit.Core.Services;
+using Bit.Core.Test.AdminConsole.AutoFixture;
+using Bit.Test.Common.AutoFixture;
+using Bit.Test.Common.AutoFixture.Attributes;
+using NSubstitute;
+using Xunit;
+
+namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
+
+[SutProviderCustomize]
+public class TwoFactorAuthenticationPolicyValidatorTests
+{
+ [Theory, BitAutoData]
+ public async Task OnSaveSideEffectsAsync_RemovesNonCompliantUsers(
+ Organization organization,
+ [PolicyUpdate(PolicyType.TwoFactorAuthentication)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.TwoFactorAuthentication, false)] Policy policy,
+ SutProvider sutProvider)
+ {
+ policy.OrganizationId = organization.Id = policyUpdate.OrganizationId;
+ sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization);
+
+ var orgUserDetailUserInvited = new OrganizationUserUserDetails
+ {
+ Id = Guid.NewGuid(),
+ Status = OrganizationUserStatusType.Invited,
+ Type = OrganizationUserType.User,
+ // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync
+ Email = "user1@test.com",
+ Name = "TEST",
+ UserId = Guid.NewGuid(),
+ HasMasterPassword = false
+ };
+ var orgUserDetailUserAcceptedWith2FA = new OrganizationUserUserDetails
+ {
+ Id = Guid.NewGuid(),
+ Status = OrganizationUserStatusType.Accepted,
+ Type = OrganizationUserType.User,
+ // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync
+ Email = "user2@test.com",
+ Name = "TEST",
+ UserId = Guid.NewGuid(),
+ HasMasterPassword = true
+ };
+ var orgUserDetailUserAcceptedWithout2FA = new OrganizationUserUserDetails
+ {
+ Id = Guid.NewGuid(),
+ Status = OrganizationUserStatusType.Accepted,
+ Type = OrganizationUserType.User,
+ // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync
+ Email = "user3@test.com",
+ Name = "TEST",
+ UserId = Guid.NewGuid(),
+ HasMasterPassword = true
+ };
+ var orgUserDetailAdmin = new OrganizationUserUserDetails
+ {
+ Id = Guid.NewGuid(),
+ Status = OrganizationUserStatusType.Confirmed,
+ Type = OrganizationUserType.Admin,
+ // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync
+ Email = "admin@test.com",
+ Name = "ADMIN",
+ UserId = Guid.NewGuid(),
+ HasMasterPassword = false
+ };
+
+ sutProvider.GetDependency()
+ .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId)
+ .Returns(new List
+ {
+ orgUserDetailUserInvited,
+ orgUserDetailUserAcceptedWith2FA,
+ orgUserDetailUserAcceptedWithout2FA,
+ orgUserDetailAdmin
+ });
+
+ sutProvider.GetDependency()
+ .TwoFactorIsEnabledAsync(Arg.Any>())
+ .Returns(new List<(OrganizationUserUserDetails user, bool hasTwoFactor)>()
+ {
+ (orgUserDetailUserInvited, false),
+ (orgUserDetailUserAcceptedWith2FA, true),
+ (orgUserDetailUserAcceptedWithout2FA, false),
+ (orgUserDetailAdmin, false),
+ });
+
+ var savingUserId = Guid.NewGuid();
+ sutProvider.GetDependency().UserId.Returns(savingUserId);
+
+ await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, policy);
+
+ var removeOrganizationUserCommand = sutProvider.GetDependency();
+
+ await removeOrganizationUserCommand.Received()
+ .RemoveUserAsync(policy.OrganizationId, orgUserDetailUserAcceptedWithout2FA.Id, savingUserId);
+ await sutProvider.GetDependency().Received()
+ .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(organization.DisplayName(), orgUserDetailUserAcceptedWithout2FA.Email);
+
+ await removeOrganizationUserCommand.DidNotReceive()
+ .RemoveUserAsync(policy.OrganizationId, orgUserDetailUserInvited.Id, savingUserId);
+ await sutProvider.GetDependency().DidNotReceive()
+ .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(organization.DisplayName(), orgUserDetailUserInvited.Email);
+ await removeOrganizationUserCommand.DidNotReceive()
+ .RemoveUserAsync(policy.OrganizationId, orgUserDetailUserAcceptedWith2FA.Id, savingUserId);
+ await sutProvider.GetDependency().DidNotReceive()
+ .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(organization.DisplayName(), orgUserDetailUserAcceptedWith2FA.Email);
+ await removeOrganizationUserCommand.DidNotReceive()
+ .RemoveUserAsync(policy.OrganizationId, orgUserDetailAdmin.Id, savingUserId);
+ await sutProvider.GetDependency().DidNotReceive()
+ .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(organization.DisplayName(), orgUserDetailAdmin.Email);
+ }
+
+ [Theory, BitAutoData]
+ public async Task OnSaveSideEffectsAsync_UsersToBeRemovedDontHaveMasterPasswords_Throws(
+ Organization organization,
+ [PolicyUpdate(PolicyType.TwoFactorAuthentication)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.TwoFactorAuthentication, false)] Policy policy,
+ SutProvider sutProvider)
+ {
+ policy.OrganizationId = organization.Id = policyUpdate.OrganizationId;
+
+ var orgUserDetailUserWith2FAAndMP = new OrganizationUserUserDetails
+ {
+ Id = Guid.NewGuid(),
+ Status = OrganizationUserStatusType.Confirmed,
+ Type = OrganizationUserType.User,
+ // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync
+ Email = "user1@test.com",
+ Name = "TEST",
+ UserId = Guid.NewGuid(),
+ HasMasterPassword = true
+ };
+ var orgUserDetailUserWith2FANoMP = new OrganizationUserUserDetails
+ {
+ Id = Guid.NewGuid(),
+ Status = OrganizationUserStatusType.Confirmed,
+ Type = OrganizationUserType.User,
+ // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync
+ Email = "user2@test.com",
+ Name = "TEST",
+ UserId = Guid.NewGuid(),
+ HasMasterPassword = false
+ };
+ var orgUserDetailUserWithout2FA = new OrganizationUserUserDetails
+ {
+ Id = Guid.NewGuid(),
+ Status = OrganizationUserStatusType.Confirmed,
+ Type = OrganizationUserType.User,
+ // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync
+ Email = "user3@test.com",
+ Name = "TEST",
+ UserId = Guid.NewGuid(),
+ HasMasterPassword = false
+ };
+ var orgUserDetailAdmin = new OrganizationUserUserDetails
+ {
+ Id = Guid.NewGuid(),
+ Status = OrganizationUserStatusType.Confirmed,
+ Type = OrganizationUserType.Admin,
+ // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync
+ Email = "admin@test.com",
+ Name = "ADMIN",
+ UserId = Guid.NewGuid(),
+ HasMasterPassword = false
+ };
+
+ sutProvider.GetDependency()
+ .GetManyDetailsByOrganizationAsync(policy.OrganizationId)
+ .Returns(new List
+ {
+ orgUserDetailUserWith2FAAndMP,
+ orgUserDetailUserWith2FANoMP,
+ orgUserDetailUserWithout2FA,
+ orgUserDetailAdmin
+ });
+
+ sutProvider.GetDependency()
+ .TwoFactorIsEnabledAsync(Arg.Is>(ids =>
+ ids.Contains(orgUserDetailUserWith2FANoMP.UserId.Value)
+ && ids.Contains(orgUserDetailUserWithout2FA.UserId.Value)
+ && ids.Contains(orgUserDetailAdmin.UserId.Value)))
+ .Returns(new List<(Guid userId, bool hasTwoFactor)>()
+ {
+ (orgUserDetailUserWith2FANoMP.UserId.Value, true),
+ (orgUserDetailUserWithout2FA.UserId.Value, false),
+ (orgUserDetailAdmin.UserId.Value, false),
+ });
+
+ var badRequestException = await Assert.ThrowsAsync(
+ () => sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, policy));
+
+ Assert.Contains("Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page.", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
+
+ await sutProvider.GetDependency().DidNotReceiveWithAnyArgs()
+ .RemoveUserAsync(organizationId: default, organizationUserId: default, deletingUserId: default);
+ }
+}
diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs
new file mode 100644
index 000000000000..342ede9c829d
--- /dev/null
+++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs
@@ -0,0 +1,330 @@
+#nullable enable
+
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Enums;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations;
+using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
+using Bit.Core.AdminConsole.Repositories;
+using Bit.Core.Exceptions;
+using Bit.Core.Models.Data.Organizations;
+using Bit.Core.Services;
+using Bit.Core.Test.AdminConsole.AutoFixture;
+using Bit.Test.Common.AutoFixture;
+using Bit.Test.Common.AutoFixture.Attributes;
+using Microsoft.Extensions.Time.Testing;
+using NSubstitute;
+using Xunit;
+using EventType = Bit.Core.Enums.EventType;
+
+namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies;
+
+public class SavePolicyCommandTests
+{
+ [Theory, BitAutoData]
+ public async Task SaveAsync_NewPolicy_Success([PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate)
+ {
+ var fakePolicyValidator = new FakeSingleOrgPolicyValidator();
+ fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns("");
+ var sutProvider = SutProviderFactory([fakePolicyValidator]);
+
+ ArrangeOrganization(sutProvider, policyUpdate);
+ sutProvider.GetDependency().GetManyByOrganizationIdAsync(policyUpdate.OrganizationId).Returns([]);
+
+ var creationDate = sutProvider.GetDependency().Start;
+
+ await sutProvider.Sut.SaveAsync(policyUpdate);
+
+ await fakePolicyValidator.ValidateAsyncMock.Received(1).Invoke(policyUpdate, null);
+ fakePolicyValidator.OnSaveSideEffectsAsyncMock.Received(1).Invoke(policyUpdate, null);
+
+ await AssertPolicySavedAsync(sutProvider, policyUpdate);
+ await sutProvider.GetDependency().Received(1).UpsertAsync(Arg.Is(p =>
+ p.CreationDate == creationDate &&
+ p.RevisionDate == creationDate));
+ }
+
+ [Theory, BitAutoData]
+ public async Task SaveAsync_ExistingPolicy_Success(
+ [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.SingleOrg, false)] Policy currentPolicy)
+ {
+ var fakePolicyValidator = new FakeSingleOrgPolicyValidator();
+ fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns("");
+ var sutProvider = SutProviderFactory([fakePolicyValidator]);
+
+ currentPolicy.OrganizationId = policyUpdate.OrganizationId;
+ sutProvider.GetDependency()
+ .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type)
+ .Returns(currentPolicy);
+
+ ArrangeOrganization(sutProvider, policyUpdate);
+ sutProvider.GetDependency()
+ .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns([currentPolicy]);
+
+ // Store mutable properties separately to assert later
+ var id = currentPolicy.Id;
+ var organizationId = currentPolicy.OrganizationId;
+ var type = currentPolicy.Type;
+ var creationDate = currentPolicy.CreationDate;
+ var revisionDate = sutProvider.GetDependency().Start;
+
+ await sutProvider.Sut.SaveAsync(policyUpdate);
+
+ await fakePolicyValidator.ValidateAsyncMock.Received(1).Invoke(policyUpdate, currentPolicy);
+ fakePolicyValidator.OnSaveSideEffectsAsyncMock.Received(1).Invoke(policyUpdate, currentPolicy);
+
+ await AssertPolicySavedAsync(sutProvider, policyUpdate);
+ // Additional assertions to ensure certain properties have or have not been updated
+ await sutProvider.GetDependency().Received(1).UpsertAsync(Arg.Is(p =>
+ p.Id == id &&
+ p.OrganizationId == organizationId &&
+ p.Type == type &&
+ p.CreationDate == creationDate &&
+ p.RevisionDate == revisionDate));
+ }
+
+ [Fact]
+ public void Constructor_DuplicatePolicyValidators_Throws()
+ {
+ var exception = Assert.Throws(() =>
+ new SavePolicyCommand(
+ Substitute.For(),
+ Substitute.For(),
+ Substitute.For(),
+ [new FakeSingleOrgPolicyValidator(), new FakeSingleOrgPolicyValidator()],
+ Substitute.For()
+ ));
+ Assert.Contains("Duplicate PolicyValidator for SingleOrg policy", exception.Message);
+ }
+
+ [Theory, BitAutoData]
+ public async Task SaveAsync_OrganizationDoesNotExist_ThrowsBadRequest(PolicyUpdate policyUpdate)
+ {
+ var sutProvider = SutProviderFactory();
+ sutProvider.GetDependency()
+ .GetOrganizationAbilityAsync(policyUpdate.OrganizationId)
+ .Returns(Task.FromResult(null));
+
+ var badRequestException = await Assert.ThrowsAsync(
+ () => sutProvider.Sut.SaveAsync(policyUpdate));
+
+ Assert.Contains("Organization not found", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
+ await AssertPolicyNotSavedAsync(sutProvider);
+ }
+
+ [Theory, BitAutoData]
+ public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest(PolicyUpdate policyUpdate)
+ {
+ var sutProvider = SutProviderFactory();
+ sutProvider.GetDependency()
+ .GetOrganizationAbilityAsync(policyUpdate.OrganizationId)
+ .Returns(new OrganizationAbility
+ {
+ Id = policyUpdate.OrganizationId,
+ UsePolicies = false
+ });
+
+ var badRequestException = await Assert.ThrowsAsync(
+ () => sutProvider.Sut.SaveAsync(policyUpdate));
+
+ Assert.Contains("cannot use policies", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
+ await AssertPolicyNotSavedAsync(sutProvider);
+ }
+
+ [Theory, BitAutoData]
+ public async Task SaveAsync_RequiredPolicyIsNull_Throws(
+ [PolicyUpdate(PolicyType.RequireSso)] PolicyUpdate policyUpdate)
+ {
+ var sutProvider = SutProviderFactory([
+ new FakeRequireSsoPolicyValidator(),
+ new FakeSingleOrgPolicyValidator()
+ ]);
+
+ ArrangeOrganization(sutProvider, policyUpdate);
+ sutProvider.GetDependency()
+ .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns([]);
+
+ var badRequestException = await Assert.ThrowsAsync(
+ () => sutProvider.Sut.SaveAsync(policyUpdate));
+
+ Assert.Contains("Turn on the Single organization policy because it is required for the Require single sign-on authentication policy", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
+ await AssertPolicyNotSavedAsync(sutProvider);
+ }
+
+ [Theory, BitAutoData]
+ public async Task SaveAsync_RequiredPolicyNotEnabled_Throws(
+ [PolicyUpdate(PolicyType.RequireSso)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.SingleOrg, false)] Policy singleOrgPolicy)
+ {
+ var sutProvider = SutProviderFactory([
+ new FakeRequireSsoPolicyValidator(),
+ new FakeSingleOrgPolicyValidator()
+ ]);
+
+ ArrangeOrganization(sutProvider, policyUpdate);
+ sutProvider.GetDependency()
+ .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns([singleOrgPolicy]);
+
+ var badRequestException = await Assert.ThrowsAsync(
+ () => sutProvider.Sut.SaveAsync(policyUpdate));
+
+ Assert.Contains("Turn on the Single organization policy because it is required for the Require single sign-on authentication policy", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
+ await AssertPolicyNotSavedAsync(sutProvider);
+ }
+
+ [Theory, BitAutoData]
+ public async Task SaveAsync_RequiredPolicyEnabled_Success(
+ [PolicyUpdate(PolicyType.RequireSso)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy)
+ {
+ var sutProvider = SutProviderFactory([
+ new FakeRequireSsoPolicyValidator(),
+ new FakeSingleOrgPolicyValidator()
+ ]);
+
+ ArrangeOrganization(sutProvider, policyUpdate);
+ sutProvider.GetDependency()
+ .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns([singleOrgPolicy]);
+
+ await sutProvider.Sut.SaveAsync(policyUpdate);
+ await AssertPolicySavedAsync(sutProvider, policyUpdate);
+ }
+
+ [Theory, BitAutoData]
+ public async Task SaveAsync_DependentPolicyIsEnabled_Throws(
+ [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.SingleOrg)] Policy currentPolicy,
+ [Policy(PolicyType.RequireSso)] Policy requireSsoPolicy) // depends on Single Org
+ {
+ var sutProvider = SutProviderFactory([
+ new FakeRequireSsoPolicyValidator(),
+ new FakeSingleOrgPolicyValidator()
+ ]);
+
+ ArrangeOrganization(sutProvider, policyUpdate);
+ sutProvider.GetDependency()
+ .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns([currentPolicy, requireSsoPolicy]);
+
+ var badRequestException = await Assert.ThrowsAsync(
+ () => sutProvider.Sut.SaveAsync(policyUpdate));
+
+ Assert.Contains("Turn off the Require single sign-on authentication policy because it requires the Single organization policy", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
+ await AssertPolicyNotSavedAsync(sutProvider);
+ }
+
+ [Theory, BitAutoData]
+ public async Task SaveAsync_MultipleDependentPoliciesAreEnabled_Throws(
+ [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.SingleOrg)] Policy currentPolicy,
+ [Policy(PolicyType.RequireSso)] Policy requireSsoPolicy, // depends on Single Org
+ [Policy(PolicyType.MaximumVaultTimeout)] Policy vaultTimeoutPolicy) // depends on Single Org
+ {
+ var sutProvider = SutProviderFactory([
+ new FakeRequireSsoPolicyValidator(),
+ new FakeSingleOrgPolicyValidator(),
+ new FakeVaultTimeoutPolicyValidator()
+ ]);
+
+ ArrangeOrganization(sutProvider, policyUpdate);
+ sutProvider.GetDependency()
+ .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns([currentPolicy, requireSsoPolicy, vaultTimeoutPolicy]);
+
+ var badRequestException = await Assert.ThrowsAsync(
+ () => sutProvider.Sut.SaveAsync(policyUpdate));
+
+ Assert.Contains("Turn off all of the policies that require the Single organization policy", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
+ await AssertPolicyNotSavedAsync(sutProvider);
+ }
+
+ [Theory, BitAutoData]
+ public async Task SaveAsync_DependentPolicyNotEnabled_Success(
+ [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate,
+ [Policy(PolicyType.SingleOrg)] Policy currentPolicy,
+ [Policy(PolicyType.RequireSso, false)] Policy requireSsoPolicy) // depends on Single Org but is not enabled
+ {
+ var sutProvider = SutProviderFactory([
+ new FakeRequireSsoPolicyValidator(),
+ new FakeSingleOrgPolicyValidator()
+ ]);
+
+ ArrangeOrganization(sutProvider, policyUpdate);
+ sutProvider.GetDependency()
+ .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId)
+ .Returns([currentPolicy, requireSsoPolicy]);
+
+ await sutProvider.Sut.SaveAsync(policyUpdate);
+
+ await AssertPolicySavedAsync(sutProvider, policyUpdate);
+ }
+
+ [Theory, BitAutoData]
+ public async Task SaveAsync_ThrowsOnValidationError([PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate)
+ {
+ var fakePolicyValidator = new FakeSingleOrgPolicyValidator();
+ fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns("Validation error!");
+ var sutProvider = SutProviderFactory([fakePolicyValidator]);
+
+ ArrangeOrganization(sutProvider, policyUpdate);
+ sutProvider.GetDependency().GetManyByOrganizationIdAsync(policyUpdate.OrganizationId).Returns([]);
+
+ var badRequestException = await Assert.ThrowsAsync(
+ () => sutProvider.Sut.SaveAsync(policyUpdate));
+
+ Assert.Contains("Validation error!", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
+ await AssertPolicyNotSavedAsync(sutProvider);
+ }
+
+ ///
+ /// Returns a new SutProvider with the PolicyValidators registered in the Sut.
+ ///
+ private static SutProvider SutProviderFactory(IEnumerable? policyValidators = null)
+ {
+ return new SutProvider()
+ .WithFakeTimeProvider()
+ .SetDependency(typeof(IEnumerable), policyValidators ?? [])
+ .Create();
+ }
+
+ private static void ArrangeOrganization(SutProvider sutProvider, PolicyUpdate policyUpdate)
+ {
+ sutProvider.GetDependency()
+ .GetOrganizationAbilityAsync(policyUpdate.OrganizationId)
+ .Returns(new OrganizationAbility
+ {
+ Id = policyUpdate.OrganizationId,
+ UsePolicies = true
+ });
+ }
+
+ private static async Task AssertPolicyNotSavedAsync(SutProvider sutProvider)
+ {
+ await sutProvider.GetDependency()
+ .DidNotReceiveWithAnyArgs()
+ .UpsertAsync(default!);
+
+ await sutProvider.GetDependency()
+ .DidNotReceiveWithAnyArgs()
+ .LogPolicyEventAsync(default, default);
+ }
+
+ private static async Task AssertPolicySavedAsync(SutProvider sutProvider, PolicyUpdate policyUpdate)
+ {
+ var expectedPolicy = () => Arg.Is(p =>
+ p.Type == policyUpdate.Type &&
+ p.OrganizationId == policyUpdate.OrganizationId &&
+ p.Enabled == policyUpdate.Enabled &&
+ p.Data == policyUpdate.Data);
+
+ await sutProvider.GetDependency().Received(1).UpsertAsync(expectedPolicy());
+
+ await sutProvider.GetDependency().Received(1)
+ .LogPolicyEventAsync(expectedPolicy(), EventType.Policy_Updated);
+ }
+}
diff --git a/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs b/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs
index fb08a32f2f4b..f9bc49bbe7c0 100644
--- a/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs
+++ b/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs
@@ -34,7 +34,6 @@ public async Task SaveAsync_OrganizationDoesNotExist_ThrowsBadRequest(
var badRequestException = await Assert.ThrowsAsync(
() => sutProvider.Sut.SaveAsync(policy,
- Substitute.For(),
Guid.NewGuid()));
Assert.Contains("Organization not found", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
@@ -61,7 +60,6 @@ public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest(
var badRequestException = await Assert.ThrowsAsync(
() => sutProvider.Sut.SaveAsync(policy,
- Substitute.For(),
Guid.NewGuid()));
Assert.Contains("cannot use policies", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
@@ -93,7 +91,6 @@ public async Task SaveAsync_SingleOrg_RequireSsoEnabled_ThrowsBadRequest(
var badRequestException = await Assert.ThrowsAsync(
() => sutProvider.Sut.SaveAsync(policy,
- Substitute.For(),
Guid.NewGuid()));
Assert.Contains("Single Sign-On Authentication policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
@@ -124,7 +121,6 @@ public async Task SaveAsync_SingleOrg_VaultTimeoutEnabled_ThrowsBadRequest([Admi
var badRequestException = await Assert.ThrowsAsync(
() => sutProvider.Sut.SaveAsync(policy,
- Substitute.For(),
Guid.NewGuid()));
Assert.Contains("Maximum Vault Timeout policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
@@ -161,7 +157,6 @@ public async Task SaveAsync_PolicyRequiredByKeyConnector_DisablePolicy_ThrowsBad
var badRequestException = await Assert.ThrowsAsync(
() => sutProvider.Sut.SaveAsync(policy,
- Substitute.For(),
Guid.NewGuid()));
Assert.Contains("Key Connector is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
@@ -189,7 +184,6 @@ public async Task SaveAsync_RequireSsoPolicy_NotEnabled_ThrowsBadRequestAsync(
var badRequestException = await Assert.ThrowsAsync(
() => sutProvider.Sut.SaveAsync(policy,
- Substitute.For(),
Guid.NewGuid()));
Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
@@ -222,7 +216,7 @@ public async Task SaveAsync_NewPolicy_Created(
var utcNow = DateTime.UtcNow;
- await sutProvider.Sut.SaveAsync(policy, Substitute.For(), Guid.NewGuid());
+ await sutProvider.Sut.SaveAsync(policy, Guid.NewGuid());
await sutProvider.GetDependency().Received()
.LogPolicyEventAsync(policy, EventType.Policy_Updated);
@@ -252,7 +246,6 @@ public async Task SaveAsync_VaultTimeoutPolicy_NotEnabled_ThrowsBadRequestAsync(
var badRequestException = await Assert.ThrowsAsync(
() => sutProvider.Sut.SaveAsync(policy,
- Substitute.For(),
Guid.NewGuid()));
Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
@@ -353,14 +346,13 @@ public async Task SaveAsync_ExistingPolicy_UpdateTwoFactor(
(orgUserDetailAdmin, false),
});
- var organizationService = Substitute.For();
var removeOrganizationUserCommand = sutProvider.GetDependency();
var utcNow = DateTime.UtcNow;
var savingUserId = Guid.NewGuid();
- await sutProvider.Sut.SaveAsync(policy, organizationService, savingUserId);
+ await sutProvider.Sut.SaveAsync(policy, savingUserId);
await removeOrganizationUserCommand.Received()
.RemoveUserAsync(policy.OrganizationId, orgUserDetailUserAcceptedWithout2FA.Id, savingUserId);
@@ -468,13 +460,12 @@ public async Task SaveAsync_EnableTwoFactor_WithoutMasterPasswordOr2FA_ThrowsBad
(orgUserDetailAdmin.UserId.Value, false),
});
- var organizationService = Substitute.For();
var removeOrganizationUserCommand = sutProvider.GetDependency();
var savingUserId = Guid.NewGuid();
var badRequestException = await Assert.ThrowsAsync(
- () => sutProvider.Sut.SaveAsync(policy, organizationService, savingUserId));
+ () => sutProvider.Sut.SaveAsync(policy, savingUserId));
Assert.Contains("Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page.", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
@@ -541,13 +532,11 @@ public async Task SaveAsync_ExistingPolicy_UpdateSingleOrg(
(orgUserDetail.UserId.Value, false),
});
- var organizationService = Substitute.For();
-
var utcNow = DateTime.UtcNow;
var savingUserId = Guid.NewGuid();
- await sutProvider.Sut.SaveAsync(policy, organizationService, savingUserId);
+ await sutProvider.Sut.SaveAsync(policy, savingUserId);
await sutProvider.GetDependency().Received()
.LogPolicyEventAsync(policy, EventType.Policy_Updated);
@@ -590,7 +579,6 @@ public async Task SaveAsync_ResetPasswordPolicyRequiredByTrustedDeviceEncryption
var badRequestException = await Assert.ThrowsAsync(
() => sutProvider.Sut.SaveAsync(policy,
- Substitute.For(),
Guid.NewGuid()));
Assert.Contains("Trusted device encryption is on and requires this policy.", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
@@ -626,7 +614,6 @@ public async Task SaveAsync_RequireSsoPolicyRequiredByTrustedDeviceEncryption_Di
var badRequestException = await Assert.ThrowsAsync(
() => sutProvider.Sut.SaveAsync(policy,
- Substitute.For(),
Guid.NewGuid()));
Assert.Contains("Trusted device encryption is on and requires this policy.", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
@@ -659,7 +646,6 @@ public async Task SaveAsync_PolicyRequiredForAccountRecovery_NotEnabled_ThrowsBa
var badRequestException = await Assert.ThrowsAsync(
() => sutProvider.Sut.SaveAsync(policy,
- Substitute.For(),
Guid.NewGuid()));
Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
@@ -692,7 +678,6 @@ public async Task SaveAsync_SingleOrg_AccountRecoveryEnabled_ThrowsBadRequest(
var badRequestException = await Assert.ThrowsAsync(
() => sutProvider.Sut.SaveAsync(policy,
- Substitute.For(),
Guid.NewGuid()));
Assert.Contains("Account recovery policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase);
diff --git a/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs b/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs
index fb566537abcd..e397c838c68f 100644
--- a/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs
+++ b/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs
@@ -11,7 +11,6 @@
using Bit.Core.Exceptions;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Repositories;
-using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
@@ -342,14 +341,12 @@ public async Task SaveAsync_Tde_Enable_Required_Policies(SutProvider().Received(1)
.SaveAsync(
Arg.Is(t => t.Type == PolicyType.SingleOrg),
- Arg.Any(),
null
);
await sutProvider.GetDependency().Received(1)
.SaveAsync(
Arg.Is(t => t.Type == PolicyType.ResetPassword && t.GetDataModel().AutoEnrollEnabled),
- Arg.Any(),
null
);
diff --git a/test/Core.Test/NotificationHub/NotificationHubConnectionTests.cs b/test/Core.Test/NotificationHub/NotificationHubConnectionTests.cs
new file mode 100644
index 000000000000..0d7382b3cc0e
--- /dev/null
+++ b/test/Core.Test/NotificationHub/NotificationHubConnectionTests.cs
@@ -0,0 +1,205 @@
+using Bit.Core.Settings;
+using Bit.Core.Utilities;
+using Xunit;
+
+namespace Bit.Core.Test.NotificationHub;
+
+public class NotificationHubConnectionTests
+{
+ [Fact]
+ public void IsValid_ConnectionStringIsNull_ReturnsFalse()
+ {
+ // Arrange
+ var hub = new GlobalSettings.NotificationHubSettings()
+ {
+ ConnectionString = null,
+ HubName = "hub",
+ RegistrationStartDate = DateTime.UtcNow,
+ RegistrationEndDate = DateTime.UtcNow.AddDays(1)
+ };
+
+ // Act
+ var connection = NotificationHubConnection.From(hub);
+
+ // Assert
+ Assert.False(connection.IsValid);
+ }
+
+ [Fact]
+ public void IsValid_HubNameIsNull_ReturnsFalse()
+ {
+ // Arrange
+ var hub = new GlobalSettings.NotificationHubSettings()
+ {
+ ConnectionString = "Endpoint=sb://example.servicebus.windows.net/;",
+ HubName = null,
+ RegistrationStartDate = DateTime.UtcNow,
+ RegistrationEndDate = DateTime.UtcNow.AddDays(1)
+ };
+
+ // Act
+ var connection = NotificationHubConnection.From(hub);
+
+ // Assert
+ Assert.False(connection.IsValid);
+ }
+
+ [Fact]
+ public void IsValid_ConnectionStringAndHubNameAreNotNull_ReturnsTrue()
+ {
+ // Arrange
+ var hub = new GlobalSettings.NotificationHubSettings()
+ {
+ ConnectionString = "connection",
+ HubName = "hub",
+ RegistrationStartDate = DateTime.UtcNow,
+ RegistrationEndDate = DateTime.UtcNow.AddDays(1)
+ };
+
+ // Act
+ var connection = NotificationHubConnection.From(hub);
+
+ // Assert
+ Assert.True(connection.IsValid);
+ }
+
+ [Fact]
+ public void RegistrationEnabled_QueryTimeIsBeforeStartDate_ReturnsFalse()
+ {
+ // Arrange
+ var hub = new GlobalSettings.NotificationHubSettings()
+ {
+ ConnectionString = "connection",
+ HubName = "hub",
+ RegistrationStartDate = DateTime.UtcNow.AddDays(1),
+ RegistrationEndDate = DateTime.UtcNow.AddDays(2)
+ };
+ var connection = NotificationHubConnection.From(hub);
+
+ // Act
+ var result = connection.RegistrationEnabled(DateTime.UtcNow);
+
+ // Assert
+ Assert.False(result);
+ }
+
+ [Fact]
+ public void RegistrationEnabled_QueryTimeIsAfterEndDate_ReturnsFalse()
+ {
+ // Arrange
+ var hub = new GlobalSettings.NotificationHubSettings()
+ {
+ ConnectionString = "connection",
+ HubName = "hub",
+ RegistrationStartDate = DateTime.UtcNow,
+ RegistrationEndDate = DateTime.UtcNow.AddDays(1)
+ };
+ var connection = NotificationHubConnection.From(hub);
+
+ // Act
+ var result = connection.RegistrationEnabled(DateTime.UtcNow.AddDays(2));
+
+ // Assert
+ Assert.False(result);
+ }
+
+ [Fact]
+ public void RegistrationEnabled_NullStartDate_ReturnsFalse()
+ {
+ // Arrange
+ var hub = new GlobalSettings.NotificationHubSettings()
+ {
+ ConnectionString = "connection",
+ HubName = "hub",
+ RegistrationStartDate = null,
+ RegistrationEndDate = DateTime.UtcNow.AddDays(1)
+ };
+ var connection = NotificationHubConnection.From(hub);
+
+ // Act
+ var result = connection.RegistrationEnabled(DateTime.UtcNow);
+
+ // Assert
+ Assert.False(result);
+ }
+
+ [Fact]
+ public void RegistrationEnabled_QueryTimeIsBetweenStartDateAndEndDate_ReturnsTrue()
+ {
+ // Arrange
+ var hub = new GlobalSettings.NotificationHubSettings()
+ {
+ ConnectionString = "connection",
+ HubName = "hub",
+ RegistrationStartDate = DateTime.UtcNow,
+ RegistrationEndDate = DateTime.UtcNow.AddDays(1)
+ };
+ var connection = NotificationHubConnection.From(hub);
+
+ // Act
+ var result = connection.RegistrationEnabled(DateTime.UtcNow.AddHours(1));
+
+ // Assert
+ Assert.True(result);
+ }
+
+ [Fact]
+ public void RegistrationEnabled_CombTimeIsBeforeStartDate_ReturnsFalse()
+ {
+ // Arrange
+ var hub = new GlobalSettings.NotificationHubSettings()
+ {
+ ConnectionString = "connection",
+ HubName = "hub",
+ RegistrationStartDate = DateTime.UtcNow.AddDays(1),
+ RegistrationEndDate = DateTime.UtcNow.AddDays(2)
+ };
+ var connection = NotificationHubConnection.From(hub);
+
+ // Act
+ var result = connection.RegistrationEnabled(CoreHelpers.GenerateComb(Guid.NewGuid(), DateTime.UtcNow));
+
+ // Assert
+ Assert.False(result);
+ }
+
+ [Fact]
+ public void RegistrationEnabled_CombTimeIsAfterEndDate_ReturnsFalse()
+ {
+ // Arrange
+ var hub = new GlobalSettings.NotificationHubSettings()
+ {
+ ConnectionString = "connection",
+ HubName = "hub",
+ RegistrationStartDate = DateTime.UtcNow,
+ RegistrationEndDate = DateTime.UtcNow.AddDays(1)
+ };
+ var connection = NotificationHubConnection.From(hub);
+
+ // Act
+ var result = connection.RegistrationEnabled(CoreHelpers.GenerateComb(Guid.NewGuid(), DateTime.UtcNow.AddDays(2)));
+
+ // Assert
+ Assert.False(result);
+ }
+
+ [Fact]
+ public void RegistrationEnabled_CombTimeIsBetweenStartDateAndEndDate_ReturnsTrue()
+ {
+ // Arrange
+ var hub = new GlobalSettings.NotificationHubSettings()
+ {
+ ConnectionString = "connection",
+ HubName = "hub",
+ RegistrationStartDate = DateTime.UtcNow,
+ RegistrationEndDate = DateTime.UtcNow.AddDays(1)
+ };
+ var connection = NotificationHubConnection.From(hub);
+
+ // Act
+ var result = connection.RegistrationEnabled(CoreHelpers.GenerateComb(Guid.NewGuid(), DateTime.UtcNow.AddHours(1)));
+
+ // Assert
+ Assert.True(result);
+ }
+}
diff --git a/test/Core.Test/NotificationHub/NotificationHubPoolTests.cs b/test/Core.Test/NotificationHub/NotificationHubPoolTests.cs
new file mode 100644
index 000000000000..dd9afb867ee1
--- /dev/null
+++ b/test/Core.Test/NotificationHub/NotificationHubPoolTests.cs
@@ -0,0 +1,156 @@
+using Bit.Core.NotificationHub;
+using Bit.Core.Settings;
+using Bit.Core.Utilities;
+using Microsoft.Extensions.Logging;
+using NSubstitute;
+using Xunit;
+using static Bit.Core.Settings.GlobalSettings;
+
+namespace Bit.Core.Test.NotificationHub;
+
+public class NotificationHubPoolTests
+{
+ [Fact]
+ public void NotificationHubPool_WarnsOnMissingConnectionString()
+ {
+ // Arrange
+ var globalSettings = new GlobalSettings()
+ {
+ NotificationHubPool = new NotificationHubPoolSettings()
+ {
+ NotificationHubs = new() {
+ new() {
+ ConnectionString = null,
+ HubName = "hub",
+ RegistrationStartDate = DateTime.UtcNow,
+ RegistrationEndDate = DateTime.UtcNow.AddDays(1)
+ }
+ }
+ }
+ };
+ var logger = Substitute.For>();
+
+ // Act
+ var sut = new NotificationHubPool(logger, globalSettings);
+
+ // Assert
+ logger.Received().Log(LogLevel.Warning, Arg.Any(),
+ Arg.Is