Skip to content

Commit

Permalink
[Fetch Migration] Added support for AWS SigV4
Browse files Browse the repository at this point in the history
Signed-off-by: Kartik Ganesh <[email protected]>
  • Loading branch information
kartg committed Nov 7, 2023
1 parent 0775d88 commit 60bc424
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 52 deletions.
3 changes: 2 additions & 1 deletion FetchMigration/python/dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
coverage>=7.3.2
pur>=7.3.1
pur>=7.3.1
moto>=4.2.7
8 changes: 5 additions & 3 deletions FetchMigration/python/endpoint_info.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import Optional

from requests_aws4auth import AWS4Auth


# Class that encapsulates endpoint information for an OpenSearch/Elasticsearch cluster
class EndpointInfo:
# Private member variables
__url: str
__auth: Optional[tuple]
__auth: Optional[tuple] | AWS4Auth
__verify_ssl: bool

def __init__(self, url: str, auth: tuple = None, verify_ssl: bool = True):
def __init__(self, url: str, auth: tuple | AWS4Auth = None, verify_ssl: bool = True):
self.__url = url
# Normalize url value to have trailing slash
if not url.endswith("/"):
Expand All @@ -31,7 +33,7 @@ def add_path(self, path: str) -> str:
def get_url(self) -> str:
return self.__url

def get_auth(self) -> Optional[tuple]:
def get_auth(self) -> Optional[tuple] | AWS4Auth:
return self.__auth

def is_verify_ssl(self) -> bool:
Expand Down
111 changes: 88 additions & 23 deletions FetchMigration/python/endpoint_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import re
from typing import Optional

from requests_aws4auth import AWS4Auth
from botocore.session import Session

from endpoint_info import EndpointInfo

# Constants
Expand All @@ -12,21 +16,61 @@
DISABLE_AUTH_KEY = "disable_authentication"
USER_KEY = "username"
PWD_KEY = "password"
# Flat-style SigV4 settings, read directly from the plugin configuration
AWS_SIGV4_KEY = "aws_sigv4"
AWS_REGION_KEY = "aws_region"
# Nested-style settings: an "aws" section that may carry a region and a serverless flag
AWS_CONFIG_KEY = "aws"
AWS_CONFIG_REGION_KEY = "region"
IS_SERVERLESS_KEY = "serverless"
# SigV4 service names: "aoss" is used when the serverless flag is set, "es" otherwise
ES_SERVICE_NAME = "es"
AOSS_SERVICE_NAME = "aoss"
# Captures the label immediately preceding the service name in hosts of the
# form *.<region>.<service>.amazonaws.com (group 1 is the region)
URL_REGION_PATTERN = re.compile(r"([\w-]*)\.(es|aoss)\.amazonaws\.com")


def __check_supported_endpoint(config: dict) -> Optional[tuple]:
# Returns the single host URL this tool should talk to for the given plugin configuration.
# Raises a KeyError if the configuration has no "hosts" entry, and an IndexError if
# "hosts" is an empty list.
def __get_url(plugin_config: dict) -> str:
    hosts = plugin_config[HOSTS_KEY]
    # "hosts" can be a simple string, or an array of hosts for Logstash to hit.
    # This tool needs one accessible host, so we pick the first entry in the latter case.
    # isinstance (rather than an exact type comparison) also recognizes list subclasses,
    # which would otherwise be returned whole instead of as a single URL.
    if isinstance(hosts, list):
        return hosts[0]
    return hosts


# Best-effort extraction of the AWS region from an endpoint URL that is expected
# to look like *.<region>.<service>.amazonaws.com
# Returns None when the URL does not contain that pattern.
def __derive_aws_region_from_url(url: str) -> Optional[str]:
    region_match = URL_REGION_PATTERN.search(url)
    # Group 0 would be the entire match; group 1 is only the region portion
    return region_match.group(1) if region_match else None


# Resolves the AWS region to use for SigV4 signing. Lookup order:
#   1. the flat "aws_region" value (only honored when "aws_sigv4" is enabled)
#   2. the "region" value inside the nested "aws" section
#   3. the region embedded in the host URL
# Raises a ValueError if the "aws" section is not a dict, or if no region can be found.
def get_aws_region(plugin_config: dict) -> str:
    flat_region = plugin_config.get(AWS_REGION_KEY, None)
    if plugin_config.get(AWS_SIGV4_KEY, False) and flat_region is not None:
        return flat_region

    aws_section = plugin_config.get(AWS_CONFIG_KEY, None)
    if aws_section is not None:
        if type(aws_section) is not dict:
            raise ValueError("Unexpected value for 'aws' configuration")
        section_region = aws_section.get(AWS_CONFIG_REGION_KEY, None)
        if section_region is not None:
            return section_region

    # No explicit region anywhere in the configuration, fall back to the host URL
    url_region = __derive_aws_region_from_url(__get_url(plugin_config))
    if url_region is not None:
        return url_region
    raise ValueError("No region configured for AWS SigV4 auth, or derivable from host URL")


def __check_supported_endpoint(section_config: dict) -> Optional[tuple]:
for supported_type in SUPPORTED_PLUGINS:
if supported_type in config:
return supported_type, config[supported_type]
if supported_type in section_config:
return supported_type, section_config[supported_type]


