Skip to content

Commit

Permalink
use gcs sa credentials (#2382)
Browse files Browse the repository at this point in the history
  • Loading branch information
drf7 authored Aug 8, 2024
1 parent 19920db commit 3b6d2c4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
26 changes: 21 additions & 5 deletions nucliadb/src/nucliadb/writer/tus/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@
import uuid
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from typing import Optional
from urllib.parse import quote_plus

import aiohttp
import backoff
import google.auth.transport.requests # type: ignore
from google.auth.exceptions import DefaultCredentialsError # type: ignore
from oauth2client.service_account import ServiceAccountCredentials # type: ignore

from nucliadb.writer import logger
Expand Down Expand Up @@ -88,9 +89,11 @@ async def get_access_headers(self):
return {"AUTHORIZATION": f"Bearer {token}"}

def _get_access_token(self):
access_token = self._credentials.get_access_token()
self._creation_access_token = datetime.now()
return access_token.access_token
if self._credentials.expired or self._credentials.valid is False:
request = google.auth.transport.requests.Request()
self._credentials.refresh(request)

return self._credentials.token

async def finalize(self):
if self.session is not None:
Expand All @@ -115,6 +118,11 @@ async def initialize(

self._credentials = None

if json_credentials is None:
self._json_credentials = None
elif isinstance(json_credentials, str) and json_credentials.strip() == "":
self._json_credentials = None

if json_credentials is not None:
self.json_credentials_file = os.path.join(tempfile.mkdtemp(), "gcs_credentials.json")
open(self.json_credentials_file, "w").write(
Expand All @@ -123,6 +131,12 @@ async def initialize(
self._credentials = ServiceAccountCredentials.from_json_keyfile_name(
self.json_credentials_file, SCOPES
)
else:
try:
self._credentials, _ = google.auth.default()
except DefaultCredentialsError:
logger.warning("Setting up without credentials as couldn't find workload identity")
self._credentials = None

loop = asyncio.get_event_loop()
self.session = aiohttp.ClientSession(loop=loop)
Expand All @@ -132,7 +146,9 @@ async def check_exists(self, bucket_name: str):
raise AttributeError()

headers = await self.get_access_headers()
url = f"{self.object_base_url}/{bucket_name}?project={self.project}"
# Use the object-listing URL rather than the bucket URL: a GET on the bucket
# itself would require granting admin permission to the service account (SA).
url = f"{self.object_base_url}/{bucket_name}/o"
async with self.session.get(
url,
headers=headers,
Expand Down
29 changes: 22 additions & 7 deletions nucliadb_utils/src/nucliadb_utils/storages/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import backoff
import google.auth.transport.requests # type: ignore
import yarl
from google.auth.exceptions import DefaultCredentialsError # type: ignore
from google.oauth2 import service_account # type: ignore

from nucliadb_protos.resources_pb2 import CloudFile
Expand Down Expand Up @@ -458,12 +459,25 @@ def __init__(
url: str = "https://www.googleapis.com",
scopes: Optional[List[str]] = None,
):
if account_credentials is not None:
if account_credentials is None:
self._json_credentials = None
elif isinstance(account_credentials, str) and account_credentials.strip() == "":
self._json_credentials = None
else:
self._json_credentials = json.loads(base64.b64decode(account_credentials))

if self._json_credentials is not None:
self._credentials = service_account.Credentials.from_service_account_info(
self._json_credentials,
scopes=DEFAULT_SCOPES if scopes is None else scopes,
)
else:
try:
self._credentials, self._project = google.auth.default()
except DefaultCredentialsError:
logger.warning("Setting up without credentials as couldn't find workload identity")
self._credentials = None

self.source = CloudFile.GCS
self.deadletter_bucket = deadletter_bucket
self.indexing_bucket = indexing_bucket
Expand All @@ -473,16 +487,15 @@ def __init__(
# https://cloud.google.com/storage/docs/bucket-locations
self._bucket_labels = labels or {}
self._executor = executor
self._creation_access_token = datetime.now()
self._upload_url = url + "/upload/storage/v1/b/{bucket}/o?uploadType=resumable" # noqa
self.object_base_url = url + "/storage/v1/b"
self._client = None

def _get_access_token(self):
if self._credentials.valid is False:
req = google.auth.transport.requests.Request()
self._credentials.refresh(req)
self._creation_access_token = datetime.now()
if self._credentials.expired or self._credentials.valid is False:
request = google.auth.transport.requests.Request()
self._credentials.refresh(request)

return self._credentials.token

@storage_ops_observer.wrap({"type": "initialize"})
Expand Down Expand Up @@ -552,7 +565,9 @@ async def check_exists(self, bucket_name: str):
raise AttributeError()

headers = await self.get_access_headers()
url = f"{self.object_base_url}/{bucket_name}?project={self._project}"
# Use the object-listing URL rather than the bucket URL: a GET on the bucket
# itself would require granting admin permission to the service account (SA).
url = f"{self.object_base_url}/{bucket_name}/o"
async with self.session.get(
url,
headers=headers,
Expand Down

0 comments on commit 3b6d2c4

Please sign in to comment.