Fix s3 storage driver: missing move method (#2194)
lferran authored May 27, 2024
1 parent c58715c · commit 60834ae
Showing 7 changed files with 102 additions and 57 deletions.
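In short: before this commit `S3StorageField` never overrode `StorageField.move`, whose base implementation simply raised `NotImplementedError`, so moving a file stored on the S3 backend failed at runtime. The commit adds the missing override (copy to the destination, then delete the origin, since S3 has no server-side rename), turns the storage base classes into ABCs so that a driver missing a required method now fails at instantiation instead of mid-operation, and tightens the `read_range` annotations to `AsyncGenerator[bytes, None]`. A minimal sketch of the repaired call path, assuming an already-initialized S3 storage object; the kbid, rid and field id are illustrative, not taken from the diff:

async def relocate(storage, cloud_file, kbid: str, rid: str):
    # file_field() builds the destination StorageField for this kb/resource;
    # "moved" is a made-up field id for the example.
    destination = storage.file_field(kbid, rid, field="moved")
    # Storage.move delegates to StorageField.move; with this commit the S3
    # field implements it as copy + delete_upload of the origin object.
    await storage.move(cloud_file, destination)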
4 changes: 2 additions & 2 deletions nucliadb_utils/nucliadb_utils/storages/gcs.py
@@ -26,7 +26,7 @@
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from typing import Any, AsyncIterator, Dict, List, Optional
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
from urllib.parse import quote_plus

import aiohttp
@@ -221,7 +221,7 @@ async def _inner_iter_data(self, headers=None):
break

@storage_ops_observer.wrap({"type": "read_range"})
async def read_range(self, start: int, end: int) -> AsyncIterator[bytes]:
async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
"""
Iterate through ranges of data
"""
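The only change to gcs.py is the `read_range` annotation, and the same two-line edit recurs in local.py, pg.py and s3.py below. Because `read_range` is written with `yield`, calling it produces an async generator; `AsyncGenerator[bytes, None]` names that object precisely (it yields bytes and is never sent a value) while remaining usable wherever an `AsyncIterator[bytes]` is expected, and it matches the new abstract signature in storage.py. A standalone sketch of the distinction, independent of nucliadb:

from typing import AsyncGenerator, AsyncIterator

async def read_range(start: int, end: int) -> AsyncGenerator[bytes, None]:
    # An `async def` containing `yield` is an async generator function:
    # it yields bytes and is never sent a value, hence the None.
    yield b"chunk"

async def consume(chunks: AsyncIterator[bytes]) -> bytes:
    data = b""
    async for chunk in chunks:  # an AsyncGenerator is also an AsyncIterator
        data += chunk
    return data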
4 changes: 2 additions & 2 deletions nucliadb_utils/nucliadb_utils/storages/local.py
@@ -24,7 +24,7 @@
import os
import shutil
from datetime import datetime
from typing import Any, AsyncIterator, Dict, Optional
from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional

import aiofiles
from nucliadb_protos.resources_pb2 import CloudFile
@@ -87,7 +87,7 @@ async def iter_data(self, headers=None):
break
yield data

async def read_range(self, start: int, end: int) -> AsyncIterator[bytes]:
async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
"""
Iterate through ranges of data
"""
4 changes: 2 additions & 2 deletions nucliadb_utils/nucliadb_utils/storages/pg.py
@@ -22,7 +22,7 @@
import asyncio
import logging
import uuid
from typing import Any, AsyncIterator, Optional, TypedDict
from typing import Any, AsyncGenerator, AsyncIterator, Optional, TypedDict

import asyncpg
from nucliadb_protos.resources_pb2 import CloudFile
@@ -427,7 +427,7 @@ async def iter_data(self, headers=None):
async for chunk in dl.iterate_chunks(bucket, key):
yield chunk["data"]

async def read_range(self, start: int, end: int) -> AsyncIterator[bytes]:
async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
"""
Iterate through ranges of data
"""
19 changes: 16 additions & 3 deletions nucliadb_utils/nucliadb_utils/storages/s3.py
@@ -21,12 +21,13 @@

from contextlib import AsyncExitStack
from datetime import datetime
from typing import Any, AsyncIterator, Optional
from typing import Any, AsyncGenerator, AsyncIterator, Optional

import aiobotocore # type: ignore
import aiohttp
import backoff # type: ignore
import botocore # type: ignore
from aiobotocore.client import AioBaseClient # type: ignore
from aiobotocore.session import AioSession, get_session # type: ignore
from nucliadb_protos.resources_pb2 import CloudFile

@@ -111,7 +112,7 @@ async def iter_data(self, **kwargs):
yield data
data = await stream.read(CHUNK_SIZE)

async def read_range(self, start: int, end: int) -> AsyncIterator[bytes]:
async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
"""
Iterate through ranges of data
"""
@@ -319,6 +320,18 @@ async def copy(
Key=destination_uri,
)

async def move(
self,
origin_uri: str,
destination_uri: str,
origin_bucket_name: str,
destination_bucket_name: str,
):
await self.copy(
origin_uri, destination_uri, origin_bucket_name, destination_bucket_name
)
await self.storage.delete_upload(origin_uri, origin_bucket_name)

async def upload(self, iterator: AsyncIterator, origin: CloudFile) -> CloudFile:
self.field = await self.start(origin)
await self.append(origin, iterator)
@@ -384,7 +397,7 @@ def session(self):

async def initialize(self):
session = AioSession()
self._s3aioclient = await self._exit_stack.enter_async_context(
self._s3aioclient: AioBaseClient = await self._exit_stack.enter_async_context(
session.create_client("s3", **self.opts)
)
for bucket in (self.deadletter_bucket, self.indexing_bucket):
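This is the core of the fix: `S3StorageField.move` now exists, implemented as the field's existing `copy` followed by `storage.delete_upload` on the origin, since S3 exposes no atomic rename. The hunk also annotates `self._s3aioclient` as `AioBaseClient`. A rough sketch of the same copy-then-delete pattern against an aiobotocore S3 client; the function and the bucket/key parameters are illustrative, not the driver's literal code:

async def s3_move(client, origin_bucket: str, origin_key: str,
                  dest_bucket: str, dest_key: str):
    # Server-side copy into the destination key.
    await client.copy_object(
        Bucket=dest_bucket,
        Key=dest_key,
        CopySource={"Bucket": origin_bucket, "Key": origin_key},
    )
    # Delete the origin only after the copy has succeeded; the two steps are
    # not atomic, so a failure in between leaves both objects in place.
    await client.delete_object(Bucket=origin_bucket, Key=origin_key)

The driver performs the same two steps, but through its own copy and delete_upload helpers rather than raw client calls.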
79 changes: 39 additions & 40 deletions nucliadb_utils/nucliadb_utils/storages/storage.py
@@ -60,7 +60,7 @@
MESSAGE_KEY = "message/{kbid}/{rid}/{mid}"


class StorageField:
class StorageField(abc.ABC, metaclass=abc.ABCMeta):
storage: Storage
bucket: str
key: str
@@ -78,18 +78,18 @@ def __init__(
self.key = fullkey
self.field = field

async def upload(self, iterator: AsyncIterator, origin: CloudFile) -> CloudFile:
raise NotImplementedError()
@abc.abstractmethod
async def upload(self, iterator: AsyncIterator, origin: CloudFile) -> CloudFile: ...

async def iter_data(self, headers=None):
@abc.abstractmethod
async def iter_data(self, headers=None) -> AsyncGenerator[bytes, None]: # type: ignore
raise NotImplementedError()
yield b""

async def read_range(self, start: int, end: int) -> AsyncIterator[bytes]:
"""
Iterate through ranges of data
"""
@abc.abstractmethod
async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
raise NotImplementedError()
yield b"" # pragma: no cover
yield b""

async def delete(self) -> bool:
deleted = False
@@ -98,38 +98,38 @@ async def delete(self) -> bool:
deleted = True
return deleted

async def exists(self) -> Optional[Dict[str, str]]:
raise NotImplementedError
@abc.abstractmethod
async def exists(self) -> Optional[Dict[str, str]]: ...

@abc.abstractmethod
async def copy(
self,
origin_uri: str,
destination_uri: str,
origin_bucket_name: str,
destination_bucket_name: str,
):
raise NotImplementedError()
): ...

@abc.abstractmethod
async def move(
self,
origin_uri: str,
destination_uri: str,
origin_bucket_name: str,
destination_bucket_name: str,
):
raise NotImplementedError()
): ...

async def start(self, cf: CloudFile) -> CloudFile:
raise NotImplementedError()
@abc.abstractmethod
async def start(self, cf: CloudFile) -> CloudFile: ...

async def append(self, cf: CloudFile, iterable: AsyncIterator) -> int:
raise NotImplementedError()
@abc.abstractmethod
async def append(self, cf: CloudFile, iterable: AsyncIterator) -> int: ...

async def finish(self):
raise NotImplementedError()
@abc.abstractmethod
async def finish(self): ...


class Storage:
class Storage(abc.ABC, metaclass=abc.ABCMeta):
source: int
field_klass: Type
deadletter_bucket: Optional[str] = None
@@ -498,40 +498,39 @@ async def download_pb(self, sf: StorageField, PBKlass: Type):
pb.ParseFromString(payload.read())
return pb

async def delete_upload(self, uri: str, bucket_name: str):
raise NotImplementedError()
@abc.abstractmethod
async def delete_upload(self, uri: str, bucket_name: str): ...

def get_bucket_name(self, kbid: str):
raise NotImplementedError()
@abc.abstractmethod
def get_bucket_name(self, kbid: str) -> str: ...

async def initialize(self):
raise NotImplementedError()
@abc.abstractmethod
async def initialize(self) -> None: ...

async def finalize(self):
raise NotImplementedError()
@abc.abstractmethod
async def finalize(self) -> None: ...

@abc.abstractmethod
def iterate_bucket(self, bucket: str, prefix: str) -> AsyncIterator[Any]:
raise NotImplementedError()
def iterate_bucket(self, bucket: str, prefix: str) -> AsyncIterator[Any]: ...

async def copy(self, file: CloudFile, destination: StorageField):
async def copy(self, file: CloudFile, destination: StorageField) -> None:
await destination.copy(
file.uri, destination.key, file.bucket_name, destination.bucket
)

async def move(self, file: CloudFile, destination: StorageField):
async def move(self, file: CloudFile, destination: StorageField) -> None:
await destination.move(
file.uri, destination.key, file.bucket_name, destination.bucket
)

async def create_kb(self, kbid: str) -> bool:
raise NotImplementedError()
@abc.abstractmethod
async def create_kb(self, kbid: str) -> bool: ...

async def delete_kb(self, kbid: str) -> Tuple[bool, bool]:
raise NotImplementedError()
@abc.abstractmethod
async def delete_kb(self, kbid: str) -> Tuple[bool, bool]: ...

async def schedule_delete_kb(self, kbid: str) -> bool:
raise NotImplementedError()
@abc.abstractmethod
async def schedule_delete_kb(self, kbid: str) -> bool: ...

async def set_stream_message(self, kbid: str, rid: str, data: bytes) -> str:
key = MESSAGE_KEY.format(kbid=kbid, rid=rid, mid=uuid.uuid4())
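`StorageField` and `Storage` are now genuine ABCs: every required operation is an `@abc.abstractmethod`, so a driver that forgets one (as the S3 field forgot `move`) can no longer be instantiated at all, rather than failing later with `NotImplementedError`. Passing `metaclass=abc.ABCMeta` alongside `abc.ABC` is redundant, since `abc.ABC` already uses that metaclass, but it is harmless. The abstract `iter_data` and `read_range` stubs keep an unreachable `yield b""` so that they remain async generator functions and satisfy the `AsyncGenerator[bytes, None]` annotation. A toy demonstration of what the ABC conversion buys, with made-up class names:

import abc

class Field(abc.ABC):
    @abc.abstractmethod
    async def move(self, origin_uri: str, destination_uri: str): ...

class IncompleteField(Field):
    pass  # forgot to implement move()

try:
    IncompleteField()
except TypeError as exc:
    print(exc)  # can't instantiate abstract class IncompleteField ... 'move'

This is also why the unit tests below instantiate `LocalStorageField` instead of `StorageField`, and why the `StorageTest` stub grows concrete `create_kb`, `delete_kb`, `delete_upload`, `initialize`, `finalize` and `schedule_delete_kb` methods. The functional `storage_field_test` that follows exercises copy and move end to end.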
@@ -68,21 +68,35 @@ async def storage_field_test(storage: Storage):
assert metadata["FILENAME"] == "myfile.txt"

# Download the file and check that it's the same
downloaded_data = b""
async for data in sfield.iter_data():
downloaded_data += data
assert downloaded_data == binary_data
async def check_downloaded_data(sfield, expected_data: bytes):
downloaded_data = b""
async for data in sfield.iter_data():
downloaded_data += data
assert downloaded_data == expected_data

await check_downloaded_data(sfield, binary_data)

# Test
if storage.source == CloudFile.Source.LOCAL:
# There is a bug to be fixed in the copy method on the local storage driver
return

# Copy the file to another bucket (with the same key)
kbid2 = uuid.uuid4().hex
assert await storage.create_kb(kbid2)
bucket2 = storage.get_bucket_name(kbid2)
rid = "rid"
field_id = "field1"
field_key = KB_RESOURCE_FIELD.format(kbid=kbid2, uuid=rid, field=field_id)
sfield_kb2 = storage.file_field(kbid2, rid, field=field_id)

await sfield.copy(sfield.key, field_key, bucket, bucket2)

await check_downloaded_data(sfield_kb2, binary_data)

# Move the file to another key (same bucket)
new_field_id = "field3"
new_field_key = KB_RESOURCE_FIELD.format(kbid=kbid2, uuid=rid, field=new_field_id)
new_sfield = storage.file_field(kbid2, rid, field=new_field_id)

await sfield_kb2.move(sfield_kb2.key, new_field_key, bucket2, bucket2)

await check_downloaded_data(new_sfield, binary_data)
23 changes: 21 additions & 2 deletions nucliadb_utils/nucliadb_utils/tests/unit/storages/test_storage.py
@@ -26,6 +26,7 @@
from nucliadb_protos.nodewriter_pb2 import IndexMessage
from nucliadb_protos.resources_pb2 import CloudFile

from nucliadb_utils.storages.local import LocalStorageField
from nucliadb_utils.storages.storage import (
Storage,
StorageField,
@@ -45,7 +46,7 @@ def field(self):

@pytest.fixture
def storage_field(self, storage, field):
yield StorageField(storage, "bucket", "fullkey", field)
yield LocalStorageField(storage, "bucket", "fullkey", field)

@pytest.mark.asyncio
async def test_delete(self, storage_field: StorageField, storage):
@@ -73,6 +74,24 @@ async def download(self, bucket_name, uri):
br = BrainResource(labels=["label"])
yield br.SerializeToString()

async def create_kb(self, kbid):
return True

async def delete_kb(self, kbid):
return True

async def delete_upload(self, uri, bucket):
return True

async def initialize(self) -> None:
pass

async def finalize(self) -> None:
pass

async def schedule_delete_kb(self, kbid: str) -> bool:
return True


class TestStorage:
@pytest.fixture
@@ -136,7 +155,7 @@ async def test_delete_indexing(self, storage: StorageTest):
async def test_download_pb(self, storage: StorageTest):
assert isinstance(
await storage.download_pb(
StorageField(storage, "bucket", "fullkey"), BrainResource
LocalStorageField(storage, "bucket", "fullkey"), BrainResource
),
BrainResource,
)