# This config key may be either directly in the main dict (for sink)
# or inside a nested dict (for source). The default value is False.
def is_insecure(config: dict) -> bool:
if INSECURE_KEY in config:
return config[INSECURE_KEY]
elif CONNECTION_KEY in config and INSECURE_KEY in config[CONNECTION_KEY]:
return config[CONNECTION_KEY][INSECURE_KEY]
def is_insecure(plugin_config: dict) -> bool:
if INSECURE_KEY in plugin_config:
return plugin_config[INSECURE_KEY]
elif CONNECTION_KEY in plugin_config and INSECURE_KEY in plugin_config[CONNECTION_KEY]:
return plugin_config[CONNECTION_KEY][INSECURE_KEY]
return False


Expand All @@ -37,14 +81,19 @@ def validate_pipeline(pipeline: dict):
raise ValueError("Missing sink configuration in Data Prepper pipeline YAML")


def validate_auth(plugin_name: str, config: dict):
# Check if auth is disabled. If so, no further validation is required
if config.get(DISABLE_AUTH_KEY, False):
def validate_auth(plugin_name: str, plugin_config: dict):
# If auth is disabled, no further validation is required
if plugin_config.get(DISABLE_AUTH_KEY, False):
return
# If AWS SigV4 is configured, validate region
if plugin_config.get(AWS_SIGV4_KEY, False) or AWS_CONFIG_KEY in plugin_config:
# Raises a ValueError if region cannot be derived
get_aws_region(plugin_config)
return
# TODO AWS / SigV4
elif USER_KEY not in config:
# Validate basic auth
elif USER_KEY not in plugin_config:
raise ValueError("Invalid auth configuration (no username) for plugin: " + plugin_name)
elif PWD_KEY not in config:
elif PWD_KEY not in plugin_config:
raise ValueError("Invalid auth configuration (no password for username) for plugin: " + plugin_name)


Expand All @@ -65,23 +114,39 @@ def get_supported_endpoint_config(pipeline_config: dict, section_key: str) -> tu
return supported_tuple[0], supported_tuple[1]


# TODO Only supports basic auth for now
def get_auth(input_data: dict) -> Optional[tuple]:
if not input_data.get(DISABLE_AUTH_KEY, False) and USER_KEY in input_data and PWD_KEY in input_data:
return input_data[USER_KEY], input_data[PWD_KEY]
# Builds a SigV4 auth object for the given region from the credentials of a
# new botocore session. The signing service name depends on is_serverless.
# Raises a ValueError if the session yields no credentials.
def get_aws_sigv4_auth(region: str, is_serverless: bool = False) -> AWS4Auth:
    session_credentials = Session().get_credentials()
    if not session_credentials:
        raise ValueError("Unable to fetch AWS session credentials for SigV4 auth")
    # Serverless endpoints are signed under a different service name
    service_name = AOSS_SERVICE_NAME if is_serverless else ES_SERVICE_NAME
    return AWS4Auth(region=region, service=service_name, refreshable_credentials=session_credentials)


# Picks the auth mechanism for a plugin configuration:
#   - a (username, password) tuple when both basic-auth keys are present
#   - a SigV4 auth object when "aws_sigv4" is enabled or an "aws" section exists
#   - None otherwise
# May raise a ValueError while resolving the region or the AWS credentials.
# NOTE(review): the "disable_authentication" flag is not consulted here - confirm
# that callers do not expect it to suppress the returned auth.
def get_auth(plugin_config: dict) -> Optional[tuple] | AWS4Auth:
    # Basic auth takes precedence over SigV4
    if USER_KEY in plugin_config and PWD_KEY in plugin_config:
        return plugin_config[USER_KEY], plugin_config[PWD_KEY]

    sigv4_requested = plugin_config.get(AWS_SIGV4_KEY, False) or AWS_CONFIG_KEY in plugin_config
    if not sigv4_requested:
        return None

    # OpenSearch Serverless uses a different service name; the flag is only
    # honored when the "aws" section is a plain dict
    aws_section = plugin_config.get(AWS_CONFIG_KEY)
    serverless = type(aws_section) is dict and bool(aws_section.get(IS_SERVERLESS_KEY, False))
    return get_aws_sigv4_auth(get_aws_region(plugin_config), serverless)


def get_endpoint_info_from_plugin_config(plugin_config: dict) -> EndpointInfo:
# "hosts" can be a simple string, or an array of hosts for Logstash to hit.
# This tool needs one accessible host, so we pick the first entry in the latter case.
url = plugin_config[HOSTS_KEY][0] if type(plugin_config[HOSTS_KEY]) is list else plugin_config[HOSTS_KEY]
# verify boolean will be the inverse of the insecure SSL key, if present
should_verify = not is_insecure(plugin_config)
return EndpointInfo(url, get_auth(plugin_config), should_verify)
return EndpointInfo(__get_url(plugin_config), get_auth(plugin_config), should_verify)


def get_endpoint_info_from_pipeline_config(pipeline_config: dict, section_key: str) -> EndpointInfo:

# Raises a ValueError if no supported endpoints are found
plugin_name, plugin_config = get_supported_endpoint_config(pipeline_config, section_key)
if HOSTS_KEY not in plugin_config:
Expand Down
1 change: 0 additions & 1 deletion FetchMigration/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@ prometheus-client>=0.17.1
pyyaml>=6.0.1
requests>=2.31.0
requests-aws4auth>=1.2.3
requests-auth-aws-sigv4>=0.7
responses>=0.23.3
Loading

0 comments on commit 60bc424

Please sign in to comment.