From 3b6d2c4a4337c95212ca16d2daacb9bef0b92005 Mon Sep 17 00:00:00 2001
From: drf7 <83355241+drf7@users.noreply.github.com>
Date: Thu, 8 Aug 2024 12:57:33 +0200
Subject: [PATCH] use gcs sa credentials (#2382)

---
 nucliadb/src/nucliadb/writer/tus/gcs.py        | 26 +++++++++++++----
 .../src/nucliadb_utils/storages/gcs.py         | 29 ++++++++++++++-----
 2 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/nucliadb/src/nucliadb/writer/tus/gcs.py b/nucliadb/src/nucliadb/writer/tus/gcs.py
index 91abb4bd9a..effbb4c87a 100644
--- a/nucliadb/src/nucliadb/writer/tus/gcs.py
+++ b/nucliadb/src/nucliadb/writer/tus/gcs.py
@@ -28,12 +28,13 @@
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
-from datetime import datetime
 from typing import Optional
 from urllib.parse import quote_plus
 
 import aiohttp
 import backoff
+import google.auth.transport.requests  # type: ignore
+from google.auth.exceptions import DefaultCredentialsError  # type: ignore
 from oauth2client.service_account import ServiceAccountCredentials  # type: ignore
 
 from nucliadb.writer import logger
@@ -88,9 +89,11 @@ async def get_access_headers(self):
         return {"AUTHORIZATION": f"Bearer {token}"}
 
     def _get_access_token(self):
-        access_token = self._credentials.get_access_token()
-        self._creation_access_token = datetime.now()
-        return access_token.access_token
+        if self._credentials.expired or self._credentials.valid is False:
+            request = google.auth.transport.requests.Request()
+            self._credentials.refresh(request)
+
+        return self._credentials.token
 
     async def finalize(self):
         if self.session is not None:
@@ -115,6 +118,11 @@ async def initialize(
 
         self._credentials = None
 
+        if json_credentials is None:
+            self._json_credentials = None
+        elif isinstance(json_credentials, str) and json_credentials.strip() == "":
+            self._json_credentials = None
+
         if json_credentials is not None:
             self.json_credentials_file = os.path.join(tempfile.mkdtemp(), "gcs_credentials.json")
             open(self.json_credentials_file, "w").write(
@@ -123,6 +131,12 @@ async def initialize(
             self._credentials = ServiceAccountCredentials.from_json_keyfile_name(
                 self.json_credentials_file, SCOPES
             )
+        else:
+            try:
+                self._credentials, _ = google.auth.default()
+            except DefaultCredentialsError:
+                logger.warning("Setting up without credentials as couldn't find workload identity")
+                self._credentials = None
 
         loop = asyncio.get_event_loop()
         self.session = aiohttp.ClientSession(loop=loop)
@@ -132,7 +146,9 @@ async def check_exists(self, bucket_name: str):
             raise AttributeError()
 
         headers = await self.get_access_headers()
-        url = f"{self.object_base_url}/{bucket_name}?project={self.project}"
+        # Using object access url instead of bucket access to avoid
+        # giving admin permission to the SA, needed to GET a bucket
+        url = f"{self.object_base_url}/{bucket_name}/o"
         async with self.session.get(
             url,
             headers=headers,
diff --git a/nucliadb_utils/src/nucliadb_utils/storages/gcs.py b/nucliadb_utils/src/nucliadb_utils/storages/gcs.py
index 71bfaf2fe5..3ddccc97b3 100644
--- a/nucliadb_utils/src/nucliadb_utils/storages/gcs.py
+++ b/nucliadb_utils/src/nucliadb_utils/storages/gcs.py
@@ -34,6 +34,7 @@
 import backoff
 import google.auth.transport.requests  # type: ignore
 import yarl
+from google.auth.exceptions import DefaultCredentialsError  # type: ignore
 from google.oauth2 import service_account  # type: ignore
 
 from nucliadb_protos.resources_pb2 import CloudFile
@@ -458,12 +459,25 @@ def __init__(
         url: str = "https://www.googleapis.com",
         scopes: Optional[List[str]] = None,
     ):
-        if account_credentials is not None:
+        if account_credentials is None:
+            self._json_credentials = None
+        elif isinstance(account_credentials, str) and account_credentials.strip() == "":
+            self._json_credentials = None
+        else:
             self._json_credentials = json.loads(base64.b64decode(account_credentials))
+
+        if self._json_credentials is not None:
             self._credentials = service_account.Credentials.from_service_account_info(
                 self._json_credentials,
                 scopes=DEFAULT_SCOPES if scopes is None else scopes,
             )
+        else:
+            try:
+                self._credentials, self._project = google.auth.default()
+            except DefaultCredentialsError:
+                logger.warning("Setting up without credentials as couldn't find workload identity")
+                self._credentials = None
+
         self.source = CloudFile.GCS
         self.deadletter_bucket = deadletter_bucket
         self.indexing_bucket = indexing_bucket
@@ -473,16 +487,15 @@ def __init__(
         # https://cloud.google.com/storage/docs/bucket-locations
         self._bucket_labels = labels or {}
         self._executor = executor
-        self._creation_access_token = datetime.now()
         self._upload_url = url + "/upload/storage/v1/b/{bucket}/o?uploadType=resumable"  # noqa
         self.object_base_url = url + "/storage/v1/b"
         self._client = None
 
     def _get_access_token(self):
-        if self._credentials.valid is False:
-            req = google.auth.transport.requests.Request()
-            self._credentials.refresh(req)
-            self._creation_access_token = datetime.now()
+        if self._credentials.expired or self._credentials.valid is False:
+            request = google.auth.transport.requests.Request()
+            self._credentials.refresh(request)
+
         return self._credentials.token
 
     @storage_ops_observer.wrap({"type": "initialize"})
@@ -552,7 +565,9 @@ async def check_exists(self, bucket_name: str):
             raise AttributeError()
 
         headers = await self.get_access_headers()
-        url = f"{self.object_base_url}/{bucket_name}?project={self._project}"
+        # Using object access url instead of bucket access to avoid
+        # giving admin permission to the SA, needed to GET a bucket
+        url = f"{self.object_base_url}/{bucket_name}/o"
         async with self.session.get(
             url,
             headers=headers,
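Note: the sketch below is not part of the patch. It is a minimal, standalone illustration of the credential flow the change adopts: use an explicit service-account key when one is configured, otherwise fall back to Application Default Credentials (workload identity, GCE metadata, local gcloud), and refresh the bearer token lazily before use. The helper names (load_gcs_credentials, bearer_token) and the scope value are illustrative assumptions; only the google-auth calls mirror what the patch itself uses.

    import base64
    import json
    from typing import Optional

    import google.auth
    import google.auth.transport.requests
    from google.auth.exceptions import DefaultCredentialsError
    from google.oauth2 import service_account

    # Illustrative scope; the real code passes its own DEFAULT_SCOPES / SCOPES.
    SCOPES = ["https://www.googleapis.com/auth/devstorage.read_write"]

    def load_gcs_credentials(account_credentials: Optional[str]):
        """Hypothetical helper mirroring the patch: explicit SA key first, then ADC."""
        if account_credentials and account_credentials.strip():
            # The key is passed base64-encoded, as in the storages/gcs.py __init__ above.
            info = json.loads(base64.b64decode(account_credentials))
            return service_account.Credentials.from_service_account_info(info, scopes=SCOPES)
        try:
            # No key configured: let google-auth discover workload identity / ADC.
            credentials, _project = google.auth.default(scopes=SCOPES)
            return credentials
        except DefaultCredentialsError:
            # Same behaviour as the patch: continue without credentials instead of failing.
            return None

    def bearer_token(credentials) -> str:
        """Refresh only when the cached token is missing or expired, then return it."""
        if credentials.expired or not credentials.valid:
            credentials.refresh(google.auth.transport.requests.Request())
        return credentials.token

The check_exists change follows the same least-privilege idea: listing objects under
/{bucket}/o only needs storage.objects.list, whereas a GET on the bucket resource
requires storage.buckets.get, which the object-level roles typically granted to the
service account do not include.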