Skip to content

Commit

Permalink
Merge pull request #107 from NASA-PDS/104_token_refresh
Browse files Browse the repository at this point in the history
Automatic Authentication Token Refresh Support
  • Loading branch information
tloubrieu-jpl authored May 14, 2024
2 parents 1d1803e + 4ee82f3 commit 48d1a1d
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 15 deletions.
112 changes: 97 additions & 15 deletions src/pds/ingress/client/pds_ingress_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import hashlib
import json
import os
import sched
import time
from datetime import datetime
from datetime import timezone
from threading import Thread

import backoff
import pds.ingress.util.log_util as log_util
Expand All @@ -26,8 +28,14 @@
from pds.ingress.util.node_util import NodeUtil
from pds.ingress.util.path_util import PathUtil

BEARER_TOKEN = None
"""Placeholder for authentication bearer token used to authenticate to API gateway"""

PARALLEL = Parallel(require="sharedmem")

REFRESH_SCHEDULER = sched.scheduler(time.time, time.sleep)
"""Scheduler object used to periodically refresh the Cognito authentication token"""

SUMMARY_TABLE = {
"uploaded": set(),
"skipped": set(),
Expand All @@ -54,7 +62,7 @@ def backoff_logger(details):
logger.warning(f"Total time elapsed: {details['elapsed']:0.1f} seconds.")


def _perform_ingress(ingress_path, node_id, prefix, bearer_token, api_gateway_config):
def _perform_ingress(ingress_path, node_id, prefix, api_gateway_config):
"""
Performs an ingress request and transfer to S3 using credentials obtained from
Cognito. This helper function is intended for use with a Joblib parallelized
Expand All @@ -69,9 +77,6 @@ def _perform_ingress(ingress_path, node_id, prefix, bearer_token, api_gateway_co
prefix : str
Global path prefix to trim from the ingress path before making the
ingress request.
bearer_token : str
JWT Bearer token string obtained from a successful authentication to
Cognito.
api_gateway_config : dict
Dictionary containing configuration details for the API Gateway instance
used to request ingress.
Expand All @@ -88,9 +93,7 @@ def _perform_ingress(ingress_path, node_id, prefix, bearer_token, api_gateway_co
trimmed_path = PathUtil.trim_ingress_path(ingress_path, prefix)

try:
s3_ingress_url = request_file_for_ingress(
object_body, ingress_path, trimmed_path, node_id, api_gateway_config, bearer_token
)
s3_ingress_url = request_file_for_ingress(object_body, ingress_path, trimmed_path, node_id, api_gateway_config)

if s3_ingress_url:
ingress_file_to_s3(object_body, ingress_path, trimmed_path, s3_ingress_url)
Expand All @@ -104,6 +107,72 @@ def _perform_ingress(ingress_path, node_id, prefix, bearer_token, api_gateway_co
SUMMARY_TABLE["failed"].add(trimmed_path)


def _schedule_token_refresh(refresh_token, token_expiration, offset=60):
"""
Schedules a refresh of the Cognito authentication token using the provided
refresh token. This function is inteded to be executed with a separate daemon
thread to prevent blocking on the main thread.
Parameters
----------
refresh_token : str
The refresh token provided by Cognito.
token_expiration : int
Time in seconds before the current authentication token is expected to
expire.
offset : int, optional
Offset in seconds to subtract from the token expiration duration to ensure
a refresh occurs some time before the expiration deadline. Defaults to
60 seconds.
"""
# Offset the expiration, so we refresh a bit ahead of time
delay = max(token_expiration - offset, offset)

REFRESH_SCHEDULER.enter(delay, priority=1, action=_token_refresh_event, argument=(refresh_token,))

# Kick off scheduler
# Since this function should be running in a seperate thread, it should be
# safe to block until the scheduler fires the next refresh event
REFRESH_SCHEDULER.run(blocking=True)


def _token_refresh_event(refresh_token):
"""
Callback event evoked when refresh scheduler kicks off a Cognito token refresh.
This function will submit the refresh request to Cognito, and if successful,
schedules the next refresh interval.
Parameters
----------
refresh_token : str
The refresh token provided by Cognito.
"""
global BEARER_TOKEN

logger = get_logger(__name__)

logger.debug("_token_refresh_event fired")

config = ConfigUtil.get_config()

cognito_config = config["COGNITO"]

# Submit the token refresh request via boto3
authentication_result = AuthUtil.refresh_auth_token(cognito_config, refresh_token)

# Update the authentication token referenced by each ingress worker thread,
# as well as the Cloudwatch logger
BEARER_TOKEN = AuthUtil.create_bearer_token(authentication_result)
log_util.CLOUDWATCH_HANDLER.bearer_token = BEARER_TOKEN

# Schedule the next refresh iteration
expiration = authentication_result["ExpiresIn"]

_schedule_token_refresh(refresh_token, expiration)


@backoff.on_exception(
backoff.constant,
requests.exceptions.RequestException,
Expand All @@ -112,7 +181,7 @@ def _perform_ingress(ingress_path, node_id, prefix, bearer_token, api_gateway_co
on_backoff=backoff_logger,
interval=15,
)
def request_file_for_ingress(object_body, ingress_path, trimmed_path, node_id, api_gateway_config, bearer_token):
def request_file_for_ingress(object_body, ingress_path, trimmed_path, node_id, api_gateway_config):
"""
Submits a request for file ingress to the PDS Ingress App API.
Expand All @@ -129,9 +198,6 @@ def request_file_for_ingress(object_body, ingress_path, trimmed_path, node_id, a
api_gateway_config : dict
Dictionary or dictionary-like containing key/value pairs used to
configure the API Gateway endpoint url.
bearer_token : str
The Bearer token authorizing the current user to access the Ingress
Lambda function.
Returns
-------
Expand All @@ -148,6 +214,8 @@ def request_file_for_ingress(object_body, ingress_path, trimmed_path, node_id, a
If the request to the Ingress Service fails.
"""
global BEARER_TOKEN

logger = get_logger(__name__)

logger.info(f"{trimmed_path} : Requesting ingress for node ID {node_id}")
Expand All @@ -173,7 +241,7 @@ def request_file_for_ingress(object_body, ingress_path, trimmed_path, node_id, a
params = {"node": node_id, "node_name": NodeUtil.node_id_to_long_name[node_id]}
payload = {"url": trimmed_path}
headers = {
"Authorization": bearer_token,
"Authorization": BEARER_TOKEN,
"UserGroup": NodeUtil.node_id_to_group_name(node_id),
"ContentMD5": md5_digest,
"ContentLength": str(file_size),
Expand Down Expand Up @@ -410,6 +478,8 @@ def main():
and dry-run is not enabled.
"""
global BEARER_TOKEN

parser = setup_argparser()

args = parser.parse_args()
Expand All @@ -435,18 +505,30 @@ def main():

authentication_result = AuthUtil.perform_cognito_authentication(cognito_config)

bearer_token = AuthUtil.create_bearer_token(authentication_result)
BEARER_TOKEN = AuthUtil.create_bearer_token(authentication_result)

# Set the bearer token on the CloudWatchHandler singleton, so it can
# be used to authenticate submissions to the CloudWatch Logs API endpoint
log_util.CLOUDWATCH_HANDLER.bearer_token = bearer_token
log_util.CLOUDWATCH_HANDLER.bearer_token = BEARER_TOKEN
log_util.CLOUDWATCH_HANDLER.node_id = node_id

# Schedule automatic refresh of the Cognito token prior to expiration within
# a separate thread. Since this thread will not allocate any
# resources, we can designate the thread as a daemon, so it will not
# preempt completion of the main thread.
refresh_thread = Thread(
target=_schedule_token_refresh,
name="token_refresh",
args=(authentication_result["RefreshToken"], authentication_result["ExpiresIn"]),
daemon=True,
)
refresh_thread.start()

# Perform uploads in parallel using the number of requested threads
PARALLEL.n_jobs = args.num_threads

PARALLEL(
delayed(_perform_ingress)(resolved_ingress_path, node_id, args.prefix, bearer_token, config["API_GATEWAY"])
delayed(_perform_ingress)(resolved_ingress_path, node_id, args.prefix, config["API_GATEWAY"])
for resolved_ingress_path in resolved_ingress_paths
)

Expand Down
42 changes: 42 additions & 0 deletions src/pds/ingress/util/auth_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,45 @@ def create_bearer_token(authentication_result):
bearer_token = f"Bearer {access_token}"

return bearer_token

@staticmethod
def refresh_auth_token(cognito_config, refresh_token):
"""
Performs a Cognito authentication token refresh request, returning a
new authentication token for use with the worker threads and CloudWatch
logger.
Parameters
----------
cognito_config : dict
The Cognito configuration parameters as read from the INI config.
refresh_token : str
The refresh token provided by Cognito.
Returns
-------
authentication_result : dict
Dictionary containing the results of the authentication refresh.
This includes an updated authentication token and expiration time.
"""
logger = get_logger(__name__)

client = boto3.client("cognito-idp", region_name=cognito_config["region"])

auth_params = {"REFRESH_TOKEN": refresh_token}

logger.info("Refreshing authentication token")

try:
response = client.initiate_auth(
AuthFlow="REFRESH_TOKEN_AUTH", AuthParameters=auth_params, ClientId=cognito_config["client_id"]
)
except Exception as err:
raise RuntimeError(f"Failed to refresh Cognito authentication token, reason: {str(err)}") from err

logger.info("Token refresh successful")

authentication_result = response["AuthenticationResult"]

return authentication_result

0 comments on commit 48d1a1d

Please sign in to comment.