diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 5121d551c53c..47d3525deff8 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -4,13 +4,22 @@ # # https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners -# DevOps for Actions and other workflow changes -.github/workflows @bitwarden/dept-devops +## Docker files have shared ownership ## +**/Dockerfile +**/*.Dockerfile +**/.dockerignore +**/entrypoint.sh -# DevOps for Docker changes -**/Dockerfile @bitwarden/dept-devops -**/*.Dockerfile @bitwarden/dept-devops -**/.dockerignore @bitwarden/dept-devops +## BRE team owns these workflows ## +.github/workflows/publish.yml @bitwarden/dept-bre + +## These are shared workflows ## +.github/workflows/_move_finalization_db_scripts.yml +.github/workflows/build.yml +.github/workflows/cleanup-after-pr.yml +.github/workflows/cleanup-rc-branch.yml +.github/workflows/release.yml +.github/workflows/repository-management.yml # Database Operations for database changes src/Sql/** @bitwarden/dept-dbops @@ -60,6 +69,6 @@ src/EventsProcessor @bitwarden/team-admin-console-dev src/Admin/Controllers/ToolsController.cs @bitwarden/team-billing-dev src/Admin/Views/Tools @bitwarden/team-billing-dev -# Multiple owners - DO NOT REMOVE (DevOps) +# Multiple owners - DO NOT REMOVE (BRE) **/packages.lock.json Directory.Build.props diff --git a/.github/workflows/_move_finalization_db_scripts.yml b/.github/workflows/_move_finalization_db_scripts.yml index 3eb3777cef27..6e3825733a34 100644 --- a/.github/workflows/_move_finalization_db_scripts.yml +++ b/.github/workflows/_move_finalization_db_scripts.yml @@ -1,4 +1,3 @@ ---- name: _move_finalization_db_scripts run-name: Move finalization database scripts diff --git a/.github/workflows/automatic-issue-responses.yml b/.github/workflows/automatic-issue-responses.yml index 0e6b9041b93b..b6a6a1ebf5df 100644 --- a/.github/workflows/automatic-issue-responses.yml +++ b/.github/workflows/automatic-issue-responses.yml @@ -1,4 +1,3 @@ ---- name: Automatic responses on: issues: diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6df666417491..6043e1e21e5f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,4 +1,3 @@ ---- name: Build on: @@ -408,7 +407,7 @@ jobs: name: swagger.json path: swagger.json if-no-files-found: error - + - name: Build Internal API Swagger run: | cd ./src/Api @@ -416,17 +415,17 @@ jobs: dotnet tool restore echo "Publish API" dotnet publish -c "Release" -o obj/build-output/publish - + dotnet swagger tofile --output ../../internal.json --host https://api.bitwarden.com \ ./obj/build-output/publish/Api.dll internal - + cd ../Identity - + echo "Restore Identity tools" dotnet tool restore echo "Publish Identity" dotnet publish -c "Release" -o obj/build-output/publish - + dotnet swagger tofile --output ../../identity.json --host https://identity.bitwarden.com \ ./obj/build-output/publish/Identity.dll v1 cd ../.. @@ -448,7 +447,7 @@ jobs: with: name: identity.json path: identity.json - if-no-files-found: error + if-no-files-found: error build-mssqlmigratorutility: name: Build MSSQL migrator utility @@ -565,7 +564,7 @@ jobs: tag: 'main' } }) - + trigger-ee-updates: name: Trigger Ephemeral Environment updates if: github.ref != 'refs/heads/main' && contains(github.event.pull_request.labels.*.name, 'ephemeral-environment') @@ -595,7 +594,7 @@ jobs: workflow_id: '_update_ephemeral_tags.yml', ref: 'main', inputs: { - ephemeral_env_branch: '${{ github.head_ref }}' + ephemeral_env_branch: process.env.GITHUB_HEAD_REF } }) diff --git a/.github/workflows/cleanup-after-pr.yml b/.github/workflows/cleanup-after-pr.yml index 1bed3542d98c..c36dc4a034b5 100644 --- a/.github/workflows/cleanup-after-pr.yml +++ b/.github/workflows/cleanup-after-pr.yml @@ -1,4 +1,3 @@ ---- name: Container registry cleanup on: diff --git a/.github/workflows/cleanup-ephemeral-environment.yml b/.github/workflows/cleanup-ephemeral-environment.yml new file mode 100644 index 000000000000..d5c34a7bb4d7 --- /dev/null +++ b/.github/workflows/cleanup-ephemeral-environment.yml @@ -0,0 +1,59 @@ +name: Ephemeral environment cleanup + +on: + pull_request: + types: [unlabeled] + +jobs: + validate-pr: + name: Validate PR + runs-on: ubuntu-24.04 + outputs: + config-exists: ${{ steps.validate-config.outputs.config-exists }} + steps: + - name: Checkout PR + uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + + - name: Validate config exists in path + id: validate-config + run: | + if [[ -f "ephemeral-environments/$GITHUB_HEAD_REF.yaml" ]]; then + echo "Ephemeral environment config found in path, continuing." + echo "config-exists=true" >> $GITHUB_OUTPUT + fi + + + cleanup-config: + name: Cleanup ephemeral environment + runs-on: ubuntu-24.04 + needs: validate-pr + if: ${{ needs.validate-pr.outputs.config-exists }} + steps: + - name: Log in to Azure - CI subscription + uses: Azure/login@e15b166166a8746d1a47596803bd8c1b595455cf # v1.6.0 + with: + creds: ${{ secrets.AZURE_KV_CI_SERVICE_PRINCIPAL }} + + - name: Retrieve GitHub PAT secrets + id: retrieve-secret-pat + uses: bitwarden/gh-actions/get-keyvault-secrets@main + with: + keyvault: "bitwarden-ci" + secrets: "github-pat-bitwarden-devops-bot-repo-scope" + + - name: Trigger Ephemeral Environment cleanup + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }} + script: | + await github.rest.actions.createWorkflowDispatch({ + owner: 'bitwarden', + repo: 'devops', + workflow_id: '_ephemeral_environment_pr_manager.yml', + ref: 'main', + inputs: { + ephemeral_env_branch: process.env.GITHUB_HEAD_REF, + cleanup_config: true, + project: 'server' + } + }) diff --git a/.github/workflows/cleanup-rc-branch.yml b/.github/workflows/cleanup-rc-branch.yml index 1eba867a9c1d..e037c18f93e9 100644 --- a/.github/workflows/cleanup-rc-branch.yml +++ b/.github/workflows/cleanup-rc-branch.yml @@ -1,4 +1,3 @@ ---- name: Cleanup RC Branch on: diff --git a/.github/workflows/enforce-labels.yml b/.github/workflows/enforce-labels.yml index 160ee15b9650..11d5654937e1 100644 --- a/.github/workflows/enforce-labels.yml +++ b/.github/workflows/enforce-labels.yml @@ -1,4 +1,3 @@ ---- name: Enforce PR labels on: @@ -7,13 +6,13 @@ on: types: [labeled, unlabeled, opened, reopened, synchronize] jobs: enforce-label: - if: ${{ contains(github.event.*.labels.*.name, 'hold') || contains(github.event.*.labels.*.name, 'needs-qa') || contains(github.event.*.labels.*.name, 'DB-migrations-changed') }} + if: ${{ contains(github.event.*.labels.*.name, 'hold') || contains(github.event.*.labels.*.name, 'needs-qa') || contains(github.event.*.labels.*.name, 'DB-migrations-changed') || contains(github.event.*.labels.*.name, 'ephemeral-environment') }} name: Enforce label runs-on: ubuntu-22.04 steps: - name: Check for label run: | - echo "PRs with the hold or needs-qa labels cannot be merged" - echo "### :x: PRs with the hold or needs-qa labels cannot be merged" >> $GITHUB_STEP_SUMMARY + echo "PRs with the hold, needs-qa or ephemeral-environment labels cannot be merged" + echo "### :x: PRs with the hold, needs-qa or ephemeral-environment labels cannot be merged" >> $GITHUB_STEP_SUMMARY exit 1 diff --git a/.github/workflows/protect-files.yml b/.github/workflows/protect-files.yml index 10924f656bae..95d57180df61 100644 --- a/.github/workflows/protect-files.yml +++ b/.github/workflows/protect-files.yml @@ -1,7 +1,6 @@ # Runs if there are changes to the paths: list. # Starts a matrix job to check for modified files, then sets output based on the results. # The input decides if the label job is ran, adding a label to the PR. ---- name: Protect files on: diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 77ea9dca4f50..4454ea1f3c28 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -1,4 +1,3 @@ ---- name: Publish run-name: Publish ${{ inputs.publish_type }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0c89a01c2f13..9d5dcb74d8ca 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,4 +1,3 @@ ---- name: Release run-name: Release ${{ inputs.release_type }} diff --git a/.github/workflows/stale-bot.yml b/.github/workflows/stale-bot.yml index d3dc92cd70ea..f8a25288f283 100644 --- a/.github/workflows/stale-bot.yml +++ b/.github/workflows/stale-bot.yml @@ -1,4 +1,3 @@ ---- name: Staleness on: workflow_dispatch: diff --git a/.github/workflows/test-database.yml b/.github/workflows/test-database.yml index 09a4b7a182aa..7a38b0f3bd4e 100644 --- a/.github/workflows/test-database.yml +++ b/.github/workflows/test-database.yml @@ -1,4 +1,3 @@ ---- name: Database testing on: @@ -55,7 +54,7 @@ jobs: # I've seen the SQL Server container not be ready for commands right after starting up and just needing a bit longer to be ready - name: Sleep run: sleep 15s - + - name: Checking pending model changes (MySQL) working-directory: "util/MySqlMigrations" run: 'dotnet ef migrations has-pending-model-changes -- --GlobalSettings:MySql:ConnectionString="$CONN_STR"' @@ -114,7 +113,7 @@ jobs: BW_TEST_DATABASES__3__CONNECTIONSTRING: "Data Source=${{ runner.temp }}/test.db" run: dotnet test --logger "trx;LogFileName=infrastructure-test-results.trx" shell: pwsh - + - name: Print MySQL Logs if: failure() run: 'docker logs $(docker ps --quiet --filter "name=mysql")' diff --git a/Directory.Build.props b/Directory.Build.props index 46a12ea3db5f..5cd12bfb7114 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,7 +3,7 @@ net8.0 - 2024.10.0 + 2024.10.1 Bit.$(MSBuildProjectName) enable diff --git a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs index 70c09a539bfe..efab8620c0d7 100644 --- a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs @@ -11,7 +11,6 @@ using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Enums; -using Bit.Core.Exceptions; using Bit.Core.Models.OrganizationConnectionConfigs; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; @@ -236,7 +235,8 @@ public async Task Edit(Guid id, OrganizationEditModel model) if (organization.UseSecretsManager && !StaticStore.GetPlan(organization.PlanType).SupportsSecretsManager) { - throw new BadRequestException("Plan does not support Secrets Manager"); + TempData["Error"] = "Plan does not support Secrets Manager"; + return RedirectToAction("Edit", new { id }); } await _organizationRepository.ReplaceAsync(organization); diff --git a/src/Admin/AdminConsole/Models/OrganizationEditModel.cs b/src/Admin/AdminConsole/Models/OrganizationEditModel.cs index 04079138d423..4ba22130f718 100644 --- a/src/Admin/AdminConsole/Models/OrganizationEditModel.cs +++ b/src/Admin/AdminConsole/Models/OrganizationEditModel.cs @@ -181,7 +181,6 @@ public OrganizationEditModel( */ public object GetPlansHelper() => StaticStore.Plans - .Where(p => p.SupportsSecretsManager) .Select(p => { var plan = new diff --git a/src/Api/AdminConsole/Controllers/PoliciesController.cs b/src/Api/AdminConsole/Controllers/PoliciesController.cs index b3be852dbca2..7bfd13c4088e 100644 --- a/src/Api/AdminConsole/Controllers/PoliciesController.cs +++ b/src/Api/AdminConsole/Controllers/PoliciesController.cs @@ -25,7 +25,6 @@ public class PoliciesController : Controller { private readonly IPolicyRepository _policyRepository; private readonly IPolicyService _policyService; - private readonly IOrganizationService _organizationService; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IUserService _userService; private readonly ICurrentContext _currentContext; @@ -36,7 +35,6 @@ public class PoliciesController : Controller public PoliciesController( IPolicyRepository policyRepository, IPolicyService policyService, - IOrganizationService organizationService, IOrganizationUserRepository organizationUserRepository, IUserService userService, ICurrentContext currentContext, @@ -46,7 +44,6 @@ public PoliciesController( { _policyRepository = policyRepository; _policyService = policyService; - _organizationService = organizationService; _organizationUserRepository = organizationUserRepository; _userService = userService; _currentContext = currentContext; @@ -185,7 +182,7 @@ public async Task Put(string orgId, int type, [FromBody] Po } var userId = _userService.GetProperUserId(User); - await _policyService.SaveAsync(policy, _organizationService, userId); + await _policyService.SaveAsync(policy, userId); return new PolicyResponseModel(policy); } } diff --git a/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs b/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs index 2d83bd70559f..71e03a547ec7 100644 --- a/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs +++ b/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs @@ -6,7 +6,6 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Context; -using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -18,18 +17,15 @@ public class PoliciesController : Controller { private readonly IPolicyRepository _policyRepository; private readonly IPolicyService _policyService; - private readonly IOrganizationService _organizationService; private readonly ICurrentContext _currentContext; public PoliciesController( IPolicyRepository policyRepository, IPolicyService policyService, - IOrganizationService organizationService, ICurrentContext currentContext) { _policyRepository = policyRepository; _policyService = policyService; - _organizationService = organizationService; _currentContext = currentContext; } @@ -96,7 +92,7 @@ public async Task Put(PolicyType type, [FromBody] PolicyUpdateReq { policy = model.ToPolicy(policy); } - await _policyService.SaveAsync(policy, _organizationService, null); + await _policyService.SaveAsync(policy, null); var response = new PolicyResponseModel(policy); return new JsonResult(response); } diff --git a/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs b/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs index 624eab1fece8..b5f9ab2f5910 100644 --- a/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs +++ b/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs @@ -3,8 +3,11 @@ namespace Bit.Api.Billing.Models.Responses; public record OrganizationMetadataResponse( + bool IsEligibleForSelfHost, bool IsOnSecretsManagerStandalone) { public static OrganizationMetadataResponse From(OrganizationMetadata metadata) - => new(metadata.IsOnSecretsManagerStandalone); + => new( + metadata.IsEligibleForSelfHost, + metadata.IsOnSecretsManagerStandalone); } diff --git a/src/Api/Controllers/PushController.cs b/src/Api/Controllers/PushController.cs index c83eb200b8df..38398051060e 100644 --- a/src/Api/Controllers/PushController.cs +++ b/src/Api/Controllers/PushController.cs @@ -46,7 +46,7 @@ await _pushRegistrationService.CreateOrUpdateRegistrationAsync(model.PushToken, public async Task PostDelete([FromBody] PushDeviceRequestModel model) { CheckUsage(); - await _pushRegistrationService.DeleteRegistrationAsync(Prefix(model.Id), model.Type); + await _pushRegistrationService.DeleteRegistrationAsync(Prefix(model.Id)); } [HttpPut("add-organization")] @@ -54,7 +54,7 @@ public async Task PutAddOrganization([FromBody] PushUpdateRequestModel model) { CheckUsage(); await _pushRegistrationService.AddUserRegistrationOrganizationAsync( - model.Devices.Select(d => new KeyValuePair(Prefix(d.Id), d.Type)), + model.Devices.Select(d => Prefix(d.Id)), Prefix(model.OrganizationId)); } @@ -63,7 +63,7 @@ public async Task PutDeleteOrganization([FromBody] PushUpdateRequestModel model) { CheckUsage(); await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync( - model.Devices.Select(d => new KeyValuePair(Prefix(d.Id), d.Type)), + model.Devices.Select(d => Prefix(d.Id)), Prefix(model.OrganizationId)); } diff --git a/src/Core/AdminConsole/Enums/PolicyType.cs b/src/Core/AdminConsole/Enums/PolicyType.cs index 0e1786cf528f..bdde3e424ec3 100644 --- a/src/Core/AdminConsole/Enums/PolicyType.cs +++ b/src/Core/AdminConsole/Enums/PolicyType.cs @@ -16,3 +16,30 @@ public enum PolicyType : byte ActivateAutofill = 11, AutomaticAppLogIn = 12, } + +public static class PolicyTypeExtensions +{ + /// + /// Returns the name of the policy for display to the user. + /// Do not include the word "policy" in the return value. + /// + public static string GetName(this PolicyType type) + { + return type switch + { + PolicyType.TwoFactorAuthentication => "Require two-step login", + PolicyType.MasterPassword => "Master password requirements", + PolicyType.PasswordGenerator => "Password generator", + PolicyType.SingleOrg => "Single organization", + PolicyType.RequireSso => "Require single sign-on authentication", + PolicyType.PersonalOwnership => "Remove individual vault", + PolicyType.DisableSend => "Remove Send", + PolicyType.SendOptions => "Send options", + PolicyType.ResetPassword => "Account recovery administration", + PolicyType.MaximumVaultTimeout => "Vault timeout", + PolicyType.DisablePersonalVaultExport => "Remove individual vault export", + PolicyType.ActivateAutofill => "Active auto-fill", + PolicyType.AutomaticAppLogIn => "Automatically log in users for allowed applications", + }; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs index 09444306e6e5..e6d56ea878a2 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs @@ -162,12 +162,12 @@ private async Task RepositoryDeleteUserAsync(OrganizationUser orgUser, Guid? del } } - private async Task>> GetUserDeviceIdsAsync(Guid userId) + private async Task> GetUserDeviceIdsAsync(Guid userId) { var devices = await _deviceRepository.GetManyByUserIdAsync(userId); return devices .Where(d => !string.IsNullOrWhiteSpace(d.PushToken)) - .Select(d => new KeyValuePair(d.Id.ToString(), d.Type)); + .Select(d => d.Id.ToString()); } private async Task DeleteAndPushUserRegistrationAsync(Guid organizationId, Guid userId) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs new file mode 100644 index 000000000000..6aef9f248b89 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs @@ -0,0 +1,43 @@ +#nullable enable + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies; + +/// +/// Defines behavior and functionality for a given PolicyType. +/// +public interface IPolicyValidator +{ + /// + /// The PolicyType that this definition relates to. + /// + public PolicyType Type { get; } + + /// + /// PolicyTypes that must be enabled before this policy can be enabled, if any. + /// These dependencies will be checked when this policy is enabled and when any required policy is disabled. + /// + public IEnumerable RequiredPolicies { get; } + + /// + /// Validates a policy before saving it. + /// Do not use this for simple dependencies between different policies - see instead. + /// Implementation is optional; by default it will not perform any validation. + /// + /// The policy update request + /// The current policy, if any + /// A validation error if validation was unsuccessful, otherwise an empty string + public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy); + + /// + /// Performs side effects after a policy is validated but before it is saved. + /// For example, this can be used to remove non-compliant users from the organization. + /// Implementation is optional; by default it will not perform any side effects. + /// + /// The policy update request + /// The current policy, if any + public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/ISavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/ISavePolicyCommand.cs new file mode 100644 index 000000000000..5bfdfc6aa7d2 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/ISavePolicyCommand.cs @@ -0,0 +1,8 @@ +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies; + +public interface ISavePolicyCommand +{ + Task SaveAsync(PolicyUpdate policy); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs new file mode 100644 index 000000000000..01ffce2cc670 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs @@ -0,0 +1,129 @@ +#nullable enable + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Services; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; + +public class SavePolicyCommand : ISavePolicyCommand +{ + private readonly IApplicationCacheService _applicationCacheService; + private readonly IEventService _eventService; + private readonly IPolicyRepository _policyRepository; + private readonly IReadOnlyDictionary _policyValidators; + private readonly TimeProvider _timeProvider; + + public SavePolicyCommand( + IApplicationCacheService applicationCacheService, + IEventService eventService, + IPolicyRepository policyRepository, + IEnumerable policyValidators, + TimeProvider timeProvider) + { + _applicationCacheService = applicationCacheService; + _eventService = eventService; + _policyRepository = policyRepository; + _timeProvider = timeProvider; + + var policyValidatorsDict = new Dictionary(); + foreach (var policyValidator in policyValidators) + { + if (!policyValidatorsDict.TryAdd(policyValidator.Type, policyValidator)) + { + throw new Exception($"Duplicate PolicyValidator for {policyValidator.Type} policy."); + } + } + + _policyValidators = policyValidatorsDict; + } + + public async Task SaveAsync(PolicyUpdate policyUpdate) + { + var org = await _applicationCacheService.GetOrganizationAbilityAsync(policyUpdate.OrganizationId); + if (org == null) + { + throw new BadRequestException("Organization not found"); + } + + if (!org.UsePolicies) + { + throw new BadRequestException("This organization cannot use policies."); + } + + if (_policyValidators.TryGetValue(policyUpdate.Type, out var validator)) + { + await RunValidatorAsync(validator, policyUpdate); + } + + var policy = await _policyRepository.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type) + ?? new Policy + { + OrganizationId = policyUpdate.OrganizationId, + Type = policyUpdate.Type, + CreationDate = _timeProvider.GetUtcNow().UtcDateTime + }; + + policy.Enabled = policyUpdate.Enabled; + policy.Data = policyUpdate.Data; + policy.RevisionDate = _timeProvider.GetUtcNow().UtcDateTime; + + await _policyRepository.UpsertAsync(policy); + await _eventService.LogPolicyEventAsync(policy, EventType.Policy_Updated); + } + + private async Task RunValidatorAsync(IPolicyValidator validator, PolicyUpdate policyUpdate) + { + var savedPolicies = await _policyRepository.GetManyByOrganizationIdAsync(policyUpdate.OrganizationId); + // Note: policies may be missing from this dict if they have never been enabled + var savedPoliciesDict = savedPolicies.ToDictionary(p => p.Type); + var currentPolicy = savedPoliciesDict.GetValueOrDefault(policyUpdate.Type); + + // If enabling this policy - check that all policy requirements are satisfied + if (currentPolicy is not { Enabled: true } && policyUpdate.Enabled) + { + var missingRequiredPolicyTypes = validator.RequiredPolicies + .Where(requiredPolicyType => + savedPoliciesDict.GetValueOrDefault(requiredPolicyType) is not { Enabled: true }) + .ToList(); + + if (missingRequiredPolicyTypes.Count != 0) + { + throw new BadRequestException($"Turn on the {missingRequiredPolicyTypes.First().GetName()} policy because it is required for the {validator.Type.GetName()} policy."); + } + } + + // If disabling this policy - ensure it's not required by any other policy + if (currentPolicy is { Enabled: true } && !policyUpdate.Enabled) + { + var dependentPolicyTypes = _policyValidators.Values + .Where(otherValidator => otherValidator.RequiredPolicies.Contains(policyUpdate.Type)) + .Select(otherValidator => otherValidator.Type) + .Where(otherPolicyType => savedPoliciesDict.ContainsKey(otherPolicyType) && + savedPoliciesDict[otherPolicyType].Enabled) + .ToList(); + + switch (dependentPolicyTypes) + { + case { Count: 1 }: + throw new BadRequestException($"Turn off the {dependentPolicyTypes.First().GetName()} policy because it requires the {validator.Type.GetName()} policy."); + case { Count: > 1 }: + throw new BadRequestException($"Turn off all of the policies that require the {validator.Type.GetName()} policy."); + } + } + + // Run other validation + var validationError = await validator.ValidateAsync(policyUpdate, currentPolicy); + if (!string.IsNullOrEmpty(validationError)) + { + throw new BadRequestException(validationError); + } + + // Run side effects + await validator.OnSaveSideEffectsAsync(policyUpdate, currentPolicy); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs new file mode 100644 index 000000000000..117a7ec73396 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs @@ -0,0 +1,28 @@ +#nullable enable + +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.Utilities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; + +/// +/// A request for SavePolicyCommand to update a policy +/// +public record PolicyUpdate +{ + public Guid OrganizationId { get; set; } + public PolicyType Type { get; set; } + public string? Data { get; set; } + public bool Enabled { get; set; } + + public T GetDataModel() where T : IPolicyDataModel, new() + { + return CoreHelpers.LoadClassFromJsonData(Data); + } + + public void SetDataModel(T dataModel) where T : IPolicyDataModel, new() + { + Data = CoreHelpers.ClassToJsonData(dataModel); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs new file mode 100644 index 000000000000..81096ef6022a --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs @@ -0,0 +1,22 @@ +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; +using Bit.Core.AdminConsole.Services; +using Bit.Core.AdminConsole.Services.Implementations; +using Microsoft.Extensions.DependencyInjection; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies; + +public static class PolicyServiceCollectionExtensions +{ + public static void AddPolicyServices(this IServiceCollection services) + { + services.AddScoped(); + services.AddScoped(); + + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs new file mode 100644 index 000000000000..bfd4dcfe0d3a --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs @@ -0,0 +1,15 @@ +#nullable enable + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +public class MaximumVaultTimeoutPolicyValidator : IPolicyValidator +{ + public PolicyType Type => PolicyType.MaximumVaultTimeout; + public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; + public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(""); + public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(0); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/PolicyValidatorHelpers.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/PolicyValidatorHelpers.cs new file mode 100644 index 000000000000..1bbaf1aa1e84 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/PolicyValidatorHelpers.cs @@ -0,0 +1,33 @@ +#nullable enable + +using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Enums; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +public static class PolicyValidatorHelpers +{ + /// + /// Validate that given Member Decryption Options are not enabled. + /// Used for validation when disabling a policy that is required by certain Member Decryption Options. + /// + /// The Member Decryption Options that require the policy to be enabled. + /// A validation error if validation was unsuccessful, otherwise an empty string + public static string ValidateDecryptionOptionsNotEnabled(this SsoConfig? ssoConfig, + MemberDecryptionType[] decryptionOptions) + { + if (ssoConfig is not { Enabled: true }) + { + return ""; + } + + return ssoConfig.GetData().MemberDecryptionType switch + { + MemberDecryptionType.KeyConnector when decryptionOptions.Contains(MemberDecryptionType.KeyConnector) + => "Key Connector is enabled and requires this policy.", + MemberDecryptionType.TrustedDeviceEncryption when decryptionOptions.Contains(MemberDecryptionType + .TrustedDeviceEncryption) => "Trusted device encryption is on and requires this policy.", + _ => "" + }; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs new file mode 100644 index 000000000000..2082d4305fa2 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs @@ -0,0 +1,38 @@ +#nullable enable + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Repositories; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +public class RequireSsoPolicyValidator : IPolicyValidator +{ + private readonly ISsoConfigRepository _ssoConfigRepository; + + public RequireSsoPolicyValidator(ISsoConfigRepository ssoConfigRepository) + { + _ssoConfigRepository = ssoConfigRepository; + } + + public PolicyType Type => PolicyType.RequireSso; + public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; + + public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + { + if (policyUpdate is not { Enabled: true }) + { + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(policyUpdate.OrganizationId); + return ssoConfig.ValidateDecryptionOptionsNotEnabled([ + MemberDecryptionType.KeyConnector, + MemberDecryptionType.TrustedDeviceEncryption + ]); + } + + return ""; + } + + public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(0); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs new file mode 100644 index 000000000000..1126c4b922b0 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs @@ -0,0 +1,36 @@ +#nullable enable + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Repositories; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +public class ResetPasswordPolicyValidator : IPolicyValidator +{ + private readonly ISsoConfigRepository _ssoConfigRepository; + public PolicyType Type => PolicyType.ResetPassword; + public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; + + public ResetPasswordPolicyValidator(ISsoConfigRepository ssoConfigRepository) + { + _ssoConfigRepository = ssoConfigRepository; + } + + public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + { + if (policyUpdate is not { Enabled: true } || + policyUpdate.GetDataModel().AutoEnrollEnabled == false) + { + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(policyUpdate.OrganizationId); + return ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.TrustedDeviceEncryption]); + } + + return ""; + } + + public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(0); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs new file mode 100644 index 000000000000..3e1f8d26c81d --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs @@ -0,0 +1,101 @@ +#nullable enable + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Repositories; +using Bit.Core.Context; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Services; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +public class SingleOrgPolicyValidator : IPolicyValidator +{ + public PolicyType Type => PolicyType.SingleOrg; + + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IMailService _mailService; + private readonly IOrganizationRepository _organizationRepository; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly ICurrentContext _currentContext; + private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; + + public SingleOrgPolicyValidator( + IOrganizationUserRepository organizationUserRepository, + IMailService mailService, + IOrganizationRepository organizationRepository, + ISsoConfigRepository ssoConfigRepository, + ICurrentContext currentContext, + IRemoveOrganizationUserCommand removeOrganizationUserCommand) + { + _organizationUserRepository = organizationUserRepository; + _mailService = mailService; + _organizationRepository = organizationRepository; + _ssoConfigRepository = ssoConfigRepository; + _currentContext = currentContext; + _removeOrganizationUserCommand = removeOrganizationUserCommand; + } + + public IEnumerable RequiredPolicies => []; + + public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + { + if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true }) + { + await RemoveNonCompliantUsersAsync(policyUpdate.OrganizationId); + } + } + + private async Task RemoveNonCompliantUsersAsync(Guid organizationId) + { + // Remove non-compliant users + var savingUserId = _currentContext.UserId; + // Note: must get OrganizationUserUserDetails so that Email is always populated from the User object + var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); + var org = await _organizationRepository.GetByIdAsync(organizationId); + if (org == null) + { + throw new NotFoundException("Organization not found."); + } + + var removableOrgUsers = orgUsers.Where(ou => + ou.Status != OrganizationUserStatusType.Invited && + ou.Status != OrganizationUserStatusType.Revoked && + ou.Type != OrganizationUserType.Owner && + ou.Type != OrganizationUserType.Admin && + ou.UserId != savingUserId + ).ToList(); + + var userOrgs = await _organizationUserRepository.GetManyByManyUsersAsync( + removableOrgUsers.Select(ou => ou.UserId!.Value)); + foreach (var orgUser in removableOrgUsers) + { + if (userOrgs.Any(ou => ou.UserId == orgUser.UserId + && ou.OrganizationId != org.Id + && ou.Status != OrganizationUserStatusType.Invited)) + { + await _removeOrganizationUserCommand.RemoveUserAsync(organizationId, orgUser.Id, + savingUserId); + + await _mailService.SendOrganizationUserRemovedForPolicySingleOrgEmailAsync( + org.DisplayName(), orgUser.Email); + } + } + } + + public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + { + if (policyUpdate is not { Enabled: true }) + { + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(policyUpdate.OrganizationId); + return ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.KeyConnector]); + } + + return ""; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs new file mode 100644 index 000000000000..ef896bbb9b57 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs @@ -0,0 +1,87 @@ +#nullable enable + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Context; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Services; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator +{ + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IMailService _mailService; + private readonly IOrganizationRepository _organizationRepository; + private readonly ICurrentContext _currentContext; + private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; + private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; + + public PolicyType Type => PolicyType.TwoFactorAuthentication; + public IEnumerable RequiredPolicies => []; + + public TwoFactorAuthenticationPolicyValidator( + IOrganizationUserRepository organizationUserRepository, + IMailService mailService, + IOrganizationRepository organizationRepository, + ICurrentContext currentContext, + ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, + IRemoveOrganizationUserCommand removeOrganizationUserCommand) + { + _organizationUserRepository = organizationUserRepository; + _mailService = mailService; + _organizationRepository = organizationRepository; + _currentContext = currentContext; + _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; + _removeOrganizationUserCommand = removeOrganizationUserCommand; + } + + public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + { + if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true }) + { + await RemoveNonCompliantUsersAsync(policyUpdate.OrganizationId); + } + } + + private async Task RemoveNonCompliantUsersAsync(Guid organizationId) + { + var org = await _organizationRepository.GetByIdAsync(organizationId); + var savingUserId = _currentContext.UserId; + + var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); + var organizationUsersTwoFactorEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(orgUsers); + var removableOrgUsers = orgUsers.Where(ou => + ou.Status != OrganizationUserStatusType.Invited && ou.Status != OrganizationUserStatusType.Revoked && + ou.Type != OrganizationUserType.Owner && ou.Type != OrganizationUserType.Admin && + ou.UserId != savingUserId); + + // Reorder by HasMasterPassword to prioritize checking users without a master if they have 2FA enabled + foreach (var orgUser in removableOrgUsers.OrderBy(ou => ou.HasMasterPassword)) + { + var userTwoFactorEnabled = organizationUsersTwoFactorEnabled.FirstOrDefault(u => u.user.Id == orgUser.Id) + .twoFactorIsEnabled; + if (!userTwoFactorEnabled) + { + if (!orgUser.HasMasterPassword) + { + throw new BadRequestException( + "Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page."); + } + + await _removeOrganizationUserCommand.RemoveUserAsync(organizationId, orgUser.Id, + savingUserId); + + await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync( + org!.DisplayName(), orgUser.Email); + } + } + } + + public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(""); +} diff --git a/src/Core/AdminConsole/Services/IPolicyService.cs b/src/Core/AdminConsole/Services/IPolicyService.cs index 6d92a3a4f749..16ff2f4fa11c 100644 --- a/src/Core/AdminConsole/Services/IPolicyService.cs +++ b/src/Core/AdminConsole/Services/IPolicyService.cs @@ -4,13 +4,12 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -using Bit.Core.Services; namespace Bit.Core.AdminConsole.Services; public interface IPolicyService { - Task SaveAsync(Policy policy, IOrganizationService organizationService, Guid? savingUserId); + Task SaveAsync(Policy policy, Guid? savingUserId); /// /// Get the combined master password policy options for the specified user. diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index 50a2ed84eb92..f44ce686f4ed 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -1838,12 +1838,12 @@ await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync(devices, } - private async Task>> GetUserDeviceIdsAsync(Guid userId) + private async Task> GetUserDeviceIdsAsync(Guid userId) { var devices = await _deviceRepository.GetManyByUserIdAsync(userId); return devices .Where(d => !string.IsNullOrWhiteSpace(d.PushToken)) - .Select(d => new KeyValuePair(d.Id.ToString(), d.Type)); + .Select(d => d.Id.ToString()); } public async Task ReplaceAndUpdateCacheAsync(Organization org, EventType? orgEvent = null) diff --git a/src/Core/AdminConsole/Services/Implementations/PolicyService.cs b/src/Core/AdminConsole/Services/Implementations/PolicyService.cs index 7e689f0342fe..6ab90afe04c6 100644 --- a/src/Core/AdminConsole/Services/Implementations/PolicyService.cs +++ b/src/Core/AdminConsole/Services/Implementations/PolicyService.cs @@ -2,6 +2,8 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; @@ -27,6 +29,8 @@ public class PolicyService : IPolicyService private readonly IMailService _mailService; private readonly GlobalSettings _globalSettings; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; + private readonly IFeatureService _featureService; + private readonly ISavePolicyCommand _savePolicyCommand; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; public PolicyService( @@ -39,6 +43,8 @@ public PolicyService( IMailService mailService, GlobalSettings globalSettings, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, + IFeatureService featureService, + ISavePolicyCommand savePolicyCommand, IRemoveOrganizationUserCommand removeOrganizationUserCommand) { _applicationCacheService = applicationCacheService; @@ -50,11 +56,28 @@ public PolicyService( _mailService = mailService; _globalSettings = globalSettings; _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; + _featureService = featureService; + _savePolicyCommand = savePolicyCommand; _removeOrganizationUserCommand = removeOrganizationUserCommand; } - public async Task SaveAsync(Policy policy, IOrganizationService organizationService, Guid? savingUserId) + public async Task SaveAsync(Policy policy, Guid? savingUserId) { + if (_featureService.IsEnabled(FeatureFlagKeys.Pm13322AddPolicyDefinitions)) + { + // Transitional mapping - this will be moved to callers once the feature flag is removed + var policyUpdate = new PolicyUpdate + { + OrganizationId = policy.OrganizationId, + Type = policy.Type, + Enabled = policy.Enabled, + Data = policy.Data + }; + + await _savePolicyCommand.SaveAsync(policyUpdate); + return; + } + var org = await _organizationRepository.GetByIdAsync(policy.OrganizationId); if (org == null) { @@ -88,7 +111,7 @@ public async Task SaveAsync(Policy policy, IOrganizationService organizationServ return; } - await EnablePolicyAsync(policy, org, organizationService, savingUserId); + await EnablePolicyAsync(policy, org, savingUserId); } public async Task GetMasterPasswordPolicyForUserAsync(User user) @@ -262,7 +285,7 @@ private async Task SetPolicyConfiguration(Policy policy) await _eventService.LogPolicyEventAsync(policy, EventType.Policy_Updated); } - private async Task EnablePolicyAsync(Policy policy, Organization org, IOrganizationService organizationService, Guid? savingUserId) + private async Task EnablePolicyAsync(Policy policy, Organization org, Guid? savingUserId) { var currentPolicy = await _policyRepository.GetByIdAsync(policy.Id); if (!currentPolicy?.Enabled ?? true) diff --git a/src/Core/Auth/Services/Implementations/SsoConfigService.cs b/src/Core/Auth/Services/Implementations/SsoConfigService.cs index fdf7e278e008..532f000394c7 100644 --- a/src/Core/Auth/Services/Implementations/SsoConfigService.cs +++ b/src/Core/Auth/Services/Implementations/SsoConfigService.cs @@ -20,7 +20,6 @@ public class SsoConfigService : ISsoConfigService private readonly IPolicyService _policyService; private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationService _organizationService; private readonly IEventService _eventService; public SsoConfigService( @@ -29,7 +28,6 @@ public SsoConfigService( IPolicyService policyService, IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, - IOrganizationService organizationService, IEventService eventService) { _ssoConfigRepository = ssoConfigRepository; @@ -37,7 +35,6 @@ public SsoConfigService( _policyService = policyService; _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; - _organizationService = organizationService; _eventService = eventService; } @@ -71,20 +68,20 @@ public async Task SaveAsync(SsoConfig config, Organization organization) singleOrgPolicy.Enabled = true; - await _policyService.SaveAsync(singleOrgPolicy, _organizationService, null); + await _policyService.SaveAsync(singleOrgPolicy, null); var resetPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.ResetPassword) ?? new Policy { OrganizationId = config.OrganizationId, Type = PolicyType.ResetPassword, }; resetPolicy.Enabled = true; resetPolicy.SetDataModel(new ResetPasswordDataModel { AutoEnrollEnabled = true }); - await _policyService.SaveAsync(resetPolicy, _organizationService, null); + await _policyService.SaveAsync(resetPolicy, null); var ssoRequiredPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.RequireSso) ?? new Policy { OrganizationId = config.OrganizationId, Type = PolicyType.RequireSso, }; ssoRequiredPolicy.Enabled = true; - await _policyService.SaveAsync(ssoRequiredPolicy, _organizationService, null); + await _policyService.SaveAsync(ssoRequiredPolicy, null); } await LogEventsAsync(config, oldConfig); diff --git a/src/Core/Billing/Migration/Models/ProviderMigrationTracker.cs b/src/Core/Billing/Migration/Models/ProviderMigrationTracker.cs index b6a58f82ff83..7bfef8a9317a 100644 --- a/src/Core/Billing/Migration/Models/ProviderMigrationTracker.cs +++ b/src/Core/Billing/Migration/Models/ProviderMigrationTracker.cs @@ -3,17 +3,14 @@ public enum ProviderMigrationProgress { Started = 1, - ClientsMigrated = 2, - TeamsPlanConfigured = 3, - EnterprisePlanConfigured = 4, - CustomerSetup = 5, - SubscriptionSetup = 6, - CreditApplied = 7, - Completed = 8, - - Reversing = 9, - ReversedClientMigrations = 10, - RemovedProviderPlans = 11 + NoClients = 2, + ClientsMigrated = 3, + TeamsPlanConfigured = 4, + EnterprisePlanConfigured = 5, + CustomerSetup = 6, + SubscriptionSetup = 7, + CreditApplied = 8, + Completed = 9, } public class ProviderMigrationTracker diff --git a/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs b/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs index 0da384ae27bd..9ca515a26030 100644 --- a/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs +++ b/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs @@ -41,7 +41,18 @@ public async Task Migrate(Guid providerId) await migrationTrackerCache.StartTracker(provider); - await MigrateClientsAsync(providerId); + var organizations = await GetClientsAsync(provider.Id); + + if (organizations.Count == 0) + { + logger.LogInformation("CB: Skipping migration for provider ({ProviderID}) with no clients", providerId); + + await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.NoClients); + + return; + } + + await MigrateClientsAsync(providerId, organizations); await ConfigureTeamsPlanAsync(providerId); @@ -65,6 +76,16 @@ public async Task GetResult(Guid providerId) return null; } + if (providerTracker.Progress == ProviderMigrationProgress.NoClients) + { + return new ProviderMigrationResult + { + ProviderId = providerTracker.ProviderId, + ProviderName = providerTracker.ProviderName, + Result = providerTracker.Progress.ToString() + }; + } + var clientTrackers = await Task.WhenAll(providerTracker.OrganizationIds.Select(organizationId => migrationTrackerCache.GetTracker(providerId, organizationId))); @@ -99,12 +120,10 @@ public async Task GetResult(Guid providerId) #region Steps - private async Task MigrateClientsAsync(Guid providerId) + private async Task MigrateClientsAsync(Guid providerId, List organizations) { logger.LogInformation("CB: Migrating clients for provider ({ProviderID})", providerId); - var organizations = await GetEnabledClientsAsync(providerId); - var organizationIds = organizations.Select(organization => organization.Id); await migrationTrackerCache.SetOrganizationIds(providerId, organizationIds); @@ -129,7 +148,7 @@ private async Task ConfigureTeamsPlanAsync(Guid providerId) { logger.LogInformation("CB: Configuring Teams plan for provider ({ProviderID})", providerId); - var organizations = await GetEnabledClientsAsync(providerId); + var organizations = await GetClientsAsync(providerId); var teamsSeats = organizations .Where(IsTeams) @@ -172,7 +191,7 @@ private async Task ConfigureEnterprisePlanAsync(Guid providerId) { logger.LogInformation("CB: Configuring Enterprise plan for provider ({ProviderID})", providerId); - var organizations = await GetEnabledClientsAsync(providerId); + var organizations = await GetClientsAsync(providerId); var enterpriseSeats = organizations .Where(IsEnterprise) @@ -215,7 +234,7 @@ private async Task SetupCustomerAsync(Provider provider) { if (string.IsNullOrEmpty(provider.GatewayCustomerId)) { - var organizations = await GetEnabledClientsAsync(provider.Id); + var organizations = await GetClientsAsync(provider.Id); var sampleOrganization = organizations.FirstOrDefault(organization => !string.IsNullOrEmpty(organization.GatewayCustomerId)); @@ -299,28 +318,43 @@ private async Task SetupSubscriptionAsync(Provider provider) private async Task ApplyCreditAsync(Provider provider) { - var organizations = await GetEnabledClientsAsync(provider.Id); + var organizations = await GetClientsAsync(provider.Id); var organizationCustomers = await Task.WhenAll(organizations.Select(organization => stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId))); var organizationCancellationCredit = organizationCustomers.Sum(customer => customer.Balance); - var legacyOrganizations = organizations.Where(organization => - organization.PlanType is + await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId, + new CustomerBalanceTransactionCreateOptions + { + Amount = organizationCancellationCredit, + Currency = "USD", + Description = "Unused, prorated time for client organization subscriptions." + }); + + var migrationRecords = await Task.WhenAll(organizations.Select(organization => + clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id))); + + var legacyOrganizationMigrationRecords = migrationRecords.Where(migrationRecord => + migrationRecord.PlanType is PlanType.EnterpriseAnnually2020 or - PlanType.EnterpriseMonthly2020 or - PlanType.TeamsAnnually2020 or - PlanType.TeamsMonthly2020); + PlanType.TeamsAnnually2020); - var legacyOrganizationCredit = legacyOrganizations.Sum(organization => organization.Seats ?? 0); + var legacyOrganizationCredit = legacyOrganizationMigrationRecords.Sum(migrationRecord => migrationRecord.Seats) * 12 * -100; - await stripeAdapter.CustomerUpdateAsync(provider.GatewayCustomerId, new CustomerUpdateOptions + if (legacyOrganizationCredit < 0) { - Balance = organizationCancellationCredit + legacyOrganizationCredit - }); + await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId, + new CustomerBalanceTransactionCreateOptions + { + Amount = legacyOrganizationCredit, + Currency = "USD", + Description = "1 year rebate for legacy client organizations." + }); + } - logger.LogInformation("CB: Applied {Credit} credit to provider ({ProviderID})", organizationCancellationCredit, provider.Id); + logger.LogInformation("CB: Applied {Credit} credit to provider ({ProviderID})", organizationCancellationCredit + legacyOrganizationCredit, provider.Id); await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.CreditApplied); } @@ -340,13 +374,12 @@ private async Task UpdateProviderAsync(Provider provider) #region Utilities - private async Task> GetEnabledClientsAsync(Guid providerId) + private async Task> GetClientsAsync(Guid providerId) { var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); return (await Task.WhenAll(providerOrganizations.Select(providerOrganization => organizationRepository.GetByIdAsync(providerOrganization.OrganizationId)))) - .Where(organization => organization.Enabled) .ToList(); } diff --git a/src/Core/Billing/Models/OrganizationMetadata.cs b/src/Core/Billing/Models/OrganizationMetadata.cs index decc35ffd8c5..136964d7c10a 100644 --- a/src/Core/Billing/Models/OrganizationMetadata.cs +++ b/src/Core/Billing/Models/OrganizationMetadata.cs @@ -1,8 +1,10 @@ namespace Bit.Core.Billing.Models; public record OrganizationMetadata( + bool IsEligibleForSelfHost, bool IsOnSecretsManagerStandalone) { public static OrganizationMetadata Default() => new( - IsOnSecretsManagerStandalone: default); + IsEligibleForSelfHost: false, + IsOnSecretsManagerStandalone: false); } diff --git a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs index 3c5938cab831..7db886203296 100644 --- a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs +++ b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs @@ -1,4 +1,5 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; @@ -26,6 +27,7 @@ public class OrganizationBillingService( IGlobalSettings globalSettings, ILogger logger, IOrganizationRepository organizationRepository, + IProviderRepository providerRepository, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService) : IOrganizationBillingService @@ -69,14 +71,11 @@ public async Task Finalize(OrganizationSale sale) var subscription = await subscriberService.GetSubscription(organization); - if (customer == null || subscription == null) - { - return OrganizationMetadata.Default(); - } + var isEligibleForSelfHost = await IsEligibleForSelfHost(organization, subscription); var isOnSecretsManagerStandalone = IsOnSecretsManagerStandalone(organization, customer, subscription); - return new OrganizationMetadata(isOnSecretsManagerStandalone); + return new OrganizationMetadata(isEligibleForSelfHost, isOnSecretsManagerStandalone); } public async Task UpdatePaymentMethod( @@ -340,11 +339,38 @@ private async Task CreateSubscriptionAsync( return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); } + private async Task IsEligibleForSelfHost( + Organization organization, + Subscription? organizationSubscription) + { + if (organization.Status != OrganizationStatusType.Managed) + { + return organization.Plan.Contains("Families") || + organization.Plan.Contains("Enterprise") && IsActive(organizationSubscription); + } + + var provider = await providerRepository.GetByOrganizationIdAsync(organization.Id); + + var providerSubscription = await subscriberService.GetSubscriptionOrThrow(provider); + + return organization.Plan.Contains("Enterprise") && IsActive(providerSubscription); + + bool IsActive(Subscription? subscription) => subscription?.Status is + StripeConstants.SubscriptionStatus.Active or + StripeConstants.SubscriptionStatus.Trialing or + StripeConstants.SubscriptionStatus.PastDue; + } + private static bool IsOnSecretsManagerStandalone( Organization organization, - Customer customer, - Subscription subscription) + Customer? customer, + Subscription? subscription) { + if (customer == null || subscription == null) + { + return false; + } + var plan = StaticStore.GetPlan(organization.PlanType); if (!plan.SupportsSecretsManager) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 6b4cf7e971ea..ecbe190ccd2a 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -146,6 +146,7 @@ public static class FeatureFlagKeys public const string RemoveServerVersionHeader = "remove-server-version-header"; public const string AccessIntelligence = "pm-13227-access-intelligence"; public const string VerifiedSsoDomainEndpoint = "pm-12337-refactor-sso-details-endpoint"; + public const string Pm13322AddPolicyDefinitions = "pm-13322-add-policy-definitions"; public const string LimitCollectionCreationDeletionSplit = "pm-10863-limit-collection-creation-deletion-split"; public static List GetAllKeys() diff --git a/src/Core/Models/Api/Request/PushDeviceRequestModel.cs b/src/Core/Models/Api/Request/PushDeviceRequestModel.cs index e1866b6f2735..8b97dcc3600f 100644 --- a/src/Core/Models/Api/Request/PushDeviceRequestModel.cs +++ b/src/Core/Models/Api/Request/PushDeviceRequestModel.cs @@ -1,5 +1,4 @@ using System.ComponentModel.DataAnnotations; -using Bit.Core.Enums; namespace Bit.Core.Models.Api; @@ -7,6 +6,4 @@ public class PushDeviceRequestModel { [Required] public string Id { get; set; } - [Required] - public DeviceType Type { get; set; } } diff --git a/src/Core/Models/Api/Request/PushUpdateRequestModel.cs b/src/Core/Models/Api/Request/PushUpdateRequestModel.cs index 9f7ed5f28851..f8c2d296fd4a 100644 --- a/src/Core/Models/Api/Request/PushUpdateRequestModel.cs +++ b/src/Core/Models/Api/Request/PushUpdateRequestModel.cs @@ -1,5 +1,4 @@ using System.ComponentModel.DataAnnotations; -using Bit.Core.Enums; namespace Bit.Core.Models.Api; @@ -8,9 +7,9 @@ public class PushUpdateRequestModel public PushUpdateRequestModel() { } - public PushUpdateRequestModel(IEnumerable> devices, string organizationId) + public PushUpdateRequestModel(IEnumerable deviceIds, string organizationId) { - Devices = devices.Select(d => new PushDeviceRequestModel { Id = d.Key, Type = d.Value }); + Devices = deviceIds.Select(d => new PushDeviceRequestModel { Id = d }); OrganizationId = organizationId; } diff --git a/src/Core/Models/Data/InstallationDeviceEntity.cs b/src/Core/Models/Data/InstallationDeviceEntity.cs index 3186efc661c7..a3d960b242c0 100644 --- a/src/Core/Models/Data/InstallationDeviceEntity.cs +++ b/src/Core/Models/Data/InstallationDeviceEntity.cs @@ -37,4 +37,25 @@ public static bool IsInstallationDeviceId(string deviceId) { return deviceId != null && deviceId.Length == 73 && deviceId[36] == '_'; } + public static bool TryParse(string deviceId, out InstallationDeviceEntity installationDeviceEntity) + { + installationDeviceEntity = null; + var installationId = Guid.Empty; + var deviceIdGuid = Guid.Empty; + if (!IsInstallationDeviceId(deviceId)) + { + return false; + } + var parts = deviceId.Split("_"); + if (parts.Length < 2) + { + return false; + } + if (!Guid.TryParse(parts[0], out installationId) || !Guid.TryParse(parts[1], out deviceIdGuid)) + { + return false; + } + installationDeviceEntity = new InstallationDeviceEntity(installationId, deviceIdGuid); + return true; + } } diff --git a/src/Core/NotificationCenter/Commands/MarkNotificationDeletedCommand.cs b/src/Core/NotificationCenter/Commands/MarkNotificationDeletedCommand.cs index fed9fd04699c..2ca7aa9051ee 100644 --- a/src/Core/NotificationCenter/Commands/MarkNotificationDeletedCommand.cs +++ b/src/Core/NotificationCenter/Commands/MarkNotificationDeletedCommand.cs @@ -49,11 +49,11 @@ await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.Us if (notificationStatus == null) { - notificationStatus = new NotificationStatus() + notificationStatus = new NotificationStatus { NotificationId = notificationId, UserId = _currentContext.UserId.Value, - DeletedDate = DateTime.Now + DeletedDate = DateTime.UtcNow }; await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.User, notificationStatus, diff --git a/src/Core/NotificationCenter/Commands/MarkNotificationReadCommand.cs b/src/Core/NotificationCenter/Commands/MarkNotificationReadCommand.cs index 93686605011a..400e44463a9d 100644 --- a/src/Core/NotificationCenter/Commands/MarkNotificationReadCommand.cs +++ b/src/Core/NotificationCenter/Commands/MarkNotificationReadCommand.cs @@ -49,11 +49,11 @@ await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.Us if (notificationStatus == null) { - notificationStatus = new NotificationStatus() + notificationStatus = new NotificationStatus { NotificationId = notificationId, UserId = _currentContext.UserId.Value, - ReadDate = DateTime.Now + ReadDate = DateTime.UtcNow }; await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.User, notificationStatus, diff --git a/src/Core/NotificationHub/INotificationHubClientProxy.cs b/src/Core/NotificationHub/INotificationHubClientProxy.cs new file mode 100644 index 000000000000..82b4d3959107 --- /dev/null +++ b/src/Core/NotificationHub/INotificationHubClientProxy.cs @@ -0,0 +1,8 @@ +using Microsoft.Azure.NotificationHubs; + +namespace Bit.Core.NotificationHub; + +public interface INotificationHubProxy +{ + Task<(INotificationHubClient Client, NotificationOutcome Outcome)[]> SendTemplateNotificationAsync(IDictionary properties, string tagExpression); +} diff --git a/src/Core/NotificationHub/INotificationHubPool.cs b/src/Core/NotificationHub/INotificationHubPool.cs new file mode 100644 index 000000000000..7c383d7b9649 --- /dev/null +++ b/src/Core/NotificationHub/INotificationHubPool.cs @@ -0,0 +1,9 @@ +using Microsoft.Azure.NotificationHubs; + +namespace Bit.Core.NotificationHub; + +public interface INotificationHubPool +{ + NotificationHubClient ClientFor(Guid comb); + INotificationHubProxy AllClients { get; } +} diff --git a/src/Core/NotificationHub/NotificationHubClientProxy.cs b/src/Core/NotificationHub/NotificationHubClientProxy.cs new file mode 100644 index 000000000000..815ac8836393 --- /dev/null +++ b/src/Core/NotificationHub/NotificationHubClientProxy.cs @@ -0,0 +1,26 @@ +using Microsoft.Azure.NotificationHubs; + +namespace Bit.Core.NotificationHub; + +public class NotificationHubClientProxy : INotificationHubProxy +{ + private readonly IEnumerable _clients; + + public NotificationHubClientProxy(IEnumerable clients) + { + _clients = clients; + } + + private async Task<(INotificationHubClient, T)[]> ApplyToAllClientsAsync(Func> action) + { + var tasks = _clients.Select(async c => (c, await action(c))); + return await Task.WhenAll(tasks); + } + + // partial proxy of INotificationHubClient implementation + // Note: Any other methods that are needed can simply be delegated as done here. + public async Task<(INotificationHubClient Client, NotificationOutcome Outcome)[]> SendTemplateNotificationAsync(IDictionary properties, string tagExpression) + { + return await ApplyToAllClientsAsync(async c => await c.SendTemplateNotificationAsync(properties, tagExpression)); + } +} diff --git a/src/Core/NotificationHub/NotificationHubConnection.cs b/src/Core/NotificationHub/NotificationHubConnection.cs new file mode 100644 index 000000000000..3a1437f70c52 --- /dev/null +++ b/src/Core/NotificationHub/NotificationHubConnection.cs @@ -0,0 +1,128 @@ +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Azure.NotificationHubs; + +class NotificationHubConnection +{ + public string HubName { get; init; } + public string ConnectionString { get; init; } + public bool EnableSendTracing { get; init; } + private NotificationHubClient _hubClient; + /// + /// Gets the NotificationHubClient for this connection. + /// + /// If the client is null, it will be initialized. + /// + /// Exception if the connection is invalid. + /// + public NotificationHubClient HubClient + { + get + { + if (_hubClient == null) + { + if (!IsValid) + { + throw new Exception("Invalid notification hub settings"); + } + Init(); + } + return _hubClient; + } + private set + { + _hubClient = value; + } + } + /// + /// Gets the start date for registration. + /// + /// If null, registration is always disabled. + /// + public DateTime? RegistrationStartDate { get; init; } + /// + /// Gets the end date for registration. + /// + /// If null, registration has no end date. + /// + public DateTime? RegistrationEndDate { get; init; } + /// + /// Gets whether all data needed to generate a connection to Notification Hub is present. + /// + public bool IsValid + { + get + { + { + var invalid = string.IsNullOrWhiteSpace(HubName) || string.IsNullOrWhiteSpace(ConnectionString); + return !invalid; + } + } + } + + public string LogString + { + get + { + return $"HubName: {HubName}, EnableSendTracing: {EnableSendTracing}, RegistrationStartDate: {RegistrationStartDate}, RegistrationEndDate: {RegistrationEndDate}"; + } + } + + /// + /// Gets whether registration is enabled for the given comb ID. + /// This is based off of the generation time encoded in the comb ID. + /// + /// + /// + public bool RegistrationEnabled(Guid comb) + { + var combTime = CoreHelpers.DateFromComb(comb); + return RegistrationEnabled(combTime); + } + + /// + /// Gets whether registration is enabled for the given time. + /// + /// The time to check + /// + public bool RegistrationEnabled(DateTime queryTime) + { + if (queryTime >= RegistrationEndDate || RegistrationStartDate == null) + { + return false; + } + + return RegistrationStartDate < queryTime; + } + + private NotificationHubConnection() { } + + /// + /// Creates a new NotificationHubConnection from the given settings. + /// + /// + /// + public static NotificationHubConnection From(GlobalSettings.NotificationHubSettings settings) + { + return new() + { + HubName = settings.HubName, + ConnectionString = settings.ConnectionString, + EnableSendTracing = settings.EnableSendTracing, + // Comb time is not precise enough for millisecond accuracy + RegistrationStartDate = settings.RegistrationStartDate.HasValue ? Truncate(settings.RegistrationStartDate.Value, TimeSpan.FromMilliseconds(10)) : null, + RegistrationEndDate = settings.RegistrationEndDate + }; + } + + private NotificationHubConnection Init() + { + HubClient = NotificationHubClient.CreateClientFromConnectionString(ConnectionString, HubName, EnableSendTracing); + return this; + } + + private static DateTime Truncate(DateTime dateTime, TimeSpan resolution) + { + return dateTime.AddTicks(-(dateTime.Ticks % resolution.Ticks)); + } +} diff --git a/src/Core/NotificationHub/NotificationHubPool.cs b/src/Core/NotificationHub/NotificationHubPool.cs new file mode 100644 index 000000000000..7448aad5bda1 --- /dev/null +++ b/src/Core/NotificationHub/NotificationHubPool.cs @@ -0,0 +1,62 @@ +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Azure.NotificationHubs; +using Microsoft.Extensions.Logging; + +namespace Bit.Core.NotificationHub; + +public class NotificationHubPool : INotificationHubPool +{ + private List _connections { get; } + private readonly IEnumerable _clients; + private readonly ILogger _logger; + public NotificationHubPool(ILogger logger, GlobalSettings globalSettings) + { + _logger = logger; + _connections = FilterInvalidHubs(globalSettings.NotificationHubPool.NotificationHubs); + _clients = _connections.GroupBy(c => c.ConnectionString).Select(g => g.First().HubClient); + } + + private List FilterInvalidHubs(IEnumerable hubs) + { + List result = new(); + _logger.LogDebug("Filtering {HubCount} notification hubs", hubs.Count()); + foreach (var hub in hubs) + { + var connection = NotificationHubConnection.From(hub); + if (!connection.IsValid) + { + _logger.LogWarning("Invalid notification hub settings: {HubName}", hub.HubName ?? "hub name missing"); + continue; + } + _logger.LogDebug("Adding notification hub: {ConnectionLogString}", connection.LogString); + result.Add(connection); + } + + return result; + } + + + /// + /// Gets the NotificationHubClient for the given comb ID. + /// + /// + /// + /// Thrown when no notification hub is found for a given comb. + public NotificationHubClient ClientFor(Guid comb) + { + var possibleConnections = _connections.Where(c => c.RegistrationEnabled(comb)).ToArray(); + if (possibleConnections.Length == 0) + { + throw new InvalidOperationException($"No valid notification hubs are available for the given comb ({comb}).\n" + + $"The comb's datetime is {CoreHelpers.DateFromComb(comb)}." + + $"Hub start and end times are configured as follows:\n" + + string.Join("\n", _connections.Select(c => $"Hub {c.HubName} - Start: {c.RegistrationStartDate}, End: {c.RegistrationEndDate}"))); + } + var resolvedConnection = possibleConnections[CoreHelpers.BinForComb(comb, possibleConnections.Length)]; + _logger.LogTrace("Resolved notification hub for comb {Comb} out of {HubCount} hubs.\n{ConnectionInfo}", comb, possibleConnections.Length, resolvedConnection.LogString); + return resolvedConnection.HubClient; + } + + public INotificationHubProxy AllClients { get { return new NotificationHubClientProxy(_clients); } } +} diff --git a/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs b/src/Core/NotificationHub/NotificationHubPushNotificationService.cs similarity index 84% rename from src/Core/Services/Implementations/NotificationHubPushNotificationService.cs rename to src/Core/NotificationHub/NotificationHubPushNotificationService.cs index 480f0dfa9ef8..6143676deffb 100644 --- a/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs +++ b/src/Core/NotificationHub/NotificationHubPushNotificationService.cs @@ -6,45 +6,31 @@ using Bit.Core.Models; using Bit.Core.Models.Data; using Bit.Core.Repositories; -using Bit.Core.Settings; +using Bit.Core.Services; using Bit.Core.Tools.Entities; using Bit.Core.Vault.Entities; using Microsoft.AspNetCore.Http; -using Microsoft.Azure.NotificationHubs; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.NotificationHub; public class NotificationHubPushNotificationService : IPushNotificationService { private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; private readonly IHttpContextAccessor _httpContextAccessor; - private readonly List _clients = []; private readonly bool _enableTracing = false; + private readonly INotificationHubPool _notificationHubPool; private readonly ILogger _logger; public NotificationHubPushNotificationService( IInstallationDeviceRepository installationDeviceRepository, - GlobalSettings globalSettings, + INotificationHubPool notificationHubPool, IHttpContextAccessor httpContextAccessor, ILogger logger) { _installationDeviceRepository = installationDeviceRepository; - _globalSettings = globalSettings; _httpContextAccessor = httpContextAccessor; - - foreach (var hub in globalSettings.NotificationHubs) - { - var client = NotificationHubClient.CreateClientFromConnectionString( - hub.ConnectionString, - hub.HubName, - hub.EnableSendTracing); - _clients.Add(client); - - _enableTracing = _enableTracing || hub.EnableSendTracing; - } - + _notificationHubPool = notificationHubPool; _logger = logger; } @@ -264,30 +250,23 @@ private string BuildTag(string tag, string identifier) private async Task SendPayloadAsync(string tag, PushType type, object payload) { - var tasks = new List>(); - foreach (var client in _clients) - { - var task = client.SendTemplateNotificationAsync( - new Dictionary - { - { "type", ((byte)type).ToString() }, - { "payload", JsonSerializer.Serialize(payload) } - }, tag); - tasks.Add(task); - } - - await Task.WhenAll(tasks); + var results = await _notificationHubPool.AllClients.SendTemplateNotificationAsync( + new Dictionary + { + { "type", ((byte)type).ToString() }, + { "payload", JsonSerializer.Serialize(payload) } + }, tag); if (_enableTracing) { - for (var i = 0; i < tasks.Count; i++) + foreach (var (client, outcome) in results) { - if (_clients[i].EnableTestSend) + if (!client.EnableTestSend) { - var outcome = await tasks[i]; - _logger.LogInformation("Azure Notification Hub Tracking ID: {id} | {type} push notification with {success} successes and {failure} failures with a payload of {@payload} and result of {@results}", - outcome.TrackingId, type, outcome.Success, outcome.Failure, payload, outcome.Results); + continue; } + _logger.LogInformation("Azure Notification Hub Tracking ID: {Id} | {Type} push notification with {Success} successes and {Failure} failures with a payload of {@Payload} and result of {@Results}", + outcome.TrackingId, type, outcome.Success, outcome.Failure, payload, outcome.Results); } } } diff --git a/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs b/src/Core/NotificationHub/NotificationHubPushRegistrationService.cs similarity index 64% rename from src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs rename to src/Core/NotificationHub/NotificationHubPushRegistrationService.cs index 87df60e8e3ce..ae32babf4477 100644 --- a/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs +++ b/src/Core/NotificationHub/NotificationHubPushRegistrationService.cs @@ -1,50 +1,34 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Repositories; +using Bit.Core.Services; using Bit.Core.Settings; using Microsoft.Azure.NotificationHubs; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.NotificationHub; public class NotificationHubPushRegistrationService : IPushRegistrationService { private readonly IInstallationDeviceRepository _installationDeviceRepository; private readonly GlobalSettings _globalSettings; + private readonly INotificationHubPool _notificationHubPool; private readonly IServiceProvider _serviceProvider; private readonly ILogger _logger; - private Dictionary _clients = []; public NotificationHubPushRegistrationService( IInstallationDeviceRepository installationDeviceRepository, GlobalSettings globalSettings, + INotificationHubPool notificationHubPool, IServiceProvider serviceProvider, ILogger logger) { _installationDeviceRepository = installationDeviceRepository; _globalSettings = globalSettings; + _notificationHubPool = notificationHubPool; _serviceProvider = serviceProvider; _logger = logger; - - // Is this dirty to do in the ctor? - void addHub(NotificationHubType type) - { - var hubRegistration = globalSettings.NotificationHubs.FirstOrDefault( - h => h.HubType == type && h.EnableRegistration); - if (hubRegistration != null) - { - var client = NotificationHubClient.CreateClientFromConnectionString( - hubRegistration.ConnectionString, - hubRegistration.HubName, - hubRegistration.EnableSendTracing); - _clients.Add(type, client); - } - } - - addHub(NotificationHubType.General); - addHub(NotificationHubType.iOS); - addHub(NotificationHubType.Android); } public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, @@ -117,7 +101,7 @@ public async Task CreateOrUpdateRegistrationAsync(string pushToken, string devic BuildInstallationTemplate(installation, "badgeMessage", badgeMessageTemplate ?? messageTemplate, userId, identifier); - await GetClient(type).CreateOrUpdateInstallationAsync(installation); + await ClientFor(GetComb(deviceId)).CreateOrUpdateInstallationAsync(installation); if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) { await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); @@ -152,11 +136,11 @@ private void BuildInstallationTemplate(Installation installation, string templat installation.Templates.Add(fullTemplateId, template); } - public async Task DeleteRegistrationAsync(string deviceId, DeviceType deviceType) + public async Task DeleteRegistrationAsync(string deviceId) { try { - await GetClient(deviceType).DeleteInstallationAsync(deviceId); + await ClientFor(GetComb(deviceId)).DeleteInstallationAsync(deviceId); if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) { await _installationDeviceRepository.DeleteAsync(new InstallationDeviceEntity(deviceId)); @@ -168,31 +152,31 @@ public async Task DeleteRegistrationAsync(string deviceId, DeviceType deviceType } } - public async Task AddUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId) + public async Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) { - await PatchTagsForUserDevicesAsync(devices, UpdateOperationType.Add, $"organizationId:{organizationId}"); - if (devices.Any() && InstallationDeviceEntity.IsInstallationDeviceId(devices.First().Key)) + await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Add, $"organizationId:{organizationId}"); + if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First())) { - var entities = devices.Select(e => new InstallationDeviceEntity(e.Key)); + var entities = deviceIds.Select(e => new InstallationDeviceEntity(e)); await _installationDeviceRepository.UpsertManyAsync(entities.ToList()); } } - public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId) + public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) { - await PatchTagsForUserDevicesAsync(devices, UpdateOperationType.Remove, + await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Remove, $"organizationId:{organizationId}"); - if (devices.Any() && InstallationDeviceEntity.IsInstallationDeviceId(devices.First().Key)) + if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First())) { - var entities = devices.Select(e => new InstallationDeviceEntity(e.Key)); + var entities = deviceIds.Select(e => new InstallationDeviceEntity(e)); await _installationDeviceRepository.UpsertManyAsync(entities.ToList()); } } - private async Task PatchTagsForUserDevicesAsync(IEnumerable> devices, UpdateOperationType op, + private async Task PatchTagsForUserDevicesAsync(IEnumerable deviceIds, UpdateOperationType op, string tag) { - if (!devices.Any()) + if (!deviceIds.Any()) { return; } @@ -212,11 +196,11 @@ private async Task PatchTagsForUserDevicesAsync(IEnumerable { operation }); + await ClientFor(GetComb(deviceId)).PatchInstallationAsync(deviceId, new List { operation }); } catch (Exception e) when (e.InnerException == null || !e.InnerException.Message.Contains("(404) Not Found")) { @@ -225,53 +209,29 @@ private async Task PatchTagsForUserDevicesAsync(IEnumerable> devices, string organizationId); - Task DeleteUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId); + Task DeleteRegistrationAsync(string deviceId); + Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); + Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); } diff --git a/src/Core/Services/IStripeAdapter.cs b/src/Core/Services/IStripeAdapter.cs index bb57f1cd0d7f..a288e1cbedce 100644 --- a/src/Core/Services/IStripeAdapter.cs +++ b/src/Core/Services/IStripeAdapter.cs @@ -10,6 +10,8 @@ public interface IStripeAdapter Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null); Task CustomerDeleteAsync(string id); Task> CustomerListPaymentMethods(string id, CustomerListPaymentMethodsOptions options = null); + Task CustomerBalanceTransactionCreate(string customerId, + CustomerBalanceTransactionCreateOptions options); Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions subscriptionCreateOptions); Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null); Task> SubscriptionListAsync(StripeSubscriptionListOptions subscriptionSearchOptions); diff --git a/src/Core/Services/Implementations/DeviceService.cs b/src/Core/Services/Implementations/DeviceService.cs index 9d8315f691ba..5b1e4b0f0171 100644 --- a/src/Core/Services/Implementations/DeviceService.cs +++ b/src/Core/Services/Implementations/DeviceService.cs @@ -38,13 +38,13 @@ await _pushRegistrationService.CreateOrUpdateRegistrationAsync(device.PushToken, public async Task ClearTokenAsync(Device device) { await _deviceRepository.ClearPushTokenAsync(device.Id); - await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString(), device.Type); + await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString()); } public async Task DeleteAsync(Device device) { await _deviceRepository.DeleteAsync(device); - await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString(), device.Type); + await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString()); } public async Task UpdateDevicesTrustAsync(string currentDeviceIdentifier, diff --git a/src/Core/Services/Implementations/MultiServicePushNotificationService.cs b/src/Core/Services/Implementations/MultiServicePushNotificationService.cs index 92e29908f50b..00be72c980e0 100644 --- a/src/Core/Services/Implementations/MultiServicePushNotificationService.cs +++ b/src/Core/Services/Implementations/MultiServicePushNotificationService.cs @@ -1,61 +1,31 @@ using Bit.Core.Auth.Entities; using Bit.Core.Enums; -using Bit.Core.Repositories; using Bit.Core.Settings; using Bit.Core.Tools.Entities; -using Bit.Core.Utilities; using Bit.Core.Vault.Entities; -using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; namespace Bit.Core.Services; public class MultiServicePushNotificationService : IPushNotificationService { - private readonly List _services = new List(); + private readonly IEnumerable _services; private readonly ILogger _logger; public MultiServicePushNotificationService( - IHttpClientFactory httpFactory, - IDeviceRepository deviceRepository, - IInstallationDeviceRepository installationDeviceRepository, - GlobalSettings globalSettings, - IHttpContextAccessor httpContextAccessor, + [FromKeyedServices("implementation")] IEnumerable services, ILogger logger, - ILogger relayLogger, - ILogger hubLogger) + GlobalSettings globalSettings) { - if (globalSettings.SelfHosted) - { - if (CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && - globalSettings.Installation?.Id != null && - CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) - { - _services.Add(new RelayPushNotificationService(httpFactory, deviceRepository, globalSettings, - httpContextAccessor, relayLogger)); - } - if (CoreHelpers.SettingHasValue(globalSettings.InternalIdentityKey) && - CoreHelpers.SettingHasValue(globalSettings.BaseServiceUri.InternalNotifications)) - { - _services.Add(new NotificationsApiPushNotificationService( - httpFactory, globalSettings, httpContextAccessor, hubLogger)); - } - } - else - { - var generalHub = globalSettings.NotificationHubs?.FirstOrDefault(h => h.HubType == NotificationHubType.General); - if (CoreHelpers.SettingHasValue(generalHub?.ConnectionString)) - { - _services.Add(new NotificationHubPushNotificationService(installationDeviceRepository, - globalSettings, httpContextAccessor, hubLogger)); - } - if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) - { - _services.Add(new AzureQueuePushNotificationService(globalSettings, httpContextAccessor)); - } - } + _services = services; _logger = logger; + _logger.LogInformation("Hub services: {Services}", _services.Count()); + globalSettings?.NotificationHubPool?.NotificationHubs?.ForEach(hub => + { + _logger.LogInformation("HubName: {HubName}, EnableSendTracing: {EnableSendTracing}, RegistrationStartDate: {RegistrationStartDate}, RegistrationEndDate: {RegistrationEndDate}", hub.HubName, hub.EnableSendTracing, hub.RegistrationStartDate, hub.RegistrationEndDate); + }); } public Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) diff --git a/src/Core/Services/Implementations/RelayPushRegistrationService.cs b/src/Core/Services/Implementations/RelayPushRegistrationService.cs index d9df7d04dcbc..d0f7736e984b 100644 --- a/src/Core/Services/Implementations/RelayPushRegistrationService.cs +++ b/src/Core/Services/Implementations/RelayPushRegistrationService.cs @@ -38,37 +38,36 @@ public async Task CreateOrUpdateRegistrationAsync(string pushToken, string devic await SendAsync(HttpMethod.Post, "push/register", requestModel); } - public async Task DeleteRegistrationAsync(string deviceId, DeviceType type) + public async Task DeleteRegistrationAsync(string deviceId) { var requestModel = new PushDeviceRequestModel { Id = deviceId, - Type = type, }; await SendAsync(HttpMethod.Post, "push/delete", requestModel); } public async Task AddUserRegistrationOrganizationAsync( - IEnumerable> devices, string organizationId) + IEnumerable deviceIds, string organizationId) { - if (!devices.Any()) + if (!deviceIds.Any()) { return; } - var requestModel = new PushUpdateRequestModel(devices, organizationId); + var requestModel = new PushUpdateRequestModel(deviceIds, organizationId); await SendAsync(HttpMethod.Put, "push/add-organization", requestModel); } public async Task DeleteUserRegistrationOrganizationAsync( - IEnumerable> devices, string organizationId) + IEnumerable deviceIds, string organizationId) { - if (!devices.Any()) + if (!deviceIds.Any()) { return; } - var requestModel = new PushUpdateRequestModel(devices, organizationId); + var requestModel = new PushUpdateRequestModel(deviceIds, organizationId); await SendAsync(HttpMethod.Put, "push/delete-organization", requestModel); } } diff --git a/src/Core/Services/Implementations/StripeAdapter.cs b/src/Core/Services/Implementations/StripeAdapter.cs index 100a47f75a3f..e5fee63b9d82 100644 --- a/src/Core/Services/Implementations/StripeAdapter.cs +++ b/src/Core/Services/Implementations/StripeAdapter.cs @@ -18,6 +18,7 @@ public class StripeAdapter : IStripeAdapter private readonly Stripe.PriceService _priceService; private readonly Stripe.SetupIntentService _setupIntentService; private readonly Stripe.TestHelpers.TestClockService _testClockService; + private readonly CustomerBalanceTransactionService _customerBalanceTransactionService; public StripeAdapter() { @@ -34,6 +35,7 @@ public StripeAdapter() _priceService = new Stripe.PriceService(); _setupIntentService = new SetupIntentService(); _testClockService = new Stripe.TestHelpers.TestClockService(); + _customerBalanceTransactionService = new CustomerBalanceTransactionService(); } public Task CustomerCreateAsync(Stripe.CustomerCreateOptions options) @@ -63,6 +65,10 @@ public async Task> CustomerListPaymentMethods(string id, return paymentMethods.Data; } + public async Task CustomerBalanceTransactionCreate(string customerId, + CustomerBalanceTransactionCreateOptions options) + => await _customerBalanceTransactionService.CreateAsync(customerId, options); + public Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions options) { return _subscriptionService.CreateAsync(options); diff --git a/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs b/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs index fcd0889248a3..f6279c946798 100644 --- a/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs +++ b/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs @@ -4,7 +4,7 @@ namespace Bit.Core.Services; public class NoopPushRegistrationService : IPushRegistrationService { - public Task AddUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId) + public Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) { return Task.FromResult(0); } @@ -15,12 +15,12 @@ public Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, s return Task.FromResult(0); } - public Task DeleteRegistrationAsync(string deviceId, DeviceType deviceType) + public Task DeleteRegistrationAsync(string deviceId) { return Task.FromResult(0); } - public Task DeleteUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId) + public Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) { return Task.FromResult(0); } diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index f99fb3b57d19..793b6ac1c16a 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -1,5 +1,4 @@ using Bit.Core.Auth.Settings; -using Bit.Core.Enums; using Bit.Core.Settings.LoggingSettings; namespace Bit.Core.Settings; @@ -65,7 +64,7 @@ public virtual string LicenseDirectory public virtual SentrySettings Sentry { get; set; } = new SentrySettings(); public virtual SyslogSettings Syslog { get; set; } = new SyslogSettings(); public virtual ILogLevelSettings MinLogLevel { get; set; } = new LogLevelSettings(); - public virtual List NotificationHubs { get; set; } = new(); + public virtual NotificationHubPoolSettings NotificationHubPool { get; set; } = new(); public virtual YubicoSettings Yubico { get; set; } = new YubicoSettings(); public virtual DuoSettings Duo { get; set; } = new DuoSettings(); public virtual BraintreeSettings Braintree { get; set; } = new BraintreeSettings(); @@ -424,7 +423,7 @@ public class NotificationHubSettings public string ConnectionString { get => _connectionString; - set => _connectionString = value.Trim('"'); + set => _connectionString = value?.Trim('"'); } public string HubName { get; set; } /// @@ -433,10 +432,32 @@ public string ConnectionString /// public bool EnableSendTracing { get; set; } = false; /// - /// At least one hub configuration should have registration enabled, preferably the General hub as a safety net. + /// The date and time at which registration will be enabled. + /// + /// **This value should not be updated once set, as it is used to determine installation location of devices.** + /// + /// If null, registration is disabled. + /// + /// + public DateTime? RegistrationStartDate { get; set; } + /// + /// The date and time at which registration will be disabled. + /// + /// **This value should not be updated once set, as it is used to determine installation location of devices.** + /// + /// If null, hub registration has no yet known expiry. + /// + public DateTime? RegistrationEndDate { get; set; } + } + + public class NotificationHubPoolSettings + { + /// + /// List of Notification Hub settings to use for sending push notifications. + /// + /// Note that hubs on the same namespace share active device limits, so multiple namespaces should be used to increase capacity. /// - public bool EnableRegistration { get; set; } - public NotificationHubType HubType { get; set; } + public List NotificationHubs { get; set; } = new(); } public class YubicoSettings diff --git a/src/Core/Utilities/CoreHelpers.cs b/src/Core/Utilities/CoreHelpers.cs index 7fcaab3c4e9b..8e4aa0e4035e 100644 --- a/src/Core/Utilities/CoreHelpers.cs +++ b/src/Core/Utilities/CoreHelpers.cs @@ -76,6 +76,39 @@ internal static Guid GenerateComb(Guid startingGuid, DateTime time) return new Guid(guidArray); } + internal static DateTime DateFromComb(Guid combGuid) + { + var guidArray = combGuid.ToByteArray(); + var daysArray = new byte[4]; + var msecsArray = new byte[4]; + + Array.Copy(guidArray, guidArray.Length - 6, daysArray, 2, 2); + Array.Copy(guidArray, guidArray.Length - 4, msecsArray, 0, 4); + + Array.Reverse(daysArray); + Array.Reverse(msecsArray); + + var days = BitConverter.ToInt32(daysArray, 0); + var msecs = BitConverter.ToInt32(msecsArray, 0); + + var time = TimeSpan.FromDays(days) + TimeSpan.FromMilliseconds(msecs * 3.333333); + return new DateTime(_baseDateTicks + time.Ticks, DateTimeKind.Utc); + } + + internal static long BinForComb(Guid combGuid, int binCount) + { + // From System.Web.Util.HashCodeCombiner + uint CombineHashCodes(uint h1, byte h2) + { + return (uint)(((h1 << 5) + h1) ^ h2); + } + var guidArray = combGuid.ToByteArray(); + var randomArray = new byte[10]; + Array.Copy(guidArray, 0, randomArray, 0, 10); + var hash = randomArray.Aggregate((uint)randomArray.Length, CombineHashCodes); + return hash % binCount; + } + public static string CleanCertificateThumbprint(string thumbprint) { // Clean possible garbage characters from thumbprint copy/paste diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index bd3aecf2f5f7..b0a2c42eaded 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -4,6 +4,7 @@ using System.Security.Cryptography.X509Certificates; using AspNetCoreRateLimit; using Bit.Core.AdminConsole.Models.Business.Tokenables; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services.Implementations; using Bit.Core.AdminConsole.Services.NoopImplementations; @@ -24,6 +25,7 @@ using Bit.Core.HostedServices; using Bit.Core.Identity; using Bit.Core.IdentityServer; +using Bit.Core.NotificationHub; using Bit.Core.OrganizationFeatures; using Bit.Core.Repositories; using Bit.Core.Resources; @@ -102,9 +104,9 @@ public static void AddBaseServices(this IServiceCollection services, IGlobalSett services.AddUserServices(globalSettings); services.AddTrialInitiationServices(); services.AddOrganizationServices(globalSettings); + services.AddPolicyServices(); services.AddScoped(); services.AddScoped(); - services.AddScoped(); services.AddScoped(); services.AddScoped(); services.AddSingleton(); @@ -263,16 +265,30 @@ public static void AddDefaultServices(this IServiceCollection services, GlobalSe } services.AddSingleton(); - if (globalSettings.SelfHosted && - CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && - globalSettings.Installation?.Id != null && - CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) + if (globalSettings.SelfHosted) { - services.AddSingleton(); + if (CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && + globalSettings.Installation?.Id != null && + CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) + { + services.AddKeyedSingleton("implementation"); + services.AddSingleton(); + } + if (CoreHelpers.SettingHasValue(globalSettings.InternalIdentityKey) && + CoreHelpers.SettingHasValue(globalSettings.BaseServiceUri.InternalNotifications)) + { + services.AddKeyedSingleton("implementation"); + } } else if (!globalSettings.SelfHosted) { + services.AddSingleton(); services.AddSingleton(); + services.AddKeyedSingleton("implementation"); + if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) + { + services.AddKeyedSingleton("implementation"); + } } else { diff --git a/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs index 70ca599400e1..b46fd307e96b 100644 --- a/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs @@ -52,7 +52,7 @@ public async Task GetMetadataAsync_OK( { sutProvider.GetDependency().AccessMembersTab(organizationId).Returns(true); sutProvider.GetDependency().GetMetadata(organizationId) - .Returns(new OrganizationMetadata(true)); + .Returns(new OrganizationMetadata(true, true)); var result = await sutProvider.Sut.GetMetadataAsync(organizationId); @@ -60,6 +60,7 @@ public async Task GetMetadataAsync_OK( var organizationMetadataResponse = ((Ok)result).Value; + Assert.True(organizationMetadataResponse.IsEligibleForSelfHost); Assert.True(organizationMetadataResponse.IsOnSecretsManagerStandalone); } diff --git a/test/Common/AutoFixture/ControllerCustomization.cs b/test/Common/AutoFixture/ControllerCustomization.cs index f695f86b5503..91fffbf09971 100644 --- a/test/Common/AutoFixture/ControllerCustomization.cs +++ b/test/Common/AutoFixture/ControllerCustomization.cs @@ -1,6 +1,5 @@ using AutoFixture; using Microsoft.AspNetCore.Mvc; -using Org.BouncyCastle.Security; namespace Bit.Test.Common.AutoFixture; @@ -15,7 +14,7 @@ public ControllerCustomization(Type controllerType) { if (!controllerType.IsAssignableTo(typeof(Controller))) { - throw new InvalidParameterException($"{nameof(controllerType)} must derive from {typeof(Controller).Name}"); + throw new Exception($"{nameof(controllerType)} must derive from {typeof(Controller).Name}"); } _controllerType = controllerType; diff --git a/test/Common/AutoFixture/SutProvider.cs b/test/Common/AutoFixture/SutProvider.cs index ac953965bdfc..fefe6c3ebf25 100644 --- a/test/Common/AutoFixture/SutProvider.cs +++ b/test/Common/AutoFixture/SutProvider.cs @@ -127,7 +127,6 @@ public object Create(object request, ISpecimenContext context) return _sutProvider.GetDependency(parameterInfo.ParameterType, ""); } - // This is the equivalent of _fixture.Create, but no overload for // Create(Type type) exists. var dependency = new SpecimenContext(_fixture).Resolve(new SeededRequest(parameterInfo.ParameterType, diff --git a/test/Common/AutoFixture/SutProviderExtensions.cs b/test/Common/AutoFixture/SutProviderExtensions.cs index 1fdf22653943..bdc8604166fd 100644 --- a/test/Common/AutoFixture/SutProviderExtensions.cs +++ b/test/Common/AutoFixture/SutProviderExtensions.cs @@ -1,6 +1,7 @@ using AutoFixture; using Bit.Core.Services; using Bit.Core.Settings; +using Microsoft.Extensions.Time.Testing; using NSubstitute; using RichardSzalay.MockHttp; @@ -47,4 +48,19 @@ public static SutProvider ConfigureBaseIdentityClientService(this SutProvi .SetDependency(mockHttpClientFactory) .Create(); } + + /// + /// Configures SutProvider to use FakeTimeProvider. + /// It is registered under both the TimeProvider type and the FakeTimeProvider type + /// so that it can be retrieved in a type-safe manner with GetDependency. + /// This can be chained with other builder methods; make sure to call + /// before use. + /// + public static SutProvider WithFakeTimeProvider(this SutProvider sutProvider) + { + var fakeTimeProvider = new FakeTimeProvider(); + return sutProvider + .SetDependency((TimeProvider)fakeTimeProvider) + .SetDependency(fakeTimeProvider); + } } diff --git a/test/Common/Common.csproj b/test/Common/Common.csproj index 7b9fbe42d5eb..2f11798cef8f 100644 --- a/test/Common/Common.csproj +++ b/test/Common/Common.csproj @@ -5,6 +5,7 @@ + diff --git a/test/Core.Test/AdminConsole/AutoFixture/PolicyFixtures.cs b/test/Core.Test/AdminConsole/AutoFixture/PolicyFixtures.cs index f70fd579e371..09b112c43c1e 100644 --- a/test/Core.Test/AdminConsole/AutoFixture/PolicyFixtures.cs +++ b/test/Core.Test/AdminConsole/AutoFixture/PolicyFixtures.cs @@ -9,10 +9,12 @@ namespace Bit.Core.Test.AdminConsole.AutoFixture; internal class PolicyCustomization : ICustomization { public PolicyType Type { get; set; } + public bool Enabled { get; set; } - public PolicyCustomization(PolicyType type) + public PolicyCustomization(PolicyType type, bool enabled) { Type = type; + Enabled = enabled; } public void Customize(IFixture fixture) @@ -20,21 +22,23 @@ public void Customize(IFixture fixture) fixture.Customize(composer => composer .With(o => o.OrganizationId, Guid.NewGuid()) .With(o => o.Type, Type) - .With(o => o.Enabled, true)); + .With(o => o.Enabled, Enabled)); } } public class PolicyAttribute : CustomizeAttribute { private readonly PolicyType _type; + private readonly bool _enabled; - public PolicyAttribute(PolicyType type) + public PolicyAttribute(PolicyType type, bool enabled = true) { _type = type; + _enabled = enabled; } public override ICustomization GetCustomization(ParameterInfo parameter) { - return new PolicyCustomization(_type); + return new PolicyCustomization(_type, _enabled); } } diff --git a/test/Core.Test/AdminConsole/AutoFixture/PolicyUpdateFixtures.cs b/test/Core.Test/AdminConsole/AutoFixture/PolicyUpdateFixtures.cs new file mode 100644 index 000000000000..dff9b571782c --- /dev/null +++ b/test/Core.Test/AdminConsole/AutoFixture/PolicyUpdateFixtures.cs @@ -0,0 +1,25 @@ +using System.Reflection; +using AutoFixture; +using AutoFixture.Xunit2; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; + +namespace Bit.Core.Test.AdminConsole.AutoFixture; + +internal class PolicyUpdateCustomization(PolicyType type, bool enabled) : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(o => o.Type, type) + .With(o => o.Enabled, enabled)); + } +} + +public class PolicyUpdateAttribute(PolicyType type, bool enabled = true) : CustomizeAttribute +{ + public override ICustomization GetCustomization(ParameterInfo parameter) + { + return new PolicyUpdateCustomization(type, enabled); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidatorFixtures.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidatorFixtures.cs new file mode 100644 index 000000000000..ba4741d8bdf9 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidatorFixtures.cs @@ -0,0 +1,43 @@ +#nullable enable + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using NSubstitute; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies; + +public class FakeSingleOrgPolicyValidator : IPolicyValidator +{ + public PolicyType Type => PolicyType.SingleOrg; + public IEnumerable RequiredPolicies => Array.Empty(); + + public readonly Func> ValidateAsyncMock = Substitute.For>>(); + public readonly Action OnSaveSideEffectsAsyncMock = Substitute.For>(); + + public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + { + return ValidateAsyncMock(policyUpdate, currentPolicy); + } + + public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + { + OnSaveSideEffectsAsyncMock(policyUpdate, currentPolicy); + return Task.FromResult(0); + } +} +public class FakeRequireSsoPolicyValidator : IPolicyValidator +{ + public PolicyType Type => PolicyType.RequireSso; + public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; + public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(""); + public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(0); +} +public class FakeVaultTimeoutPolicyValidator : IPolicyValidator +{ + public PolicyType Type => PolicyType.MaximumVaultTimeout; + public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; + public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(""); + public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(0); +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidatorHelpersTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidatorHelpersTests.cs new file mode 100644 index 000000000000..99f99706fa86 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidatorHelpersTests.cs @@ -0,0 +1,64 @@ +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; +using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Models.Data; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies; + +public class PolicyValidatorHelpersTests +{ + [Fact] + public void ValidateDecryptionOptionsNotEnabled_RequiredByKeyConnector_ValidationError() + { + var ssoConfig = new SsoConfig(); + ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.KeyConnector }); + + var result = ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.KeyConnector]); + + Assert.Contains("Key Connector is enabled", result); + } + + [Fact] + public void ValidateDecryptionOptionsNotEnabled_RequiredByTDE_ValidationError() + { + var ssoConfig = new SsoConfig(); + ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption }); + + var result = ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.TrustedDeviceEncryption]); + + Assert.Contains("Trusted device encryption is on", result); + } + + [Fact] + public void ValidateDecryptionOptionsNotEnabled_NullSsoConfig_NoValidationError() + { + var ssoConfig = new SsoConfig(); + var result = ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.KeyConnector]); + + Assert.True(string.IsNullOrEmpty(result)); + } + + [Fact] + public void ValidateDecryptionOptionsNotEnabled_RequiredOptionNotEnabled_NoValidationError() + { + var ssoConfig = new SsoConfig(); + ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.KeyConnector }); + + var result = ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.TrustedDeviceEncryption]); + + Assert.True(string.IsNullOrEmpty(result)); + } + + [Fact] + public void ValidateDecryptionOptionsNotEnabled_SsoConfigDisabled_NoValidationError() + { + var ssoConfig = new SsoConfig(); + ssoConfig.Enabled = false; + ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.KeyConnector }); + + var result = ssoConfig.ValidateDecryptionOptionsNotEnabled([MemberDecryptionType.KeyConnector]); + + Assert.True(string.IsNullOrEmpty(result)); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidatorTests.cs new file mode 100644 index 000000000000..d3af765f799a --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidatorTests.cs @@ -0,0 +1,75 @@ +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; +using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Auth.Repositories; +using Bit.Core.Test.AdminConsole.AutoFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +[SutProviderCustomize] +public class RequireSsoPolicyValidatorTests +{ + [Theory, BitAutoData] + public async Task ValidateAsync_DisablingPolicy_KeyConnectorEnabled_ValidationError( + [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = true }; + ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.KeyConnector }); + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy); + Assert.Contains("Key Connector is enabled", result, StringComparison.OrdinalIgnoreCase); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_DisablingPolicy_TdeEnabled_ValidationError( + [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = true }; + ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption }); + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy); + Assert.Contains("Trusted device encryption is on", result, StringComparison.OrdinalIgnoreCase); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_DisablingPolicy_DecryptionOptionsNotEnabled_Success( + [PolicyUpdate(PolicyType.ResetPassword, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.ResetPassword)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = false }; + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy); + Assert.True(string.IsNullOrEmpty(result)); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidatorTests.cs new file mode 100644 index 000000000000..83939406b590 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidatorTests.cs @@ -0,0 +1,71 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; +using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Auth.Repositories; +using Bit.Core.Test.AdminConsole.AutoFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +[SutProviderCustomize] +public class ResetPasswordPolicyValidatorTests +{ + [Theory] + [BitAutoData(true, false)] + [BitAutoData(false, true)] + [BitAutoData(false, false)] + public async Task ValidateAsync_DisablingPolicy_TdeEnabled_ValidationError( + bool policyEnabled, + bool autoEnrollEnabled, + [PolicyUpdate(PolicyType.ResetPassword)] PolicyUpdate policyUpdate, + [Policy(PolicyType.ResetPassword)] Policy policy, + SutProvider sutProvider) + { + policyUpdate.Enabled = policyEnabled; + policyUpdate.SetDataModel(new ResetPasswordDataModel + { + AutoEnrollEnabled = autoEnrollEnabled + }); + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = true }; + ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption }); + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy); + Assert.Contains("Trusted device encryption is on and requires this policy.", result, StringComparison.OrdinalIgnoreCase); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_DisablingPolicy_TdeNotEnabled_Success( + [PolicyUpdate(PolicyType.ResetPassword, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.ResetPassword)] Policy policy, + SutProvider sutProvider) + { + policyUpdate.SetDataModel(new ResetPasswordDataModel + { + AutoEnrollEnabled = false + }); + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = false }; + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy); + Assert.True(string.IsNullOrEmpty(result)); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs new file mode 100644 index 000000000000..76ee5748403f --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs @@ -0,0 +1,129 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; +using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Auth.Repositories; +using Bit.Core.Context; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +[SutProviderCustomize] +public class SingleOrgPolicyValidatorTests +{ + [Theory, BitAutoData] + public async Task ValidateAsync_DisablingPolicy_KeyConnectorEnabled_ValidationError( + [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = true }; + ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.KeyConnector }); + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy); + Assert.Contains("Key Connector is enabled", result, StringComparison.OrdinalIgnoreCase); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_DisablingPolicy_KeyConnectorNotEnabled_Success( + [PolicyUpdate(PolicyType.ResetPassword, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.ResetPassword)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = false }; + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy); + Assert.True(string.IsNullOrEmpty(result)); + } + + [Theory, BitAutoData] + public async Task OnSaveSideEffectsAsync_RemovesNonCompliantUsers( + [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy policy, + Guid savingUserId, + Guid nonCompliantUserId, + Organization organization, SutProvider sutProvider) + { + policy.OrganizationId = organization.Id = policyUpdate.OrganizationId; + + var compliantUser1 = new OrganizationUserUserDetails + { + OrganizationId = organization.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = new Guid(), + Email = "user1@example.com" + }; + + var compliantUser2 = new OrganizationUserUserDetails + { + OrganizationId = organization.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = new Guid(), + Email = "user2@example.com" + }; + + var nonCompliantUser = new OrganizationUserUserDetails + { + OrganizationId = organization.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = nonCompliantUserId, + Email = "user3@example.com" + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([compliantUser1, compliantUser2, nonCompliantUser]); + + var otherOrganizationUser = new OrganizationUser + { + OrganizationId = new Guid(), + UserId = nonCompliantUserId, + Status = OrganizationUserStatusType.Confirmed + }; + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Is>(ids => ids.Contains(nonCompliantUserId))) + .Returns([otherOrganizationUser]); + + sutProvider.GetDependency().UserId.Returns(savingUserId); + sutProvider.GetDependency().GetByIdAsync(policyUpdate.OrganizationId).Returns(organization); + + await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, policy); + + await sutProvider.GetDependency() + .Received(1) + .RemoveUserAsync(policyUpdate.OrganizationId, nonCompliantUser.Id, savingUserId); + await sutProvider.GetDependency() + .Received(1) + .SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(organization.DisplayName(), + "user3@example.com"); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs new file mode 100644 index 000000000000..4dce13174934 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs @@ -0,0 +1,209 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; +using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Context; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +[SutProviderCustomize] +public class TwoFactorAuthenticationPolicyValidatorTests +{ + [Theory, BitAutoData] + public async Task OnSaveSideEffectsAsync_RemovesNonCompliantUsers( + Organization organization, + [PolicyUpdate(PolicyType.TwoFactorAuthentication)] PolicyUpdate policyUpdate, + [Policy(PolicyType.TwoFactorAuthentication, false)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = organization.Id = policyUpdate.OrganizationId; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + var orgUserDetailUserInvited = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Invited, + Type = OrganizationUserType.User, + // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync + Email = "user1@test.com", + Name = "TEST", + UserId = Guid.NewGuid(), + HasMasterPassword = false + }; + var orgUserDetailUserAcceptedWith2FA = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User, + // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync + Email = "user2@test.com", + Name = "TEST", + UserId = Guid.NewGuid(), + HasMasterPassword = true + }; + var orgUserDetailUserAcceptedWithout2FA = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User, + // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync + Email = "user3@test.com", + Name = "TEST", + UserId = Guid.NewGuid(), + HasMasterPassword = true + }; + var orgUserDetailAdmin = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.Admin, + // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync + Email = "admin@test.com", + Name = "ADMIN", + UserId = Guid.NewGuid(), + HasMasterPassword = false + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns(new List + { + orgUserDetailUserInvited, + orgUserDetailUserAcceptedWith2FA, + orgUserDetailUserAcceptedWithout2FA, + orgUserDetailAdmin + }); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns(new List<(OrganizationUserUserDetails user, bool hasTwoFactor)>() + { + (orgUserDetailUserInvited, false), + (orgUserDetailUserAcceptedWith2FA, true), + (orgUserDetailUserAcceptedWithout2FA, false), + (orgUserDetailAdmin, false), + }); + + var savingUserId = Guid.NewGuid(); + sutProvider.GetDependency().UserId.Returns(savingUserId); + + await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, policy); + + var removeOrganizationUserCommand = sutProvider.GetDependency(); + + await removeOrganizationUserCommand.Received() + .RemoveUserAsync(policy.OrganizationId, orgUserDetailUserAcceptedWithout2FA.Id, savingUserId); + await sutProvider.GetDependency().Received() + .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(organization.DisplayName(), orgUserDetailUserAcceptedWithout2FA.Email); + + await removeOrganizationUserCommand.DidNotReceive() + .RemoveUserAsync(policy.OrganizationId, orgUserDetailUserInvited.Id, savingUserId); + await sutProvider.GetDependency().DidNotReceive() + .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(organization.DisplayName(), orgUserDetailUserInvited.Email); + await removeOrganizationUserCommand.DidNotReceive() + .RemoveUserAsync(policy.OrganizationId, orgUserDetailUserAcceptedWith2FA.Id, savingUserId); + await sutProvider.GetDependency().DidNotReceive() + .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(organization.DisplayName(), orgUserDetailUserAcceptedWith2FA.Email); + await removeOrganizationUserCommand.DidNotReceive() + .RemoveUserAsync(policy.OrganizationId, orgUserDetailAdmin.Id, savingUserId); + await sutProvider.GetDependency().DidNotReceive() + .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(organization.DisplayName(), orgUserDetailAdmin.Email); + } + + [Theory, BitAutoData] + public async Task OnSaveSideEffectsAsync_UsersToBeRemovedDontHaveMasterPasswords_Throws( + Organization organization, + [PolicyUpdate(PolicyType.TwoFactorAuthentication)] PolicyUpdate policyUpdate, + [Policy(PolicyType.TwoFactorAuthentication, false)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = organization.Id = policyUpdate.OrganizationId; + + var orgUserDetailUserWith2FAAndMP = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.User, + // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync + Email = "user1@test.com", + Name = "TEST", + UserId = Guid.NewGuid(), + HasMasterPassword = true + }; + var orgUserDetailUserWith2FANoMP = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.User, + // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync + Email = "user2@test.com", + Name = "TEST", + UserId = Guid.NewGuid(), + HasMasterPassword = false + }; + var orgUserDetailUserWithout2FA = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.User, + // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync + Email = "user3@test.com", + Name = "TEST", + UserId = Guid.NewGuid(), + HasMasterPassword = false + }; + var orgUserDetailAdmin = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.Admin, + // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync + Email = "admin@test.com", + Name = "ADMIN", + UserId = Guid.NewGuid(), + HasMasterPassword = false + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policy.OrganizationId) + .Returns(new List + { + orgUserDetailUserWith2FAAndMP, + orgUserDetailUserWith2FANoMP, + orgUserDetailUserWithout2FA, + orgUserDetailAdmin + }); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Is>(ids => + ids.Contains(orgUserDetailUserWith2FANoMP.UserId.Value) + && ids.Contains(orgUserDetailUserWithout2FA.UserId.Value) + && ids.Contains(orgUserDetailAdmin.UserId.Value))) + .Returns(new List<(Guid userId, bool hasTwoFactor)>() + { + (orgUserDetailUserWith2FANoMP.UserId.Value, true), + (orgUserDetailUserWithout2FA.UserId.Value, false), + (orgUserDetailAdmin.UserId.Value, false), + }); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, policy)); + + Assert.Contains("Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .RemoveUserAsync(organizationId: default, organizationUserId: default, deletingUserId: default); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs new file mode 100644 index 000000000000..342ede9c829d --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs @@ -0,0 +1,330 @@ +#nullable enable + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Exceptions; +using Bit.Core.Models.Data.Organizations; +using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.Extensions.Time.Testing; +using NSubstitute; +using Xunit; +using EventType = Bit.Core.Enums.EventType; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies; + +public class SavePolicyCommandTests +{ + [Theory, BitAutoData] + public async Task SaveAsync_NewPolicy_Success([PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate) + { + var fakePolicyValidator = new FakeSingleOrgPolicyValidator(); + fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns(""); + var sutProvider = SutProviderFactory([fakePolicyValidator]); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(policyUpdate.OrganizationId).Returns([]); + + var creationDate = sutProvider.GetDependency().Start; + + await sutProvider.Sut.SaveAsync(policyUpdate); + + await fakePolicyValidator.ValidateAsyncMock.Received(1).Invoke(policyUpdate, null); + fakePolicyValidator.OnSaveSideEffectsAsyncMock.Received(1).Invoke(policyUpdate, null); + + await AssertPolicySavedAsync(sutProvider, policyUpdate); + await sutProvider.GetDependency().Received(1).UpsertAsync(Arg.Is(p => + p.CreationDate == creationDate && + p.RevisionDate == creationDate)); + } + + [Theory, BitAutoData] + public async Task SaveAsync_ExistingPolicy_Success( + [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy currentPolicy) + { + var fakePolicyValidator = new FakeSingleOrgPolicyValidator(); + fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns(""); + var sutProvider = SutProviderFactory([fakePolicyValidator]); + + currentPolicy.OrganizationId = policyUpdate.OrganizationId; + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type) + .Returns(currentPolicy); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy]); + + // Store mutable properties separately to assert later + var id = currentPolicy.Id; + var organizationId = currentPolicy.OrganizationId; + var type = currentPolicy.Type; + var creationDate = currentPolicy.CreationDate; + var revisionDate = sutProvider.GetDependency().Start; + + await sutProvider.Sut.SaveAsync(policyUpdate); + + await fakePolicyValidator.ValidateAsyncMock.Received(1).Invoke(policyUpdate, currentPolicy); + fakePolicyValidator.OnSaveSideEffectsAsyncMock.Received(1).Invoke(policyUpdate, currentPolicy); + + await AssertPolicySavedAsync(sutProvider, policyUpdate); + // Additional assertions to ensure certain properties have or have not been updated + await sutProvider.GetDependency().Received(1).UpsertAsync(Arg.Is(p => + p.Id == id && + p.OrganizationId == organizationId && + p.Type == type && + p.CreationDate == creationDate && + p.RevisionDate == revisionDate)); + } + + [Fact] + public void Constructor_DuplicatePolicyValidators_Throws() + { + var exception = Assert.Throws(() => + new SavePolicyCommand( + Substitute.For(), + Substitute.For(), + Substitute.For(), + [new FakeSingleOrgPolicyValidator(), new FakeSingleOrgPolicyValidator()], + Substitute.For() + )); + Assert.Contains("Duplicate PolicyValidator for SingleOrg policy", exception.Message); + } + + [Theory, BitAutoData] + public async Task SaveAsync_OrganizationDoesNotExist_ThrowsBadRequest(PolicyUpdate policyUpdate) + { + var sutProvider = SutProviderFactory(); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(policyUpdate.OrganizationId) + .Returns(Task.FromResult(null)); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policyUpdate)); + + Assert.Contains("Organization not found", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest(PolicyUpdate policyUpdate) + { + var sutProvider = SutProviderFactory(); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(policyUpdate.OrganizationId) + .Returns(new OrganizationAbility + { + Id = policyUpdate.OrganizationId, + UsePolicies = false + }); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policyUpdate)); + + Assert.Contains("cannot use policies", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task SaveAsync_RequiredPolicyIsNull_Throws( + [PolicyUpdate(PolicyType.RequireSso)] PolicyUpdate policyUpdate) + { + var sutProvider = SutProviderFactory([ + new FakeRequireSsoPolicyValidator(), + new FakeSingleOrgPolicyValidator() + ]); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([]); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policyUpdate)); + + Assert.Contains("Turn on the Single organization policy because it is required for the Require single sign-on authentication policy", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task SaveAsync_RequiredPolicyNotEnabled_Throws( + [PolicyUpdate(PolicyType.RequireSso)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy singleOrgPolicy) + { + var sutProvider = SutProviderFactory([ + new FakeRequireSsoPolicyValidator(), + new FakeSingleOrgPolicyValidator() + ]); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([singleOrgPolicy]); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policyUpdate)); + + Assert.Contains("Turn on the Single organization policy because it is required for the Require single sign-on authentication policy", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task SaveAsync_RequiredPolicyEnabled_Success( + [PolicyUpdate(PolicyType.RequireSso)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy) + { + var sutProvider = SutProviderFactory([ + new FakeRequireSsoPolicyValidator(), + new FakeSingleOrgPolicyValidator() + ]); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([singleOrgPolicy]); + + await sutProvider.Sut.SaveAsync(policyUpdate); + await AssertPolicySavedAsync(sutProvider, policyUpdate); + } + + [Theory, BitAutoData] + public async Task SaveAsync_DependentPolicyIsEnabled_Throws( + [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg)] Policy currentPolicy, + [Policy(PolicyType.RequireSso)] Policy requireSsoPolicy) // depends on Single Org + { + var sutProvider = SutProviderFactory([ + new FakeRequireSsoPolicyValidator(), + new FakeSingleOrgPolicyValidator() + ]); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy, requireSsoPolicy]); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policyUpdate)); + + Assert.Contains("Turn off the Require single sign-on authentication policy because it requires the Single organization policy", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task SaveAsync_MultipleDependentPoliciesAreEnabled_Throws( + [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg)] Policy currentPolicy, + [Policy(PolicyType.RequireSso)] Policy requireSsoPolicy, // depends on Single Org + [Policy(PolicyType.MaximumVaultTimeout)] Policy vaultTimeoutPolicy) // depends on Single Org + { + var sutProvider = SutProviderFactory([ + new FakeRequireSsoPolicyValidator(), + new FakeSingleOrgPolicyValidator(), + new FakeVaultTimeoutPolicyValidator() + ]); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy, requireSsoPolicy, vaultTimeoutPolicy]); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policyUpdate)); + + Assert.Contains("Turn off all of the policies that require the Single organization policy", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task SaveAsync_DependentPolicyNotEnabled_Success( + [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg)] Policy currentPolicy, + [Policy(PolicyType.RequireSso, false)] Policy requireSsoPolicy) // depends on Single Org but is not enabled + { + var sutProvider = SutProviderFactory([ + new FakeRequireSsoPolicyValidator(), + new FakeSingleOrgPolicyValidator() + ]); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy, requireSsoPolicy]); + + await sutProvider.Sut.SaveAsync(policyUpdate); + + await AssertPolicySavedAsync(sutProvider, policyUpdate); + } + + [Theory, BitAutoData] + public async Task SaveAsync_ThrowsOnValidationError([PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate) + { + var fakePolicyValidator = new FakeSingleOrgPolicyValidator(); + fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns("Validation error!"); + var sutProvider = SutProviderFactory([fakePolicyValidator]); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(policyUpdate.OrganizationId).Returns([]); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policyUpdate)); + + Assert.Contains("Validation error!", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + /// + /// Returns a new SutProvider with the PolicyValidators registered in the Sut. + /// + private static SutProvider SutProviderFactory(IEnumerable? policyValidators = null) + { + return new SutProvider() + .WithFakeTimeProvider() + .SetDependency(typeof(IEnumerable), policyValidators ?? []) + .Create(); + } + + private static void ArrangeOrganization(SutProvider sutProvider, PolicyUpdate policyUpdate) + { + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(policyUpdate.OrganizationId) + .Returns(new OrganizationAbility + { + Id = policyUpdate.OrganizationId, + UsePolicies = true + }); + } + + private static async Task AssertPolicyNotSavedAsync(SutProvider sutProvider) + { + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default!); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogPolicyEventAsync(default, default); + } + + private static async Task AssertPolicySavedAsync(SutProvider sutProvider, PolicyUpdate policyUpdate) + { + var expectedPolicy = () => Arg.Is(p => + p.Type == policyUpdate.Type && + p.OrganizationId == policyUpdate.OrganizationId && + p.Enabled == policyUpdate.Enabled && + p.Data == policyUpdate.Data); + + await sutProvider.GetDependency().Received(1).UpsertAsync(expectedPolicy()); + + await sutProvider.GetDependency().Received(1) + .LogPolicyEventAsync(expectedPolicy(), EventType.Policy_Updated); + } +} diff --git a/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs b/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs index fb08a32f2f4b..f9bc49bbe7c0 100644 --- a/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs +++ b/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs @@ -34,7 +34,6 @@ public async Task SaveAsync_OrganizationDoesNotExist_ThrowsBadRequest( var badRequestException = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), Guid.NewGuid())); Assert.Contains("Organization not found", badRequestException.Message, StringComparison.OrdinalIgnoreCase); @@ -61,7 +60,6 @@ public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest( var badRequestException = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), Guid.NewGuid())); Assert.Contains("cannot use policies", badRequestException.Message, StringComparison.OrdinalIgnoreCase); @@ -93,7 +91,6 @@ public async Task SaveAsync_SingleOrg_RequireSsoEnabled_ThrowsBadRequest( var badRequestException = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), Guid.NewGuid())); Assert.Contains("Single Sign-On Authentication policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); @@ -124,7 +121,6 @@ public async Task SaveAsync_SingleOrg_VaultTimeoutEnabled_ThrowsBadRequest([Admi var badRequestException = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), Guid.NewGuid())); Assert.Contains("Maximum Vault Timeout policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); @@ -161,7 +157,6 @@ public async Task SaveAsync_PolicyRequiredByKeyConnector_DisablePolicy_ThrowsBad var badRequestException = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), Guid.NewGuid())); Assert.Contains("Key Connector is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); @@ -189,7 +184,6 @@ public async Task SaveAsync_RequireSsoPolicy_NotEnabled_ThrowsBadRequestAsync( var badRequestException = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), Guid.NewGuid())); Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); @@ -222,7 +216,7 @@ public async Task SaveAsync_NewPolicy_Created( var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(policy, Substitute.For(), Guid.NewGuid()); + await sutProvider.Sut.SaveAsync(policy, Guid.NewGuid()); await sutProvider.GetDependency().Received() .LogPolicyEventAsync(policy, EventType.Policy_Updated); @@ -252,7 +246,6 @@ public async Task SaveAsync_VaultTimeoutPolicy_NotEnabled_ThrowsBadRequestAsync( var badRequestException = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), Guid.NewGuid())); Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); @@ -353,14 +346,13 @@ public async Task SaveAsync_ExistingPolicy_UpdateTwoFactor( (orgUserDetailAdmin, false), }); - var organizationService = Substitute.For(); var removeOrganizationUserCommand = sutProvider.GetDependency(); var utcNow = DateTime.UtcNow; var savingUserId = Guid.NewGuid(); - await sutProvider.Sut.SaveAsync(policy, organizationService, savingUserId); + await sutProvider.Sut.SaveAsync(policy, savingUserId); await removeOrganizationUserCommand.Received() .RemoveUserAsync(policy.OrganizationId, orgUserDetailUserAcceptedWithout2FA.Id, savingUserId); @@ -468,13 +460,12 @@ public async Task SaveAsync_EnableTwoFactor_WithoutMasterPasswordOr2FA_ThrowsBad (orgUserDetailAdmin.UserId.Value, false), }); - var organizationService = Substitute.For(); var removeOrganizationUserCommand = sutProvider.GetDependency(); var savingUserId = Guid.NewGuid(); var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, organizationService, savingUserId)); + () => sutProvider.Sut.SaveAsync(policy, savingUserId)); Assert.Contains("Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); @@ -541,13 +532,11 @@ public async Task SaveAsync_ExistingPolicy_UpdateSingleOrg( (orgUserDetail.UserId.Value, false), }); - var organizationService = Substitute.For(); - var utcNow = DateTime.UtcNow; var savingUserId = Guid.NewGuid(); - await sutProvider.Sut.SaveAsync(policy, organizationService, savingUserId); + await sutProvider.Sut.SaveAsync(policy, savingUserId); await sutProvider.GetDependency().Received() .LogPolicyEventAsync(policy, EventType.Policy_Updated); @@ -590,7 +579,6 @@ public async Task SaveAsync_ResetPasswordPolicyRequiredByTrustedDeviceEncryption var badRequestException = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), Guid.NewGuid())); Assert.Contains("Trusted device encryption is on and requires this policy.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); @@ -626,7 +614,6 @@ public async Task SaveAsync_RequireSsoPolicyRequiredByTrustedDeviceEncryption_Di var badRequestException = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), Guid.NewGuid())); Assert.Contains("Trusted device encryption is on and requires this policy.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); @@ -659,7 +646,6 @@ public async Task SaveAsync_PolicyRequiredForAccountRecovery_NotEnabled_ThrowsBa var badRequestException = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), Guid.NewGuid())); Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); @@ -692,7 +678,6 @@ public async Task SaveAsync_SingleOrg_AccountRecoveryEnabled_ThrowsBadRequest( var badRequestException = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), Guid.NewGuid())); Assert.Contains("Account recovery policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); diff --git a/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs b/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs index fb566537abcd..e397c838c68f 100644 --- a/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs +++ b/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs @@ -11,7 +11,6 @@ using Bit.Core.Exceptions; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -342,14 +341,12 @@ public async Task SaveAsync_Tde_Enable_Required_Policies(SutProvider().Received(1) .SaveAsync( Arg.Is(t => t.Type == PolicyType.SingleOrg), - Arg.Any(), null ); await sutProvider.GetDependency().Received(1) .SaveAsync( Arg.Is(t => t.Type == PolicyType.ResetPassword && t.GetDataModel().AutoEnrollEnabled), - Arg.Any(), null ); diff --git a/test/Core.Test/NotificationHub/NotificationHubConnectionTests.cs b/test/Core.Test/NotificationHub/NotificationHubConnectionTests.cs new file mode 100644 index 000000000000..0d7382b3cc0e --- /dev/null +++ b/test/Core.Test/NotificationHub/NotificationHubConnectionTests.cs @@ -0,0 +1,205 @@ +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Xunit; + +namespace Bit.Core.Test.NotificationHub; + +public class NotificationHubConnectionTests +{ + [Fact] + public void IsValid_ConnectionStringIsNull_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = null, + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + + // Act + var connection = NotificationHubConnection.From(hub); + + // Assert + Assert.False(connection.IsValid); + } + + [Fact] + public void IsValid_HubNameIsNull_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "Endpoint=sb://example.servicebus.windows.net/;", + HubName = null, + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + + // Act + var connection = NotificationHubConnection.From(hub); + + // Assert + Assert.False(connection.IsValid); + } + + [Fact] + public void IsValid_ConnectionStringAndHubNameAreNotNull_ReturnsTrue() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + + // Act + var connection = NotificationHubConnection.From(hub); + + // Assert + Assert.True(connection.IsValid); + } + + [Fact] + public void RegistrationEnabled_QueryTimeIsBeforeStartDate_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow.AddDays(1), + RegistrationEndDate = DateTime.UtcNow.AddDays(2) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(DateTime.UtcNow); + + // Assert + Assert.False(result); + } + + [Fact] + public void RegistrationEnabled_QueryTimeIsAfterEndDate_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(DateTime.UtcNow.AddDays(2)); + + // Assert + Assert.False(result); + } + + [Fact] + public void RegistrationEnabled_NullStartDate_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = null, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(DateTime.UtcNow); + + // Assert + Assert.False(result); + } + + [Fact] + public void RegistrationEnabled_QueryTimeIsBetweenStartDateAndEndDate_ReturnsTrue() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(DateTime.UtcNow.AddHours(1)); + + // Assert + Assert.True(result); + } + + [Fact] + public void RegistrationEnabled_CombTimeIsBeforeStartDate_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow.AddDays(1), + RegistrationEndDate = DateTime.UtcNow.AddDays(2) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(CoreHelpers.GenerateComb(Guid.NewGuid(), DateTime.UtcNow)); + + // Assert + Assert.False(result); + } + + [Fact] + public void RegistrationEnabled_CombTimeIsAfterEndDate_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(CoreHelpers.GenerateComb(Guid.NewGuid(), DateTime.UtcNow.AddDays(2))); + + // Assert + Assert.False(result); + } + + [Fact] + public void RegistrationEnabled_CombTimeIsBetweenStartDateAndEndDate_ReturnsTrue() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(CoreHelpers.GenerateComb(Guid.NewGuid(), DateTime.UtcNow.AddHours(1))); + + // Assert + Assert.True(result); + } +} diff --git a/test/Core.Test/NotificationHub/NotificationHubPoolTests.cs b/test/Core.Test/NotificationHub/NotificationHubPoolTests.cs new file mode 100644 index 000000000000..dd9afb867ee1 --- /dev/null +++ b/test/Core.Test/NotificationHub/NotificationHubPoolTests.cs @@ -0,0 +1,156 @@ +using Bit.Core.NotificationHub; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; +using static Bit.Core.Settings.GlobalSettings; + +namespace Bit.Core.Test.NotificationHub; + +public class NotificationHubPoolTests +{ + [Fact] + public void NotificationHubPool_WarnsOnMissingConnectionString() + { + // Arrange + var globalSettings = new GlobalSettings() + { + NotificationHubPool = new NotificationHubPoolSettings() + { + NotificationHubs = new() { + new() { + ConnectionString = null, + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + } + } + } + }; + var logger = Substitute.For>(); + + // Act + var sut = new NotificationHubPool(logger, globalSettings); + + // Assert + logger.Received().Log(LogLevel.Warning, Arg.Any(), + Arg.Is(o => o.ToString() == "Invalid notification hub settings: hub"), + null, + Arg.Any>()); + } + + [Fact] + public void NotificationHubPool_WarnsOnMissingHubName() + { + // Arrange + var globalSettings = new GlobalSettings() + { + NotificationHubPool = new NotificationHubPoolSettings() + { + NotificationHubs = new() { + new() { + ConnectionString = "connection", + HubName = null, + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + } + } + } + }; + var logger = Substitute.For>(); + + // Act + var sut = new NotificationHubPool(logger, globalSettings); + + // Assert + logger.Received().Log(LogLevel.Warning, Arg.Any(), + Arg.Is(o => o.ToString() == "Invalid notification hub settings: hub name missing"), + null, + Arg.Any>()); + } + + [Fact] + public void NotificationHubPool_ClientFor_ThrowsOnNoValidHubs() + { + // Arrange + var globalSettings = new GlobalSettings() + { + NotificationHubPool = new NotificationHubPoolSettings() + { + NotificationHubs = new() { + new() { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = null, + RegistrationEndDate = null, + } + } + } + }; + var logger = Substitute.For>(); + var sut = new NotificationHubPool(logger, globalSettings); + + // Act + Action act = () => sut.ClientFor(Guid.NewGuid()); + + // Assert + Assert.Throws(act); + } + + [Fact] + public void NotificationHubPool_ClientFor_ReturnsClient() + { + // Arrange + var globalSettings = new GlobalSettings() + { + NotificationHubPool = new NotificationHubPoolSettings() + { + NotificationHubs = new() { + new() { + ConnectionString = "Endpoint=sb://example.servicebus.windows.net/;SharedAccessKey=example///example=", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow.AddMinutes(-1), + RegistrationEndDate = DateTime.UtcNow.AddDays(1), + } + } + } + }; + var logger = Substitute.For>(); + var sut = new NotificationHubPool(logger, globalSettings); + + // Act + var client = sut.ClientFor(CoreHelpers.GenerateComb(Guid.NewGuid(), DateTime.UtcNow)); + + // Assert + Assert.NotNull(client); + } + + [Fact] + public void NotificationHubPool_AllClients_ReturnsProxy() + { + // Arrange + var globalSettings = new GlobalSettings() + { + NotificationHubPool = new NotificationHubPoolSettings() + { + NotificationHubs = new() { + new() { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1), + } + } + } + }; + var logger = Substitute.For>(); + var sut = new NotificationHubPool(logger, globalSettings); + + // Act + var proxy = sut.AllClients; + + // Assert + Assert.NotNull(proxy); + } +} diff --git a/test/Core.Test/NotificationHub/NotificationHubProxyTests.cs b/test/Core.Test/NotificationHub/NotificationHubProxyTests.cs new file mode 100644 index 000000000000..b2e9c4f9f330 --- /dev/null +++ b/test/Core.Test/NotificationHub/NotificationHubProxyTests.cs @@ -0,0 +1,40 @@ +using AutoFixture; +using Bit.Core.NotificationHub; +using Bit.Test.Common.AutoFixture; +using Microsoft.Azure.NotificationHubs; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.NotificationHub; + +public class NotificationHubProxyTests +{ + private readonly IEnumerable _clients; + public NotificationHubProxyTests() + { + _clients = new Fixture().WithAutoNSubstitutions().CreateMany(); + } + + public static IEnumerable ClientMethods = + [ + [ + (NotificationHubClientProxy c) => c.SendTemplateNotificationAsync(new Dictionary() { { "key", "value" } }, "tag"), + (INotificationHubClient c) => c.SendTemplateNotificationAsync(Arg.Is>((a) => a.Keys.Count == 1 && a.ContainsKey("key") && a["key"] == "value"), "tag"), + ], + ]; + + [Theory] + [MemberData(nameof(ClientMethods))] + public async void CallsAllClients(Func proxyMethod, Func clientMethod) + { + var clients = _clients.ToArray(); + var proxy = new NotificationHubClientProxy(clients); + + await proxyMethod(proxy); + + foreach (var client in clients) + { + await clientMethod(client.Received()); + } + } +} diff --git a/test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs b/test/Core.Test/NotificationHub/NotificationHubPushNotificationServiceTests.cs similarity index 81% rename from test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs rename to test/Core.Test/NotificationHub/NotificationHubPushNotificationServiceTests.cs index 82594445a6e1..ea9ce54131e8 100644 --- a/test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs +++ b/test/Core.Test/NotificationHub/NotificationHubPushNotificationServiceTests.cs @@ -1,32 +1,32 @@ -using Bit.Core.Repositories; +using Bit.Core.NotificationHub; +using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Settings; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.NotificationHub; public class NotificationHubPushNotificationServiceTests { private readonly NotificationHubPushNotificationService _sut; private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; + private readonly INotificationHubPool _notificationHubPool; private readonly IHttpContextAccessor _httpContextAccessor; private readonly ILogger _logger; public NotificationHubPushNotificationServiceTests() { _installationDeviceRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); _httpContextAccessor = Substitute.For(); + _notificationHubPool = Substitute.For(); _logger = Substitute.For>(); _sut = new NotificationHubPushNotificationService( _installationDeviceRepository, - _globalSettings, + _notificationHubPool, _httpContextAccessor, _logger ); diff --git a/test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs b/test/Core.Test/NotificationHub/NotificationHubPushRegistrationServiceTests.cs similarity index 82% rename from test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs rename to test/Core.Test/NotificationHub/NotificationHubPushRegistrationServiceTests.cs index a8dd536b87d7..c5851f279148 100644 --- a/test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs +++ b/test/Core.Test/NotificationHub/NotificationHubPushRegistrationServiceTests.cs @@ -1,11 +1,11 @@ -using Bit.Core.Repositories; -using Bit.Core.Services; +using Bit.Core.NotificationHub; +using Bit.Core.Repositories; using Bit.Core.Settings; using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.NotificationHub; public class NotificationHubPushRegistrationServiceTests { @@ -15,6 +15,7 @@ public class NotificationHubPushRegistrationServiceTests private readonly IServiceProvider _serviceProvider; private readonly ILogger _logger; private readonly GlobalSettings _globalSettings; + private readonly INotificationHubPool _notificationHubPool; public NotificationHubPushRegistrationServiceTests() { @@ -22,10 +23,12 @@ public NotificationHubPushRegistrationServiceTests() _serviceProvider = Substitute.For(); _logger = Substitute.For>(); _globalSettings = new GlobalSettings(); + _notificationHubPool = Substitute.For(); _sut = new NotificationHubPushRegistrationService( _installationDeviceRepository, _globalSettings, + _notificationHubPool, _serviceProvider, _logger ); diff --git a/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs b/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs index b1876f1dda7f..68d6c50a7ede 100644 --- a/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs +++ b/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs @@ -1,10 +1,10 @@ -using Bit.Core.Repositories; +using AutoFixture; using Bit.Core.Services; -using Bit.Core.Settings; -using Microsoft.AspNetCore.Http; +using Bit.Test.Common.AutoFixture; using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; +using GlobalSettingsCustomization = Bit.Test.Common.AutoFixture.GlobalSettings; namespace Bit.Core.Test.Services; @@ -12,35 +12,26 @@ public class MultiServicePushNotificationServiceTests { private readonly MultiServicePushNotificationService _sut; - private readonly IHttpClientFactory _httpFactory; - private readonly IDeviceRepository _deviceRepository; - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; private readonly ILogger _logger; private readonly ILogger _relayLogger; private readonly ILogger _hubLogger; + private readonly IEnumerable _services; + private readonly Settings.GlobalSettings _globalSettings; public MultiServicePushNotificationServiceTests() { - _httpFactory = Substitute.For(); - _deviceRepository = Substitute.For(); - _installationDeviceRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); - _httpContextAccessor = Substitute.For(); _logger = Substitute.For>(); _relayLogger = Substitute.For>(); _hubLogger = Substitute.For>(); + var fixture = new Fixture().WithAutoNSubstitutions().Customize(new GlobalSettingsCustomization()); + _services = fixture.CreateMany(); + _globalSettings = fixture.Create(); + _sut = new MultiServicePushNotificationService( - _httpFactory, - _deviceRepository, - _installationDeviceRepository, - _globalSettings, - _httpContextAccessor, + _services, _logger, - _relayLogger, - _hubLogger + _globalSettings ); } diff --git a/test/Core.Test/Utilities/CoreHelpersTests.cs b/test/Core.Test/Utilities/CoreHelpersTests.cs index af115679893a..2cce276fcb8a 100644 --- a/test/Core.Test/Utilities/CoreHelpersTests.cs +++ b/test/Core.Test/Utilities/CoreHelpersTests.cs @@ -34,33 +34,30 @@ public void GenerateComb_Success() // the comb are working properly } - public static IEnumerable GenerateCombCases = new[] - { - new object[] - { + public static IEnumerable GuidSeedCases = [ + [ Guid.Parse("a58db474-43d8-42f1-b4ee-0c17647cd0c0"), // Input Guid new DateTime(2022, 3, 12, 12, 12, 0, DateTimeKind.Utc), // Input Time - Guid.Parse("a58db474-43d8-42f1-b4ee-ae5600c90cc1"), // Expected Comb - }, - new object[] - { + ], + [ Guid.Parse("f776e6ee-511f-4352-bb28-88513002bdeb"), new DateTime(2021, 5, 10, 10, 52, 0, DateTimeKind.Utc), - Guid.Parse("f776e6ee-511f-4352-bb28-ad2400b313c1"), - }, - new object[] - { + ], + [ Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011648a1"), new DateTime(1999, 2, 26, 16, 53, 13, DateTimeKind.Utc), - Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011649cd"), - }, - new object[] - { + ], + [ Guid.Parse("bfb8f353-3b32-4a9e-bef6-24fe0b54bfb0"), new DateTime(2024, 10, 20, 1, 32, 16, DateTimeKind.Utc), - Guid.Parse("bfb8f353-3b32-4a9e-bef6-b20f00195780"), - } - }; + ] + ]; + public static IEnumerable GenerateCombCases = GuidSeedCases.Zip([ + Guid.Parse("a58db474-43d8-42f1-b4ee-ae5600c90cc1"), // Expected Comb for each Guid Seed case + Guid.Parse("f776e6ee-511f-4352-bb28-ad2400b313c1"), + Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011649cd"), + Guid.Parse("bfb8f353-3b32-4a9e-bef6-b20f00195780"), + ]).Select((zip) => new object[] { zip.Item1[0], zip.Item1[1], zip.Item2 }); [Theory] [MemberData(nameof(GenerateCombCases))] @@ -71,6 +68,31 @@ public void GenerateComb_WithInputs_Success(Guid inputGuid, DateTime inputTime, Assert.Equal(expectedComb, comb); } + [Theory] + [MemberData(nameof(GuidSeedCases))] + public void DateFromComb_WithComb_Success(Guid inputGuid, DateTime inputTime) + { + var comb = CoreHelpers.GenerateComb(inputGuid, inputTime); + var inverseComb = CoreHelpers.DateFromComb(comb); + + Assert.Equal(inputTime, inverseComb, TimeSpan.FromMilliseconds(4)); + } + + [Theory] + [InlineData("00000000-0000-0000-0000-000000000000", 1, 0)] + [InlineData("00000000-0000-0000-0000-000000000001", 1, 0)] + [InlineData("00000000-0000-0000-0000-000000000000", 500, 430)] + [InlineData("00000000-0000-0000-0000-000000000001", 500, 430)] + [InlineData("10000000-0000-0000-0000-000000000001", 500, 454)] + [InlineData("00000000-0000-0100-0000-000000000001", 500, 19)] + public void BinForComb_Success(string guidString, int nbins, int expectedBin) + { + var guid = Guid.Parse(guidString); + var bin = CoreHelpers.BinForComb(guid, nbins); + + Assert.Equal(expectedBin, bin); + } + /* [Fact] public void ToGuidIdArrayTVP_Success()