Skip to content

Commit

Permalink
Fix bootstrapping accounts in non-protected OUs only (#590)
Browse files Browse the repository at this point in the history
**Why?**

While running the `adf-build/main.py` script:
1. It would get the list of accounts of this specific AWS Organization.
2. Spin up a thread for any of the non-deployment accounts.
3. In the thread, check if the account is in the root or in a protected OU. If
   so, it would stop the thread. If not, it would deploy the bootstrap
   templates. So far so good.
4. When invoking the Step Function to enable the cross-account access, it would
   use the list of account ids it retrieved before. However, this list is
   unfiltered.

Therefore, the Step Function State Machine would try to enable
the cross account access on accounts that were in the root and/or in
protected organization units.

**What?**

We would only need to bootstrap accounts that are:
* Active,
* Not in the AWS Organization root, and
* Not in an AWS Organization OU that is listed as protected.

ADF should also fix the cross-account access in those accounts only.
As the others don't have a bootstrap template deployed that needs to be updated
any way.

Thus, instead of introducing the same logic in the Step Function. This change
set moved the responsibility for filtering the accounts based on their state
and location in the Organizations class.

Tests were added to validate that this works correctly.

Additionally, fixes were introduced to reduce the line lengths where needed.
  • Loading branch information
sbkok authored Jan 18, 2023
1 parent 947e200 commit 3ae94ba
Show file tree
Hide file tree
Showing 4 changed files with 419 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,6 @@
LOGGER = configure_logger(__name__)


def is_account_in_invalid_state(ou_id, config):
"""
Check if Account is sitting in the root
of the Organization or in Protected OU
"""
if ou_id.startswith('r-'):
return "Is in the Root of the Organization, it will be skipped."

protected = config.get('protected', [])
if ou_id in protected:
return (
f"Is in a protected Organizational Unit {ou_id}, "
"it will be skipped."
)

return False


def ensure_generic_account_can_be_setup(sts, config, account_id):
"""
If the target account has been configured returns the role to assume
Expand Down Expand Up @@ -233,11 +215,6 @@ def worker_thread(
)
ou_id = organizations.get_parent_info().get("ou_parent_id")

account_state = is_account_in_invalid_state(ou_id, config.config)
if account_state:
LOGGER.info("%s %s", account_id, account_state)
return

account_path = organizations.build_account_path(
ou_id,
[], # Initial empty array to hold OU Path,
Expand Down Expand Up @@ -490,7 +467,10 @@ def main(): # pylint: disable=R0915
threads = []
account_ids = [
account_id["Id"]
for account_id in organizations.get_accounts()
for account_id in organizations.get_accounts(
protected_ou_ids=config.config.get('protected'),
include_root=False,
)
]
non_deployment_account_ids = [
account for account in account_ids
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,49 @@ def __init__(self, role, account_id=None):
config=Organizations._config
)
self.account_id = account_id
self.account_ids = []
self.root_id = None

def get_parent_info(self):
response = self.list_parents(self.account_id)
def get_parent_info(self, account_id=None):
"""
Get the parent info of the account_id specified. If no specific
account id is specified, it will use the account_id setup when
initiating the Organizations instance.
Args:
account_id (str|None): The specific account id if any.
Returns:
dict: The ou_parent_id and ou_parent_type are returned in a
dictionary.
"""
response = self.list_parents(account_id or self.account_id)
return {
"ou_parent_id": response.get('Id'),
"ou_parent_type": response.get('Type')
}

def enable_organization_policies(self, policy_type='SERVICE_CONTROL_POLICY'): # or 'TAG_POLICY'
def enable_organization_policies(
self,
policy_type='SERVICE_CONTROL_POLICY',
):
"""
Enable the policies on the organization unit root id.
Args:
policy_type (str):
The policy type, either 'SERVICE_CONTROL_POLICY' or
'TAG_POLICY'. It defaults to the 'SERVICE_CONTROL_POLICY'.
"""
try:
self.client.enable_policy_type(
RootId=self.get_ou_root_id(),
PolicyType=policy_type
)
except self.client.exceptions.PolicyTypeAlreadyEnabledException:
LOGGER.info('%s are currently enabled within the Organization', policy_type)
LOGGER.info(
'%s are currently enabled within the Organization',
policy_type,
)

@staticmethod
def trim_policy_path(policy):
Expand All @@ -73,30 +98,58 @@ def get_organization_map(self, org_structure, counter=0):
if not Organizations.is_ou_id(ou_id):
continue
# List OUs
for organization_id in [organization_id['Id'] for organization_id in paginator(self.client.list_children, **{"ParentId":ou_id, "ChildType":"ORGANIZATIONAL_UNIT"})]:
for organization_id in [
ou_data['Id'] for ou_data in paginator(
self.client.list_children,
**{
"ParentId": ou_id,
"ChildType": "ORGANIZATIONAL_UNIT",
},
)
]:
if organization_id in org_structure.values() and counter != 0:
continue
ou_name = self.describe_ou_name(organization_id)
trimmed_path = Organizations.trim_policy_path(f"{name}/{ou_name}")
trimmed_path = Organizations.trim_policy_path(
f"{name}/{ou_name}",
)
org_structure[trimmed_path] = organization_id
# List accounts
for account_id in [account_id['Id'] for account_id in paginator(self.client.list_children, **{"ParentId":ou_id, "ChildType":"ACCOUNT"})]:
for account_id in [
account_data['Id'] for account_data in paginator(
self.client.list_children,
**{
"ParentId": ou_id,
"ChildType": "ACCOUNT",
}
)
]:
if account_id in org_structure.values() and counter != 0:
continue
account_name = self.describe_account_name(account_id)
trimmed_path = Organizations.trim_policy_path(f"{name}/{account_name}")
trimmed_path = Organizations.trim_policy_path(
f"{name}/{account_name}",
)
org_structure[trimmed_path] = account_id
counter = counter + 1
# Counter is greater than 5 here is the conditional as organizations cannot have more than 5 levels of nested OUs + 1 accounts "level"
return org_structure if counter > 5 else self.get_organization_map(org_structure, counter)
return (
org_structure if counter > 5
else self.get_organization_map(org_structure, counter)
)

def update_policy(self, content, policy_id):
self.client.update_policy(
PolicyId=policy_id,
Content=content
)

def create_policy(self, content, ou_path, policy_type="SERVICE_CONTROL_POLICY"):
def create_policy(
self,
content,
ou_path,
policy_type="SERVICE_CONTROL_POLICY",
):
policy_type_name = (
'scp' if policy_type == "SERVICE_CONTROL_POLICY"
else 'tagging-policy'
Expand All @@ -111,37 +164,50 @@ def create_policy(self, content, ou_path, policy_type="SERVICE_CONTROL_POLICY"):

@staticmethod
def get_policy_body(path):
with open(f'./adf-bootstrap/{path}', mode='r', encoding='utf-8') as policy:
bootstrap_path = f'./adf-bootstrap/{path}'
with open(bootstrap_path, mode='r', encoding='utf-8') as policy:
return json.dumps(json.load(policy))

def list_policies(self, name, policy_type="SERVICE_CONTROL_POLICY"):
response = list(paginator(self.client.list_policies, Filter=policy_type))
try:
return [policy for policy in response if policy['Name'] == name][0]['Id']
except IndexError:
return []

def describe_policy_id_for_target(self, target_id, policy_type='SERVICE_CONTROL_POLICY'):
response = list(
paginator(self.client.list_policies, Filter=policy_type)
)
filtered_policies = [
policy for policy in response
if policy['Name'] == name
]
if len(filtered_policies) > 0:
return filtered_policies[0]['Id']
return []

def describe_policy_id_for_target(
self,
target_id,
policy_type='SERVICE_CONTROL_POLICY',
):
response = self.client.list_policies_for_target(
TargetId=target_id,
Filter=policy_type
)
try:
return [p for p in response['Policies'] if f'ADF Managed {policy_type}' in p['Description']][0]['Id']
except IndexError:
return []
adf_managed_policies = [
policy for policy in response['Policies']
if f'ADF Managed {policy_type}' in policy['Description']
]
if len(adf_managed_policies) > 0:
return adf_managed_policies[0]['Id']
return []

def describe_policy(self, policy_id):
response = self.client.describe_policy(
PolicyId=policy_id
PolicyId=policy_id,
)
return response.get('Policy')

def attach_policy(self, policy_id, target_id):
try:
self.client.attach_policy(
PolicyId=policy_id,
TargetId=target_id
TargetId=target_id,
)
except self.client.exceptions.DuplicatePolicyAttachmentException:
pass
Expand All @@ -157,13 +223,66 @@ def delete_policy(self, policy_id):
PolicyId=policy_id
)

def get_accounts(self):
def _account_available_to_adf(
self,
account,
protected_ou_ids,
include_root,
):
if protected_ou_ids or not include_root:
account_ou_id = (
self.get_parent_info(account["Id"]).get("ou_parent_id")
)
if not include_root and account_ou_id.startswith("r-"):
LOGGER.info(
"Account %s is in the root of the AWS Organization, "
"therefore skipping it",
account["Id"],
)
return False
if protected_ou_ids and account_ou_id in protected_ou_ids:
LOGGER.info(
"Account %s is in OU %s which is marked as protected, "
"therefore skipping it",
account["Id"],
account_ou_id,
)
return False
if account.get("Status") != "ACTIVE":
LOGGER.warning(
"Account %s is not an active AWS Account, state reported: %s",
account["Id"],
account.get("Status"),
)
return False
return True

def get_accounts(self, protected_ou_ids=None, include_root=True):
"""
Get the accounts from this AWS Organizations.
Filtered by the given arguments if required.
Args:
protected_ou_ids (list(str)): The list of protected organization
unit ids as configured in the adfconfig.yml file.
The organization unit ids are structured like: ou-123.
include_root (bool): Whether or not to include accounts that are
located in the root of the AWS Organization.
ADF does not adopt these accounts.
Returns:
list(str): The list of account details, filtered as requested.
"""
accounts = []
for account in paginator(self.client.list_accounts):
if not account.get('Status') == 'ACTIVE':
LOGGER.warning('Account %s is not an Active AWS Account', account['Id'])
continue
self.account_ids.append(account)
return self.account_ids
if self._account_available_to_adf(
account,
protected_ou_ids,
include_root,
):
accounts.append(account)
return accounts

def get_organization_info(self):
response = self.client.describe_organization()
Expand All @@ -183,7 +302,9 @@ def describe_ou_name(self, ou_id):
)
return response['OrganizationalUnit']['Name']
except ClientError as error:
raise RootOUIDError("OU is the Root of the Organization") from error
raise RootOUIDError(
"OU is the Root of the Organization",
) from error

def describe_account_name(self, account_id):
try:
Expand All @@ -192,7 +313,10 @@ def describe_account_name(self, account_id):
)
return response['Account']['Name']
except ClientError as error:
LOGGER.error('Failed to retrieve account name for account ID %s', account_id)
LOGGER.error(
"Failed to retrieve account name for account ID %s",
account_id,
)
raise error

@staticmethod
Expand Down
Loading

0 comments on commit 3ae94ba

Please sign in to comment.