From 60bc42421823441c9146eae9f6b675a25d776f34 Mon Sep 17 00:00:00 2001 From: Kartik Ganesh Date: Tue, 7 Nov 2023 14:44:44 -0800 Subject: [PATCH] [Fetch Migration] Added support for AWS SigV4 Signed-off-by: Kartik Ganesh --- FetchMigration/python/dev-requirements.txt | 3 +- FetchMigration/python/endpoint_info.py | 8 +- FetchMigration/python/endpoint_utils.py | 111 +++++++++--- FetchMigration/python/requirements.txt | 1 - .../python/tests/test_endpoint_utils.py | 166 +++++++++++++++--- 5 files changed, 237 insertions(+), 52 deletions(-) diff --git a/FetchMigration/python/dev-requirements.txt b/FetchMigration/python/dev-requirements.txt index 2efcff09b9..4697944d11 100644 --- a/FetchMigration/python/dev-requirements.txt +++ b/FetchMigration/python/dev-requirements.txt @@ -1,2 +1,3 @@ coverage>=7.3.2 -pur>=7.3.1 \ No newline at end of file +pur>=7.3.1 +moto>=4.2.7 \ No newline at end of file diff --git a/FetchMigration/python/endpoint_info.py b/FetchMigration/python/endpoint_info.py index a9cce2b13d..81ace5be3b 100644 --- a/FetchMigration/python/endpoint_info.py +++ b/FetchMigration/python/endpoint_info.py @@ -1,14 +1,16 @@ from typing import Optional +from requests_aws4auth import AWS4Auth + # Class that encapsulates endpoint information for an OpenSearch/Elasticsearch cluster class EndpointInfo: # Private member variables __url: str - __auth: Optional[tuple] + __auth: Optional[tuple] | AWS4Auth __verify_ssl: bool - def __init__(self, url: str, auth: tuple = None, verify_ssl: bool = True): + def __init__(self, url: str, auth: tuple | AWS4Auth = None, verify_ssl: bool = True): self.__url = url # Normalize url value to have trailing slash if not url.endswith("/"): @@ -31,7 +33,7 @@ def add_path(self, path: str) -> str: def get_url(self) -> str: return self.__url - def get_auth(self) -> Optional[tuple]: + def get_auth(self) -> Optional[tuple] | AWS4Auth: return self.__auth def is_verify_ssl(self) -> bool: diff --git a/FetchMigration/python/endpoint_utils.py 
b/FetchMigration/python/endpoint_utils.py index 3f58a3f709..f3e74e6fe2 100644 --- a/FetchMigration/python/endpoint_utils.py +++ b/FetchMigration/python/endpoint_utils.py @@ -1,5 +1,9 @@ +import re from typing import Optional +from requests_aws4auth import AWS4Auth +from botocore.session import Session + from endpoint_info import EndpointInfo # Constants @@ -12,21 +16,61 @@ DISABLE_AUTH_KEY = "disable_authentication" USER_KEY = "username" PWD_KEY = "password" +AWS_SIGV4_KEY = "aws_sigv4" +AWS_REGION_KEY = "aws_region" +AWS_CONFIG_KEY = "aws" +AWS_CONFIG_REGION_KEY = "region" +IS_SERVERLESS_KEY = "serverless" +ES_SERVICE_NAME = "es" +AOSS_SERVICE_NAME = "aoss" +URL_REGION_PATTERN = re.compile(r"([\w-]*)\.(es|aoss)\.amazonaws\.com") -def __check_supported_endpoint(config: dict) -> Optional[tuple]: +def __get_url(plugin_config: dict) -> str: + # "hosts" can be a simple string, or an array of hosts for Logstash to hit. + # This tool needs one accessible host, so we pick the first entry in the latter case. 
+ return plugin_config[HOSTS_KEY][0] if type(plugin_config[HOSTS_KEY]) is list else plugin_config[HOSTS_KEY] + + +# Helper function that attempts to extract the AWS region from a URL, +# assuming it is of the form *.<region>.<service>.amazonaws.com +def __derive_aws_region_from_url(url: str) -> Optional[str]: + match = URL_REGION_PATTERN.search(url) + if match: + # Index 0 returns the entire match, index 1 returns only the first group + return match.group(1) + return None + + +def get_aws_region(plugin_config: dict) -> str: + if plugin_config.get(AWS_SIGV4_KEY, False) and plugin_config.get(AWS_REGION_KEY, None) is not None: + return plugin_config[AWS_REGION_KEY] + elif plugin_config.get(AWS_CONFIG_KEY, None) is not None: + aws_config = plugin_config[AWS_CONFIG_KEY] + if type(aws_config) is not dict: + raise ValueError("Unexpected value for 'aws' configuration") + elif aws_config.get(AWS_CONFIG_REGION_KEY, None) is not None: + return aws_config[AWS_CONFIG_REGION_KEY] + # Region not explicitly defined, attempt to derive from URL + derived_region = __derive_aws_region_from_url(__get_url(plugin_config)) + if derived_region is None: + raise ValueError("No region configured for AWS SigV4 auth, or derivable from host URL") + return derived_region + + +def __check_supported_endpoint(section_config: dict) -> Optional[tuple]: for supported_type in SUPPORTED_PLUGINS: - if supported_type in config: - return supported_type, config[supported_type] + if supported_type in section_config: + return supported_type, section_config[supported_type] # This config key may be either directly in the main dict (for sink) # or inside a nested dict (for source). The default value is False. 
-def is_insecure(config: dict) -> bool: - if INSECURE_KEY in config: - return config[INSECURE_KEY] - elif CONNECTION_KEY in config and INSECURE_KEY in config[CONNECTION_KEY]: - return config[CONNECTION_KEY][INSECURE_KEY] +def is_insecure(plugin_config: dict) -> bool: + if INSECURE_KEY in plugin_config: + return plugin_config[INSECURE_KEY] + elif CONNECTION_KEY in plugin_config and INSECURE_KEY in plugin_config[CONNECTION_KEY]: + return plugin_config[CONNECTION_KEY][INSECURE_KEY] return False @@ -37,14 +81,19 @@ def validate_pipeline(pipeline: dict): raise ValueError("Missing sink configuration in Data Prepper pipeline YAML") -def validate_auth(plugin_name: str, config: dict): - # Check if auth is disabled. If so, no further validation is required - if config.get(DISABLE_AUTH_KEY, False): +def validate_auth(plugin_name: str, plugin_config: dict): + # If auth is disabled, no further validation is required + if plugin_config.get(DISABLE_AUTH_KEY, False): + return + # If AWS SigV4 is configured, validate region + if plugin_config.get(AWS_SIGV4_KEY, False) or AWS_CONFIG_KEY in plugin_config: + # Raises a ValueError if region cannot be derived + get_aws_region(plugin_config) return - # TODO AWS / SigV4 - elif USER_KEY not in config: + # Validate basic auth + elif USER_KEY not in plugin_config: raise ValueError("Invalid auth configuration (no username) for plugin: " + plugin_name) - elif PWD_KEY not in config: + elif PWD_KEY not in plugin_config: raise ValueError("Invalid auth configuration (no password for username) for plugin: " + plugin_name) @@ -65,23 +114,39 @@ def get_supported_endpoint_config(pipeline_config: dict, section_key: str) -> tu return supported_tuple[0], supported_tuple[1] -# TODO Only supports basic auth for now -def get_auth(input_data: dict) -> Optional[tuple]: - if not input_data.get(DISABLE_AUTH_KEY, False) and USER_KEY in input_data and PWD_KEY in input_data: - return input_data[USER_KEY], input_data[PWD_KEY] +def get_aws_sigv4_auth(region: str, 
is_serverless: bool = False) -> AWS4Auth: + credentials = Session().get_credentials() + if not credentials: + raise ValueError("Unable to fetch AWS session credentials for SigV4 auth") + if is_serverless: + return AWS4Auth(region=region, service=AOSS_SERVICE_NAME, refreshable_credentials=credentials) + else: + return AWS4Auth(region=region, service=ES_SERVICE_NAME, refreshable_credentials=credentials) + + +def get_auth(plugin_config: dict) -> Optional[tuple] | AWS4Auth: + # Basic auth + if USER_KEY in plugin_config and PWD_KEY in plugin_config: + return plugin_config[USER_KEY], plugin_config[PWD_KEY] + elif plugin_config.get(AWS_SIGV4_KEY, False) or AWS_CONFIG_KEY in plugin_config: + is_serverless = False + # OpenSearch Serverless uses a different service name + if AWS_CONFIG_KEY in plugin_config: + aws_config = plugin_config[AWS_CONFIG_KEY] + if type(aws_config) is dict and aws_config.get(IS_SERVERLESS_KEY, False): + is_serverless = True + region = get_aws_region(plugin_config) + return get_aws_sigv4_auth(region, is_serverless) + return None def get_endpoint_info_from_plugin_config(plugin_config: dict) -> EndpointInfo: - # "hosts" can be a simple string, or an array of hosts for Logstash to hit. - # This tool needs one accessible host, so we pick the first entry in the latter case. 
- url = plugin_config[HOSTS_KEY][0] if type(plugin_config[HOSTS_KEY]) is list else plugin_config[HOSTS_KEY] # verify boolean will be the inverse of the insecure SSL key, if present should_verify = not is_insecure(plugin_config) - return EndpointInfo(url, get_auth(plugin_config), should_verify) + return EndpointInfo(__get_url(plugin_config), get_auth(plugin_config), should_verify) def get_endpoint_info_from_pipeline_config(pipeline_config: dict, section_key: str) -> EndpointInfo: - # Raises a ValueError if no supported endpoints are found plugin_name, plugin_config = get_supported_endpoint_config(pipeline_config, section_key) if HOSTS_KEY not in plugin_config: diff --git a/FetchMigration/python/requirements.txt b/FetchMigration/python/requirements.txt index 257bda427c..a8f57b550c 100644 --- a/FetchMigration/python/requirements.txt +++ b/FetchMigration/python/requirements.txt @@ -4,5 +4,4 @@ prometheus-client>=0.17.1 pyyaml>=6.0.1 requests>=2.31.0 requests-aws4auth>=1.2.3 -requests-auth-aws-sigv4>=0.7 responses>=0.23.3 diff --git a/FetchMigration/python/tests/test_endpoint_utils.py b/FetchMigration/python/tests/test_endpoint_utils.py index 6909754aa8..2cf03909ff 100644 --- a/FetchMigration/python/tests/test_endpoint_utils.py +++ b/FetchMigration/python/tests/test_endpoint_utils.py @@ -3,12 +3,14 @@ import random import unittest from typing import Optional +from unittest.mock import MagicMock, patch -import endpoint_utils +from moto import mock_iam -# Constants +import endpoint_utils from tests import test_constants +# Constants SUPPORTED_ENDPOINTS = ["opensearch", "elasticsearch"] INSECURE_KEY = "insecure" CONNECTION_KEY = "connection" @@ -20,17 +22,22 @@ # Utility method to create a test plugin config def create_plugin_config(host_list: list[str], - user: Optional[str] = None, - password: Optional[str] = None, - disable_auth: Optional[bool] = None) -> dict: + basic_auth_tuple: Optional[tuple[Optional[str], Optional[str]]] = None, + aws_config_snippet: Optional[dict] 
= None, + disable_auth: Optional[bool] = None + ) -> dict: config = dict() config["hosts"] = host_list - if user: - config["username"] = user - if password: - config["password"] = password - if disable_auth is not None: + if disable_auth: config["disable_authentication"] = disable_auth + elif basic_auth_tuple: + user, password = basic_auth_tuple + if user: + config["username"] = user + if password: + config["password"] = password + elif aws_config_snippet: + config.update(aws_config_snippet) return config @@ -93,18 +100,18 @@ def get_endpoint_info_from_plugin_config(self): self.assertIsNone(result.get_auth()) self.assertTrue(result.is_verify_ssl()) # Invalid auth config - test_config = create_plugin_config([host_input], test_user) + test_config = create_plugin_config([host_input]) result = endpoint_utils.get_endpoint_info_from_plugin_config(test_config) self.assertEqual(expected_endpoint, result.get_url()) self.assertIsNone(result.get_auth()) # Valid auth config - test_config = create_plugin_config([host_input], user=test_user, password=test_password) + test_config = create_plugin_config([host_input], (test_user, test_password)) result = endpoint_utils.get_endpoint_info_from_plugin_config(test_config) self.assertEqual(expected_endpoint, result.get_url()) self.assertEqual(test_user, result.get_auth()[0]) self.assertEqual(test_password, result.get_auth()[1]) # Array of hosts uses the first entry - test_config = create_plugin_config([host_input, "other_host"], test_user, test_password) + test_config = create_plugin_config([host_input, "other_host"], (test_user, test_password)) result = endpoint_utils.get_endpoint_info_from_plugin_config(test_config) self.assertEqual(expected_endpoint, result.get_url()) self.assertEqual(test_user, result.get_auth()[0]) @@ -123,25 +130,86 @@ def test_validate_plugin_config_missing_auth(self): test_data = create_config_section(create_plugin_config(["host"])) self.assertRaises(ValueError, 
endpoint_utils.get_endpoint_info_from_pipeline_config, test_data, TEST_KEY) - def test_validate_plugin_config_missing_password(self): - test_data = create_config_section(create_plugin_config(["host"], user="test", disable_auth=False)) - self.assertRaises(ValueError, endpoint_utils.get_endpoint_info_from_pipeline_config, test_data, TEST_KEY) - - def test_validate_plugin_config_missing_user(self): - test_data = create_config_section(create_plugin_config(["host"], password="test")) - self.assertRaises(ValueError, endpoint_utils.get_endpoint_info_from_pipeline_config, test_data, TEST_KEY) - def test_validate_plugin_config_auth_disabled(self): - test_data = create_config_section(create_plugin_config(["host"], user="test", disable_auth=True)) + test_data = create_config_section(create_plugin_config(["host"], ("test", None), disable_auth=True)) # Should complete without errors endpoint_utils.get_endpoint_info_from_pipeline_config(test_data, TEST_KEY) - def test_validate_plugin_config_happy_case(self): - plugin_config = create_plugin_config(["host"], "user", "password") + def test_validate_plugin_config_basic_auth(self): + plugin_config = create_plugin_config(["host"], ("user", "password")) test_data = create_config_section(plugin_config) # Should complete without errors endpoint_utils.get_endpoint_info_from_pipeline_config(test_data, TEST_KEY) + def test_validate_auth_missing_password(self): + test_plugin_config = create_plugin_config(["host"], ("test", None), disable_auth=False) + self.assertRaises(ValueError, endpoint_utils.validate_auth, TEST_KEY, test_plugin_config) + + def test_validate_auth_missing_user(self): + test_plugin_config = create_plugin_config(["host"], (None, "test")) + self.assertRaises(ValueError, endpoint_utils.validate_auth, TEST_KEY, test_plugin_config) + + def test_validate_auth_bad_empty_config(self): + test_plugin_config = create_plugin_config(["host"], aws_config_snippet={}) + self.assertRaises(ValueError, endpoint_utils.validate_auth, TEST_KEY, 
test_plugin_config) + + @patch('endpoint_utils.get_aws_region') + # Note that mock objects are passed bottom-up from the patch order above + def test_validate_auth_aws_sigv4(self, mock_get_aws_region: MagicMock): + test_plugin_config = create_plugin_config(["host"], aws_config_snippet={"aws_sigv4": False}) + self.assertRaises(ValueError, endpoint_utils.validate_auth, TEST_KEY, test_plugin_config) + mock_get_aws_region.assert_not_called() + test_plugin_config = create_plugin_config(["host"], aws_config_snippet={"aws_sigv4": True}) + # Should complete without errors + endpoint_utils.validate_auth(TEST_KEY, test_plugin_config) + mock_get_aws_region.assert_called_once() + mock_get_aws_region.reset_mock() + # "aws" is expected to be a section so the check is only for the presence of the key + test_plugin_config = create_plugin_config(["host"], aws_config_snippet={"aws": False}) + endpoint_utils.validate_auth(TEST_KEY, test_plugin_config) + mock_get_aws_region.assert_called_once() + + @patch('endpoint_utils.get_aws_region') + @patch('endpoint_utils.get_aws_sigv4_auth') + # Note that mock objects are passed bottom-up from the patch order above + def test_get_auth_aws_sigv4(self, mock_get_sigv4_auth: MagicMock, mock_get_aws_region: MagicMock): + # AWS SigV4 key specified, but disabled + test_plugin_config = create_plugin_config(["host"], aws_config_snippet={"aws_sigv4": False}) + result = endpoint_utils.get_auth(test_plugin_config) + self.assertIsNone(result) + mock_get_sigv4_auth.assert_not_called() + # AWS SigV4 key enabled + expected_region = "region" + mock_get_aws_region.return_value = expected_region + test_plugin_config = create_plugin_config(["host"], aws_config_snippet={"aws_sigv4": True}) + result = endpoint_utils.get_auth(test_plugin_config) + self.assertIsNotNone(result) + mock_get_sigv4_auth.assert_called_once_with(expected_region, False) + + @patch('endpoint_utils.get_aws_region') + @patch('endpoint_utils.get_aws_sigv4_auth') + # Note that mock objects are 
passed bottom-up from the patch order above + def test_get_auth_aws_config(self, mock_get_sigv4_auth: MagicMock, mock_get_aws_region: MagicMock): + expected_region = "region" + mock_get_aws_region.return_value = expected_region + test_plugin_config = create_plugin_config(["host"], aws_config_snippet={"aws": {"key": "value"}}) + result = endpoint_utils.get_auth(test_plugin_config) + self.assertIsNotNone(result) + mock_get_sigv4_auth.assert_called_once_with(expected_region, False) + mock_get_aws_region.assert_called_once() + + @patch('endpoint_utils.get_aws_region') + @patch('endpoint_utils.get_aws_sigv4_auth') + # Note that mock objects are passed bottom-up from the patch order above + def test_get_auth_aws_sigv4_serverless(self, mock_get_sigv4_auth: MagicMock, mock_get_aws_region: MagicMock): + expected_region = "region" + mock_get_aws_region.return_value = expected_region + test_plugin_config = create_plugin_config(["host"], aws_config_snippet={"aws": {"serverless": True}}) + result = endpoint_utils.get_auth(test_plugin_config) + self.assertIsNotNone(result) + mock_get_sigv4_auth.assert_called_once_with(expected_region, True) + mock_get_aws_region.assert_called_once() + def test_validate_pipeline_missing_required_keys(self): # Test cases: # - Empty input @@ -159,6 +227,56 @@ def test_validate_pipeline_config_happy_case(self): endpoint_utils.get_endpoint_info_from_pipeline_config(test_config, "sink") self.assertIsNotNone(result) + @patch('endpoint_utils.__derive_aws_region_from_url') + def test_get_aws_region_aws_sigv4(self, mock_derive_region: MagicMock): + derived_value = "derived" + mock_derive_region.return_value = derived_value + aws_sigv4_config = dict() + aws_sigv4_config["aws_sigv4"] = True + aws_sigv4_config["aws_region"] = "test" + self.assertEqual("test", endpoint_utils.get_aws_region( + create_plugin_config(["host"], aws_config_snippet=aws_sigv4_config))) + mock_derive_region.assert_not_called() + del aws_sigv4_config["aws_region"] + 
self.assertEqual(derived_value, endpoint_utils.get_aws_region( + create_plugin_config(["host"], aws_config_snippet=aws_sigv4_config))) + mock_derive_region.assert_called_once() + + @patch('endpoint_utils.__derive_aws_region_from_url') + def test_get_aws_region_aws_config(self, mock_derive_region: MagicMock): + derived_value = "derived" + mock_derive_region.return_value = derived_value + test_config = create_plugin_config(["host"], aws_config_snippet={"aws": {"region": "test"}}) + self.assertEqual("test", endpoint_utils.get_aws_region(test_config)) + mock_derive_region.assert_not_called() + test_config = create_plugin_config(["host"], aws_config_snippet={"aws": {"serverless": True}}) + self.assertEqual(derived_value, endpoint_utils.get_aws_region(test_config)) + mock_derive_region.assert_called_once() + # Invalid configuration + test_config = create_plugin_config(["host"], aws_config_snippet={"aws": True}) + self.assertRaises(ValueError, endpoint_utils.get_aws_region, test_config) + + def test_derive_aws_region(self): + # Custom endpoint that does not match regex + test_config = create_plugin_config(["https://www.custom.endpoint.amazon.com"], + aws_config_snippet={"aws_sigv4": True}) + self.assertRaises(ValueError, endpoint_utils.get_aws_region, test_config) + # Non-matching service name + test_config = create_plugin_config(["test123.test-region.s3.amazonaws.com"], + aws_config_snippet={"aws_sigv4": True}) + self.assertRaises(ValueError, endpoint_utils.get_aws_region, test_config) + test_config = create_plugin_config(["test-123.test-region.es.amazonaws.com"], + aws_config_snippet={"aws": {"serverless": True}}) + # Should return region successfully + self.assertEqual("test-region", endpoint_utils.get_aws_region(test_config)) + + @mock_iam + def test_get_aws_sigv4_auth(self): + result = endpoint_utils.get_aws_sigv4_auth("test") + self.assertEqual(result.service, "es") + result = endpoint_utils.get_aws_sigv4_auth("test", True) + self.assertEqual(result.service, "aoss") 
+ if __name__ == '__main__': unittest.main()