diff --git a/python/src/etos_api/library/docker.py b/python/src/etos_api/library/docker.py index 8e1ca04..308f30b 100644 --- a/python/src/etos_api/library/docker.py +++ b/python/src/etos_api/library/docker.py @@ -16,6 +16,7 @@ """Docker operations for the ETOS API.""" import time import logging +from typing import Optional, Mapping from threading import Lock import aiohttp @@ -33,12 +34,14 @@ class Docker: """ logger = logging.getLogger(__name__) + # The amount of time in seconds before the token expires where we should recreate it. + token_expire_modifier = 30 # In-memory database for stored authorization tokens. # This dictionary shares memory with all instances of `Docker`, by design. tokens = {} lock = Lock() - def token(self, manifest_url: str) -> str: + def token(self, manifest_url: str) -> Optional[str]: """Get a stored token, removing it if expired. :param manifest_url: URL the token has been stored for. @@ -54,7 +57,7 @@ def token(self, manifest_url: str) -> str: return token.get("token") async def head( - self, session: aiohttp.ClientSession, url: str, token: str = None + self, session: aiohttp.ClientSession, url: str, token: Optional[str] = None ) -> aiohttp.ClientResponse: """Make a HEAD request to a URL, adding token to headers if supplied. @@ -70,36 +73,78 @@ async def head( async with session.head(url, headers=headers) as response: return response - async def authorize( - self, session: aiohttp.ClientSession, response: aiohttp.ClientResponse - ) -> str: - """Get a token from an unauthorized request to image repository. + async def get_token_from_container_registry( + self, session: aiohttp.ClientSession, realm: str, parameters: dict + ) -> dict: + """Get a token from an unauthorized request to container repository. :param session: Client HTTP session to use for HTTP request. - :param response: HTTP response to get headers from. + :param realm: The realm to authorize against. + :param parameters: Parameters to use for the authorization request. :return: Response JSON from authorization request. """ - www_auth_header = response.headers.get("www-authenticate") - challenge = www_auth_header.replace("Bearer ", "") - parts = challenge.split(",") + async with session.get(realm, params=parameters) as response: + response.raise_for_status() + return await response.json() - url = None - query = {} - for part in parts: - key, value = part.split("=") - if key == "realm": - url = value.strip('"') - else: - query[key] = value.strip('"') + async def authorize( + self, session: aiohttp.ClientSession, response: aiohttp.ClientResponse, manifest_url: str + ) -> str: + """Authorize against container registry. + :param session: Client HTTP session to use for HTTP request. + :param response: Response from a previous request to parse headers from. + :param manifest_url: Manifest URL to query should a new auth request be needed. + :return: A token retrieved from container registry. + """ + parameters = await self.parse_headers(response.headers) + url = parameters.get("realm") if not isinstance(url, str) or not ( url.startswith("http://") or url.startswith("https://") ): - raise ValueError(f"No realm URL found in www-authenticate header: {www_auth_header}") + self.logger.warning("No realm in original request, retrying without a token") + with self.lock: + try: + del self.tokens[manifest_url] + except KeyError: + pass + response = await self.head(session, manifest_url) + parameters = await self.parse_headers(response.headers) + url = parameters.get("realm") + if not isinstance(url, str) or not ( + url.startswith("http://") or url.startswith("https://") + ): + raise ValueError( + f"No realm URL found in www-authenticate header: {response.headers}" + ) + url = parameters.pop("realm") + response_json = await self.get_token_from_container_registry(session, url, parameters) + with self.lock: + self.tokens[manifest_url] = { + "token": response_json.get("token"), + "expire": time.time() + + response_json.get("expires_in", 0.0) + - self.token_expire_modifier, + } + return "" + + async def parse_headers(self, headers: Mapping) -> dict: + """Parse the www-authenticate header and convert it to a dict. + + :param headers: Headers to parse. + :return: Dictionary of the keys in the www-authenticate header. + """ + www_auth_header = headers.get("www-authenticate") + if www_auth_header is None: + return {} + challenge = www_auth_header.replace("Bearer ", "") + parts = challenge.split(",") - async with session.get(url, params=query) as response: - response.raise_for_status() - return await response.json() + parameters = {} + for part in parts: + key, value = part.split("=") + parameters[key] = value.strip('"') + return parameters def tag(self, base: str) -> tuple[str, str]: """Figure out tag from a container image name. @@ -149,7 +194,7 @@ def repository(self, repo: str) -> tuple[str, str]: registry = DEFAULT_REGISTRY return registry, repo - async def digest(self, name: str) -> str: + async def digest(self, name: str) -> Optional[str]: """Get a sha256 digest from an image in an image repository. :param name: The name of the container image. @@ -167,13 +212,8 @@ async def digest(self, name: str) -> str: try: if response.status == 401 and "www-authenticate" in response.headers: self.logger.info("Generate a new authorization token for %r", manifest_url) - response_json = await self.authorize(session, response) - with self.lock: - self.tokens[manifest_url] = { - "token": response_json.get("token"), - "expire": time.time() + response_json.get("expires_in"), - } - response = await self.head(session, manifest_url, self.token(manifest_url)) + token = await self.authorize(session, response, manifest_url) + response = await self.head(session, manifest_url, token) digest = response.headers.get("Docker-Content-Digest") except aiohttp.ClientResponseError as exception: self.logger.error("Error getting container image %r", exception)