Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding headers in STS Calls for Confused Deputy #1061

Merged
merged 2 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ disable=
duplicate-code, # finds dupes between tests and plugins
too-few-public-methods, # triggers when inheriting
ungrouped-imports, # clashes with isort
W0613 # Unused argument 'kwargs'

[BASIC]

Expand All @@ -23,4 +24,5 @@ indent-string=' '
max-line-length=160
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add an exception to just the one line if it's not possible to break it down? 160 is very long.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


[DESIGN]
max-locals=16
max-locals=17
max-args=6
22 changes: 20 additions & 2 deletions src/rpdk/core/boto_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,35 @@ def _known_error(msg):
return session


def get_temporary_credentials(session, key_names=BOTO_CRED_KEYS, role_arn=None):
def get_temporary_credentials(
session, key_names=BOTO_CRED_KEYS, role_arn=None, headers=None
):
sts_client = session.client(
"sts",
endpoint_url=get_service_endpoint("sts", session.region_name),
region_name=session.region_name,
)
check_keys = {"account_id", "source_arn"}
if (
headers
and check_keys.issubset(headers.keys())
and headers["account_id"]
and headers["source_arn"]
):
# Inject headers through the event system.
def inject_confused_deputy_headers(params, **kwargs):
params["headers"]["x-amz-source-account"] = headers["account_id"]
params["headers"]["x-amz-source-arn"] = headers["source_arn"]

sts_client.meta.events.register("before-call", inject_confused_deputy_headers)
LOG.info(headers)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to still log it

if role_arn:
session_name = f"CloudFormationContractTest-{datetime.now():%Y%m%d%H%M%S}"
try:
response = sts_client.assume_role(
RoleArn=role_arn, RoleSessionName=session_name, DurationSeconds=900
RoleArn=role_arn,
RoleSessionName=session_name,
DurationSeconds=900,
)
except ClientError:
# pylint: disable=W1201
Expand Down
10 changes: 7 additions & 3 deletions src/rpdk/core/contract/hook_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
type_name=None,
log_group_name=None,
log_role_arn=None,
headers=None,
docker_image=None,
typeconfig=None,
executable_entrypoint=None,
Expand All @@ -69,9 +70,12 @@ def __init__(
self._log_group_name = log_group_name
self._log_role_arn = log_role_arn
self.region = region
self._headers = headers
self.account = get_account(
self._session,
get_temporary_credentials(self._session, LOWER_CAMEL_CRED_KEYS, role_arn),
get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, role_arn, headers
),
)
self._function_name = function_name
if endpoint.startswith("http://"):
Expand Down Expand Up @@ -396,11 +400,11 @@ def _make_payload(
self.account,
invocation_point,
get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn, self._headers
),
self._log_group_name,
get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn
self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn, self._headers
),
self.generate_token(),
target_model,
Expand Down
12 changes: 8 additions & 4 deletions src/rpdk/core/contract/resource_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def __init__(
type_name=None,
log_group_name=None,
log_role_arn=None,
headers=None,
docker_image=None,
typeconfig=None,
executable_entrypoint=None,
Expand All @@ -182,9 +183,12 @@ def __init__(
self._log_group_name = log_group_name
self._log_role_arn = log_role_arn
self.region = region
self._headers = headers
self.account = get_account(
self._session,
get_temporary_credentials(self._session, LOWER_CAMEL_CRED_KEYS, role_arn),
get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, role_arn, headers
),
)
self._function_name = function_name
if endpoint.startswith("http://"):
Expand Down Expand Up @@ -674,12 +678,12 @@ def _make_payload(
self.account,
action,
get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn, self._headers
),
self._type_name,
self._log_group_name,
get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn
self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn, self._headers
),
self.generate_token(),
type_configuration=type_configuration,
Expand Down Expand Up @@ -794,7 +798,7 @@ def call(self, action, current_model, previous_model=None, **kwargs):
request["callbackContext"] = response.get("callbackContext")
# refresh credential for every handler invocation
request["requestData"]["callerCredentials"] = get_temporary_credentials(
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn, self._headers
)

