From 6f286fcfbdfb4845a22add074cb01336d6316a25 Mon Sep 17 00:00:00 2001
From: mhh
Date: Wed, 25 Oct 2023 22:29:25 +0200
Subject: [PATCH] Feature: Add LightNode and MessageCache

A LightNode can synchronize on a subset, or domain, of aleph.im messages.
It relies on the MessageCache, which manages a message database with peewee.
---
 .gitignore                            |   1 +
 setup.cfg                             |   1 +
 src/aleph/sdk/__init__.py             |   4 +-
 src/aleph/sdk/client/__init__.py      |   4 +
 src/aleph/sdk/client/http.py          |   8 +
 src/aleph/sdk/client/light_node.py    | 394 +++++++++++++++++++++
 src/aleph/sdk/client/message_cache.py | 490 ++++++++++++++++++++++++++
 src/aleph/sdk/conf.py                 |   9 +
 src/aleph/sdk/db/aggregate.py         |  31 ++
 src/aleph/sdk/db/common.py            |  34 ++
 src/aleph/sdk/db/message.py           | 124 +++++++
 src/aleph/sdk/db/post.py              | 140 ++++++++
 src/aleph/sdk/exceptions.py           |   8 +
 src/aleph/sdk/vm/cache.py             |  18 +-
 tests/unit/test_download.py           |  35 +-
 tests/unit/test_light_node.py         | 262 ++++++++++++++
 tests/unit/test_message_cache.py      | 321 +++++++++++++++++
 17 files changed, 1870 insertions(+), 14 deletions(-)
 create mode 100644 src/aleph/sdk/client/light_node.py
 create mode 100644 src/aleph/sdk/client/message_cache.py
 create mode 100644 src/aleph/sdk/db/aggregate.py
 create mode 100644 src/aleph/sdk/db/common.py
 create mode 100644 src/aleph/sdk/db/message.py
 create mode 100644 src/aleph/sdk/db/post.py
 create mode 100644 tests/unit/test_light_node.py
 create mode 100644 tests/unit/test_message_cache.py

diff --git a/.gitignore b/.gitignore
index c4734889..a12a6219 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@
 *.pot
 __pycache__/*
 .cache/*
+cache/**/*
 .*.swp
 */.ipynb_checkpoints/*
 
diff --git a/setup.cfg b/setup.cfg
index ca39e9ce..4d271446 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -43,6 +43,7 @@ install_requires =
     # Required to fix a dependency issue with parsimonious and Python3.11
     eth_abi==4.0.0b2; python_version>="3.11"
     python-magic
+    peewee
 # The usage of test_requires is discouraged, see `Dependency Management` docs
 # tests_require = pytest; pytest-cov
 # Require a specific Python version, e.g. Python 2.7 or >= 3.4
diff --git a/src/aleph/sdk/__init__.py b/src/aleph/sdk/__init__.py
index c14b64f6..416f12cd 100644
--- a/src/aleph/sdk/__init__.py
+++ b/src/aleph/sdk/__init__.py
@@ -1,6 +1,6 @@
 from pkg_resources import DistributionNotFound, get_distribution
 
-from aleph.sdk.client import AlephHttpClient, AuthenticatedAlephHttpClient
+from aleph.sdk.client import AlephHttpClient, AuthenticatedAlephHttpClient, LightNode
 
 try:
     # Change here if project is renamed and does not equal the package name
@@ -11,4 +11,4 @@
 finally:
     del get_distribution, DistributionNotFound
 
-__all__ = ["AlephHttpClient", "AuthenticatedAlephHttpClient"]
+__all__ = ["AlephHttpClient", "AuthenticatedAlephHttpClient", "LightNode"]
diff --git a/src/aleph/sdk/client/__init__.py b/src/aleph/sdk/client/__init__.py
index 9ee25dd9..eadc6cac 100644
--- a/src/aleph/sdk/client/__init__.py
+++ b/src/aleph/sdk/client/__init__.py
@@ -1,10 +1,14 @@
 from .abstract import AlephClient, AuthenticatedAlephClient
 from .authenticated_http import AuthenticatedAlephHttpClient
 from .http import AlephHttpClient
+from .light_node import LightNode
+from .message_cache import MessageCache
 
 __all__ = [
     "AlephClient",
     "AuthenticatedAlephClient",
     "AlephHttpClient",
     "AuthenticatedAlephHttpClient",
+    "MessageCache",
+    "LightNode",
 ]
diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py
index 97ad0c08..554d6024 100644
--- a/src/aleph/sdk/client/http.py
+++ b/src/aleph/sdk/client/http.py
@@ -6,6 +6,7 @@
 import aiohttp
 from aleph_message import parse_message
 from aleph_message.models import AlephMessage, ItemHash, ItemType
+from aleph_message.status import MessageStatus
 from pydantic import ValidationError
 
 from ..conf import settings
@@ -171,6 +172,8 @@ async def download_file_to_buffer(
                 )
             else:
                 raise FileTooLarge(f"The file from {file_hash} is too large")
+        else:
+            response.raise_for_status()
 
     async def download_file_ipfs_to_buffer(
         self,
@@ -313,6 +316,11 @@ async def get_message(
         )
         return message
 
+    async def get_message_status(self, item_hash: str) -> MessageStatus:
+        async with self.http_session.get(f"/api/v0/messages/{item_hash}") as resp:
+            resp.raise_for_status()
+            return MessageStatus((await resp.json())["status"])
+
     async def watch_messages(
         self,
         message_filter: Optional[MessageFilter] = None,
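
The LightNode introduced below ties these client additions together. As a quick orientation before the file itself, here is a minimal usage sketch; it is illustrative only, and the dummy private key, channel name, and post content are placeholders:

    import asyncio

    from aleph.sdk import AuthenticatedAlephHttpClient, LightNode
    from aleph.sdk.chains.ethereum import ETHAccount


    async def main():
        account = ETHAccount(b"\x01" * 32)  # placeholder key, not for real use
        async with AuthenticatedAlephHttpClient(account) as session:
            # Restrict the domain to a single channel; a smaller domain
            # means fewer messages to synchronize.
            node = LightNode(session=session, channels=["MYAPP"])
            await node.run()  # watch new messages and sync past ones
            message, status = await node.create_post(
                post_content={"hello": "world"},
                post_type="test",
                channel="MYAPP",
            )
            print(message.item_hash, status)


    asyncio.run(main())
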
diff --git a/src/aleph/sdk/client/light_node.py b/src/aleph/sdk/client/light_node.py
new file mode 100644
index 00000000..91506e8f
--- /dev/null
+++ b/src/aleph/sdk/client/light_node.py
@@ -0,0 +1,394 @@
+import asyncio
+from datetime import datetime
+from io import BytesIO
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
+
+from aleph_message.models import AlephMessage, Chain, MessageType
+from aleph_message.models.execution.base import Encoding
+from aleph_message.status import MessageStatus
+
+from ..query.filters import MessageFilter
+from ..types import StorageEnum
+from ..utils import Writable
+from .abstract import AuthenticatedAlephClient
+from .authenticated_http import AuthenticatedAlephHttpClient
+from .message_cache import MessageCache
+
+
+class LightNode(MessageCache, AuthenticatedAlephClient):
+    """
+    A LightNode is a client that listens to the Aleph network and stores messages
+    in a local database. Furthermore, it can create messages and submit them to
+    the network, as well as upload files, while keeping track of the
+    corresponding messages locally.
+
+    It synchronizes with the network on a subset of the messages (the "domain")
+    by listening to the network and storing the messages in the database. The
+    user may define the domain by specifying channels, tags, senders, chains
+    and/or message types.
+    """
+
+    def __init__(
+        self,
+        session: AuthenticatedAlephHttpClient,
+        channels: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[Chain]] = None,
+        message_types: Optional[Iterable[MessageType]] = None,
+    ):
+        """
+        Initialize a LightNode. Besides requiring an established session with a
+        core channel node, the user may specify a domain to listen to. The domain
+        is the intersection of the specified channels, tags, senders, chains and
+        message types. A smaller domain will synchronize faster and consume less
+        storage and bandwidth.
+
+        Args:
+            session: An authenticated session to an Aleph core channel node.
+            channels: The channels to listen to.
+            tags: The tags to listen to.
+            addresses: The addresses to listen to.
+            chains: The chains to listen to.
+            message_types: The message types to listen to.
+
+        Raises:
+            InvalidCacheDatabaseSchema: If the database schema does not match the
+                expected message schema.
+        """
+        super().__init__()
+        self.session = session
+        self.channels = channels
+        self.tags = tags
+        self.addresses = (
+            list(addresses) + [session.account.get_address()]
+            if addresses
+            else [session.account.get_address()]
+        )
+        self.chains = (
+            list(chains) + [Chain(session.account.CHAIN)]
+            if chains
+            else [session.account.CHAIN]
+        )
+        self.message_types = message_types
+
+    async def run(self):
+        """
+        Start listening to the network and synchronize with past messages.
+        """
+        asyncio.create_task(
+            self.listen_to(
+                self.session.watch_messages(
+                    message_filter=MessageFilter(
+                        channels=self.channels,
+                        tags=self.tags,
+                        addresses=self.addresses,
+                        chains=self.chains,
+                        message_types=self.message_types,
+                    )
+                )
+            )
+        )
+        # synchronize with past messages
+        await self.synchronize(
+            channels=self.channels,
+            tags=self.tags,
+            addresses=self.addresses,
+            chains=self.chains,
+            message_types=self.message_types,
+        )
+
+    async def synchronize(
+        self,
+        channels: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[Chain]] = None,
+        message_types: Optional[Iterable[MessageType]] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+    ):
+        """
+        Synchronize with past messages.
+        """
+        chunk_size = 200
+        messages = []
+        async for message in self.session.get_messages_iterator(
+            message_filter=MessageFilter(
+                channels=channels,
+                tags=tags,
+                addresses=addresses,
+                chains=chains,
+                message_types=message_types,
+                start_date=start_date,
+                end_date=end_date,
+            )
+        ):
+            messages.append(message)
+            if len(messages) >= chunk_size:
+                self.add(messages)
+                messages = []
+        if messages:
+            self.add(messages)
+
+    async def download_file(self, file_hash: str) -> bytes:
+        """
+        Download a file from the network and store it locally. If it already
+        exists locally, it will not be downloaded again.
+
+        Args:
+            file_hash: The hash of the file to download.
+
+        Returns:
+            The file content.
+
+        Raises:
+            FileNotFoundError: If the file does not exist on the network.
+        """
+        try:
+            return await super().download_file(file_hash)
+        except FileNotFoundError:
+            pass
+        file = await self.session.download_file(file_hash)
+        self._file_path(file_hash).parent.mkdir(parents=True, exist_ok=True)
+        with open(self._file_path(file_hash), "wb") as f:
+            f.write(file)
+        return file
+
+    async def download_file_to_buffer(
+        self,
+        file_hash: str,
+        output_buffer: Writable[bytes],
+    ) -> None:
+        """
+        Download a file from the network and store it in a buffer. If it already
+        exists locally, it will not be downloaded again.
+
+        Args:
+            file_hash: The hash of the file to download.
+            output_buffer: The buffer to store the file content in.
+
+        Raises:
+            FileNotFoundError: If the file does not exist on the network.
+        """
+        try:
+            return await super().download_file_to_buffer(file_hash, output_buffer)
+        except FileNotFoundError:
+            pass
+        buffer = BytesIO()
+        await self.session.download_file_ipfs_to_buffer(file_hash, buffer)
+        self._file_path(file_hash).parent.mkdir(parents=True, exist_ok=True)
+        with open(self._file_path(file_hash), "wb") as f:
+            f.write(buffer.getvalue())
+        output_buffer.write(buffer.getvalue())
+
+    def check_validity(
+        self,
+        message_type: MessageType,
+        address: Optional[str] = None,
+        channel: Optional[str] = None,
+        content: Optional[Dict] = None,
+    ):
+        if self.message_types and message_type not in self.message_types:
+            raise ValueError(
+                f"Cannot create {message_type.value} message because LightNode is not listening to {message_type.value} messages."
+            )
+        if address and self.addresses and address not in self.addresses:
+            raise ValueError(
+                f"Cannot create {message_type.value} message because LightNode is not listening to messages from address {address}."
+            )
+        if channel and self.channels and channel not in self.channels:
+            raise ValueError(
+                f"Cannot create {message_type.value} message because LightNode is not listening to messages from channel {channel}."
+            )
+        if (
+            content
+            and self.tags
+            and not set(content.get("tags", [])).intersection(self.tags)
+        ):
+            raise ValueError(
+                f"Cannot create {message_type.value} message because LightNode is not listening to any of these tags: {content.get('tags', [])}."
+            )
+
+    async def delete_if_rejected(self, item_hash):
+        async def _delete_if_rejected():
+            await asyncio.sleep(5)
+            retries = 0
+            status = await self.session.get_message_status(item_hash)
+            while status == MessageStatus.PENDING:
+                await asyncio.sleep(5)
+                status = await self.session.get_message_status(item_hash)
+                retries += 1
+                if retries > 10:
+                    raise TimeoutError(
+                        f"Message {item_hash} has been pending for too long."
+                    )
+            if status in [MessageStatus.REJECTED, MessageStatus.FORGOTTEN]:
+                del self[item_hash]
+
+        await _delete_if_rejected()
+
+    async def create_post(
+        self,
+        post_content: Any,
+        post_type: str,
+        ref: Optional[str] = None,
+        address: Optional[str] = None,
+        channel: Optional[str] = None,
+        inline: bool = True,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        self.check_validity(MessageType.post, address, channel, post_content)
+        resp, status = await self.session.create_post(
+            post_content=post_content,
+            post_type=post_type,
+            ref=ref,
+            address=address,
+            channel=channel,
+            inline=inline,
+            storage_engine=storage_engine,
+            sync=sync,
+        )
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+            asyncio.create_task(self.delete_if_rejected(resp.item_hash))
+        return resp, status
+
+    async def create_aggregate(
+        self,
+        key: str,
+        content: Mapping[str, Any],
+        address: Optional[str] = None,
+        channel: Optional[str] = None,
+        inline: bool = True,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        self.check_validity(MessageType.aggregate, address, channel)
+        resp, status = await self.session.create_aggregate(
+            key=key,
+            content=content,
+            address=address,
+            channel=channel,
+            inline=inline,
+            sync=sync,
+        )
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+            asyncio.create_task(self.delete_if_rejected(resp.item_hash))
+        return resp, status
+
+    async def create_store(
+        self,
+        address: Optional[str] = None,
+        file_content: Optional[bytes] = None,
+        file_path: Optional[Union[str, Path]] = None,
+        file_hash: Optional[str] = None,
+        guess_mime_type: bool = False,
+        ref: Optional[str] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        extra_fields: Optional[dict] = None,
+        channel: Optional[str] = None,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        self.check_validity(MessageType.store, address, channel, extra_fields)
+        resp, status = await self.session.create_store(
+            address=address,
+            file_content=file_content,
+            file_path=file_path,
+            file_hash=file_hash,
+            guess_mime_type=guess_mime_type,
+            ref=ref,
+            storage_engine=storage_engine,
+            extra_fields=extra_fields,
+            channel=channel,
+            sync=sync,
+        )
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+            asyncio.create_task(self.delete_if_rejected(resp.item_hash))
+        return resp, status
+
+    async def create_program(
+        self,
+        program_ref: str,
+        entrypoint: str,
+        runtime: str,
+        environment_variables: Optional[Mapping[str, str]] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        channel: Optional[str] = None,
+        address: Optional[str] = None,
+        sync: bool = False,
+        memory: Optional[int] = None,
+        vcpus: Optional[int] = None,
+        timeout_seconds: Optional[float] = None,
+        persistent: bool = False,
+        encoding: Encoding = Encoding.zip,
+        volumes: Optional[List[Mapping]] = None,
+        subscriptions: Optional[List[Mapping]] = None,
+        metadata: Optional[Mapping[str, Any]] = None,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        self.check_validity(
+            MessageType.program, address, channel, dict(metadata) if metadata else None
+        )
+        resp, status = await self.session.create_program(
+            program_ref=program_ref,
+            entrypoint=entrypoint,
+            runtime=runtime,
+            environment_variables=environment_variables,
+            storage_engine=storage_engine,
+            channel=channel,
+            address=address,
+            sync=sync,
+            memory=memory,
+            vcpus=vcpus,
+            timeout_seconds=timeout_seconds,
+            persistent=persistent,
+            encoding=encoding,
+            volumes=volumes,
+            subscriptions=subscriptions,
+            metadata=metadata,
+        )
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+            asyncio.create_task(self.delete_if_rejected(resp.item_hash))
+        return resp, status
+
+    async def forget(
+        self,
+        hashes: List[str],
+        reason: Optional[str],
+        storage_engine: StorageEnum = StorageEnum.storage,
+        channel: Optional[str] = None,
+        address: Optional[str] = None,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        self.check_validity(MessageType.forget, address, channel)
+        resp, status = await self.session.forget(
+            hashes=hashes,
+            reason=reason,
+            storage_engine=storage_engine,
+            channel=channel,
+            address=address,
+            sync=sync,
+        )
+        self._handle_forgets([resp])
+        return resp, status
+
+    async def submit(
+        self,
+        content: Dict[str, Any],
+        message_type: MessageType,
+        channel: Optional[str] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        allow_inlining: bool = True,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.submit(
+            content=content,
+            message_type=message_type,
+            channel=channel,
+            storage_engine=storage_engine,
+            allow_inlining=allow_inlining,
+            sync=sync,
+        )
+        if status in [MessageStatus.PROCESSED, MessageStatus.PENDING]:
+            self.add(resp)
+            asyncio.create_task(self.delete_if_rejected(resp.item_hash))
+        return resp, status
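
Before the MessageCache itself, a short sketch of the API it exposes; `some_messages` stands in for any list of AlephMessage objects, and the snippet is assumed to run inside a coroutine:

    from aleph.sdk.client import MessageCache
    from aleph.sdk.query.filters import MessageFilter

    cache = MessageCache()  # backed by settings.CACHE_DATABASE_PATH

    cache.add(some_messages)  # accepts a single message or an iterable
    assert some_messages[0].item_hash in cache

    # dict-style access by item hash
    message = cache[some_messages[0].item_hash]

    # filtered queries mirror the AlephClient interface
    response = await cache.get_messages(
        message_filter=MessageFilter(addresses=[message.sender])
    )
    print(response.messages)
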
diff --git a/src/aleph/sdk/client/message_cache.py b/src/aleph/sdk/client/message_cache.py
new file mode 100644
index 00000000..c0789bac
--- /dev/null
+++ b/src/aleph/sdk/client/message_cache.py
@@ -0,0 +1,490 @@
+import logging
+import typing
+from datetime import datetime
+from pathlib import Path
+from typing import (
+    AsyncIterable,
+    Coroutine,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Type,
+    Union,
+)
+
+from aleph_message import MessagesResponse
+from aleph_message.models import AlephMessage, ItemHash, MessageType, PostMessage
+from peewee import SqliteDatabase
+from playhouse.shortcuts import model_to_dict
+
+from ..conf import settings
+from ..db.aggregate import AggregateDBModel, aggregate_to_model
+from ..db.message import (
+    MessageDBModel,
+    message_filter_to_query,
+    message_to_model,
+    model_to_message,
+)
+from ..db.post import (
+    PostDBModel,
+    message_to_post,
+    model_to_post,
+    post_filter_to_query,
+    post_to_model,
+)
+from ..exceptions import InvalidCacheDatabaseSchema, MessageNotFoundError
+from ..query.filters import MessageFilter, PostFilter
+from ..query.responses import PostsResponse
+from ..types import GenericMessage
+from ..utils import Writable
+from .abstract import AlephClient
+
+
+class MessageCache(AlephClient):
+    """
+    A wrapper around a sqlite3 database for caching AlephMessage objects.
+
+    It can be used independently of a LightNode to implement any kind of
+    caching strategy.
+    """
+
+    missing_posts: Dict[ItemHash, PostMessage] = {}
+    """A dict of amend messages by the item_hash of their missing original post."""
+
+    def __init__(self, database_path: Optional[Union[str, Path]] = None):
+        """
+        Args:
+            database_path: The path to the sqlite3 database file. If not
+                provided, the default path will be used.
+
+        Note:
+            The database schema is automatically checked and updated if necessary.
+
+        !!! warning
+            :memory: databases are not supported, as they do not persist across
+            connections.
+
+        Raises:
+            InvalidCacheDatabaseSchema: If the database schema does not match the
+                expected message schema.
+        """
+        self.database_path: Path = (
+            Path(database_path) if database_path else settings.CACHE_DATABASE_PATH
+        )
+        if not self.database_path.exists():
+            self.database_path.parent.mkdir(parents=True, exist_ok=True)
+
+        self.db = SqliteDatabase(self.database_path)
+        MessageDBModel._meta.database = self.db
+        PostDBModel._meta.database = self.db
+        AggregateDBModel._meta.database = self.db
+
+        self.db.connect(reuse_if_open=True)
+        if not MessageDBModel.table_exists():
+            self.db.create_tables([MessageDBModel])
+        if not PostDBModel.table_exists():
+            self.db.create_tables([PostDBModel])
+        if not AggregateDBModel.table_exists():
+            self.db.create_tables([AggregateDBModel])
+        self._check_schema()
+        self.db.close()
+
+    def _check_schema(self):
+        if sorted(MessageDBModel._meta.fields.keys()) != sorted(
+            [col.name for col in self.db.get_columns(MessageDBModel._meta.table_name)]
+        ):
+            raise InvalidCacheDatabaseSchema(
+                "MessageDBModel schema does not match MessageModel schema"
+            )
+        if sorted(PostDBModel._meta.fields.keys()) != sorted(
+            [col.name for col in self.db.get_columns(PostDBModel._meta.table_name)]
+        ):
+            raise InvalidCacheDatabaseSchema(
+                "PostDBModel schema does not match PostModel schema"
+            )
+        if sorted(AggregateDBModel._meta.fields.keys()) != sorted(
+            [col.name for col in self.db.get_columns(AggregateDBModel._meta.table_name)]
+        ):
+            raise InvalidCacheDatabaseSchema(
+                "AggregateDBModel schema does not match AggregateModel schema"
+            )
+
+    async def __aenter__(self):
+        self.db.connect(reuse_if_open=True)
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        self.db.close()
+
+    def __del__(self):
+        self.db.close()
+
+    def __getitem__(self, item_hash: ItemHash) -> Optional[AlephMessage]:
+        item = MessageDBModel.get_or_none(MessageDBModel.item_hash == str(item_hash))
+        return model_to_message(item) if item else None
+
+    def __delitem__(self, item_hash: ItemHash):
+        MessageDBModel.delete().where(
+            MessageDBModel.item_hash == str(item_hash)
+        ).execute()
+        PostDBModel.delete().where(
+            PostDBModel.original_item_hash == str(item_hash)
+        ).execute()
+        AggregateDBModel.delete().where(
+            AggregateDBModel.original_message_hash == str(item_hash)
+        ).execute()
+        # delete stored files
+        file_path = self._file_path(str(item_hash))
+        if file_path.exists():
+            file_path.unlink()
+
+    def __contains__(self, item_hash: ItemHash) -> bool:
+        return (
+            MessageDBModel.select()
+            .where(MessageDBModel.item_hash == str(item_hash))
+            .exists()
+        )
+
+    def __len__(self):
+        return MessageDBModel.select().count()
+
+    def __iter__(self) -> Iterator[AlephMessage]:
+        """
+        Iterate over all messages in the db, the latest first.
+        """
+        for item in iter(MessageDBModel.select().order_by(-MessageDBModel.time)):
+            yield model_to_message(item)
+
+    def __repr__(self) -> str:
+        return f"<MessageCache: {self.database_path}>"
+
+    def __str__(self) -> str:
+        return repr(self)
+
+    @typing.overload
+    def add(self, messages: Iterable[AlephMessage]):
+        ...
+
+    @typing.overload
+    def add(self, messages: AlephMessage):
+        ...
+
+    def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]):
+        """
+        Add a message or a list of messages to the database. If a message is an
+        amend, it will be applied to the original post. If the original post is
+        not in the db, the amend message will be stored in memory until the
+        original post is added. Aggregate messages will be merged with any
+        existing aggregates.
+
+        Args:
+            messages: A message or list of messages to add to the database.
+        """
+        if isinstance(messages, typing.get_args(AlephMessage)):
+            messages = [messages]
+
+        message_data = (message_to_model(message) for message in messages)
+        MessageDBModel.insert_many(message_data).on_conflict_replace().execute()
+
+        # Sort messages and insert posts first
+        post_data = []
+        amend_messages = []
+        aggregate_messages = []
+        forget_messages = []
+        for message in messages:
+            if message.type == MessageType.aggregate.value:
+                aggregate_messages.append(message)
+                continue
+            if message.type == MessageType.forget.value:
+                forget_messages.append(message)
+                continue
+            if message.type != MessageType.post.value:
+                continue
+            if message.content.type == "amend":
+                amend_messages.append(message)
+                continue
+
+            post = post_to_model(message_to_post(message))
+            post_data.append(post)
+
+            # Check if we can now add any amend messages that had missing refs
+            if message.item_hash in self.missing_posts:
+                amend_messages.append(self.missing_posts.pop(message.item_hash))
+
+        with self.db.atomic():
+            PostDBModel.insert_many(post_data).on_conflict_replace().execute()
+
+        self._handle_amends(amend_messages)
+
+        self._handle_aggregates(aggregate_messages)
+
+        self._handle_forgets(forget_messages)
+
+    def _handle_amends(self, amend_messages: List[PostMessage]):
+        post_data = []
+        for amend in amend_messages:
+            original_post = PostDBModel.get_or_none(
+                PostDBModel.original_item_hash == amend.content.ref
+            )
+            if not original_post:
+                # Keep only the latest amend for each missing original post
+                latest_amend = self.missing_posts.get(ItemHash(amend.content.ref))
+                if not latest_amend or amend.time > latest_amend.time:
+                    self.missing_posts[ItemHash(amend.content.ref)] = amend
+                continue
+
+            if amend.time < original_post.time:
+                continue
+
+            original_post.item_hash = amend.item_hash
+            original_post.content = amend.content.content
+            original_post.original_item_hash = amend.content.ref
+            original_post.original_type = amend.content.type
+            original_post.sender = amend.sender
+            original_post.channel = amend.channel
+            original_post.time = amend.time
+            post_data.append(model_to_dict(original_post))
+        with self.db.atomic():
+            PostDBModel.insert_many(post_data).on_conflict_replace().execute()
+
+    def _handle_aggregates(self, aggregate_messages):
+        aggregate_data = []
+        for aggregate in aggregate_messages:
+            existing_aggregate = AggregateDBModel.get_or_none(
+                AggregateDBModel.address == aggregate.sender,
+                AggregateDBModel.key == aggregate.content.key,
+            )
+            if not existing_aggregate:
+                aggregate_data.append(aggregate_to_model(aggregate))
+                continue
+            data = model_to_dict(existing_aggregate)
+            if aggregate.time > existing_aggregate.time:
+                existing_aggregate.content.update(aggregate.content.content)
+                existing_aggregate.time = aggregate.time
+            elif existing_aggregate.content is None:
+                existing_aggregate.content = aggregate.content.content
+            else:
+                existing_aggregate.content = dict(
+                    aggregate.content.content, **existing_aggregate.content
+                )
+            data = model_to_dict(existing_aggregate)
+            aggregate_data.append(data)
+        with self.db.atomic():
+            AggregateDBModel.insert_many(aggregate_data).on_conflict_replace().execute()
+
+    def _handle_forgets(self, forget_messages):
+        refs = [str(h) for forget in forget_messages for h in forget.content.hashes]
+        with self.db.atomic():
+            MessageDBModel.delete().where(MessageDBModel.item_hash.in_(refs)).execute()
+            PostDBModel.delete().where(PostDBModel.item_hash.in_(refs)).execute()
+            AggregateDBModel.delete().where(
+                AggregateDBModel.original_message_hash.in_(refs)
+            ).execute()
+
+    @typing.overload
+    def get(self, item_hashes: Iterable[ItemHash]) -> List[AlephMessage]:
+        ...
+
+    @typing.overload
+    def get(self, item_hashes: ItemHash) -> Optional[AlephMessage]:
+        ...
+
+    def get(
+        self, item_hashes: Union[ItemHash, Iterable[ItemHash]]
+    ) -> List[AlephMessage]:
+        """
+        Get many messages from the db by their item hash.
+        """
+        if isinstance(item_hashes, ItemHash):
+            item_hashes = [item_hashes]
+        item_hashes = [str(item_hash) for item_hash in item_hashes]
+        items = (
+            MessageDBModel.select()
+            .where(MessageDBModel.item_hash.in_(item_hashes))
+            .execute()
+        )
+        return [model_to_message(item) for item in items]
+
+    def listen_to(self, message_stream: AsyncIterable[AlephMessage]) -> Coroutine:
+        """
+        Listen to a stream of messages and add them to the database.
+        """
+
+        async def _listen():
+            async for message in message_stream:
+                self.add(message)
+                print(f"Added message {message.item_hash} to db")
+
+        return _listen()
+
+    async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]:
+        item = (
+            AggregateDBModel.select()
+            .where(AggregateDBModel.address == address)
+            .where(AggregateDBModel.key == key)
+            .order_by(AggregateDBModel.key.desc())
+            .get_or_none()
+        )
+        if not item:
+            raise MessageNotFoundError(f"No such aggregate {address} {key}")
+        return item.content
+
+    async def fetch_aggregates(
+        self, address: str, keys: Optional[Iterable[str]] = None
+    ) -> Dict[str, Dict]:
+        query = (
+            AggregateDBModel.select()
+            .where(AggregateDBModel.address == address)
+            .order_by(AggregateDBModel.key.desc())
+        )
+        if keys:
+            query = query.where(AggregateDBModel.key.in_(keys))
+        return {item.key: item.content for item in list(query)}
+
+    async def get_posts(
+        self,
+        pagination: int = 200,
+        page: int = 1,
+        post_filter: Optional[PostFilter] = None,
+        ignore_invalid_messages: Optional[bool] = True,
+        invalid_messages_log_level: Optional[int] = logging.NOTSET,
+    ) -> PostsResponse:
+        query = (
+            post_filter_to_query(post_filter) if post_filter else PostDBModel.select()
+        )
+
+        query = query.paginate(page, pagination)
+
+        posts = [model_to_post(item) for item in list(query)]
+
+        return PostsResponse(
+            posts=posts,
+            pagination_page=page,
+            pagination_per_page=pagination,
+            pagination_total=query.count(),
+            pagination_item="posts",
+        )
+
+    @staticmethod
+    def _file_path(file_hash: str) -> Path:
+        return settings.CACHE_FILES_PATH / Path(file_hash)
+
+    async def download_file(self, file_hash: str) -> bytes:
+        """
+        Opens a file that has been locally stored by its hash.
+
+        Raises:
+            FileNotFoundError: If the file does not exist.
+        """
+        with open(self._file_path(file_hash), "rb") as f:
+            return f.read()
+
+    async def download_file_ipfs(self, file_hash: str) -> bytes:
+        """
+        Opens a file that has been locally stored by its IPFS hash.
+
+        Raises:
+            FileNotFoundError: If the file does not exist.
+        """
+        return await self.download_file(file_hash)
+
+    async def download_file_to_buffer(
+        self,
+        file_hash: str,
+        output_buffer: Writable[bytes],
+    ) -> None:
+        """
+        Opens a file and writes it to a buffer.
+
+        Raises:
+            FileNotFoundError: If the file does not exist.
+        """
+        with open(self._file_path(file_hash), "rb") as f:
+            output_buffer.write(f.read())
+
+    async def download_file_to_buffer_ipfs(
+        self,
+        file_hash: str,
+        output_buffer: Writable[bytes],
+    ) -> None:
+        """
+        Opens a file and writes it to a buffer.
+
+        Raises:
+            FileNotFoundError: If the file does not exist.
+        """
+        await self.download_file_to_buffer(file_hash, output_buffer)
+
+    async def add_file(self, file_hash: str, file_content: bytes):
+        """
+        Store a file locally by its hash.
+        """
+        if not self._file_path(file_hash).exists():
+            self._file_path(file_hash).parent.mkdir(parents=True, exist_ok=True)
+        with open(self._file_path(file_hash), "wb") as f:
+            f.write(file_content)
+
+    async def get_messages(
+        self,
+        pagination: int = 200,
+        page: int = 1,
+        message_filter: Optional[MessageFilter] = None,
+        ignore_invalid_messages: Optional[bool] = True,
+        invalid_messages_log_level: Optional[int] = logging.NOTSET,
+    ) -> MessagesResponse:
+        """
+        Get and filter many messages from the database.
+        """
+        query = (
+            message_filter_to_query(message_filter)
+            if message_filter
+            else MessageDBModel.select()
+        )
+
+        query = query.paginate(page, pagination)
+
+        messages = [model_to_message(item) for item in list(query)]
+
+        return MessagesResponse(
+            messages=messages,
+            pagination_page=page,
+            pagination_per_page=pagination,
+            pagination_total=query.count(),
+            pagination_item="messages",
+        )
+
+    async def get_message(
+        self,
+        item_hash: str,
+        message_type: Optional[Type[GenericMessage]] = None,
+        channel: Optional[str] = None,
+    ) -> GenericMessage:
+        """
+        Get a single message from the database by its item hash.
+        """
+        query = MessageDBModel.select().where(MessageDBModel.item_hash == item_hash)
+
+        if message_type:
+            query = query.where(MessageDBModel.type == message_type.value)
+        if channel:
+            query = query.where(MessageDBModel.channel == channel)
+
+        item = query.first()
+
+        if item:
+            return model_to_message(item)
+
+        raise MessageNotFoundError(f"No such hash {item_hash}")
+
+    async def watch_messages(
+        self,
+        message_filter: Optional[MessageFilter] = None,
+    ) -> AsyncIterable[AlephMessage]:
+        """
+        Watch new messages as they are added to the database.
+        """
+        query = (
+            message_filter_to_query(message_filter)
+            if message_filter
+            else MessageDBModel.select()
+        )
+
+        for item in query:
+            yield model_to_message(item)
diff --git a/src/aleph/sdk/conf.py b/src/aleph/sdk/conf.py
index f8d798c6..626b1ba8 100644
--- a/src/aleph/sdk/conf.py
+++ b/src/aleph/sdk/conf.py
@@ -43,6 +43,15 @@ class Settings(BaseSettings):
     DNS_STATIC_DOMAIN = "static.public.aleph.sh"
     DNS_RESOLVERS = ["9.9.9.9", "1.1.1.1"]
 
+    CACHE_DATABASE_PATH: Path = Field(
+        default=Path("cache", "message_cache.sqlite"),
+        description="Path to the message cache database",
+    )
+    CACHE_FILES_PATH: Path = Field(
+        default=Path("cache", "files"),
+        description="Path to the locally cached files",
+    )
+
     class Config:
         env_prefix = "ALEPH_"
         case_sensitive = False
+ """ + + original_message_hash = CharField(primary_key=True) + address = CharField(index=True) + key = CharField() + channel = CharField(null=True) + content = JSONField(json_dumps=pydantic_json_dumps, null=True) + time = FloatField() + + +def aggregate_to_model(message: AggregateMessage) -> Dict: + return { + "original_message_hash": str(message.item_hash), + "address": str(message.sender), + "key": str(message.content.key), + "channel": message.channel, + "content": message.content.content, + "time": message.time, + } diff --git a/src/aleph/sdk/db/common.py b/src/aleph/sdk/db/common.py new file mode 100644 index 00000000..2b4ccb40 --- /dev/null +++ b/src/aleph/sdk/db/common.py @@ -0,0 +1,34 @@ +import json +from functools import partial +from typing import Generic, Optional, TypeVar + +from playhouse.sqlite_ext import JSONField +from pydantic import BaseModel +from pydantic.json import pydantic_encoder + +T = TypeVar("T", bound=BaseModel) + + +pydantic_json_dumps = partial(json.dumps, default=pydantic_encoder) + + +class PydanticField(JSONField, Generic[T]): + """ + A field for storing pydantic model types as JSON in a database. Uses json for serialization. + """ + + type: T + + def __init__(self, *args, **kwargs): + self.type = kwargs.pop("type") + super().__init__(*args, **kwargs) + + def db_value(self, value: Optional[T]) -> Optional[str]: + if value is None: + return None + return pydantic_json_dumps(value) + + def python_value(self, value: Optional[str]) -> Optional[T]: + if not value: + return None + return self.type.parse_raw(value) diff --git a/src/aleph/sdk/db/message.py b/src/aleph/sdk/db/message.py new file mode 100644 index 00000000..962de88c --- /dev/null +++ b/src/aleph/sdk/db/message.py @@ -0,0 +1,124 @@ +from typing import Any, Dict, Iterable + +from aleph_message import parse_message +from aleph_message.models import AlephMessage, MessageConfirmation +from peewee import BooleanField, CharField, FloatField, IntegerField, Model +from playhouse.shortcuts import model_to_dict +from playhouse.sqlite_ext import JSONField + +from ..query.filters import MessageFilter +from .common import PydanticField, pydantic_json_dumps + + +class MessageDBModel(Model): + """ + A simple database model for storing AlephMessage objects. 
+ """ + + item_hash = CharField(primary_key=True) + chain = CharField(5) + type = CharField(9) + sender = CharField() + channel = CharField(null=True) + confirmations: PydanticField[MessageConfirmation] = PydanticField( + type=MessageConfirmation, null=True + ) + confirmed = BooleanField(null=True) + signature = CharField(null=True) + size = IntegerField(null=True) + time = FloatField() + item_type = CharField(7) + item_content = CharField(null=True) + hash_type = CharField(6, null=True) + content = JSONField(json_dumps=pydantic_json_dumps) + forgotten_by = CharField(null=True) + tags = JSONField(json_dumps=pydantic_json_dumps, null=True) + key = CharField(null=True) + ref = CharField(null=True) + content_type = CharField(null=True) + + +def message_to_model(message: AlephMessage) -> Dict: + return { + "item_hash": str(message.item_hash), + "chain": message.chain, + "type": message.type, + "sender": message.sender, + "channel": message.channel, + "confirmations": message.confirmations[0] if message.confirmations else None, + "confirmed": message.confirmed, + "signature": message.signature, + "size": message.size, + "time": message.time, + "item_type": message.item_type, + "item_content": message.item_content, + "hash_type": message.hash_type, + "content": message.content, + "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None, + "tags": message.content.content.get("tags", None) + if hasattr(message.content, "content") + else None, + "key": message.content.key if hasattr(message.content, "key") else None, + "ref": message.content.ref if hasattr(message.content, "ref") else None, + "content_type": message.content.type + if hasattr(message.content, "type") + else None, + } + + +def model_to_message(item: Any) -> AlephMessage: + item.confirmations = [item.confirmations] if item.confirmations else [] + item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None + + to_exclude = [ + MessageDBModel.tags, + MessageDBModel.ref, + MessageDBModel.key, + MessageDBModel.content_type, + ] + + item_dict = model_to_dict(item, exclude=to_exclude) + return parse_message(item_dict) + + +def query_field(field_name, field_values: Iterable[str]): + field = getattr(MessageDBModel, field_name) + values = list(field_values) + + if len(values) == 1: + return field == values[0] + return field.in_(values) + + +def message_filter_to_query(filter: MessageFilter) -> MessageDBModel: + query = MessageDBModel.select().order_by(MessageDBModel.time.desc()) + conditions = [] + if filter.message_types: + conditions.append( + query_field("type", [type.value for type in filter.message_types]) + ) + if filter.content_keys: + conditions.append(query_field("key", filter.content_keys)) + if filter.content_types: + conditions.append(query_field("content_type", filter.content_types)) + if filter.refs: + conditions.append(query_field("ref", filter.refs)) + if filter.addresses: + conditions.append(query_field("sender", filter.addresses)) + if filter.tags: + for tag in filter.tags: + conditions.append(MessageDBModel.tags.contains(tag)) + if filter.hashes: + conditions.append(query_field("item_hash", filter.hashes)) + if filter.channels: + conditions.append(query_field("channel", filter.channels)) + if filter.chains: + conditions.append(query_field("chain", filter.chains)) + if filter.start_date: + conditions.append(MessageDBModel.time >= filter.start_date) + if filter.end_date: + conditions.append(MessageDBModel.time <= filter.end_date) + + if conditions: + query = query.where(*conditions) + return query 
diff --git a/src/aleph/sdk/db/post.py b/src/aleph/sdk/db/post.py
new file mode 100644
index 00000000..b842e7d6
--- /dev/null
+++ b/src/aleph/sdk/db/post.py
@@ -0,0 +1,140 @@
+from typing import Any, Dict, Iterable
+
+from aleph_message.models import MessageConfirmation, PostMessage
+from peewee import BooleanField, CharField, FloatField, IntegerField, Model
+from playhouse.shortcuts import model_to_dict
+from playhouse.sqlite_ext import JSONField
+
+from ..query.filters import PostFilter
+from ..query.responses import Post
+from .common import PydanticField, pydantic_json_dumps
+
+
+class PostDBModel(Model):
+    """
+    A simple database model for storing aleph.im Posts.
+    """
+
+    original_item_hash = CharField(primary_key=True)
+    original_type = CharField()
+    original_signature = CharField()
+    item_hash = CharField()
+    chain = CharField(5)
+    type = CharField(index=True)
+    sender = CharField()
+    channel = CharField(null=True)
+    confirmations: PydanticField[MessageConfirmation] = PydanticField(
+        type=MessageConfirmation, null=True
+    )
+    confirmed = BooleanField()
+    signature = CharField()
+    size = IntegerField(null=True)
+    time = FloatField()
+    item_type = CharField(7)
+    item_content = CharField(null=True)
+    content = JSONField(json_dumps=pydantic_json_dumps)
+    tags = JSONField(json_dumps=pydantic_json_dumps, null=True)
+    key = CharField(null=True)
+    ref = CharField(null=True)
+    content_type = CharField(null=True)
+
+    @classmethod
+    def get_all_fields(cls):
+        return cls._meta.sorted_field_names
+
+
+def post_to_model(post: Post) -> Dict:
+    return {
+        "item_hash": str(post.item_hash),
+        "chain": post.chain.value,
+        "type": post.type,
+        "sender": post.sender,
+        "channel": post.channel,
+        "confirmations": post.confirmations[0] if post.confirmations else None,
+        "confirmed": post.confirmed,
+        "signature": post.signature,
+        "size": post.size,
+        "time": post.time,
+        "original_item_hash": str(post.original_item_hash),
+        "original_type": post.original_type if post.original_type else post.type,
+        "original_signature": post.original_signature
+        if post.original_signature
+        else post.signature,
+        "item_type": post.item_type,
+        "item_content": post.item_content,
+        "content": post.content,
+        "tags": post.content.content.get("tags", None)
+        if hasattr(post.content, "content")
+        else None,
+        "ref": post.ref,
+    }
+
+
+def message_to_post(message: PostMessage) -> Post:
+    return Post(
+        chain=message.chain,
+        item_hash=message.item_hash,
+        sender=message.sender,
+        type=message.content.type,
+        channel=message.channel,
+        confirmed=message.confirmed if message.confirmed else False,
+        confirmations=message.confirmations if message.confirmations else [],
+        content=message.content,
+        item_content=message.item_content,
+        item_type=message.item_type,
+        signature=message.signature,
+        size=message.size if message.size else len(message.content.json().encode()),
+        time=message.time,
+        original_item_hash=message.item_hash,
+        original_signature=message.signature,
+        original_type=message.content.type,
+        hash=message.item_hash,
+        ref=message.content.ref,
+    )
+
+
+def model_to_post(item: Any) -> Post:
+    to_exclude = [PostDBModel.tags]
+    model_dict = model_to_dict(item, exclude=to_exclude)
+    model_dict["confirmations"] = (
+        [model_dict["confirmations"]] if model_dict["confirmations"] else []
+    )
+    model_dict["hash"] = model_dict["item_hash"]
+    return Post.parse_obj(model_dict)
+
+
+def query_field(field_name, field_values: Iterable[str]):
+    field = getattr(PostDBModel, field_name)
+    values = list(field_values)
+
+    if len(values) == 1:
+        return field == values[0]
+    return field.in_(values)
+
+
+def post_filter_to_query(filter: PostFilter) -> PostDBModel:
+    query = PostDBModel.select().order_by(PostDBModel.time.desc())
+    conditions = []
+    if filter.types:
+        conditions.append(query_field("type", filter.types))
+    if filter.refs:
+        conditions.append(query_field("ref", filter.refs))
+    if filter.addresses:
+        conditions.append(query_field("sender", filter.addresses))
+    if filter.tags:
+        for tag in filter.tags:
+            conditions.append(PostDBModel.tags.contains(tag))
+    if filter.hashes:
+        conditions.append(query_field("original_item_hash", filter.hashes))
+    if filter.channels:
+        conditions.append(query_field("channel", filter.channels))
+    if filter.chains:
+        conditions.append(query_field("chain", filter.chains))
+    if filter.start_date:
+        conditions.append(PostDBModel.time >= filter.start_date)
+    if filter.end_date:
+        conditions.append(PostDBModel.time <= filter.end_date)
+
+    if conditions:
+        query = query.where(*conditions)
+    return query
diff --git a/src/aleph/sdk/exceptions.py b/src/aleph/sdk/exceptions.py
index f2cd96d6..13d8ea59 100644
--- a/src/aleph/sdk/exceptions.py
+++ b/src/aleph/sdk/exceptions.py
@@ -62,3 +62,11 @@ class ForgottenMessageError(QueryError):
     """The requested message was forgotten"""
 
     pass
+
+
+class InvalidCacheDatabaseSchema(Exception):
+    """
+    The cache's database schema is invalid.
+    """
+
+    pass
diff --git a/src/aleph/sdk/vm/cache.py b/src/aleph/sdk/vm/cache.py
index ff5ca7c8..e02e5d85 100644
--- a/src/aleph/sdk/vm/cache.py
+++ b/src/aleph/sdk/vm/cache.py
@@ -19,7 +19,7 @@ def sanitize_cache_key(key: str) -> CacheKey:
 
 
 class BaseVmCache(abc.ABC):
-    """Virtual Machines can use this cache to store temporary data in memory on the host."""
+    """Virtual Machines can use this db to store temporary data in memory on the host."""
 
     @abc.abstractmethod
     async def get(self, key: str) -> Optional[bytes]:
@@ -43,7 +43,7 @@ async def keys(self, pattern: str = "*") -> List[str]:
 
 
 class VmCache(BaseVmCache):
-    """Virtual Machines can use this cache to store temporary data in memory on the host."""
+    """Virtual Machines can use this db to store temporary data in memory on the host."""
 
     session: aiohttp.ClientSession
     cache: Dict[str, bytes]
@@ -74,7 +74,7 @@ def __init__(
 
     async def get(self, key: str) -> Optional[bytes]:
         sanitized_key = sanitize_cache_key(key)
-        async with self.session.get(f"{self.api_host}/cache/{sanitized_key}") as resp:
+        async with self.session.get(f"{self.api_host}/db/{sanitized_key}") as resp:
             if resp.status == 404:
                 return None
 
@@ -85,16 +85,14 @@ async def set(self, key: str, value: Union[str, bytes]) -> Any:
         sanitized_key = sanitize_cache_key(key)
         data = value if isinstance(value, bytes) else value.encode()
         async with self.session.put(
-            f"{self.api_host}/cache/{sanitized_key}", data=data
+            f"{self.api_host}/db/{sanitized_key}", data=data
         ) as resp:
             resp.raise_for_status()
             return await resp.json()
 
     async def delete(self, key: str) -> Any:
         sanitized_key = sanitize_cache_key(key)
-        async with self.session.delete(
-            f"{self.api_host}/cache/{sanitized_key}"
-        ) as resp:
+        async with self.session.delete(f"{self.api_host}/db/{sanitized_key}") as resp:
             resp.raise_for_status()
             return await resp.json()
 
@@ -103,15 +101,13 @@ async def keys(self, pattern: str = "*") -> List[str]:
             raise ValueError(
                 "Pattern may only contain letters, numbers, underscore, ?, *, ^, -"
             )
-        async with self.session.get(
-            f"{self.api_host}/cache/?pattern={pattern}"
-        ) as resp:
+        async with self.session.get(f"{self.api_host}/db/?pattern={pattern}") as resp:
             resp.raise_for_status()
             return await resp.json()
 
 
 class LocalVmCache(BaseVmCache):
-    """This is a local, dict-based cache that can be used for testing purposes."""
+    """This is a local, dict-based db that can be used for testing purposes."""
 
     def __init__(self):
         self._cache: Dict[str, bytes] = {}
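
The docstring tweaks above also touch `LocalVmCache`, which remains handy for exercising cache-dependent code without a host API; a small sketch (the key and value are arbitrary):

    import asyncio

    from aleph.sdk.vm.cache import LocalVmCache


    async def demo():
        cache = LocalVmCache()
        await cache.set("greeting", b"hello")
        assert await cache.get("greeting") == b"hello"
        print(await cache.keys("gr*"))


    asyncio.run(demo())
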
diff --git a/tests/unit/test_download.py b/tests/unit/test_download.py
index 377e6d41..b8fc299c 100644
--- a/tests/unit/test_download.py
+++ b/tests/unit/test_download.py
@@ -1,6 +1,9 @@
+from io import BytesIO
+
 import pytest
 
-from aleph.sdk import AlephHttpClient
+from aleph.sdk import AlephHttpClient, AuthenticatedAlephHttpClient
+from aleph.sdk.client import LightNode
 from aleph.sdk.conf import settings as sdk_settings
 
 
@@ -19,6 +22,19 @@ async def test_download(file_hash: str, expected_size: int):
         assert file_size == expected_size
 
 
+@pytest.mark.asyncio
+async def test_download_light_node(solana_account):
+    session = AuthenticatedAlephHttpClient(
+        solana_account, api_server=sdk_settings.API_HOST
+    )
+    async with LightNode(session=session) as node:
+        file_content = await node.download_file(
+            "QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH"
+        )
+        file_size = len(file_content)
+        assert file_size == 5
+
+
 @pytest.mark.parametrize(
     "file_hash,expected_size",
     [
@@ -32,3 +48,20 @@ async def test_download_ipfs(file_hash: str, expected_size: int):
         file_content = await client.download_file_ipfs(file_hash)  # 5817703 B FILE
         file_size = len(file_content)
         assert file_size == expected_size
+
+
+@pytest.mark.asyncio
+async def test_download_to_buffer_light_node(solana_account):
+    session = AuthenticatedAlephHttpClient(
+        solana_account, api_server=sdk_settings.API_HOST
+    )
+    async with LightNode(session=session) as node:
+        item_hash = "QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH"
+        del node[item_hash]
+        buffer = BytesIO()
+        await node.download_file_to_buffer(
+            "QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH",
+            buffer,
+        )
+        file_size = buffer.getbuffer().nbytes
+        assert file_size == 5
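
The two LightNode download tests above reuse the default cache location; when running them repeatedly it can help to isolate the cache per test. A possible (hypothetical) autouse fixture, assuming the settings object accepts attribute patching, as pydantic models do by default:

    import pytest

    from aleph.sdk.conf import settings


    @pytest.fixture(autouse=True)
    def isolated_cache(monkeypatch, tmp_path):
        monkeypatch.setattr(settings, "CACHE_DATABASE_PATH", tmp_path / "cache.sqlite")
        monkeypatch.setattr(settings, "CACHE_FILES_PATH", tmp_path / "files")
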
diff --git a/tests/unit/test_light_node.py b/tests/unit/test_light_node.py
new file mode 100644
index 00000000..318babc9
--- /dev/null
+++ b/tests/unit/test_light_node.py
@@ -0,0 +1,262 @@
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+import pytest_asyncio
+from aleph_message.models import (
+    AggregateMessage,
+    ForgetMessage,
+    MessageType,
+    PostMessage,
+    ProgramMessage,
+    StoreMessage,
+)
+from aleph_message.status import MessageStatus
+
+from aleph.sdk import AuthenticatedAlephHttpClient
+from aleph.sdk.client.light_node import LightNode
+from aleph.sdk.conf import settings
+from aleph.sdk.types import Account, StorageEnum
+
+
+class MockPostResponse:
+    def __init__(self, response_message: Any, sync: bool):
+        self.response_message = response_message
+        self.sync = sync
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        ...
+
+    @property
+    def status(self):
+        return 200 if self.sync else 202
+
+    def raise_for_status(self):
+        if self.status not in [200, 202]:
+            raise Exception("Bad status code")
+
+    async def json(self):
+        message_status = "processed" if self.sync else "pending"
+        return {
+            "message_status": message_status,
+            "publication_status": {"status": "success", "failed": []},
+            "hash": "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy",
+            "message": self.response_message,
+        }
+
+    async def text(self):
+        return json.dumps(await self.json())
+
+
+class MockGetResponse:
+    def __init__(self, response_message, page=1):
+        self.response_message = response_message
+        self.page = page
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        ...
+
+    @property
+    def status(self):
+        return 200
+
+    def raise_for_status(self):
+        if self.status != 200:
+            raise Exception("Bad status code")
+
+    async def json(self):
+        return self.response_message(self.page)
+
+
+@pytest.fixture
+def mock_session_with_two_messages(
+    ethereum_account: Account, raw_messages_response: Dict[str, Any]
+) -> AuthenticatedAlephHttpClient:
+    http_session = AsyncMock()
+    http_session.post = MagicMock()
+    http_session.post.side_effect = lambda *args, **kwargs: MockPostResponse(
+        response_message={
+            "type": "post",
+            "channel": "TEST",
+            "content": {"Hello": "World"},
+            "key": "QmBlahBlahBlah",
+            "item_hash": "QmBlahBlahBlah",
+        },
+        sync=kwargs.get("sync", False),
+    )
+    http_session.get = MagicMock()
+    http_session.get.side_effect = lambda *args, **kwargs: MockGetResponse(
+        response_message=raw_messages_response,
+        page=kwargs.get("params", {}).get("page", 1),
+    )
+
+    client = AuthenticatedAlephHttpClient(
+        account=ethereum_account, api_server="http://localhost"
+    )
+    client.http_session = http_session
+
+    return client
+
+
+@pytest.mark.asyncio
+async def test_node_init(mock_session_with_two_messages):
+    node = LightNode(session=mock_session_with_two_messages)
+    await node.run()
+    assert node.session == mock_session_with_two_messages
+    assert len(node) >= 2
+
+
+@pytest_asyncio.fixture
+async def mock_node_with_post_success(mock_session_with_two_messages):
+    node = LightNode(session=mock_session_with_two_messages)
+    await node.run()
+    return node
+
+
+@pytest.mark.asyncio
+async def test_create_post(mock_node_with_post_success):
+    async with mock_node_with_post_success as session:
+        content = {"Hello": "World"}
+
+        post_message, message_status = await session.create_post(
+            post_content=content,
+            post_type="TEST",
+            channel="TEST",
+            sync=False,
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(post_message, PostMessage)
+    assert message_status == MessageStatus.PENDING
+
+
+@pytest.mark.asyncio
+async def test_create_aggregate(mock_node_with_post_success):
+    async with mock_node_with_post_success as session:
+        aggregate_message, message_status = await session.create_aggregate(
+            key="hello",
+            content={"Hello": "world"},
+            channel="TEST",
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(aggregate_message, AggregateMessage)
+
+
+@pytest.mark.asyncio
+async def test_create_store(mock_node_with_post_success):
+    mock_ipfs_push_file = AsyncMock()
+    mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
+
+    mock_node_with_post_success.ipfs_push_file = mock_ipfs_push_file
+
+    async with mock_node_with_post_success as node:
+        _ = await node.create_store(
+            file_content=b"HELLO",
+            channel="TEST",
+            storage_engine=StorageEnum.ipfs,
+        )
+
+        _ = await node.create_store(
+            file_hash="QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy",
+            channel="TEST",
+            storage_engine=StorageEnum.ipfs,
+        )
+
+    mock_storage_push_file = AsyncMock()
+    mock_storage_push_file.return_value = (
+        "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
+    )
+    mock_node_with_post_success.storage_push_file = mock_storage_push_file
+    async with mock_node_with_post_success as node:
+        store_message, message_status = await node.create_store(
+            file_content=b"HELLO",
+            channel="TEST",
+            storage_engine=StorageEnum.storage,
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(store_message, StoreMessage)
+
+
+@pytest.mark.asyncio
+async def test_create_program(mock_node_with_post_success):
+    async with mock_node_with_post_success as node:
+        program_message, message_status = await node.create_program(
+            program_ref="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe",
+            entrypoint="main:app",
+            runtime="facefacefacefacefacefacefacefacefacefacefacefacefacefacefaceface",
+            channel="TEST",
+            metadata={"tags": ["test"]},
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(program_message, ProgramMessage)
+
+
+@pytest.mark.asyncio
+async def test_forget(mock_node_with_post_success):
+    async with mock_node_with_post_success as node:
+        forget_message, message_status = await node.forget(
+            hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"],
+            reason="GDPR",
+            channel="TEST",
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(forget_message, ForgetMessage)
+
+
+@pytest.mark.asyncio
+async def test_download_file(mock_node_with_post_success):
+    mock_node_with_post_success.session.download_file = AsyncMock()
+    mock_node_with_post_success.session.download_file.return_value = b"HELLO"
+
+    # remove file locally
+    if os.path.exists(settings.CACHE_FILES_PATH / Path("QmAndSoOn")):
+        os.remove(settings.CACHE_FILES_PATH / Path("QmAndSoOn"))
+
+    # fetch from mocked response
+    async with mock_node_with_post_success as node:
+        file_content = await node.download_file(
+            file_hash="QmAndSoOn",
+        )
+
+    assert mock_node_with_post_success.session.http_session.get.called
+    assert file_content == b"HELLO"
+
+    # fetch cached
+    async with mock_node_with_post_success as node:
+        file_content = await node.download_file(
+            file_hash="QmAndSoOn",
+        )
+
+    assert file_content == b"HELLO"
+
+
+@pytest.mark.asyncio
+async def test_submit_message(mock_node_with_post_success):
+    content = {"Hello": "World"}
+    async with mock_node_with_post_success as node:
+        message, status = await node.submit(
+            content={
+                "address": "0x1234567890123456789012345678901234567890",
+                "time": 1234567890,
+                "type": "TEST",
+                "content": content,
+            },
+            message_type=MessageType.post,
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert message.content.content == content
+    assert status == MessageStatus.PENDING
diff --git a/tests/unit/test_message_cache.py b/tests/unit/test_message_cache.py
new file mode 100644
index 00000000..c8b445a3
--- /dev/null
+++ b/tests/unit/test_message_cache.py
@@ -0,0 +1,321 @@
+import json
+from hashlib import sha256
+from typing import List
+
+import pytest
+from aleph_message.models import (
+    AlephMessage,
+    Chain,
+    MessageType,
+    PostContent,
+    PostMessage,
+)
+
+from aleph.sdk.chains.ethereum import get_fallback_account
+from aleph.sdk.client.message_cache import MessageCache
+from aleph.sdk.db.post import message_to_post
+from aleph.sdk.exceptions import MessageNotFoundError
+from aleph.sdk.query.filters import MessageFilter, PostFilter
+
+
+@pytest.mark.asyncio
+async def test_base(aleph_messages):
+    # test add_many
+    cache = MessageCache()
+    cache.add(aleph_messages)
+    assert len(cache) >= len(aleph_messages)
+
+    item_hashes = [message.item_hash for message in aleph_messages]
+    cached_messages = cache.get(item_hashes)
+    assert len(cached_messages) == len(aleph_messages)
+
+    for message in aleph_messages:
+        assert cache[message.item_hash] == message
+
+    for message in aleph_messages:
+        assert message.item_hash in cache
+
+    for message in cache:
+        del cache[message.item_hash]
+        assert message.item_hash not in cache
+
+    assert len(cache) == 0
+    del cache
+
+
+class TestMessageQueries:
+    messages: List[AlephMessage]
+    cache: MessageCache
+
+    @pytest.fixture(autouse=True)
+    def class_setup(self, aleph_messages):
+        self.messages = aleph_messages
+        self.cache = MessageCache()
+        self.cache.add(self.messages)
+
+    def class_teardown(self):
+        del self.cache
+
+    @pytest.mark.asyncio
+    async def test_iterate(self):
+        assert len(self.cache) == len(self.messages)
+        for message in self.cache:
+            assert message in self.messages
+
+    @pytest.mark.asyncio
+    async def test_addresses(self):
+        assert (
+            self.messages[0]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(
+                        addresses=[self.messages[0].sender],
+                    )
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_tags(self):
+        assert (
+            len(
+                (
+                    await self.cache.get_messages(
+                        message_filter=MessageFilter(tags=["thistagdoesnotexist"])
+                    )
+                ).messages
+            )
+            == 0
+        )
+
+    @pytest.mark.asyncio
+    async def test_message_type(self):
+        assert (
+            self.messages[1]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(message_types=[MessageType.post])
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_refs(self):
+        assert (
+            self.messages[1]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(refs=[self.messages[1].content.ref])
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_hashes(self):
+        assert (
+            self.messages[0]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(hashes=[self.messages[0].item_hash])
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_pagination(self):
+        assert len((await self.cache.get_messages(pagination=1)).messages) == 1
+
+    @pytest.mark.asyncio
+    async def test_content_types(self):
+        assert (
+            self.messages[1]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(
+                        content_types=[self.messages[1].content.type]
+                    )
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_channels(self):
+        assert (
+            self.messages[1]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(channels=[self.messages[1].channel])
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_chains(self):
+        assert (
+            self.messages[1]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(chains=[self.messages[1].chain])
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_content_keys(self):
+        assert (
+            self.messages[0]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(
+                        content_keys=[self.messages[0].content.key]
+                    )
+                )
+            ).messages
+        )
+
+
+class TestPostQueries:
+    messages: List[AlephMessage]
+    cache: MessageCache
+
+    @pytest.fixture(autouse=True)
+    def class_setup(self, aleph_messages):
+        self.messages = aleph_messages
+        self.cache = MessageCache()
+        self.cache.add(self.messages)
+
+    def class_teardown(self):
+        del self.cache
+
+    @pytest.mark.asyncio
+    async def test_addresses(self):
+        assert (
+            message_to_post(self.messages[1])
+            in (
+                await self.cache.get_posts(
+                    post_filter=PostFilter(addresses=[self.messages[1].sender])
+                )
+            ).posts
+        )
+
+    @pytest.mark.asyncio
+    async def test_tags(self):
+        assert (
+            len(
+                (
+                    await self.cache.get_posts(
+                        post_filter=PostFilter(tags=["thistagdoesnotexist"])
+                    )
+                ).posts
+            )
+            == 0
+        )
+
+    @pytest.mark.asyncio
+    async def test_types(self):
+        assert (
+            len(
+                (
+                    await self.cache.get_posts(
+                        post_filter=PostFilter(types=["thistypedoesnotexist"])
+                    )
+                ).posts
+            )
+            == 0
+        )
+
+    @pytest.mark.asyncio
+    async def test_channels(self):
+        assert (
+            message_to_post(self.messages[1])
+            in (
+                await self.cache.get_posts(
+                    post_filter=PostFilter(channels=[self.messages[1].channel])
+                )
+            ).posts
+        )
+
+    @pytest.mark.asyncio
+    async def test_chains(self):
+        assert (
+            message_to_post(self.messages[1])
+            in (
+                await self.cache.get_posts(
+                    post_filter=PostFilter(chains=[self.messages[1].chain])
+                )
+            ).posts
+        )
+
+
+@pytest.mark.asyncio
+async def test_message_cache_listener():
+    async def mock_message_stream():
+        for i in range(3):
+            content = PostContent(
+                content={"hello": f"world{i}"},
+                type="test",
+                address=get_fallback_account().get_address(),
+                time=0,
+            )
+            message = PostMessage(
+                sender=get_fallback_account().get_address(),
+                item_hash=sha256(json.dumps(content.dict()).encode()).hexdigest(),
+                chain=Chain.ETH.value,
+                type=MessageType.post.value,
+                item_type="inline",
+                time=0,
+                content=content,
+                item_content=json.dumps(content.dict()),
+                signature="",
+            )
+            yield message
+
+    cache = MessageCache()
+    # test listener
+    coro = cache.listen_to(mock_message_stream())
+    await coro
+    assert len(cache) >= 3
+
+
+@pytest.mark.asyncio
+async def test_fetch_aggregate(aleph_messages):
+    cache = MessageCache()
+    cache.add(aleph_messages)
+
+    aggregate = await cache.fetch_aggregate(
+        aleph_messages[0].sender, aleph_messages[0].content.key
+    )
+
+    assert aggregate == aleph_messages[0].content.content
+
+
+@pytest.mark.asyncio
+async def test_fetch_aggregates(aleph_messages):
+    cache = MessageCache()
+    cache.add(aleph_messages)
+
+    aggregates = await cache.fetch_aggregates(aleph_messages[0].sender)
+
+    assert aggregates == {
+        aleph_messages[0].content.key: aleph_messages[0].content.content
+    }
+
+
+@pytest.mark.asyncio
+async def test_get_message(aleph_messages):
+    cache = MessageCache()
+    cache.add(aleph_messages)
+
+    message: AlephMessage = await cache.get_message(aleph_messages[0].item_hash)
+
+    assert message == aleph_messages[0]
+
+
+@pytest.mark.asyncio
+async def test_get_message_fail():
+    cache = MessageCache()
+
+    with pytest.raises(MessageNotFoundError):
+        await cache.get_message("0x1234567890123456789012345678901234567890")