Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use google.auth TokenState for credentials validity #321

Merged
merged 1 commit into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions google/cloud/alloydb/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING

import aiohttp
from google.auth.transport.requests import Request

from google.cloud.alloydb.connector.version import __version__ as version

Expand Down Expand Up @@ -116,10 +115,6 @@ async def _get_metadata(
"""
logger.debug(f"['{project}/{region}/{cluster}/{name}']: Requesting metadata")

if not self._credentials.valid:
request = Request()
self._credentials.refresh(request)

headers = {
"Authorization": f"Bearer {self._credentials.token}",
}
Expand Down Expand Up @@ -167,10 +162,6 @@ async def _get_client_certificate(
"""
logger.debug(f"['{project}/{region}/{cluster}']: Requesting client certificate")

if not self._credentials.valid:
request = Request()
self._credentials.refresh(request)

headers = {
"Authorization": f"Bearer {self._credentials.token}",
}
Expand Down
8 changes: 8 additions & 0 deletions google/cloud/alloydb/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import re
from typing import Tuple, TYPE_CHECKING

from google.auth.credentials import TokenState
from google.auth.transport import requests

from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
from google.cloud.alloydb.connector.exceptions import RefreshError
from google.cloud.alloydb.connector.rate_limiter import AsyncRateLimiter
Expand Down Expand Up @@ -130,6 +133,11 @@ async def _perform_refresh(self) -> RefreshResult:
try:
await self._refresh_rate_limiter.acquire()
priv_key, pub_key = await self._keys

# before making AlloyDB API calls, refresh creds if required
if not self._client._credentials.token_state == TokenState.FRESH:
self._client._credentials.refresh(requests.Request())

# fetch metadata
metadata_task = asyncio.create_task(
self._client._get_metadata(
Expand Down
39 changes: 31 additions & 8 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
import ipaddress
import ssl
import struct
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from google.auth.credentials import _helpers
from google.auth.credentials import TokenState

import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb

Expand All @@ -35,7 +37,7 @@ def __init__(self) -> None:
self.token: Optional[str] = None
self.expiry: Optional[datetime] = None

def refresh(self, request: Callable) -> None:
def refresh(self, _: Callable) -> None:
"""Refreshes the access token."""
self.token = "12345"
self.expiry = datetime.now(timezone.utc) + timedelta(minutes=60)
Expand All @@ -51,13 +53,33 @@ def expired(self) -> bool:
return False if not self.expiry else True

@property
def valid(self) -> bool:
"""Checks the validity of the credentials.

This is True if the credentials have a token and the token
is not expired.
def token_state(
self,
) -> Literal[TokenState.FRESH, TokenState.STALE, TokenState.INVALID]:
"""
return self.token is not None and not self.expired
Tracks the state of a token.
FRESH: The token is valid. It is not expired or close to expired, or the token has no expiry.
STALE: The token is close to expired, and should be refreshed. The token can be used normally.
INVALID: The token is expired or invalid. The token cannot be used for a normal operation.
"""
if self.token is None:
return TokenState.INVALID

# Credentials that can't expire are always treated as fresh.
if self.expiry is None:
return TokenState.FRESH

expired = datetime.now(timezone.utc) >= self.expiry
if expired:
return TokenState.INVALID

is_stale = datetime.now(timezone.utc) >= (
self.expiry - _helpers.REFRESH_THRESHOLD
)
if is_stale:
return TokenState.STALE

return TokenState.FRESH


def generate_cert(
Expand Down Expand Up @@ -180,6 +202,7 @@ def __init__(
self.instance = FakeInstance() if instance is None else instance
self.closed = False
self._user_agent = f"test-user-agent+{driver}"
self._credentials = FakeCredentials()

async def _get_metadata(self, *args: Any, **kwargs: Any) -> str:
return self.instance.ip_addrs
Expand Down
Loading