From 9197399f0d22ad7faa61219dae58c4cae1fda3b4 Mon Sep 17 00:00:00 2001 From: saieshwarm Date: Wed, 6 Mar 2024 03:05:55 +0000 Subject: [PATCH] Adding headers in STS Calls for Confused Deputy --- .pylintrc | 4 +- src/rpdk/core/boto_helpers.py | 22 ++++++- src/rpdk/core/contract/hook_client.py | 10 ++- src/rpdk/core/contract/resource_client.py | 12 ++-- src/rpdk/core/test.py | 55 ++++++++++++++--- tests/contract/test_hook_client.py | 6 +- tests/contract/test_resource_client.py | 14 ++--- tests/test_boto_helpers.py | 74 +++++++++++++++++++++++ tests/test_test.py | 36 ++++++----- 9 files changed, 187 insertions(+), 46 deletions(-) diff --git a/.pylintrc b/.pylintrc index 13ea7479..c3b15401 100644 --- a/.pylintrc +++ b/.pylintrc @@ -12,6 +12,7 @@ disable= duplicate-code, # finds dupes between tests and plugins too-few-public-methods, # triggers when inheriting ungrouped-imports, # clashes with isort + W0613 # Unused argument 'kwargs' [BASIC] @@ -23,4 +24,5 @@ indent-string=' ' max-line-length=160 [DESIGN] -max-locals=16 +max-locals=17 +max-args=6 diff --git a/src/rpdk/core/boto_helpers.py b/src/rpdk/core/boto_helpers.py index 67aad6cf..c4f326b8 100644 --- a/src/rpdk/core/boto_helpers.py +++ b/src/rpdk/core/boto_helpers.py @@ -32,17 +32,35 @@ def _known_error(msg): return session -def get_temporary_credentials(session, key_names=BOTO_CRED_KEYS, role_arn=None): +def get_temporary_credentials( + session, key_names=BOTO_CRED_KEYS, role_arn=None, headers=None +): sts_client = session.client( "sts", endpoint_url=get_service_endpoint("sts", session.region_name), region_name=session.region_name, ) + check_keys = {"account_id", "source_arn"} + if ( + headers + and check_keys.issubset(headers.keys()) + and headers["account_id"] + and headers["source_arn"] + ): + # Inject headers through the event system. + def inject_confused_deputy_headers(params, **kwargs): + params["headers"]["x-amz-source-account"] = headers["account_id"] + params["headers"]["x-amz-source-arn"] = headers["source_arn"] + + sts_client.meta.events.register("before-call", inject_confused_deputy_headers) + LOG.info(headers) if role_arn: session_name = f"CloudFormationContractTest-{datetime.now():%Y%m%d%H%M%S}" try: response = sts_client.assume_role( - RoleArn=role_arn, RoleSessionName=session_name, DurationSeconds=900 + RoleArn=role_arn, + RoleSessionName=session_name, + DurationSeconds=900, ) except ClientError: # pylint: disable=W1201 diff --git a/src/rpdk/core/contract/hook_client.py b/src/rpdk/core/contract/hook_client.py index 31fe5251..6c0e187a 100644 --- a/src/rpdk/core/contract/hook_client.py +++ b/src/rpdk/core/contract/hook_client.py @@ -56,6 +56,7 @@ def __init__( type_name=None, log_group_name=None, log_role_arn=None, + headers=None, docker_image=None, typeconfig=None, executable_entrypoint=None, @@ -69,9 +70,12 @@ def __init__( self._log_group_name = log_group_name self._log_role_arn = log_role_arn self.region = region + self._headers = headers self.account = get_account( self._session, - get_temporary_credentials(self._session, LOWER_CAMEL_CRED_KEYS, role_arn), + get_temporary_credentials( + self._session, LOWER_CAMEL_CRED_KEYS, role_arn, headers + ), ) self._function_name = function_name if endpoint.startswith("http://"): @@ -396,11 +400,11 @@ def _make_payload( self.account, invocation_point, get_temporary_credentials( - self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn + self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn, self._headers ), self._log_group_name, get_temporary_credentials( - self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn + self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn, self._headers ), self.generate_token(), target_model, diff --git a/src/rpdk/core/contract/resource_client.py b/src/rpdk/core/contract/resource_client.py index 77a219d2..89e89bbc 100644 --- a/src/rpdk/core/contract/resource_client.py +++ b/src/rpdk/core/contract/resource_client.py @@ -171,6 +171,7 @@ def __init__( type_name=None, log_group_name=None, log_role_arn=None, + headers=None, docker_image=None, typeconfig=None, executable_entrypoint=None, @@ -182,9 +183,12 @@ def __init__( self._log_group_name = log_group_name self._log_role_arn = log_role_arn self.region = region + self._headers = headers self.account = get_account( self._session, - get_temporary_credentials(self._session, LOWER_CAMEL_CRED_KEYS, role_arn), + get_temporary_credentials( + self._session, LOWER_CAMEL_CRED_KEYS, role_arn, headers + ), ) self._function_name = function_name if endpoint.startswith("http://"): @@ -674,12 +678,12 @@ def _make_payload( self.account, action, get_temporary_credentials( - self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn + self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn, self._headers ), self._type_name, self._log_group_name, get_temporary_credentials( - self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn + self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn, self._headers ), self.generate_token(), type_configuration=type_configuration, @@ -794,7 +798,7 @@ def call(self, action, current_model, previous_model=None, **kwargs): request["callbackContext"] = response.get("callbackContext") # refresh credential for every handler invocation request["requestData"]["callerCredentials"] = get_temporary_credentials( - self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn + self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn, self._headers ) response = self._call(request) diff --git a/src/rpdk/core/test.py b/src/rpdk/core/test.py index 81d70c30..e0f909c3 100644 --- a/src/rpdk/core/test.py +++ b/src/rpdk/core/test.py @@ -102,9 +102,13 @@ def temporary_ini_file(): yield str(path) -def get_cloudformation_exports(region_name, endpoint_url, role_arn, profile_name): +def get_cloudformation_exports( + region_name, endpoint_url, role_arn, profile_name, headers +): session = create_sdk_session(region_name, profile_name) - temp_credentials = get_temporary_credentials(session, role_arn=role_arn) + temp_credentials = get_temporary_credentials( + session, role_arn=role_arn, headers=headers + ) cfn_client = session.client( "cloudformation", endpoint_url=endpoint_url, **temp_credentials ) @@ -132,13 +136,13 @@ def __retrieve_args(match): def render_template( - overrides_string, region_name, endpoint_url, role_arn, profile_name + overrides_string, region_name, endpoint_url, role_arn, profile_name, headers ): regex = r"{{([-A-Za-z0-9:\s]+?)}}" variables = set(str(match).strip() for match in re.findall(regex, overrides_string)) if variables: exports = get_cloudformation_exports( - region_name, endpoint_url, role_arn, profile_name + region_name, endpoint_url, role_arn, profile_name, headers ) invalid_exports = variables - exports.keys() if len(invalid_exports) > 0: @@ -166,7 +170,7 @@ def filter_overrides(overrides, project): return overrides -def get_overrides(root, region_name, endpoint_url, role_arn, profile_name): +def get_overrides(root, region_name, endpoint_url, role_arn, profile_name, headers): if not root: return empty_override() @@ -174,7 +178,12 @@ def get_overrides(root, region_name, endpoint_url, role_arn, profile_name): try: with path.open("r", encoding="utf-8") as f: overrides_raw = render_template( - f.read(), region_name, endpoint_url, role_arn, profile_name + f.read(), + region_name, + endpoint_url, + role_arn, + profile_name, + headers=headers, ) except FileNotFoundError: LOG.debug("Override file '%s' not found. No overrides will be applied", path) @@ -203,7 +212,9 @@ def get_overrides(root, region_name, endpoint_url, role_arn, profile_name): # pylint: disable=R0914 # flake8: noqa: C901 -def get_hook_overrides(root, region_name, endpoint_url, role_arn, profile_name): +def get_hook_overrides( + root, region_name, endpoint_url, role_arn, profile_name, headers +): if not root: return empty_hook_override() @@ -211,7 +222,12 @@ def get_hook_overrides(root, region_name, endpoint_url, role_arn, profile_name): try: with path.open("r", encoding="utf-8") as f: overrides_raw = render_template( - f.read(), region_name, endpoint_url, role_arn, profile_name + f.read(), + region_name, + endpoint_url, + role_arn, + profile_name, + headers=headers, ) except FileNotFoundError: LOG.debug("Override file '%s' not found. No overrides will be applied", path) @@ -258,7 +274,7 @@ def get_hook_overrides(root, region_name, endpoint_url, role_arn, profile_name): # pylint: disable=R0914,too-many-arguments -def get_inputs(root, region_name, endpoint_url, value, role_arn, profile_name): +def get_inputs(root, region_name, endpoint_url, value, role_arn, profile_name, headers): inputs = {} if not root: return None @@ -280,7 +296,12 @@ def get_inputs(root, region_name, endpoint_url, value, role_arn, profile_name): file_path = path / file with file_path.open("r", encoding="utf-8") as f: overrides_raw = render_template( - f.read(), region_name, endpoint_url, role_arn, profile_name + f.read(), + region_name, + endpoint_url, + role_arn, + profile_name, + headers=headers, ) overrides = {} for pointer, obj in overrides_raw.items(): @@ -355,6 +376,7 @@ def get_contract_plugin_client(args, project, overrides, inputs): project.type_name, args.log_group_name, args.log_role_arn, + headers={"account_id": args.source_account, "source_arn": args.source_arn}, executable_entrypoint=project.executable_entrypoint, docker_image=args.docker_image, typeconfig=args.typeconfig, @@ -378,6 +400,7 @@ def get_contract_plugin_client(args, project, overrides, inputs): project.type_name, args.log_group_name, args.log_role_arn, + headers={"account_id": args.source_account, "source_arn": args.source_arn}, typeconfig=args.typeconfig, executable_entrypoint=project.executable_entrypoint, docker_image=args.docker_image, @@ -402,6 +425,7 @@ def test(args): args.cloudformation_endpoint_url, args.role_arn, args.profile, + headers={"account_id": args.source_account, "source_arn": args.source_arn}, ) else: overrides = get_overrides( @@ -410,6 +434,7 @@ def test(args): args.cloudformation_endpoint_url, args.role_arn, args.profile, + headers={"account_id": args.source_account, "source_arn": args.source_arn}, ) filter_overrides(overrides, project) @@ -422,6 +447,7 @@ def test(args): index, args.role_arn, args.profile, + headers={"account_id": args.source_account, "source_arn": args.source_arn}, ) if not inputs: break @@ -509,6 +535,15 @@ def setup_subparser(subparsers, parents): " '~/.cfn-cli/typeConfiguration.json.'" ), ) + parser.add_argument( + "--source-account", + help="Source Account key used for Assume Role to Run Contract Tests", + ) + + parser.add_argument( + "--source-arn", + help="Source Type Version Arn key used for Assume Role to Run Contract Tests", + ) def _sam_arguments(parser): diff --git a/tests/contract/test_hook_client.py b/tests/contract/test_hook_client.py index e9269c0a..e8072cc5 100644 --- a/tests/contract/test_hook_client.py +++ b/tests/contract/test_hook_client.py @@ -121,7 +121,7 @@ def hook_client(): ) mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint) - mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None) + mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None) mock_account.assert_called_once_with(mock_sesh, {}) assert client._function_name == DEFAULT_FUNCTION assert client._schema == SCHEMA_ @@ -179,7 +179,7 @@ def hook_client_inputs(): ) mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint) - mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None) + mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None) mock_account.assert_called_once_with(mock_sesh, {}) assert client._function_name == DEFAULT_FUNCTION assert client._schema == SCHEMA_ @@ -215,7 +215,7 @@ def test_init_sam_cli_client(): mock_sesh.client.assert_called_once_with( "lambda", endpoint_url=DEFAULT_ENDPOINT, use_ssl=False, verify=False, config=ANY ) - mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None) + mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None) mock_account.assert_called_once_with(mock_sesh, {}) assert client.account == ACCOUNT diff --git a/tests/contract/test_resource_client.py b/tests/contract/test_resource_client.py index 5182ac03..cc09456b 100644 --- a/tests/contract/test_resource_client.py +++ b/tests/contract/test_resource_client.py @@ -179,7 +179,7 @@ def resource_client(): ) mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint) - mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None) + mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None) mock_account.assert_called_once_with(mock_sesh, {}) assert client._function_name == DEFAULT_FUNCTION assert client._schema == EMPTY_SCHEMA @@ -214,7 +214,7 @@ def resource_client_no_handler(): ) mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint) - mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None) + mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None) mock_account.assert_called_once_with(mock_sesh, {}) assert client._function_name == DEFAULT_FUNCTION assert client._schema == {} @@ -254,7 +254,7 @@ def resource_client_inputs(): ) mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint) - mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None) + mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None) mock_account.assert_called_once_with(mock_sesh, {}) assert client._function_name == DEFAULT_FUNCTION @@ -299,7 +299,7 @@ def resource_client_inputs_schema(request): ) mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint) - mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None) + mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None) mock_account.assert_called_once_with(mock_sesh, {}) assert client._function_name == DEFAULT_FUNCTION @@ -344,7 +344,7 @@ def resource_client_inputs_composite_key(): ) mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint) - mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None) + mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None) mock_account.assert_called_once_with(mock_sesh, {}) assert client._function_name == DEFAULT_FUNCTION @@ -384,7 +384,7 @@ def resource_client_inputs_property_transform(): ) mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint) - mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None) + mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None) mock_account.assert_called_once_with(mock_sesh, {}) assert client._function_name == DEFAULT_FUNCTION assert client._schema == SCHEMA_WITH_PROPERTY_TRANSFORM @@ -693,7 +693,7 @@ def test_init_sam_cli_client(): mock_sesh.client.assert_called_once_with( "lambda", endpoint_url=DEFAULT_ENDPOINT, use_ssl=False, verify=False, config=ANY ) - mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None) + mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None) mock_account.assert_called_once_with(mock_sesh, {}) assert client.account == ACCOUNT diff --git a/tests/test_boto_helpers.py b/tests/test_boto_helpers.py index 2414e6d2..4abdaa27 100644 --- a/tests/test_boto_helpers.py +++ b/tests/test_boto_helpers.py @@ -14,6 +14,8 @@ from rpdk.core.exceptions import CLIMisconfiguredError, DownstreamError EXPECTED_ROLE = "someroleArn" +SOURCE_ACCOUNT = "123456789" +SOURCE_ARN = "someSourceArn" def test_create_sdk_session_region(): @@ -200,6 +202,78 @@ def test_get_temporary_credentials_assume_role(): assert tuple(creds.values()) == (access_key, secret_key, token) +def test_get_temporary_credentials_assume_role_with_headers(): + session = create_autospec(spec=Session, spec_set=True) + + access_key = object() + secret_key = object() + token = object() + + client = session.client.return_value + client.assume_role.return_value = { + "Credentials": { + "AccessKeyId": access_key, + "SecretAccessKey": secret_key, + "SessionToken": token, + } + } + session.region_name = "cn-north-1" + + header = {"account_id": SOURCE_ACCOUNT, "source_arn": SOURCE_ARN} + creds = get_temporary_credentials( + session, LOWER_CAMEL_CRED_KEYS, EXPECTED_ROLE, header + ) + + session.client.assert_called_once_with( + "sts", + endpoint_url="https://sts.cn-north-1.amazonaws.com.cn", + region_name="cn-north-1", + ) + client.assume_role.assert_called_once_with( + RoleArn=EXPECTED_ROLE, RoleSessionName=ANY, DurationSeconds=900 + ) + + assert len(creds) == 3 + assert tuple(creds.keys()) == LOWER_CAMEL_CRED_KEYS + assert tuple(creds.values()) == (access_key, secret_key, token) + + +def test_get_temporary_credentials_assume_role_with_missing_account_id_header(): + session = create_autospec(spec=Session, spec_set=True) + + access_key = object() + secret_key = object() + token = object() + + client = session.client.return_value + client.assume_role.return_value = { + "Credentials": { + "AccessKeyId": access_key, + "SecretAccessKey": secret_key, + "SessionToken": token, + } + } + session.region_name = "cn-north-1" + + header = {"source_arn": None} + creds = get_temporary_credentials( + session, LOWER_CAMEL_CRED_KEYS, EXPECTED_ROLE, header + ) + + session.client.assert_called_once_with( + "sts", + endpoint_url="https://sts.cn-north-1.amazonaws.com.cn", + region_name="cn-north-1", + ) + client.assume_role.assert_called_once_with( + RoleArn=EXPECTED_ROLE, RoleSessionName=ANY, DurationSeconds=900 + ) + + assert len(creds) == 3 + assert tuple(creds.keys()) == LOWER_CAMEL_CRED_KEYS + assert tuple(creds.values()) == (access_key, secret_key, token) + + def test_get_account_with_temporary_credentials(): session = create_autospec(spec=Session, spec_set=True) client = session.client.return_value diff --git a/tests/test_test.py b/tests/test_test.py index 68f13dbe..977a9805 100644 --- a/tests/test_test.py +++ b/tests/test_test.py @@ -266,6 +266,7 @@ def test_test_command_happy_path_resource( mock_project.type_name, None, None, + headers={"account_id": None, "source_arn": None}, typeconfig=None, executable_entrypoint=None, docker_image=None, @@ -374,6 +375,7 @@ def test_test_command_happy_path_hook( mock_project.type_name, None, None, + headers={"account_id": None, "source_arn": None}, typeconfig=None, executable_entrypoint=None, docker_image=None, @@ -440,7 +442,7 @@ def test_temporary_ini_file(): def test_get_overrides_no_root(): assert ( - get_overrides(None, DEFAULT_REGION, "", None, DEFAULT_PROFILE) + get_overrides(None, DEFAULT_REGION, "", None, DEFAULT_PROFILE, None) == EMPTY_RESOURCE_OVERRIDE ) @@ -452,7 +454,7 @@ def test_get_overrides_file_not_found(base): except FileNotFoundError: pass assert ( - get_overrides(path, DEFAULT_REGION, "", None, DEFAULT_PROFILE) + get_overrides(path, DEFAULT_REGION, "", None, DEFAULT_PROFILE, None) == EMPTY_RESOURCE_OVERRIDE ) @@ -461,7 +463,7 @@ def test_get_overrides_invalid_file(base): path = base / "overrides.json" path.write_text("{}") assert ( - get_overrides(base, DEFAULT_REGION, "", None, DEFAULT_PROFILE) + get_overrides(base, DEFAULT_REGION, "", None, DEFAULT_PROFILE, None) == EMPTY_RESOURCE_OVERRIDE ) @@ -471,7 +473,7 @@ def test_get_overrides_empty_overrides(base): with path.open("w", encoding="utf-8") as f: json.dump(EMPTY_RESOURCE_OVERRIDE, f) assert ( - get_overrides(base, DEFAULT_REGION, "", None, DEFAULT_PROFILE) + get_overrides(base, DEFAULT_REGION, "", None, DEFAULT_PROFILE, None) == EMPTY_RESOURCE_OVERRIDE ) @@ -484,7 +486,7 @@ def test_get_overrides_invalid_pointer_skipped(base): with path.open("w", encoding="utf-8") as f: json.dump(overrides, f) assert ( - get_overrides(base, DEFAULT_REGION, "", None, DEFAULT_PROFILE) + get_overrides(base, DEFAULT_REGION, "", None, DEFAULT_PROFILE, None) == EMPTY_RESOURCE_OVERRIDE ) @@ -496,14 +498,14 @@ def test_get_overrides_good_path(base): path = base / "overrides.json" with path.open("w", encoding="utf-8") as f: json.dump(overrides, f) - assert get_overrides(base, DEFAULT_REGION, "", None, DEFAULT_PROFILE) == { + assert get_overrides(base, DEFAULT_REGION, "", None, DEFAULT_PROFILE, None) == { "CREATE": {("foo", "bar"): {}} } def test_get_hook_overrides_no_root(): assert ( - get_hook_overrides(None, DEFAULT_REGION, "", None, DEFAULT_PROFILE) + get_hook_overrides(None, DEFAULT_REGION, "", None, DEFAULT_PROFILE, None) == EMPTY_HOOK_OVERRIDE ) @@ -515,7 +517,7 @@ def test_get_hook_overrides_file_not_found(base): except FileNotFoundError: pass assert ( - get_hook_overrides(path, DEFAULT_REGION, "", None, DEFAULT_PROFILE) + get_hook_overrides(path, DEFAULT_REGION, "", None, DEFAULT_PROFILE, None) == EMPTY_HOOK_OVERRIDE ) @@ -524,7 +526,7 @@ def test_get_hook_overrides_invalid_file(base): path = base / "overrides.json" path.write_text("{}") assert ( - get_hook_overrides(base, DEFAULT_REGION, "", None, DEFAULT_PROFILE) + get_hook_overrides(base, DEFAULT_REGION, "", None, DEFAULT_PROFILE, None) == EMPTY_HOOK_OVERRIDE ) @@ -538,7 +540,9 @@ def test_get_hook_overrides_good_path(base): path = base / "overrides.json" with path.open("w", encoding="utf-8") as f: json.dump(overrides, f) - assert get_hook_overrides(base, DEFAULT_REGION, "", None, DEFAULT_PROFILE) == { + assert get_hook_overrides( + base, DEFAULT_REGION, "", None, DEFAULT_PROFILE, None + ) == { "CREATE_PRE_PROVISION": { "My::Example::Resource": {"resourceProperties": {("foo", "bar"): {}}} } @@ -595,7 +599,7 @@ def test_get_overrides_with_jinja( mock_cfn_client, Mock(), ] - result = get_overrides(base, DEFAULT_REGION, None, None, DEFAULT_PROFILE) + result = get_overrides(base, DEFAULT_REGION, None, None, DEFAULT_PROFILE, None) assert result == expected_overrides @@ -661,7 +665,7 @@ def test_with_inputs( mock_cfn_client, Mock(), ] - result = get_inputs(base, DEFAULT_REGION, None, 1, None, DEFAULT_PROFILE) + result = get_inputs(base, DEFAULT_REGION, None, 1, None, DEFAULT_PROFILE, None) assert result == expected_inputs @@ -685,23 +689,23 @@ def test_with_inputs_invalid(base): mock_cfn_client, Mock(), ] - result = get_inputs(base, DEFAULT_REGION, None, 1, None, DEFAULT_PROFILE) + result = get_inputs(base, DEFAULT_REGION, None, 1, None, DEFAULT_PROFILE, None) assert not result def test_get_input_invalid_root(): - assert not get_inputs("", DEFAULT_REGION, "", 1, None, DEFAULT_PROFILE) + assert not get_inputs("", DEFAULT_REGION, "", 1, None, DEFAULT_PROFILE, None) def test_get_input_input_folder_does_not_exist(base): - assert not get_inputs(base, DEFAULT_REGION, "", 1, None, DEFAULT_PROFILE) + assert not get_inputs(base, DEFAULT_REGION, "", 1, None, DEFAULT_PROFILE, None) def test_get_input_file_not_found(base): path = base / "inputs" os.mkdir(path, mode=0o777) - assert not get_inputs(base, DEFAULT_REGION, "", 1, None, DEFAULT_PROFILE) + assert not get_inputs(base, DEFAULT_REGION, "", 1, None, DEFAULT_PROFILE, None) def test_use_both_sam_and_docker_arguments():