Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic Authentication Token Refresh Support #107

Merged
merged 2 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 97 additions & 15 deletions src/pds/ingress/client/pds_ingress_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import hashlib
import json
import os
import sched
import time
from datetime import datetime
from datetime import timezone
from threading import Thread

import backoff
import pds.ingress.util.log_util as log_util
Expand All @@ -26,8 +28,14 @@
from pds.ingress.util.node_util import NodeUtil
from pds.ingress.util.path_util import PathUtil

BEARER_TOKEN = None
"""Placeholder for authentication bearer token used to authenticate to API gateway"""

PARALLEL = Parallel(require="sharedmem")

REFRESH_SCHEDULER = sched.scheduler(time.time, time.sleep)
"""Scheduler object used to periodically refresh the Cognito authentication token"""

SUMMARY_TABLE = {
"uploaded": set(),
"skipped": set(),
Expand All @@ -54,7 +62,7 @@ def backoff_logger(details):
logger.warning(f"Total time elapsed: {details['elapsed']:0.1f} seconds.")


def _perform_ingress(ingress_path, node_id, prefix, bearer_token, api_gateway_config):
def _perform_ingress(ingress_path, node_id, prefix, api_gateway_config):
"""
Performs an ingress request and transfer to S3 using credentials obtained from
Cognito. This helper function is intended for use with a Joblib parallelized
Expand All @@ -69,9 +77,6 @@ def _perform_ingress(ingress_path, node_id, prefix, bearer_token, api_gateway_co
prefix : str
Global path prefix to trim from the ingress path before making the
ingress request.
bearer_token : str
JWT Bearer token string obtained from a successful authentication to
Cognito.
api_gateway_config : dict
Dictionary containing configuration details for the API Gateway instance
used to request ingress.
Expand All @@ -88,9 +93,7 @@ def _perform_ingress(ingress_path, node_id, prefix, bearer_token, api_gateway_co
trimmed_path = PathUtil.trim_ingress_path(ingress_path, prefix)

try:
s3_ingress_url = request_file_for_ingress(
object_body, ingress_path, trimmed_path, node_id, api_gateway_config, bearer_token
)
s3_ingress_url = request_file_for_ingress(object_body, ingress_path, trimmed_path, node_id, api_gateway_config)

if s3_ingress_url:
ingress_file_to_s3(object_body, ingress_path, trimmed_path, s3_ingress_url)
Expand All @@ -104,6 +107,72 @@ def _perform_ingress(ingress_path, node_id, prefix, bearer_token, api_gateway_co
SUMMARY_TABLE["failed"].add(trimmed_path)


def _token_refresh_delay(token_expiration, offset=60):
    """
    Computes how long to wait before refreshing the authentication token.

    Parameters
    ----------
    token_expiration : int
        Time in seconds before the current authentication token is expected to
        expire.
    offset : int, optional
        Offset in seconds to subtract from the token expiration duration.
        Defaults to 60 seconds.

    Returns
    -------
    delay : int
        ``token_expiration - offset``, but never less than ``offset``.

    """
    return max(token_expiration - offset, offset)


def _schedule_token_refresh(refresh_token, token_expiration, offset=60):
    """
    Schedules a refresh of the Cognito authentication token using the provided
    refresh token, then runs the refresh scheduler. This function is intended
    to be executed within a separate daemon thread to prevent blocking on the
    main thread.

    Parameters
    ----------
    refresh_token : str
        The refresh token provided by Cognito.
    token_expiration : int
        Time in seconds before the current authentication token is expected to
        expire.
    offset : int, optional
        Offset in seconds to subtract from the token expiration duration to ensure
        a refresh occurs some time before the expiration deadline. Defaults to
        60 seconds.

    """
    # Offset the expiration, so we refresh a bit ahead of time
    delay = _token_refresh_delay(token_expiration, offset)

    REFRESH_SCHEDULER.enter(delay, priority=1, action=_token_refresh_event, argument=(refresh_token,))

    # Kick off scheduler
    # Since this function should be running in a separate thread, it should be
    # safe to block here. Each refresh event enqueues its successor before it
    # returns, so this single run() call keeps servicing refreshes for as long
    # as they succeed.
    REFRESH_SCHEDULER.run(blocking=True)


def _token_refresh_event(refresh_token):
    """
    Callback event invoked when the refresh scheduler kicks off a Cognito token
    refresh. This function submits the refresh request to Cognito, and if
    successful, publishes the new bearer token and enqueues the next refresh
    event on the scheduler.

    Parameters
    ----------
    refresh_token : str
        The refresh token provided by Cognito.

    Raises
    ------
    Any exception raised by AuthUtil.refresh_auth_token propagates to the
    caller; in that case no further refresh event is enqueued.

    """
    global BEARER_TOKEN

    logger = get_logger(__name__)

    logger.debug("_token_refresh_event fired")

    config = ConfigUtil.get_config()

    cognito_config = config["COGNITO"]

    # Submit the token refresh request via boto3
    authentication_result = AuthUtil.refresh_auth_token(cognito_config, refresh_token)

    # Update the authentication token referenced by each ingress worker thread,
    # as well as the Cloudwatch logger
    BEARER_TOKEN = AuthUtil.create_bearer_token(authentication_result)
    log_util.CLOUDWATCH_HANDLER.bearer_token = BEARER_TOKEN

    # Schedule the next refresh iteration
    expiration = authentication_result["ExpiresIn"]

    # Enqueue the follow-up event directly instead of calling
    # _schedule_token_refresh(): that function starts REFRESH_SCHEDULER.run(),
    # and invoking it from inside a scheduler callback would nest one more
    # blocking run() per refresh, deepening the call stack without bound.
    # The run() loop that invoked this callback picks up the new event once
    # this function returns.
    REFRESH_SCHEDULER.enter(
        _token_refresh_delay(expiration),
        priority=1,
        action=_token_refresh_event,
        argument=(refresh_token,),
    )


@backoff.on_exception(
backoff.constant,
requests.exceptions.RequestException,
Expand All @@ -112,7 +181,7 @@ def _perform_ingress(ingress_path, node_id, prefix, bearer_token, api_gateway_co
on_backoff=backoff_logger,
interval=15,
)
def request_file_for_ingress(object_body, ingress_path, trimmed_path, node_id, api_gateway_config, bearer_token):
def request_file_for_ingress(object_body, ingress_path, trimmed_path, node_id, api_gateway_config):
"""
Submits a request for file ingress to the PDS Ingress App API.

Expand All @@ -129,9 +198,6 @@ def request_file_for_ingress(object_body, ingress_path, trimmed_path, node_id, a
api_gateway_config : dict
Dictionary or dictionary-like containing key/value pairs used to
configure the API Gateway endpoint url.
bearer_token : str
The Bearer token authorizing the current user to access the Ingress
Lambda function.

Returns
-------
Expand All @@ -148,6 +214,8 @@ def request_file_for_ingress(object_body, ingress_path, trimmed_path, node_id, a
If the request to the Ingress Service fails.

"""
global BEARER_TOKEN

logger = get_logger(__name__)

logger.info(f"{trimmed_path} : Requesting ingress for node ID {node_id}")
Expand All @@ -173,7 +241,7 @@ def request_file_for_ingress(object_body, ingress_path, trimmed_path, node_id, a
params = {"node": node_id, "node_name": NodeUtil.node_id_to_long_name[node_id]}
payload = {"url": trimmed_path}
headers = {
"Authorization": bearer_token,
"Authorization": BEARER_TOKEN,
"UserGroup": NodeUtil.node_id_to_group_name(node_id),
"ContentMD5": md5_digest,
"ContentLength": str(file_size),
Expand Down Expand Up @@ -410,6 +478,8 @@ def main():
and dry-run is not enabled.

"""
global BEARER_TOKEN

parser = setup_argparser()

args = parser.parse_args()
Expand All @@ -435,18 +505,30 @@ def main():

authentication_result = AuthUtil.perform_cognito_authentication(cognito_config)

bearer_token = AuthUtil.create_bearer_token(authentication_result)
BEARER_TOKEN = AuthUtil.create_bearer_token(authentication_result)

# Set the bearer token on the CloudWatchHandler singleton, so it can
# be used to authenticate submissions to the CloudWatch Logs API endpoint
log_util.CLOUDWATCH_HANDLER.bearer_token = bearer_token
log_util.CLOUDWATCH_HANDLER.bearer_token = BEARER_TOKEN
log_util.CLOUDWATCH_HANDLER.node_id = node_id

# Schedule automatic refresh of the Cognito token prior to expiration within
# a separate thread. Since this thread will not allocate any
# resources, we can designate the thread as a daemon, so it will not
# preempt completion of the main thread.
refresh_thread = Thread(
target=_schedule_token_refresh,
name="token_refresh",
args=(authentication_result["RefreshToken"], authentication_result["ExpiresIn"]),
daemon=True,
)
refresh_thread.start()

# Perform uploads in parallel using the number of requested threads
PARALLEL.n_jobs = args.num_threads

PARALLEL(
delayed(_perform_ingress)(resolved_ingress_path, node_id, args.prefix, bearer_token, config["API_GATEWAY"])
delayed(_perform_ingress)(resolved_ingress_path, node_id, args.prefix, config["API_GATEWAY"])
for resolved_ingress_path in resolved_ingress_paths
)

Expand Down
42 changes: 42 additions & 0 deletions src/pds/ingress/util/auth_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,45 @@ def create_bearer_token(authentication_result):
bearer_token = f"Bearer {access_token}"

return bearer_token

@staticmethod
def refresh_auth_token(cognito_config, refresh_token):
    """
    Exchanges a Cognito refresh token for a new authentication result via
    the boto3 "cognito-idp" client, using the REFRESH_TOKEN_AUTH flow.

    Parameters
    ----------
    cognito_config : dict
        The Cognito configuration parameters as read from the INI config.
        The "region" and "client_id" entries are read by this function.
    refresh_token : str
        The refresh token provided by Cognito.

    Returns
    -------
    authentication_result : dict
        The "AuthenticationResult" entry of the Cognito response.

    Raises
    ------
    RuntimeError
        If the refresh request raises any exception; the original
        exception is chained as the cause.

    """
    logger = get_logger(__name__)

    cognito_client = boto3.client("cognito-idp", region_name=cognito_config["region"])

    logger.info("Refreshing authentication token")

    try:
        response = cognito_client.initiate_auth(
            AuthFlow="REFRESH_TOKEN_AUTH",
            AuthParameters={"REFRESH_TOKEN": refresh_token},
            ClientId=cognito_config["client_id"],
        )
    except Exception as err:
        # Surface every failure as a single exception type for callers
        raise RuntimeError(f"Failed to refresh Cognito authentication token, reason: {err!s}") from err

    logger.info("Token refresh successful")

    return response["AuthenticationResult"]