gcs sa credentials again (#2386)
drf7 authored Aug 9, 2024
1 parent bcae03b commit c0b8aca
Showing 3 changed files with 119 additions and 25 deletions.
46 changes: 28 additions & 18 deletions nucliadb/src/nucliadb/writer/tus/gcs.py
@@ -20,21 +20,20 @@
from __future__ import annotations

import asyncio
import base64
import json
import os
import socket
import tempfile
import uuid
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from typing import Optional
from urllib.parse import quote_plus

import aiohttp
import backoff
from oauth2client.service_account import ServiceAccountCredentials # type: ignore
import google.auth.compute_engine.credentials # type: ignore
import google.auth.transport.requests # type: ignore
from google.auth.exceptions import DefaultCredentialsError # type: ignore
from google.oauth2 import service_account

from nucliadb.writer import logger
from nucliadb.writer.tus.dm import FileDataManager
@@ -74,7 +73,7 @@ class GCloudBlobStore(BlobStore):
loop = None
upload_url: str
object_base_url: str
json_credentials: str
json_credentials: Optional[str]
bucket: str
location: str
project: str
@@ -88,9 +87,16 @@ async def get_access_headers(self):
return {"AUTHORIZATION": f"Bearer {token}"}

def _get_access_token(self):
access_token = self._credentials.get_access_token()
self._creation_access_token = datetime.now()
return access_token.access_token
if isinstance(self._credentials, google.auth.compute_engine.credentials.Credentials):
# google default auth object
if self._credentials.expired or self._credentials.valid is False:
request = google.auth.transport.requests.Request()
self._credentials.refresh(request)

return self._credentials.token
else:
access_token = self._credentials.get_access_token()
return access_token.access_token

async def finalize(self):
if self.session is not None:
@@ -112,17 +118,19 @@ async def initialize(
self.bucket_labels = bucket_labels
self.object_base_url = object_base_url + "/storage/v1/b"
self.upload_url = object_base_url + "/upload/storage/v1/b/{bucket}/o?uploadType=resumable" # noqa

self.json_credentials = json_credentials
self._credentials = None

if json_credentials is not None:
self.json_credentials_file = os.path.join(tempfile.mkdtemp(), "gcs_credentials.json")
open(self.json_credentials_file, "w").write(
base64.b64decode(json_credentials).decode("utf-8")
)
self._credentials = ServiceAccountCredentials.from_json_keyfile_name(
self.json_credentials_file, SCOPES
if self.json_credentials is not None and self.json_credentials.strip() != "":
self._credentials = service_account.Credentials.from_service_account_info(
self.json_credentials, scopes=SCOPES
)
else:
try:
self._credentials, self.project = google.auth.default()
except DefaultCredentialsError:
logger.warning("Setting up without credentials as couldn't find workload identity")
self._credentials = None

loop = asyncio.get_event_loop()
self.session = aiohttp.ClientSession(loop=loop)
Expand All @@ -132,7 +140,9 @@ async def check_exists(self, bucket_name: str):
raise AttributeError()

headers = await self.get_access_headers()
url = f"{self.object_base_url}/{bucket_name}?project={self.project}"
# Using object access url instead of bucket access to avoid
# giving admin permission to the SA, needed to GET a bucket
url = f"{self.object_base_url}/{bucket_name}/o"
async with self.session.get(
url,
headers=headers,
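A minimal sketch of the credential fallback chain this commit introduces in the writer's GCS store: explicit service-account JSON first, then Application Default Credentials (workload identity), then no credentials at all. The helper name select_credentials and the standalone layout are illustrative and not part of the repository; it assumes the google-auth package is installed.

```python
# Illustrative sketch only (not nucliadb code): the credential fallback chain
# that GCloudBlobStore.initialize follows after this commit.
from typing import Optional

import google.auth
from google.auth.exceptions import DefaultCredentialsError
from google.oauth2 import service_account

SCOPES = ["https://www.googleapis.com/auth/devstorage.read_write"]


def select_credentials(json_credentials: Optional[dict]):
    # 1. Explicit service-account info takes precedence.
    if json_credentials:
        return service_account.Credentials.from_service_account_info(
            json_credentials, scopes=SCOPES
        )
    # 2. Otherwise try Application Default Credentials (e.g. workload identity).
    try:
        credentials, _project = google.auth.default()
        return credentials
    except DefaultCredentialsError:
        # 3. No ambient credentials found: proceed without credentials.
        return None
```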
69 changes: 69 additions & 0 deletions nucliadb/tests/writer/unit/tus/test_gcs.py
@@ -0,0 +1,69 @@
# Copyright (C) 2021 Bosutech XXI S.L.
#
# nucliadb is offered under the AGPL v3.0 and as commercial software.
# For commercial licensing, contact us at [email protected].
#
# AGPL:
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from unittest.mock import MagicMock, patch

import pytest

from nucliadb.writer.tus.gcs import GCloudBlobStore


@pytest.mark.asyncio
@patch("nucliadb.writer.tus.gcs.aiohttp")
@patch("nucliadb.writer.tus.gcs.service_account")
async def test_tus_gcs(mock_sa, mock_aiohttp):
mock_session = MagicMock()
mock_aiohttp.ClientSession.return_value = mock_session
mock_aenter = MagicMock()
mock_aenter.__aenter__.return_value = MagicMock(status=200)
mock_session.get.return_value = mock_aenter

gblobstore = GCloudBlobStore()
await gblobstore.initialize(
bucket="test-bucket",
location="test-location",
project="test-project",
bucket_labels="test-labels",
object_base_url="test-url",
json_credentials="test-cred",
)

mock_sa.Credentials.from_service_account_info.assert_called_once_with(
"test-cred", scopes=["https://www.googleapis.com/auth/devstorage.read_write"]
)

assert await gblobstore.check_exists("test-bucket")


@pytest.mark.asyncio
@patch("nucliadb.writer.tus.gcs.google")
async def test_tus_gcs_empty_json_credentials(mock_google):
mock_google.auth.default.return_value = ("credential", "project")

gblobstore = GCloudBlobStore()
await gblobstore.initialize(
bucket="test-bucket",
location="test-location",
project="test-project",
bucket_labels="test-labels",
object_base_url="test-url",
json_credentials=None,
)

mock_google.auth.default.assert_called_once()
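
Both gcs.py changes in this commit converge on the same token handling for google-auth credential objects: refresh only when the cached token is expired or not yet valid, then read .token. A condensed sketch, with the helper name bearer_token chosen for illustration:

```python
# Illustrative sketch of the refresh-before-use pattern used by the updated
# _get_access_token implementations; assumes a google-auth Credentials object.
import google.auth.transport.requests


def bearer_token(credentials) -> str:
    # Refresh only when the current token is expired or not yet valid.
    if credentials.expired or not credentials.valid:
        credentials.refresh(google.auth.transport.requests.Request())
    return credentials.token
```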
29 changes: 22 additions & 7 deletions nucliadb_utils/src/nucliadb_utils/storages/gcs.py
@@ -34,6 +34,7 @@
import backoff
import google.auth.transport.requests # type: ignore
import yarl
from google.auth.exceptions import DefaultCredentialsError # type: ignore
from google.oauth2 import service_account # type: ignore

from nucliadb_protos.resources_pb2 import CloudFile
@@ -458,12 +459,25 @@ def __init__(
url: str = "https://www.googleapis.com",
scopes: Optional[List[str]] = None,
):
if account_credentials is not None:
if account_credentials is None:
self._json_credentials = None
elif isinstance(account_credentials, str) and account_credentials.strip() == "":
self._json_credentials = None
else:
self._json_credentials = json.loads(base64.b64decode(account_credentials))

if self._json_credentials is not None:
self._credentials = service_account.Credentials.from_service_account_info(
self._json_credentials,
scopes=DEFAULT_SCOPES if scopes is None else scopes,
)
else:
try:
self._credentials, self._project = google.auth.default()
except DefaultCredentialsError:
logger.warning("Setting up without credentials as couldn't find workload identity")
self._credentials = None

self.source = CloudFile.GCS
self.deadletter_bucket = deadletter_bucket
self.indexing_bucket = indexing_bucket
@@ -473,16 +487,15 @@ def __init__(
# https://cloud.google.com/storage/docs/bucket-locations
self._bucket_labels = labels or {}
self._executor = executor
self._creation_access_token = datetime.now()
self._upload_url = url + "/upload/storage/v1/b/{bucket}/o?uploadType=resumable" # noqa
self.object_base_url = url + "/storage/v1/b"
self._client = None

def _get_access_token(self):
if self._credentials.valid is False:
req = google.auth.transport.requests.Request()
self._credentials.refresh(req)
self._creation_access_token = datetime.now()
if self._credentials.expired or self._credentials.valid is False:
request = google.auth.transport.requests.Request()
self._credentials.refresh(request)

return self._credentials.token

@storage_ops_observer.wrap({"type": "initialize"})
@@ -552,7 +565,9 @@ async def check_exists(self, bucket_name: str):
raise AttributeError()

headers = await self.get_access_headers()
url = f"{self.object_base_url}/{bucket_name}?project={self._project}"
# Using object access url instead of bucket access to avoid
# giving admin permission to the SA, needed to GET a bucket
url = f"{self.object_base_url}/{bucket_name}/o"
async with self.session.get(
url,
headers=headers,
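In nucliadb_utils, account_credentials is still expected to be a base64-encoded service-account JSON string, which __init__ decodes with json.loads(base64.b64decode(...)). A small, hypothetical example of preparing such a value (the key fields below are fake placeholders, not real credentials):

```python
# Hypothetical example of producing the base64-encoded string that the
# storage __init__ above decodes; all key fields are fake placeholders.
import base64
import json

fake_key_info = {
    "type": "service_account",
    "project_id": "my-project",  # placeholder
    "client_email": "sa@my-project.iam.gserviceaccount.com",  # placeholder
}
account_credentials = base64.b64encode(
    json.dumps(fake_key_info).encode("utf-8")
).decode("utf-8")

# The storage class reverses the encoding:
assert json.loads(base64.b64decode(account_credentials)) == fake_key_info
```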
