From 60834ae0db3031fa2360ac92495a550a7217520e Mon Sep 17 00:00:00 2001
From: Ferran Llamas
Date: Mon, 27 May 2024 12:26:23 +0200
Subject: [PATCH] Fix s3 storage driver: missing move method (#2194)

---
 nucliadb_utils/nucliadb_utils/storages/gcs.py |  4 +-
 .../nucliadb_utils/storages/local.py          |  4 +-
 nucliadb_utils/nucliadb_utils/storages/pg.py  |  4 +-
 nucliadb_utils/nucliadb_utils/storages/s3.py  | 19 ++++-
 .../nucliadb_utils/storages/storage.py        | 79 +++++++++----------
 .../storages/test_field_storage.py            | 26 ++++--
 .../tests/unit/storages/test_storage.py       | 23 +++++-
 7 files changed, 102 insertions(+), 57 deletions(-)

diff --git a/nucliadb_utils/nucliadb_utils/storages/gcs.py b/nucliadb_utils/nucliadb_utils/storages/gcs.py
index b3c5a27bea..2a41d5495c 100644
--- a/nucliadb_utils/nucliadb_utils/storages/gcs.py
+++ b/nucliadb_utils/nucliadb_utils/storages/gcs.py
@@ -26,7 +26,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from datetime import datetime
-from typing import Any, AsyncIterator, Dict, List, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
 from urllib.parse import quote_plus
 
 import aiohttp
@@ -221,7 +221,7 @@ async def _inner_iter_data(self, headers=None):
                 break
 
     @storage_ops_observer.wrap({"type": "read_range"})
-    async def read_range(self, start: int, end: int) -> AsyncIterator[bytes]:
+    async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
         """
         Iterate through ranges of data
         """
diff --git a/nucliadb_utils/nucliadb_utils/storages/local.py b/nucliadb_utils/nucliadb_utils/storages/local.py
index b0ea806578..7f9abea178 100644
--- a/nucliadb_utils/nucliadb_utils/storages/local.py
+++ b/nucliadb_utils/nucliadb_utils/storages/local.py
@@ -24,7 +24,7 @@
 import os
 import shutil
 from datetime import datetime
-from typing import Any, AsyncIterator, Dict, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional
 
 import aiofiles
 from nucliadb_protos.resources_pb2 import CloudFile
@@ -87,7 +87,7 @@ async def iter_data(self, headers=None):
                     break
                 yield data
 
-    async def read_range(self, start: int, end: int) -> AsyncIterator[bytes]:
+    async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
         """
         Iterate through ranges of data
         """
diff --git a/nucliadb_utils/nucliadb_utils/storages/pg.py b/nucliadb_utils/nucliadb_utils/storages/pg.py
index c3ee9e2731..79ead53bad 100644
--- a/nucliadb_utils/nucliadb_utils/storages/pg.py
+++ b/nucliadb_utils/nucliadb_utils/storages/pg.py
@@ -22,7 +22,7 @@
 import asyncio
 import logging
 import uuid
-from typing import Any, AsyncIterator, Optional, TypedDict
+from typing import Any, AsyncGenerator, AsyncIterator, Optional, TypedDict
 
 import asyncpg
 from nucliadb_protos.resources_pb2 import CloudFile
@@ -427,7 +427,7 @@ async def iter_data(self, headers=None):
         async for chunk in dl.iterate_chunks(bucket, key):
             yield chunk["data"]
 
-    async def read_range(self, start: int, end: int) -> AsyncIterator[bytes]:
+    async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
         """
         Iterate through ranges of data
         """
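Note (illustrative, not part of the patch): the gcs, local and pg hunks above all retype read_range from AsyncIterator[bytes] to AsyncGenerator[bytes, None]. The sketch below uses invented stand-ins (read_range, consume), not the driver code, to show why that is the more precise annotation for an "async def" that yields, while callers may still accept the wider AsyncIterator type.

from typing import AsyncGenerator, AsyncIterator


async def read_range(data: bytes, start: int, end: int) -> AsyncGenerator[bytes, None]:
    # An "async def" containing "yield" is an async generator, so
    # AsyncGenerator[bytes, None] (yield type, send type) describes it exactly.
    yield data[start:end]


async def consume(chunks: AsyncIterator[bytes]) -> bytes:
    # AsyncGenerator is a subtype of AsyncIterator, so consumers can keep the wider type.
    out = b""
    async for chunk in chunks:
        out += chunk
    return out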
diff --git a/nucliadb_utils/nucliadb_utils/storages/s3.py b/nucliadb_utils/nucliadb_utils/storages/s3.py
index db9dfcb3d9..09236f9f75 100644
--- a/nucliadb_utils/nucliadb_utils/storages/s3.py
+++ b/nucliadb_utils/nucliadb_utils/storages/s3.py
@@ -21,12 +21,13 @@
 
 from contextlib import AsyncExitStack
 from datetime import datetime
-from typing import Any, AsyncIterator, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, Optional
 
 import aiobotocore  # type: ignore
 import aiohttp
 import backoff  # type: ignore
 import botocore  # type: ignore
+from aiobotocore.client import AioBaseClient  # type: ignore
 from aiobotocore.session import AioSession, get_session  # type: ignore
 from nucliadb_protos.resources_pb2 import CloudFile
 
@@ -111,7 +112,7 @@ async def iter_data(self, **kwargs):
             yield data
             data = await stream.read(CHUNK_SIZE)
 
-    async def read_range(self, start: int, end: int) -> AsyncIterator[bytes]:
+    async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
         """
         Iterate through ranges of data
         """
@@ -319,6 +320,18 @@ async def copy(
             Key=destination_uri,
         )
 
+    async def move(
+        self,
+        origin_uri: str,
+        destination_uri: str,
+        origin_bucket_name: str,
+        destination_bucket_name: str,
+    ):
+        await self.copy(
+            origin_uri, destination_uri, origin_bucket_name, destination_bucket_name
+        )
+        await self.storage.delete_upload(origin_uri, origin_bucket_name)
+
     async def upload(self, iterator: AsyncIterator, origin: CloudFile) -> CloudFile:
         self.field = await self.start(origin)
         await self.append(origin, iterator)
@@ -384,7 +397,7 @@ def session(self):
 
     async def initialize(self):
         session = AioSession()
-        self._s3aioclient = await self._exit_stack.enter_async_context(
+        self._s3aioclient: AioBaseClient = await self._exit_stack.enter_async_context(
             session.create_client("s3", **self.opts)
         )
         for bucket in (self.deadletter_bucket, self.indexing_bucket):
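Note (illustrative, not part of the patch): the move() added to the S3 storage field above delegates to the existing copy plus the storage's delete_upload, since S3 has no native rename. Assuming an aiobotocore S3 client like the one created in initialize(), the raw calls underneath look roughly like this sketch; the function name, bucket and key parameters are invented for the example.

async def move_object(client, origin_bucket: str, origin_key: str,
                      destination_bucket: str, destination_key: str) -> None:
    # Server-side copy to the destination, then remove the origin object.
    await client.copy_object(
        CopySource={"Bucket": origin_bucket, "Key": origin_key},
        Bucket=destination_bucket,
        Key=destination_key,
    )
    await client.delete_object(Bucket=origin_bucket, Key=origin_key)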
diff --git a/nucliadb_utils/nucliadb_utils/storages/storage.py b/nucliadb_utils/nucliadb_utils/storages/storage.py
index 7032716482..3b50f6eea5 100644
--- a/nucliadb_utils/nucliadb_utils/storages/storage.py
+++ b/nucliadb_utils/nucliadb_utils/storages/storage.py
@@ -60,7 +60,7 @@
 MESSAGE_KEY = "message/{kbid}/{rid}/{mid}"
 
 
-class StorageField:
+class StorageField(abc.ABC, metaclass=abc.ABCMeta):
     storage: Storage
     bucket: str
     key: str
@@ -78,18 +78,18 @@ def __init__(
         self.key = fullkey
         self.field = field
 
-    async def upload(self, iterator: AsyncIterator, origin: CloudFile) -> CloudFile:
-        raise NotImplementedError()
+    @abc.abstractmethod
+    async def upload(self, iterator: AsyncIterator, origin: CloudFile) -> CloudFile: ...
 
-    async def iter_data(self, headers=None):
+    @abc.abstractmethod
+    async def iter_data(self, headers=None) -> AsyncGenerator[bytes, None]:  # type: ignore
         raise NotImplementedError()
+        yield b""
 
-    async def read_range(self, start: int, end: int) -> AsyncIterator[bytes]:
-        """
-        Iterate through ranges of data
-        """
+    @abc.abstractmethod
+    async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
         raise NotImplementedError()
-        yield b""  # pragma: no cover
+        yield b""
 
     async def delete(self) -> bool:
         deleted = False
@@ -98,38 +98,38 @@ async def delete(self) -> bool:
             deleted = True
         return deleted
 
-    async def exists(self) -> Optional[Dict[str, str]]:
-        raise NotImplementedError
+    @abc.abstractmethod
+    async def exists(self) -> Optional[Dict[str, str]]: ...
 
+    @abc.abstractmethod
     async def copy(
         self,
         origin_uri: str,
         destination_uri: str,
         origin_bucket_name: str,
         destination_bucket_name: str,
-    ):
-        raise NotImplementedError()
+    ): ...
 
+    @abc.abstractmethod
     async def move(
         self,
         origin_uri: str,
         destination_uri: str,
         origin_bucket_name: str,
         destination_bucket_name: str,
-    ):
-        raise NotImplementedError()
+    ): ...
 
-    async def start(self, cf: CloudFile) -> CloudFile:
-        raise NotImplementedError()
+    @abc.abstractmethod
+    async def start(self, cf: CloudFile) -> CloudFile: ...
 
-    async def append(self, cf: CloudFile, iterable: AsyncIterator) -> int:
-        raise NotImplementedError()
+    @abc.abstractmethod
+    async def append(self, cf: CloudFile, iterable: AsyncIterator) -> int: ...
 
-    async def finish(self):
-        raise NotImplementedError()
+    @abc.abstractmethod
+    async def finish(self): ...
 
 
-class Storage:
+class Storage(abc.ABC, metaclass=abc.ABCMeta):
     source: int
     field_klass: Type
     deadletter_bucket: Optional[str] = None
@@ -498,40 +498,39 @@ async def download_pb(self, sf: StorageField, PBKlass: Type):
         pb.ParseFromString(payload.read())
         return pb
 
-    async def delete_upload(self, uri: str, bucket_name: str):
-        raise NotImplementedError()
+    @abc.abstractmethod
+    async def delete_upload(self, uri: str, bucket_name: str): ...
 
-    def get_bucket_name(self, kbid: str):
-        raise NotImplementedError()
+    @abc.abstractmethod
+    def get_bucket_name(self, kbid: str) -> str: ...
 
-    async def initialize(self):
-        raise NotImplementedError()
+    @abc.abstractmethod
+    async def initialize(self) -> None: ...
 
-    async def finalize(self):
-        raise NotImplementedError()
+    @abc.abstractmethod
+    async def finalize(self) -> None: ...
 
     @abc.abstractmethod
-    def iterate_bucket(self, bucket: str, prefix: str) -> AsyncIterator[Any]:
-        raise NotImplementedError()
+    def iterate_bucket(self, bucket: str, prefix: str) -> AsyncIterator[Any]: ...
 
-    async def copy(self, file: CloudFile, destination: StorageField):
+    async def copy(self, file: CloudFile, destination: StorageField) -> None:
         await destination.copy(
             file.uri, destination.key, file.bucket_name, destination.bucket
         )
 
-    async def move(self, file: CloudFile, destination: StorageField):
+    async def move(self, file: CloudFile, destination: StorageField) -> None:
         await destination.move(
             file.uri, destination.key, file.bucket_name, destination.bucket
        )
 
-    async def create_kb(self, kbid: str) -> bool:
-        raise NotImplementedError()
+    @abc.abstractmethod
+    async def create_kb(self, kbid: str) -> bool: ...
 
-    async def delete_kb(self, kbid: str) -> Tuple[bool, bool]:
-        raise NotImplementedError()
+    @abc.abstractmethod
+    async def delete_kb(self, kbid: str) -> Tuple[bool, bool]: ...
 
-    async def schedule_delete_kb(self, kbid: str) -> bool:
-        raise NotImplementedError()
+    @abc.abstractmethod
+    async def schedule_delete_kb(self, kbid: str) -> bool: ...
 
     async def set_stream_message(self, kbid: str, rid: str, data: bytes) -> str:
         key = MESSAGE_KEY.format(kbid=kbid, rid=rid, mid=uuid.uuid4())
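Note (illustrative, not part of the patch): storage.py above turns StorageField and Storage into abstract base classes, so a driver missing a required method (such as the previously missing move on the s3 driver) now fails at instantiation time instead of raising NotImplementedError at call time; the unreachable yield b"" kept after raise NotImplementedError() appears intended to keep iter_data and read_range typed as async generators. A minimal sketch of the instantiation-time failure, with invented class names:

import abc


class BaseField(abc.ABC):
    @abc.abstractmethod
    async def move(self, origin_uri: str, destination_uri: str): ...


class IncompleteField(BaseField):
    # No move() implementation, on purpose.
    pass


try:
    IncompleteField()
except TypeError as exc:
    # e.g. "Can't instantiate abstract class IncompleteField with abstract method move"
    print(exc)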
diff --git a/nucliadb_utils/nucliadb_utils/tests/integration/storages/test_field_storage.py b/nucliadb_utils/nucliadb_utils/tests/integration/storages/test_field_storage.py
index 67623ff7ff..d4a209733e 100644
--- a/nucliadb_utils/nucliadb_utils/tests/integration/storages/test_field_storage.py
+++ b/nucliadb_utils/nucliadb_utils/tests/integration/storages/test_field_storage.py
@@ -68,21 +68,35 @@ async def storage_field_test(storage: Storage):
     assert metadata["FILENAME"] == "myfile.txt"
 
     # Download the file and check that it's the same
-    downloaded_data = b""
-    async for data in sfield.iter_data():
-        downloaded_data += data
-    assert downloaded_data == binary_data
+    async def check_downloaded_data(sfield, expected_data: bytes):
+        downloaded_data = b""
+        async for data in sfield.iter_data():
+            downloaded_data += data
+        assert downloaded_data == expected_data
+
+    await check_downloaded_data(sfield, binary_data)
 
     # Test
     if storage.source == CloudFile.Source.LOCAL:
         # There is a bug to be fixed in the copy method on the local storage driver
         return
 
+    # Copy the file to another bucket (with the same key)
     kbid2 = uuid.uuid4().hex
     assert await storage.create_kb(kbid2)
     bucket2 = storage.get_bucket_name(kbid2)
-    rid = "rid"
-    field_id = "field1"
     field_key = KB_RESOURCE_FIELD.format(kbid=kbid2, uuid=rid, field=field_id)
+    sfield_kb2 = storage.file_field(kbid2, rid, field=field_id)
     await sfield.copy(sfield.key, field_key, bucket, bucket2)
+
+    await check_downloaded_data(sfield_kb2, binary_data)
+
+    # Move the file to another key (same bucket)
+    new_field_id = "field3"
+    new_field_key = KB_RESOURCE_FIELD.format(kbid=kbid2, uuid=rid, field=new_field_id)
+    new_sfield = storage.file_field(kbid2, rid, field=new_field_id)
+
+    await sfield_kb2.move(sfield_kb2.key, new_field_key, bucket2, bucket2)
+
+    await check_downloaded_data(new_sfield, binary_data)
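Note (illustrative, not part of the patch): the integration test above exercises copy and move through the per-field StorageField API. The same operation is also reachable through the driver-agnostic Storage.move shown earlier in storage.py; the helper below is an invented usage sketch (relocate_field_file and the "field3" id are made up), not test code from the repository.

from nucliadb_protos.resources_pb2 import CloudFile


async def relocate_field_file(storage, kbid: str, rid: str, cf: CloudFile):
    # Pick the destination field and let Storage.move delegate to the
    # driver-specific StorageField.move (copy + delete on the s3 driver).
    destination = storage.file_field(kbid, rid, field="field3")
    await storage.move(cf, destination)
    return destination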
diff --git a/nucliadb_utils/nucliadb_utils/tests/unit/storages/test_storage.py b/nucliadb_utils/nucliadb_utils/tests/unit/storages/test_storage.py
index b5b8ed0af9..e7fef60244 100644
--- a/nucliadb_utils/nucliadb_utils/tests/unit/storages/test_storage.py
+++ b/nucliadb_utils/nucliadb_utils/tests/unit/storages/test_storage.py
@@ -26,6 +26,7 @@
 from nucliadb_protos.nodewriter_pb2 import IndexMessage
 from nucliadb_protos.resources_pb2 import CloudFile
 
+from nucliadb_utils.storages.local import LocalStorageField
 from nucliadb_utils.storages.storage import (
     Storage,
     StorageField,
@@ -45,7 +46,7 @@ def field(self):
 
     @pytest.fixture
     def storage_field(self, storage, field):
-        yield StorageField(storage, "bucket", "fullkey", field)
+        yield LocalStorageField(storage, "bucket", "fullkey", field)
 
     @pytest.mark.asyncio
     async def test_delete(self, storage_field: StorageField, storage):
@@ -73,6 +74,24 @@ async def download(self, bucket_name, uri):
         br = BrainResource(labels=["label"])
         yield br.SerializeToString()
 
+    async def create_kb(self, kbid):
+        return True
+
+    async def delete_kb(self, kbid):
+        return True
+
+    async def delete_upload(self, uri, bucket):
+        return True
+
+    async def initialize(self) -> None:
+        pass
+
+    async def finalize(self) -> None:
+        pass
+
+    async def schedule_delete_kb(self, kbid: str) -> bool:
+        return True
+
 
 class TestStorage:
     @pytest.fixture
@@ -136,7 +155,7 @@ async def test_delete_indexing(self, storage: StorageTest):
     async def test_download_pb(self, storage: StorageTest):
         assert isinstance(
             await storage.download_pb(
-                StorageField(storage, "bucket", "fullkey"), BrainResource
+                LocalStorageField(storage, "bucket", "fullkey"), BrainResource
             ),
             BrainResource,
         )
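Note (illustrative, not part of the patch): because Storage and StorageField are now abstract, the unit-test fixtures switch to LocalStorageField and the StorageTest double has to implement every abstract method before it can be instantiated. A guard test along the lines of the sketch below could catch abstract methods that a test double silently misses; it is a suggestion, not code from the repository.

from nucliadb_utils.storages.storage import Storage


def test_storage_declares_expected_abstract_methods():
    # __abstractmethods__ is the frozenset of names every concrete subclass must implement.
    abstract = set(getattr(Storage, "__abstractmethods__", set()))
    assert {"create_kb", "delete_kb", "delete_upload", "schedule_delete_kb"} <= abstract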