From 3ae94baf6908a6f25177ea21cd2f2e0d3a5b808b Mon Sep 17 00:00:00 2001 From: Simon Kok Date: Wed, 18 Jan 2023 13:52:43 +0100 Subject: [PATCH] Fix bootstrapping accounts in non-protected OUs only (#590) **Why?** While running the `adf-build/main.py` script: 1. It would get the list of accounts of this specific AWS Organization. 2. Spin up a thread for any of the non-deployment accounts. 3. In the thread, check if the account is in the root or in a protected OU. If so, it would stop the thread. If not, it would deploy the bootstrap templates. So far so good. 4. When invoking the Step Function to enable the cross-account access, it would use the list of account ids it retrieved before. However, this list is unfiltered. Therefore, the Step Function State Machine would try to enable the cross account access on accounts that were in the root and/or in protected organization units. **What?** We would only need to bootstrap accounts that are: * Active, * Not in the AWS Organization root, and * Not in an AWS Organization OU that is listed as protected. ADF should also fix the cross-account access in those accounts only. As the others don't have a bootstrap template deployed that needs to be updated any way. Thus, instead of introducing the same logic in the Step Function. This change set moved the responsibility for filtering the accounts based on their state and location in the Organizations class. Tests were added to validate that this works correctly. Additionally, fixes were introduced to reduce the line lengths where needed. --- .../bootstrap_repository/adf-build/main.py | 28 +- .../adf-build/shared/python/organizations.py | 190 ++++++++++--- .../shared/python/tests/test_organizations.py | 255 +++++++++++++++++- .../adf-build/tests/test_main.py | 31 +-- 4 files changed, 419 insertions(+), 85 deletions(-) diff --git a/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/main.py b/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/main.py index 70eb1fb47..d1fed23be 100644 --- a/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/main.py +++ b/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/main.py @@ -59,24 +59,6 @@ LOGGER = configure_logger(__name__) -def is_account_in_invalid_state(ou_id, config): - """ - Check if Account is sitting in the root - of the Organization or in Protected OU - """ - if ou_id.startswith('r-'): - return "Is in the Root of the Organization, it will be skipped." - - protected = config.get('protected', []) - if ou_id in protected: - return ( - f"Is in a protected Organizational Unit {ou_id}, " - "it will be skipped." - ) - - return False - - def ensure_generic_account_can_be_setup(sts, config, account_id): """ If the target account has been configured returns the role to assume @@ -233,11 +215,6 @@ def worker_thread( ) ou_id = organizations.get_parent_info().get("ou_parent_id") - account_state = is_account_in_invalid_state(ou_id, config.config) - if account_state: - LOGGER.info("%s %s", account_id, account_state) - return - account_path = organizations.build_account_path( ou_id, [], # Initial empty array to hold OU Path, @@ -490,7 +467,10 @@ def main(): # pylint: disable=R0915 threads = [] account_ids = [ account_id["Id"] - for account_id in organizations.get_accounts() + for account_id in organizations.get_accounts( + protected_ou_ids=config.config.get('protected'), + include_root=False, + ) ] non_deployment_account_ids = [ account for account in account_ids diff --git a/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/shared/python/organizations.py b/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/shared/python/organizations.py index 8bde1f91f..3cea74ed4 100644 --- a/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/shared/python/organizations.py +++ b/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/shared/python/organizations.py @@ -40,24 +40,49 @@ def __init__(self, role, account_id=None): config=Organizations._config ) self.account_id = account_id - self.account_ids = [] self.root_id = None - def get_parent_info(self): - response = self.list_parents(self.account_id) + def get_parent_info(self, account_id=None): + """ + Get the parent info of the account_id specified. If no specific + account id is specified, it will use the account_id setup when + initiating the Organizations instance. + + Args: + account_id (str|None): The specific account id if any. + + Returns: + dict: The ou_parent_id and ou_parent_type are returned in a + dictionary. + """ + response = self.list_parents(account_id or self.account_id) return { "ou_parent_id": response.get('Id'), "ou_parent_type": response.get('Type') } - def enable_organization_policies(self, policy_type='SERVICE_CONTROL_POLICY'): # or 'TAG_POLICY' + def enable_organization_policies( + self, + policy_type='SERVICE_CONTROL_POLICY', + ): + """ + Enable the policies on the organization unit root id. + + Args: + policy_type (str): + The policy type, either 'SERVICE_CONTROL_POLICY' or + 'TAG_POLICY'. It defaults to the 'SERVICE_CONTROL_POLICY'. + """ try: self.client.enable_policy_type( RootId=self.get_ou_root_id(), PolicyType=policy_type ) except self.client.exceptions.PolicyTypeAlreadyEnabledException: - LOGGER.info('%s are currently enabled within the Organization', policy_type) + LOGGER.info( + '%s are currently enabled within the Organization', + policy_type, + ) @staticmethod def trim_policy_path(policy): @@ -73,22 +98,45 @@ def get_organization_map(self, org_structure, counter=0): if not Organizations.is_ou_id(ou_id): continue # List OUs - for organization_id in [organization_id['Id'] for organization_id in paginator(self.client.list_children, **{"ParentId":ou_id, "ChildType":"ORGANIZATIONAL_UNIT"})]: + for organization_id in [ + ou_data['Id'] for ou_data in paginator( + self.client.list_children, + **{ + "ParentId": ou_id, + "ChildType": "ORGANIZATIONAL_UNIT", + }, + ) + ]: if organization_id in org_structure.values() and counter != 0: continue ou_name = self.describe_ou_name(organization_id) - trimmed_path = Organizations.trim_policy_path(f"{name}/{ou_name}") + trimmed_path = Organizations.trim_policy_path( + f"{name}/{ou_name}", + ) org_structure[trimmed_path] = organization_id # List accounts - for account_id in [account_id['Id'] for account_id in paginator(self.client.list_children, **{"ParentId":ou_id, "ChildType":"ACCOUNT"})]: + for account_id in [ + account_data['Id'] for account_data in paginator( + self.client.list_children, + **{ + "ParentId": ou_id, + "ChildType": "ACCOUNT", + } + ) + ]: if account_id in org_structure.values() and counter != 0: continue account_name = self.describe_account_name(account_id) - trimmed_path = Organizations.trim_policy_path(f"{name}/{account_name}") + trimmed_path = Organizations.trim_policy_path( + f"{name}/{account_name}", + ) org_structure[trimmed_path] = account_id counter = counter + 1 # Counter is greater than 5 here is the conditional as organizations cannot have more than 5 levels of nested OUs + 1 accounts "level" - return org_structure if counter > 5 else self.get_organization_map(org_structure, counter) + return ( + org_structure if counter > 5 + else self.get_organization_map(org_structure, counter) + ) def update_policy(self, content, policy_id): self.client.update_policy( @@ -96,7 +144,12 @@ def update_policy(self, content, policy_id): Content=content ) - def create_policy(self, content, ou_path, policy_type="SERVICE_CONTROL_POLICY"): + def create_policy( + self, + content, + ou_path, + policy_type="SERVICE_CONTROL_POLICY", + ): policy_type_name = ( 'scp' if policy_type == "SERVICE_CONTROL_POLICY" else 'tagging-policy' @@ -111,29 +164,42 @@ def create_policy(self, content, ou_path, policy_type="SERVICE_CONTROL_POLICY"): @staticmethod def get_policy_body(path): - with open(f'./adf-bootstrap/{path}', mode='r', encoding='utf-8') as policy: + bootstrap_path = f'./adf-bootstrap/{path}' + with open(bootstrap_path, mode='r', encoding='utf-8') as policy: return json.dumps(json.load(policy)) def list_policies(self, name, policy_type="SERVICE_CONTROL_POLICY"): - response = list(paginator(self.client.list_policies, Filter=policy_type)) - try: - return [policy for policy in response if policy['Name'] == name][0]['Id'] - except IndexError: - return [] - - def describe_policy_id_for_target(self, target_id, policy_type='SERVICE_CONTROL_POLICY'): + response = list( + paginator(self.client.list_policies, Filter=policy_type) + ) + filtered_policies = [ + policy for policy in response + if policy['Name'] == name + ] + if len(filtered_policies) > 0: + return filtered_policies[0]['Id'] + return [] + + def describe_policy_id_for_target( + self, + target_id, + policy_type='SERVICE_CONTROL_POLICY', + ): response = self.client.list_policies_for_target( TargetId=target_id, Filter=policy_type ) - try: - return [p for p in response['Policies'] if f'ADF Managed {policy_type}' in p['Description']][0]['Id'] - except IndexError: - return [] + adf_managed_policies = [ + policy for policy in response['Policies'] + if f'ADF Managed {policy_type}' in policy['Description'] + ] + if len(adf_managed_policies) > 0: + return adf_managed_policies[0]['Id'] + return [] def describe_policy(self, policy_id): response = self.client.describe_policy( - PolicyId=policy_id + PolicyId=policy_id, ) return response.get('Policy') @@ -141,7 +207,7 @@ def attach_policy(self, policy_id, target_id): try: self.client.attach_policy( PolicyId=policy_id, - TargetId=target_id + TargetId=target_id, ) except self.client.exceptions.DuplicatePolicyAttachmentException: pass @@ -157,13 +223,66 @@ def delete_policy(self, policy_id): PolicyId=policy_id ) - def get_accounts(self): + def _account_available_to_adf( + self, + account, + protected_ou_ids, + include_root, + ): + if protected_ou_ids or not include_root: + account_ou_id = ( + self.get_parent_info(account["Id"]).get("ou_parent_id") + ) + if not include_root and account_ou_id.startswith("r-"): + LOGGER.info( + "Account %s is in the root of the AWS Organization, " + "therefore skipping it", + account["Id"], + ) + return False + if protected_ou_ids and account_ou_id in protected_ou_ids: + LOGGER.info( + "Account %s is in OU %s which is marked as protected, " + "therefore skipping it", + account["Id"], + account_ou_id, + ) + return False + if account.get("Status") != "ACTIVE": + LOGGER.warning( + "Account %s is not an active AWS Account, state reported: %s", + account["Id"], + account.get("Status"), + ) + return False + return True + + def get_accounts(self, protected_ou_ids=None, include_root=True): + """ + Get the accounts from this AWS Organizations. + Filtered by the given arguments if required. + + Args: + protected_ou_ids (list(str)): The list of protected organization + unit ids as configured in the adfconfig.yml file. + The organization unit ids are structured like: ou-123. + + include_root (bool): Whether or not to include accounts that are + located in the root of the AWS Organization. + ADF does not adopt these accounts. + + Returns: + list(str): The list of account details, filtered as requested. + """ + accounts = [] for account in paginator(self.client.list_accounts): - if not account.get('Status') == 'ACTIVE': - LOGGER.warning('Account %s is not an Active AWS Account', account['Id']) - continue - self.account_ids.append(account) - return self.account_ids + if self._account_available_to_adf( + account, + protected_ou_ids, + include_root, + ): + accounts.append(account) + return accounts def get_organization_info(self): response = self.client.describe_organization() @@ -183,7 +302,9 @@ def describe_ou_name(self, ou_id): ) return response['OrganizationalUnit']['Name'] except ClientError as error: - raise RootOUIDError("OU is the Root of the Organization") from error + raise RootOUIDError( + "OU is the Root of the Organization", + ) from error def describe_account_name(self, account_id): try: @@ -192,7 +313,10 @@ def describe_account_name(self, account_id): ) return response['Account']['Name'] except ClientError as error: - LOGGER.error('Failed to retrieve account name for account ID %s', account_id) + LOGGER.error( + "Failed to retrieve account name for account ID %s", + account_id, + ) raise error @staticmethod diff --git a/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/shared/python/tests/test_organizations.py b/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/shared/python/tests/test_organizations.py index ec58c2cf2..ba47554e0 100644 --- a/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/shared/python/tests/test_organizations.py +++ b/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/shared/python/tests/test_organizations.py @@ -28,11 +28,252 @@ def test_get_parent_info(cls): "ou_parent_id": 'some_id', "ou_parent_type": 'ORGANIZATIONAL_UNIT' } + cls.client.list_parents.assert_called_once_with( + ChildId=cls.account_id, + ) + + +def test_get_parent_info_specific_account(cls): + specific_account_id = '111111111111' + cls.client = Mock() + cls.client.list_parents.return_value = stub_organizations.list_parents + assert cls.get_parent_info(specific_account_id) == { + "ou_parent_id": 'some_id', + "ou_parent_type": 'ORGANIZATIONAL_UNIT' + } + cls.client.list_parents.assert_called_once_with( + ChildId=specific_account_id, + ) + + +@patch('organizations.paginator') +def test_get_accounts(paginator_mock, cls): + all_account_ids = [ + '111111111111', + '222222222222', + '333333333333', + '444444444444', + ] + root_account_ids = [ + '333333333333', + ] + cls.client = Mock() + cls.client.list_parents.side_effect = lambda account_id: ( + { + "Id": ( + f"r-{account_id}" if account_id in root_account_ids + else f"ou-{account_id}" + ), + "Type": "ORGANIZATIONAL_UNIT", + } + ) + paginator_mock.return_value = list(map( + lambda account_id: ({ + "Id": account_id, + "Status": "ACTIVE", + }), + all_account_ids, + )) + assert set(map( + lambda account: account['Id'], + cls.get_accounts(), + )) == set(all_account_ids) + + +@patch('organizations.paginator') +def test_get_accounts_with_suspended(paginator_mock, cls): + all_account_ids = [ + '111111111111', + '222222222222', + '333333333333', + '444444444444', + ] + root_account_ids = [ + '333333333333', + ] + suspended_account_ids = [ + '444444444444', + ] + cls.client = Mock() + cls.client.list_parents.side_effect = lambda account_id: ( + { + "Id": ( + f"r-{account_id}" if account_id in root_account_ids + else f"ou-{account_id}" + ), + "Type": "ORGANIZATIONAL_UNIT", + } + ) + paginator_mock.return_value = list(map( + lambda account_id: ({ + "Id": account_id, + "Status": ( + "SUSPENDED" if account_id in suspended_account_ids + else "ACTIVE" + ), + }), + all_account_ids, + )) + assert set(map( + lambda account: account['Id'], + cls.get_accounts(), + )) == (set(all_account_ids) - set(suspended_account_ids)) + + +@patch('organizations.paginator') +def test_get_accounts_ignore_root(paginator_mock, cls): + all_account_ids = [ + '111111111111', + '222222222222', + '333333333333', + '444444444444', + ] + root_account_ids = [ + '444444444444', + ] + cls.client = Mock() + cls.client.list_parents.side_effect = lambda ChildId: ({ + "Parents": [{ + "Id": ( + f"r-{ChildId}" if ChildId in root_account_ids + else f"ou-{ChildId}" + ), + "Type": "ORGANIZATIONAL_UNIT", + }], + }) + paginator_mock.return_value = list(map( + lambda account_id: ({ + "Id": account_id, + "Status": "ACTIVE", + }), + all_account_ids, + )) + assert set(map( + lambda account: account['Id'], + cls.get_accounts( + include_root=False, + ), + )) == (set(all_account_ids) - set(root_account_ids)) + + +@patch('organizations.paginator') +def test_get_accounts_ignore_protected(paginator_mock, cls): + all_account_ids = [ + '111111111111', + '222222222222', + '333333333333', + '444444444444', + ] + root_account_ids = [ + '444444444444', + ] + protected_account_ids = [ + '222222222222', + ] + protected_ou_ids = list(map( + lambda account_id: f"ou-{account_id}", + protected_account_ids, + )) + cls.client = Mock() + cls.client.list_parents.side_effect = lambda ChildId: ({ + "Parents": [{ + "Id": ( + f"r-{ChildId}" if ChildId in root_account_ids + else f"ou-{ChildId}" + ), + "Type": "ORGANIZATIONAL_UNIT", + }], + }) + paginator_mock.return_value = list(map( + lambda account_id: ({ + "Id": account_id, + "Status": "ACTIVE", + }), + all_account_ids, + )) + assert set(map( + lambda account: account['Id'], + cls.get_accounts( + protected_ou_ids=protected_ou_ids, + ), + )) == (set(all_account_ids) - set(protected_account_ids)) + + +@patch('organizations.paginator') +def test_get_accounts_ignore_root_protected_and_inactive(paginator_mock, cls): + all_account_ids = [ + '111111111111', + '222222222222', + '333333333333', + '444444444444', + '555555555555', + '666666666666', + '777777777777', + '888888888888', + ] + protected_account_ids = [ + '222222222222', + '777777777777', + ] + root_account_ids = [ + '333333333333', + '888888888888', + ] + suspended_account_ids = [ + '444444444444', + ] + pending_closure_account_ids = [ + '555555555555', + ] + protected_ou_ids = list(map( + lambda account_id: f"ou-{account_id}", + protected_account_ids, + )) + cls.client = Mock() + cls.client.list_parents.side_effect = lambda ChildId: ({ + "Parents": [{ + "Id": ( + f"r-{ChildId}" if ChildId in root_account_ids + else f"ou-{ChildId}" + ), + "Type": "ORGANIZATIONAL_UNIT", + }], + }) + paginator_mock.return_value = list(map( + lambda account_id: ({ + "Id": account_id, + "Status": ( + "SUSPENDED" + if account_id in suspended_account_ids + else ( + "PENDING_CLOSURE" + if account_id in pending_closure_account_ids + else "ACTIVE" + ) + ), + }), + all_account_ids, + )) + assert set(map( + lambda account: account['Id'], + cls.get_accounts( + protected_ou_ids=protected_ou_ids, + include_root=False, + ), + )) == ( + set(all_account_ids) + - set(protected_account_ids) + - set(root_account_ids) + - set(suspended_account_ids) + - set(pending_closure_account_ids) + ) def test_get_organization_info(cls): cls.client = Mock() - cls.client.describe_organization.return_value = stub_organizations.describe_organization + cls.client.describe_organization.return_value = ( + stub_organizations.describe_organization + ) assert cls.get_organization_info() == { 'organization_id': 'some_org_id', 'organization_master_account_id': 'some_master_account_id', @@ -42,13 +283,17 @@ def test_get_organization_info(cls): def test_describe_ou_name(cls): cls.client = Mock() - cls.client.describe_organizational_unit.return_value = stub_organizations.describe_organizational_unit + cls.client.describe_organizational_unit.return_value = ( + stub_organizations.describe_organizational_unit + ) assert cls.describe_ou_name('some_ou_id') == 'some_ou_name' def test_describe_account_name(cls): cls.client = Mock() - cls.client.describe_account.return_value = stub_organizations.describe_account + cls.client.describe_account.return_value = ( + stub_organizations.describe_account + ) assert cls.describe_account_name('some_account_id') == 'some_account_name' @@ -66,6 +311,8 @@ def test_build_account_path(cls): cls.client = Mock() cache = Cache() cls.client.list_parents.return_value = stub_organizations.list_parents_root - cls.client.describe_organizational_unit.return_value = stub_organizations.describe_organizational_unit + cls.client.describe_organizational_unit.return_value = ( + stub_organizations.describe_organizational_unit + ) assert cls.build_account_path('some_ou_id', [], cache) == 'some_ou_name' diff --git a/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/tests/test_main.py b/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/tests/test_main.py index 93ea8bd96..7ff8ba183 100644 --- a/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/tests/test_main.py +++ b/src/lambda_codebase/initial_commit/bootstrap_repository/adf-build/tests/test_main.py @@ -8,7 +8,11 @@ from pytest import fixture from parameter_store import ParameterStore from mock import Mock, patch, call -from main import * +from main import ( + Config, + ensure_generic_account_can_be_setup, + update_deployment_account_output_parameters, +) @fixture @@ -41,26 +45,6 @@ def sts(): return sts -def test_is_account_valid_state(cls): - assert is_account_in_invalid_state('ou-123', cls.__dict__) == False - - -def test_is_account_in_invalid_state(cls): - cls.protected = [] - cls.protected.append('ou-123') - assert is_account_in_invalid_state('ou-123', cls.__dict__) == ( - 'Is in a protected Organizational Unit ou-123, it will be skipped.' - ) - - - -def test_is_account_is_in_root(cls): - assert is_account_in_invalid_state('r-123', cls.__dict__) == ( - 'Is in the Root of the Organization, it will be skipped.' - ) - - - def test_ensure_generic_account_can_be_setup(cls, sts): assert ensure_generic_account_can_be_setup(sts, cls, '123456789012') == ( sts.assume_cross_account_role() @@ -68,8 +52,8 @@ def test_ensure_generic_account_can_be_setup(cls, sts): def test_update_deployment_account_output_parameters(cls, sts): - cloudformation=Mock() - parameter_store=Mock() + cloudformation = Mock() + parameter_store = Mock() parameter_store.client.put_parameter.return_value = True cloudformation.get_stack_regional_outputs.return_value = { "kms_arn": 'some_kms_arn', @@ -86,7 +70,6 @@ def test_update_deployment_account_output_parameters(cls, sts): 'some_s3_bucket', ), ] - kms_and_bucket_dict={} update_deployment_account_output_parameters( deployment_account_region='eu-central-1', region='eu-central-1',