response = self._call(request)
Expand Down
55 changes: 45 additions & 10 deletions src/rpdk/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,13 @@ def temporary_ini_file():
yield str(path)


def get_cloudformation_exports(region_name, endpoint_url, role_arn, profile_name):
def get_cloudformation_exports(
region_name, endpoint_url, role_arn, profile_name, headers
):
session = create_sdk_session(region_name, profile_name)
temp_credentials = get_temporary_credentials(session, role_arn=role_arn)
temp_credentials = get_temporary_credentials(
session, role_arn=role_arn, headers=headers
)
cfn_client = session.client(
"cloudformation", endpoint_url=endpoint_url, **temp_credentials
)
Expand Down Expand Up @@ -132,13 +136,13 @@ def __retrieve_args(match):


def render_template(
overrides_string, region_name, endpoint_url, role_arn, profile_name
overrides_string, region_name, endpoint_url, role_arn, profile_name, headers
):
regex = r"{{([-A-Za-z0-9:\s]+?)}}"
variables = set(str(match).strip() for match in re.findall(regex, overrides_string))
if variables:
exports = get_cloudformation_exports(
region_name, endpoint_url, role_arn, profile_name
region_name, endpoint_url, role_arn, profile_name, headers
)
invalid_exports = variables - exports.keys()
if len(invalid_exports) > 0:
Expand Down Expand Up @@ -166,15 +170,20 @@ def filter_overrides(overrides, project):
return overrides


def get_overrides(root, region_name, endpoint_url, role_arn, profile_name):
def get_overrides(root, region_name, endpoint_url, role_arn, profile_name, headers):
if not root:
return empty_override()

path = root / "overrides.json"
try:
with path.open("r", encoding="utf-8") as f:
overrides_raw = render_template(
f.read(), region_name, endpoint_url, role_arn, profile_name
f.read(),
region_name,
endpoint_url,
role_arn,
profile_name,
headers=headers,
)
except FileNotFoundError:
LOG.debug("Override file '%s' not found. No overrides will be applied", path)
Expand Down Expand Up @@ -203,15 +212,22 @@ def get_overrides(root, region_name, endpoint_url, role_arn, profile_name):

# pylint: disable=R0914
# flake8: noqa: C901
def get_hook_overrides(root, region_name, endpoint_url, role_arn, profile_name):
def get_hook_overrides(
root, region_name, endpoint_url, role_arn, profile_name, headers
):
if not root:
return empty_hook_override()

path = root / "overrides.json"
try:
with path.open("r", encoding="utf-8") as f:
overrides_raw = render_template(
f.read(), region_name, endpoint_url, role_arn, profile_name
f.read(),
region_name,
endpoint_url,
role_arn,
profile_name,
headers=headers,
)
except FileNotFoundError:
LOG.debug("Override file '%s' not found. No overrides will be applied", path)
Expand Down Expand Up @@ -258,7 +274,7 @@ def get_hook_overrides(root, region_name, endpoint_url, role_arn, profile_name):


