diff --git a/google/cloud/alloydb/connector/client.py b/google/cloud/alloydb/connector/client.py index ecb5e65c..2c6afef5 100644 --- a/google/cloud/alloydb/connector/client.py +++ b/google/cloud/alloydb/connector/client.py @@ -18,7 +18,6 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING import aiohttp -from google.auth.transport.requests import Request from google.cloud.alloydb.connector.version import __version__ as version @@ -116,10 +115,6 @@ async def _get_metadata( """ logger.debug(f"['{project}/{region}/{cluster}/{name}']: Requesting metadata") - if not self._credentials.valid: - request = Request() - self._credentials.refresh(request) - headers = { "Authorization": f"Bearer {self._credentials.token}", } @@ -167,10 +162,6 @@ async def _get_client_certificate( """ logger.debug(f"['{project}/{region}/{cluster}']: Requesting client certificate") - if not self._credentials.valid: - request = Request() - self._credentials.refresh(request) - headers = { "Authorization": f"Bearer {self._credentials.token}", } diff --git a/google/cloud/alloydb/connector/instance.py b/google/cloud/alloydb/connector/instance.py index 75169bce..c3ae25a0 100644 --- a/google/cloud/alloydb/connector/instance.py +++ b/google/cloud/alloydb/connector/instance.py @@ -20,6 +20,9 @@ import re from typing import Tuple, TYPE_CHECKING +from google.auth.credentials import TokenState +from google.auth.transport import requests + from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError from google.cloud.alloydb.connector.exceptions import RefreshError from google.cloud.alloydb.connector.rate_limiter import AsyncRateLimiter @@ -130,6 +133,11 @@ async def _perform_refresh(self) -> RefreshResult: try: await self._refresh_rate_limiter.acquire() priv_key, pub_key = await self._keys + + # before making AlloyDB API calls, refresh creds if required + if not self._client._credentials.token_state == TokenState.FRESH: + self._client._credentials.refresh(requests.Request()) + # fetch metadata metadata_task = asyncio.create_task( self._client._get_metadata( diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index a052fb3f..eb0cbd17 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -19,13 +19,15 @@ import ipaddress import ssl import struct -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple from cryptography import x509 from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID +from google.auth.credentials import _helpers +from google.auth.credentials import TokenState import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb @@ -35,7 +37,7 @@ def __init__(self) -> None: self.token: Optional[str] = None self.expiry: Optional[datetime] = None - def refresh(self, request: Callable) -> None: + def refresh(self, _: Callable) -> None: """Refreshes the access token.""" self.token = "12345" self.expiry = datetime.now(timezone.utc) + timedelta(minutes=60) @@ -51,13 +53,33 @@ def expired(self) -> bool: return False if not self.expiry else True @property - def valid(self) -> bool: - """Checks the validity of the credentials. - - This is True if the credentials have a token and the token - is not expired. + def token_state( + self, + ) -> Literal[TokenState.FRESH, TokenState.STALE, TokenState.INVALID]: """ - return self.token is not None and not self.expired + Tracks the state of a token. + FRESH: The token is valid. It is not expired or close to expired, or the token has no expiry. + STALE: The token is close to expired, and should be refreshed. The token can be used normally. + INVALID: The token is expired or invalid. The token cannot be used for a normal operation. + """ + if self.token is None: + return TokenState.INVALID + + # Credentials that can't expire are always treated as fresh. + if self.expiry is None: + return TokenState.FRESH + + expired = datetime.now(timezone.utc) >= self.expiry + if expired: + return TokenState.INVALID + + is_stale = datetime.now(timezone.utc) >= ( + self.expiry - _helpers.REFRESH_THRESHOLD + ) + if is_stale: + return TokenState.STALE + + return TokenState.FRESH def generate_cert( @@ -180,6 +202,7 @@ def __init__( self.instance = FakeInstance() if instance is None else instance self.closed = False self._user_agent = f"test-user-agent+{driver}" + self._credentials = FakeCredentials() async def _get_metadata(self, *args: Any, **kwargs: Any) -> str: return self.instance.ip_addrs