# pylint: disable=R0914,too-many-arguments
def get_inputs(root, region_name, endpoint_url, value, role_arn, profile_name):
def get_inputs(root, region_name, endpoint_url, value, role_arn, profile_name, headers):
inputs = {}
if not root:
return None
Expand All @@ -280,7 +296,12 @@ def get_inputs(root, region_name, endpoint_url, value, role_arn, profile_name):
file_path = path / file
with file_path.open("r", encoding="utf-8") as f:
overrides_raw = render_template(
f.read(), region_name, endpoint_url, role_arn, profile_name
f.read(),
region_name,
endpoint_url,
role_arn,
profile_name,
headers=headers,
)
overrides = {}
for pointer, obj in overrides_raw.items():
Expand Down Expand Up @@ -355,6 +376,7 @@ def get_contract_plugin_client(args, project, overrides, inputs):
project.type_name,
args.log_group_name,
args.log_role_arn,
headers={"account_id": args.source_account, "source_arn": args.source_arn},
executable_entrypoint=project.executable_entrypoint,
docker_image=args.docker_image,
typeconfig=args.typeconfig,
Expand All @@ -378,6 +400,7 @@ def get_contract_plugin_client(args, project, overrides, inputs):
project.type_name,
args.log_group_name,
args.log_role_arn,
headers={"account_id": args.source_account, "source_arn": args.source_arn},
typeconfig=args.typeconfig,
executable_entrypoint=project.executable_entrypoint,
docker_image=args.docker_image,
Expand All @@ -402,6 +425,7 @@ def test(args):
args.cloudformation_endpoint_url,
args.role_arn,
args.profile,
headers={"account_id": args.source_account, "source_arn": args.source_arn},
)
else:
overrides = get_overrides(
Expand All @@ -410,6 +434,7 @@ def test(args):
args.cloudformation_endpoint_url,
args.role_arn,
args.profile,
headers={"account_id": args.source_account, "source_arn": args.source_arn},
)
filter_overrides(overrides, project)

Expand All @@ -422,6 +447,7 @@ def test(args):
index,
args.role_arn,
args.profile,
headers={"account_id": args.source_account, "source_arn": args.source_arn},
)
if not inputs:
break
Expand Down Expand Up @@ -509,6 +535,15 @@ def setup_subparser(subparsers, parents):
" '~/.cfn-cli/typeConfiguration.json.'"
),
)
parser.add_argument(
"--source-account",
help="Source Account key used for Assume Role to Run Contract Tests",
)

parser.add_argument(
"--source-arn",
help="Source Type Version Arn key used for Assume Role to Run Contract Tests",
)


def _sam_arguments(parser):
Expand Down
6 changes: 3 additions & 3 deletions tests/contract/test_hook_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def hook_client():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client._function_name == DEFAULT_FUNCTION
assert client._schema == SCHEMA_
Expand Down Expand Up @@ -179,7 +179,7 @@ def hook_client_inputs():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client._function_name == DEFAULT_FUNCTION
assert client._schema == SCHEMA_
Expand Down Expand Up @@ -215,7 +215,7 @@ def test_init_sam_cli_client():
mock_sesh.client.assert_called_once_with(
"lambda", endpoint_url=DEFAULT_ENDPOINT, use_ssl=False, verify=False, config=ANY
)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client.account == ACCOUNT

Expand Down
14 changes: 7 additions & 7 deletions tests/contract/test_resource_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def resource_client():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client._function_name == DEFAULT_FUNCTION
assert client._schema == EMPTY_SCHEMA
Expand Down Expand Up @@ -214,7 +214,7 @@ def resource_client_no_handler():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client._function_name == DEFAULT_FUNCTION
assert client._schema == {}
Expand Down Expand Up @@ -254,7 +254,7 @@ def resource_client_inputs():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})

assert client._function_name == DEFAULT_FUNCTION
Expand Down Expand Up @@ -299,7 +299,7 @@ def resource_client_inputs_schema(request):
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})

assert client._function_name == DEFAULT_FUNCTION
Expand Down Expand Up @@ -344,7 +344,7 @@ def resource_client_inputs_composite_key():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})

assert client._function_name == DEFAULT_FUNCTION
Expand Down Expand Up @@ -384,7 +384,7 @@ def resource_client_inputs_property_transform():
)

mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client._function_name == DEFAULT_FUNCTION
assert client._schema == SCHEMA_WITH_PROPERTY_TRANSFORM
Expand Down Expand Up @@ -693,7 +693,7 @@ def test_init_sam_cli_client():
mock_sesh.client.assert_called_once_with(
"lambda", endpoint_url=DEFAULT_ENDPOINT, use_ssl=False, verify=False, config=ANY
)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
mock_account.assert_called_once_with(mock_sesh, {})
assert client.account == ACCOUNT

Expand Down
Loading
Loading