From af01b9a91933c8dfc7d17a84958f0689cd7339b6 Mon Sep 17 00:00:00 2001 From: mhh Date: Wed, 25 Oct 2023 22:29:25 +0200 Subject: [PATCH 01/18] Feature: Add LightNode and MessageCache A LightNode can synchronize on a subset, or domain, of aleph.im messages. It relies on the MessageCache, which manages a message database with peewee. --- .gitignore | 1 + setup.cfg | 1 + src/aleph/sdk/__init__.py | 4 +- src/aleph/sdk/client/__init__.py | 4 + src/aleph/sdk/client/http.py | 8 + src/aleph/sdk/client/light_node.py | 394 +++++++++++++++++++++ src/aleph/sdk/client/message_cache.py | 490 ++++++++++++++++++++++++++ src/aleph/sdk/conf.py | 9 + src/aleph/sdk/db/aggregate.py | 31 ++ src/aleph/sdk/db/common.py | 34 ++ src/aleph/sdk/db/message.py | 124 +++++++ src/aleph/sdk/db/post.py | 140 ++++++++ src/aleph/sdk/exceptions.py | 8 + src/aleph/sdk/vm/cache.py | 18 +- tests/unit/test_download.py | 35 +- tests/unit/test_light_node.py | 262 ++++++++++++++ tests/unit/test_message_cache.py | 321 +++++++++++++++++ 17 files changed, 1870 insertions(+), 14 deletions(-) create mode 100644 src/aleph/sdk/client/light_node.py create mode 100644 src/aleph/sdk/client/message_cache.py create mode 100644 src/aleph/sdk/db/aggregate.py create mode 100644 src/aleph/sdk/db/common.py create mode 100644 src/aleph/sdk/db/message.py create mode 100644 src/aleph/sdk/db/post.py create mode 100644 tests/unit/test_light_node.py create mode 100644 tests/unit/test_message_cache.py diff --git a/.gitignore b/.gitignore index c4734889..a12a6219 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ *.pot __pycache__/* .cache/* +cache/**/* .*.swp */.ipynb_checkpoints/* diff --git a/setup.cfg b/setup.cfg index ca39e9ce..4d271446 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,7 @@ install_requires = # Required to fix a dependency issue with parsimonious and Python3.11 eth_abi==4.0.0b2; python_version>="3.11" python-magic + peewee # The usage of test_requires is discouraged, see `Dependency Management` docs # 
tests_require = pytest; pytest-cov # Require a specific Python version, e.g. Python 2.7 or >= 3.4 diff --git a/src/aleph/sdk/__init__.py b/src/aleph/sdk/__init__.py index c14b64f6..416f12cd 100644 --- a/src/aleph/sdk/__init__.py +++ b/src/aleph/sdk/__init__.py @@ -1,6 +1,6 @@ from pkg_resources import DistributionNotFound, get_distribution -from aleph.sdk.client import AlephHttpClient, AuthenticatedAlephHttpClient +from aleph.sdk.client import AlephHttpClient, AuthenticatedAlephHttpClient, LightNode try: # Change here if project is renamed and does not equal the package name @@ -11,4 +11,4 @@ finally: del get_distribution, DistributionNotFound -__all__ = ["AlephHttpClient", "AuthenticatedAlephHttpClient"] +__all__ = ["AlephHttpClient", "AuthenticatedAlephHttpClient", "LightNode"] diff --git a/src/aleph/sdk/client/__init__.py b/src/aleph/sdk/client/__init__.py index 9ee25dd9..eadc6cac 100644 --- a/src/aleph/sdk/client/__init__.py +++ b/src/aleph/sdk/client/__init__.py @@ -1,10 +1,14 @@ from .abstract import AlephClient, AuthenticatedAlephClient from .authenticated_http import AuthenticatedAlephHttpClient from .http import AlephHttpClient +from .light_node import LightNode +from .message_cache import MessageCache __all__ = [ "AlephClient", "AuthenticatedAlephClient", "AlephHttpClient", "AuthenticatedAlephHttpClient", + "MessageCache", + "LightNode", ] diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py index 0a46896b..69bde3c8 100644 --- a/src/aleph/sdk/client/http.py +++ b/src/aleph/sdk/client/http.py @@ -6,6 +6,7 @@ import aiohttp from aleph_message import parse_message from aleph_message.models import AlephMessage, ItemHash, ItemType +from aleph_message.status import MessageStatus from pydantic import ValidationError from ..conf import settings @@ -178,6 +179,8 @@ async def download_file_to_buffer( ) else: raise FileTooLarge(f"The file from {file_hash} is too large") + else: + response.raise_for_status() async def download_file_ipfs_to_buffer( 
self, @@ -343,6 +346,11 @@ async def get_message_error( "details": message_raw["details"], } + async def get_message_status(self, item_hash: str) -> MessageStatus: + async with self.http_session.get(f"/api/v0/messages/{item_hash}") as resp: + resp.raise_for_status() + return MessageStatus((await resp.json())["status"]) + async def watch_messages( self, message_filter: Optional[MessageFilter] = None, diff --git a/src/aleph/sdk/client/light_node.py b/src/aleph/sdk/client/light_node.py new file mode 100644 index 00000000..91506e8f --- /dev/null +++ b/src/aleph/sdk/client/light_node.py @@ -0,0 +1,394 @@ +import asyncio +from datetime import datetime +from io import BytesIO +from pathlib import Path +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union + +from aleph_message.models import AlephMessage, Chain, MessageType +from aleph_message.models.execution.base import Encoding +from aleph_message.status import MessageStatus + +from ..query.filters import MessageFilter +from ..types import StorageEnum +from ..utils import Writable +from .abstract import AuthenticatedAlephClient +from .authenticated_http import AuthenticatedAlephHttpClient +from .message_cache import MessageCache + + +class LightNode(MessageCache, AuthenticatedAlephClient): + """ + A LightNode is a client that can listen to the Aleph network and stores messages in a local database. Furthermore, + it can create messages and submit them to the network, as well as upload files, while keeping track of the + corresponding messages locally. + + It synchronizes with the network on a subset of the messages (the "domain") by listening to the network and storing + the messages in the database. The user may define the domain by specifying a channels, tags, senders, chains and/or + message types. 
+ """ + + def __init__( + self, + session: AuthenticatedAlephHttpClient, + channels: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + chains: Optional[Iterable[Chain]] = None, + message_types: Optional[Iterable[MessageType]] = None, + ): + """ + Initialize a LightNode. Besides requiring an established session with a core channel node, the user may specify + a domain to listen to. The domain is the intersection of the specified channels, tags, senders, chains and + message types. A smaller domain will synchronize faster, require less storage and less bandwidth. + + Args: + session: An authenticated session to an Aleph core channel node. + channels: The channels to listen to. + tags: The tags to listen to. + addresses: The addresses to listen to. + chains: The chains to listen to. + message_types: The message types to listen to. + + Raises: + InvalidCacheDatabaseSchema: If the database schema does not match the expected message schema. + """ + super().__init__() + self.session = session + self.channels = channels + self.tags = tags + self.addresses = ( + list(addresses) + [session.account.get_address()] + if addresses + else [session.account.get_address()] + ) + self.chains = ( + list(chains) + [Chain(session.account.CHAIN)] + if chains + else [session.account.CHAIN] + ) + self.message_types = message_types + + async def run(self): + """ + Start listening to the network and synchronize with past messages. 
+ """ + asyncio.create_task( + self.listen_to( + self.session.watch_messages( + message_filter=MessageFilter( + channels=self.channels, + tags=self.tags, + addresses=self.addresses, + chains=self.chains, + message_types=self.message_types, + ) + ) + ) + ) + # synchronize with past messages + await self.synchronize( + channels=self.channels, + tags=self.tags, + addresses=self.addresses, + chains=self.chains, + message_types=self.message_types, + ) + + async def synchronize( + self, + channels: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + chains: Optional[Iterable[Chain]] = None, + message_types: Optional[Iterable[MessageType]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ): + """ + Synchronize with past messages. + """ + chunk_size = 200 + messages = [] + async for message in self.session.get_messages_iterator( + message_filter=MessageFilter( + channels=channels, + tags=tags, + addresses=addresses, + chains=chains, + message_types=message_types, + start_date=start_date, + end_date=end_date, + ) + ): + messages.append(message) + if len(messages) >= chunk_size: + self.add(messages) + messages = [] + if messages: + self.add(messages) + + async def download_file(self, file_hash: str) -> bytes: + """ + Download a file from the network and store it locally. If it already exists locally, it will not be downloaded + again. + + Args: + file_hash: The hash of the file to download. + + Returns: + The file content. + + Raises: + FileNotFoundError: If the file does not exist on the network. 
+ """ + try: + return await super().download_file(file_hash) + except FileNotFoundError: + pass + file = await self.session.download_file(file_hash) + self._file_path(file_hash).parent.mkdir(parents=True, exist_ok=True) + with open(self._file_path(file_hash), "wb") as f: + f.write(file) + return file + + async def download_file_to_buffer( + self, + file_hash: str, + output_buffer: Writable[bytes], + ) -> None: + """ + Download a file from the network and store it in a buffer. If it already exists locally, it will not be + downloaded again. + + Args: + file_hash: The hash of the file to download. + output_buffer: The buffer to store the file content in. + + Raises: + FileNotFoundError: If the file does not exist on the network. + """ + try: + return await super().download_file_to_buffer(file_hash, output_buffer) + except FileNotFoundError: + pass + buffer = BytesIO() + await self.session.download_file_ipfs_to_buffer(file_hash, buffer) + self._file_path(file_hash).parent.mkdir(parents=True, exist_ok=True) + with open(self._file_path(file_hash), "wb") as f: + f.write(buffer.getvalue()) + output_buffer.write(buffer.getvalue()) + + def check_validity( + self, + message_type: MessageType, + address: Optional[str] = None, + channel: Optional[str] = None, + content: Optional[Dict] = None, + ): + if self.message_types and message_type not in self.message_types: + raise ValueError( + f"Cannot create {message_type.value} message because DomainNode is not listening to post messages." + ) + if address and self.addresses and address not in self.addresses: + raise ValueError( + f"Cannot create {message_type.value} message because DomainNode is not listening to messages from address {address}." + ) + if channel and self.channels and channel not in self.channels: + raise ValueError( + f"Cannot create {message_type.value} message because DomainNode is not listening to messages from channel {channel}." 
+ ) + if ( + content + and self.tags + and not set(content.get("tags", [])).intersection(self.tags) + ): + raise ValueError( + f"Cannot create {message_type.value} message because DomainNode is not listening to any of these tags: {content.get('tags', [])}." + ) + + async def delete_if_rejected(self, item_hash): + async def _delete_if_rejected(): + await asyncio.sleep(5) + retries = 0 + status = await self.session.get_message_status(item_hash) + while status == MessageStatus.PENDING: + await asyncio.sleep(5) + status = await self.session.get_message_status(item_hash) + retries += 1 + if retries > 10: + raise TimeoutError( + f"Message {item_hash} has been pending for too long." + ) + if status in [MessageStatus.REJECTED, MessageStatus.FORGOTTEN]: + del self[item_hash] + + return _delete_if_rejected + + async def create_post( + self, + post_content: Any, + post_type: str, + ref: Optional[str] = None, + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + self.check_validity(MessageType.post, address, channel, post_content) + resp, status = await self.session.create_post( + post_content=post_content, + post_type=post_type, + ref=ref, + address=address, + channel=channel, + inline=inline, + storage_engine=storage_engine, + sync=sync, + ) + if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]: + self.add(resp) + asyncio.create_task(self.delete_if_rejected(resp.item_hash)) + return resp, status + + async def create_aggregate( + self, + key: str, + content: Mapping[str, Any], + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + self.check_validity(MessageType.aggregate, address, channel) + resp, status = await self.session.create_aggregate( + key=key, + content=content, + address=address, + channel=channel, + 
inline=inline, + sync=sync, + ) + if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]: + self.add(resp) + asyncio.create_task(self.delete_if_rejected(resp.item_hash)) + return resp, status + + async def create_store( + self, + address: Optional[str] = None, + file_content: Optional[bytes] = None, + file_path: Optional[Union[str, Path]] = None, + file_hash: Optional[str] = None, + guess_mime_type: bool = False, + ref: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + extra_fields: Optional[dict] = None, + channel: Optional[str] = None, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + self.check_validity(MessageType.store, address, channel, extra_fields) + resp, status = await self.session.create_store( + address=address, + file_content=file_content, + file_path=file_path, + file_hash=file_hash, + guess_mime_type=guess_mime_type, + ref=ref, + storage_engine=storage_engine, + extra_fields=extra_fields, + channel=channel, + sync=sync, + ) + if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]: + self.add(resp) + asyncio.create_task(self.delete_if_rejected(resp.item_hash)) + return resp, status + + async def create_program( + self, + program_ref: str, + entrypoint: str, + runtime: str, + environment_variables: Optional[Mapping[str, str]] = None, + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + persistent: bool = False, + encoding: Encoding = Encoding.zip, + volumes: Optional[List[Mapping]] = None, + subscriptions: Optional[List[Mapping]] = None, + metadata: Optional[Mapping[str, Any]] = None, + ) -> Tuple[AlephMessage, MessageStatus]: + self.check_validity( + MessageType.program, address, channel, dict(metadata) if metadata else None + ) + resp, status = await self.session.create_program( + 
program_ref=program_ref, + entrypoint=entrypoint, + runtime=runtime, + environment_variables=environment_variables, + storage_engine=storage_engine, + channel=channel, + address=address, + sync=sync, + memory=memory, + vcpus=vcpus, + timeout_seconds=timeout_seconds, + persistent=persistent, + encoding=encoding, + volumes=volumes, + subscriptions=subscriptions, + metadata=metadata, + ) + if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]: + self.add(resp) + asyncio.create_task(self.delete_if_rejected(resp.item_hash)) + return resp, status + + async def forget( + self, + hashes: List[str], + reason: Optional[str], + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + self.check_validity(MessageType.forget, address, channel) + resp, status = await self.session.forget( + hashes=hashes, + reason=reason, + storage_engine=storage_engine, + channel=channel, + address=address, + sync=sync, + ) + del self[resp.item_hash] + return resp, status + + async def submit( + self, + content: Dict[str, Any], + message_type: MessageType, + channel: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + allow_inlining: bool = True, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + resp, status = await self.session.submit( + content=content, + message_type=message_type, + channel=channel, + storage_engine=storage_engine, + allow_inlining=allow_inlining, + sync=sync, + ) + if status in [MessageStatus.PROCESSED, MessageStatus.PENDING]: + self.add(resp) + asyncio.create_task(self.delete_if_rejected(resp.item_hash)) + return resp, status diff --git a/src/aleph/sdk/client/message_cache.py b/src/aleph/sdk/client/message_cache.py new file mode 100644 index 00000000..c0789bac --- /dev/null +++ b/src/aleph/sdk/client/message_cache.py @@ -0,0 +1,490 @@ +import logging +import typing +from datetime import datetime +from 
pathlib import Path +from typing import ( + AsyncIterable, + Coroutine, + Dict, + Iterable, + Iterator, + List, + Optional, + Type, + Union, +) + +from aleph_message import MessagesResponse +from aleph_message.models import AlephMessage, ItemHash, MessageType, PostMessage +from peewee import SqliteDatabase +from playhouse.shortcuts import model_to_dict + +from ..conf import settings +from ..db.aggregate import AggregateDBModel, aggregate_to_model +from ..db.message import ( + MessageDBModel, + message_filter_to_query, + message_to_model, + model_to_message, +) +from ..db.post import ( + PostDBModel, + message_to_post, + model_to_post, + post_filter_to_query, + post_to_model, +) +from ..exceptions import InvalidCacheDatabaseSchema, MessageNotFoundError +from ..query.filters import MessageFilter, PostFilter +from ..query.responses import PostsResponse +from ..types import GenericMessage +from ..utils import Writable +from .abstract import AlephClient + + +class MessageCache(AlephClient): + """ + A wrapper around a sqlite3 database for caching AlephMessage objects. + + It can be used independently of a DomainNode to implement any kind of caching strategy. + """ + + missing_posts: Dict[ItemHash, PostMessage] = {} + """A dict of all posts by item_hash and their amend messages that are missing from the db.""" + + def __init__(self, database_path: Optional[Union[str, Path]] = None): + """ + Args: + database_path: The path to the sqlite3 database file. If not provided, the default + path will be used. + + Note: + The database schema is automatically checked and updated if necessary. + + !!! warning + :memory: databases are not supported, as they do not persist across connections. + + Raises: + InvalidCacheDatabaseSchema: If the database schema does not match the expected message schema. 
+ """ + self.database_path: Path = ( + Path(database_path) if database_path else settings.CACHE_DATABASE_PATH + ) + if not self.database_path.exists(): + self.database_path.parent.mkdir(parents=True, exist_ok=True) + + self.db = SqliteDatabase(self.database_path) + MessageDBModel._meta.database = self.db + PostDBModel._meta.database = self.db + AggregateDBModel._meta.database = self.db + + self.db.connect(reuse_if_open=True) + if not MessageDBModel.table_exists(): + self.db.create_tables([MessageDBModel]) + if not PostDBModel.table_exists(): + self.db.create_tables([PostDBModel]) + if not AggregateDBModel.table_exists(): + self.db.create_tables([AggregateDBModel]) + self._check_schema() + self.db.close() + + def _check_schema(self): + if sorted(MessageDBModel._meta.fields.keys()) != sorted( + [col.name for col in self.db.get_columns(MessageDBModel._meta.table_name)] + ): + raise InvalidCacheDatabaseSchema( + "MessageDBModel schema does not match MessageModel schema" + ) + if sorted(PostDBModel._meta.fields.keys()) != sorted( + [col.name for col in self.db.get_columns(PostDBModel._meta.table_name)] + ): + raise InvalidCacheDatabaseSchema( + "PostDBModel schema does not match PostModel schema" + ) + if sorted(AggregateDBModel._meta.fields.keys()) != sorted( + [col.name for col in self.db.get_columns(AggregateDBModel._meta.table_name)] + ): + raise InvalidCacheDatabaseSchema( + "AggregateDBModel schema does not match AggregateModel schema" + ) + + async def __aenter__(self): + self.db.connect(reuse_if_open=True) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.db.close() + + def __del__(self): + self.db.close() + + def __getitem__(self, item_hash: ItemHash) -> Optional[AlephMessage]: + item = MessageDBModel.get_or_none(MessageDBModel.item_hash == str(item_hash)) + return model_to_message(item) if item else None + + def __delitem__(self, item_hash: ItemHash): + MessageDBModel.delete().where( + MessageDBModel.item_hash == str(item_hash) + 
).execute() + PostDBModel.delete().where( + PostDBModel.original_item_hash == str(item_hash) + ).execute() + AggregateDBModel.delete().where( + AggregateDBModel.original_message_hash == str(item_hash) + ).execute() + # delete stored files + file_path = self._file_path(str(item_hash)) + if file_path.exists(): + file_path.unlink() + + def __contains__(self, item_hash: ItemHash) -> bool: + return ( + MessageDBModel.select() + .where(MessageDBModel.item_hash == str(item_hash)) + .exists() + ) + + def __len__(self): + return MessageDBModel.select().count() + + def __iter__(self) -> Iterator[AlephMessage]: + """ + Iterate over all messages in the db, the latest first. + """ + for item in iter(MessageDBModel.select().order_by(-MessageDBModel.time)): + yield model_to_message(item) + + def __repr__(self) -> str: + return f"<MessageCache: {self.database_path}>" + + def __str__(self) -> str: + return repr(self) + + @typing.overload + def add(self, messages: Iterable[AlephMessage]): + ... + + @typing.overload + def add(self, messages: AlephMessage): + ... + + def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]): + """ + Add a message or a list of messages to the database. If the message is an amend, + it will be applied to the original post. If the original post is not in the db, + the amend message will be stored in memory until the original post is added. + Aggregate message will be merged with any existing aggregate messages. + + Args: + messages: A message or list of messages to add to the database. 
+ """ + if isinstance(messages, typing.get_args(AlephMessage)): + messages = [messages] + + message_data = (message_to_model(message) for message in messages) + MessageDBModel.insert_many(message_data).on_conflict_replace().execute() + + # Sort messages and insert posts first + post_data = [] + amend_messages = [] + aggregate_messages = [] + forget_messages = [] + for message in messages: + if message.type == MessageType.aggregate.value: + aggregate_messages.append(message) + continue + if message.type == MessageType.forget.value: + forget_messages.append(message) + continue + if message.type != MessageType.post.value: + continue + if message.content.type == "amend": + amend_messages.append(message) + continue + + post = post_to_model(message_to_post(message)) + post_data.append(post) + + # Check if we can now add any amend messages that had missing refs + if message.item_hash in self.missing_posts: + amend_messages += self.missing_posts.pop(message.item_hash) + + with self.db.atomic(): + PostDBModel.insert_many(post_data).on_conflict_replace().execute() + + self._handle_amends(amend_messages) + + self._handle_aggregates(aggregate_messages) + + self._handle_forgets(forget_messages) + + def _handle_amends(self, amend_messages: List[PostMessage]): + post_data = [] + for amend in amend_messages: + original_post = MessageDBModel.get_or_none( + MessageDBModel.original_item_hash == amend.content.ref + ) + if not original_post: + latest_amend = self.missing_posts.get(ItemHash(amend.content.ref)) + if latest_amend and amend.time < latest_amend.time: + self.missing_posts[ItemHash(amend.content.ref)] = amend + continue + + if datetime.fromtimestamp(amend.time) < original_post.last_updated: + continue + + original_post.item_hash = amend.item_hash + original_post.content = amend.content.content + original_post.original_item_hash = amend.content.ref + original_post.original_type = amend.content.type + original_post.address = amend.sender + original_post.channel = amend.channel 
+ original_post.last_updated = datetime.fromtimestamp(amend.time) + post_data.append(model_to_dict(original_post)) + with self.db.atomic(): + PostDBModel.insert_many(post_data).on_conflict_replace().execute() + + def _handle_aggregates(self, aggregate_messages): + aggregate_data = [] + for aggregate in aggregate_messages: + existing_aggregate = AggregateDBModel.get_or_none( + AggregateDBModel.address == aggregate.sender, + AggregateDBModel.key == aggregate.content.key, + ) + if not existing_aggregate: + aggregate_data.append(aggregate_to_model(aggregate)) + continue + data = model_to_dict(existing_aggregate) + if aggregate.time > existing_aggregate.time: + existing_aggregate.content.update(aggregate.content.content) + existing_aggregate.time = aggregate.time + elif existing_aggregate.content is None: + existing_aggregate.content = aggregate.content.content + else: + existing_aggregate.content = dict( + aggregate.content.content, **existing_aggregate.content + ) + data = model_to_dict(existing_aggregate) + aggregate_data.append(data) + with self.db.atomic(): + AggregateDBModel.insert_many(aggregate_data).on_conflict_replace().execute() + + def _handle_forgets(self, forget_messages): + refs = [forget.content.ref for forget in forget_messages] + with self.db.atomic(): + MessageDBModel.delete().where(MessageDBModel.item_hash.in_(refs)).execute() + PostDBModel.delete().where(PostDBModel.item_hash.in_(refs)).execute() + AggregateDBModel.delete().where( + AggregateDBModel.original_message_hash.in_(refs) + ).execute() + + @typing.overload + def get(self, item_hashes: Iterable[ItemHash]) -> List[AlephMessage]: + ... + + @typing.overload + def get(self, item_hashes: ItemHash) -> Optional[AlephMessage]: + ... + + def get( + self, item_hashes: Union[ItemHash, Iterable[ItemHash]] + ) -> List[AlephMessage]: + """ + Get many messages from the db by their item hash. 
+ """ + if isinstance(item_hashes, ItemHash): + item_hashes = [item_hashes] + item_hashes = [str(item_hash) for item_hash in item_hashes] + items = ( + MessageDBModel.select() + .where(MessageDBModel.item_hash.in_(item_hashes)) + .execute() + ) + return [model_to_message(item) for item in items] + + def listen_to(self, message_stream: AsyncIterable[AlephMessage]) -> Coroutine: + """ + Listen to a stream of messages and add them to the database. + """ + + async def _listen(): + async for message in message_stream: + self.add(message) + print(f"Added message {message.item_hash} to db") + + return _listen() + + async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: + item = ( + AggregateDBModel.select() + .where(AggregateDBModel.address == address) + .where(AggregateDBModel.key == key) + .order_by(AggregateDBModel.key.desc()) + .get_or_none() + ) + if not item: + raise MessageNotFoundError(f"No such aggregate {address} {key}") + return item.content + + async def fetch_aggregates( + self, address: str, keys: Optional[Iterable[str]] = None + ) -> Dict[str, Dict]: + query = ( + AggregateDBModel.select() + .where(AggregateDBModel.address == address) + .order_by(AggregateDBModel.key.desc()) + ) + if keys: + query = query.where(AggregateDBModel.key.in_(keys)) + return {item.key: item.content for item in list(query)} + + async def get_posts( + self, + pagination: int = 200, + page: int = 1, + post_filter: Optional[PostFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + ) -> PostsResponse: + query = ( + post_filter_to_query(post_filter) if post_filter else PostDBModel.select() + ) + + query = query.paginate(page, pagination) + + posts = [model_to_post(item) for item in list(query)] + + return PostsResponse( + posts=posts, + pagination_page=page, + pagination_per_page=pagination, + pagination_total=query.count(), + pagination_item="posts", + ) + + @staticmethod + def 
_file_path(file_hash: str) -> Path: + return settings.CACHE_FILES_PATH / Path(file_hash) + + async def download_file(self, file_hash: str) -> bytes: + """ + Opens a file that has been locally stored by its hash. + + Raises: + FileNotFoundError: If the file does not exist. + """ + with open(self._file_path(file_hash), "rb") as f: + return f.read() + + async def download_file_ipfs(self, file_hash: str) -> bytes: + """ + Opens a file that has been locally stored by its IPFS hash. + + Raises: + FileNotFoundError: If the file does not exist. + """ + return await self.download_file(file_hash) + + async def download_file_to_buffer( + self, + file_hash: str, + output_buffer: Writable[bytes], + ) -> None: + """ + Opens a file and writes it to a buffer. + + Raises: + FileNotFoundError: If the file does not exist. + """ + with open(self._file_path(file_hash), "rb") as f: + output_buffer.write(f.read()) + + async def download_file_to_buffer_ipfs( + self, + file_hash: str, + output_buffer: Writable[bytes], + ) -> None: + """ + Opens a file and writes it to a buffer. + + Raises: + FileNotFoundError: If the file does not exist. + """ + await self.download_file_to_buffer(file_hash, output_buffer) + + async def add_file(self, file_hash: str, file_content: bytes): + """ + Store a file locally by its hash. + """ + if not self._file_path(file_hash).exists(): + self._file_path(file_hash).parent.mkdir(parents=True, exist_ok=True) + with open(self._file_path(file_hash), "wb") as f: + f.write(file_content) + + async def get_messages( + self, + pagination: int = 200, + page: int = 1, + message_filter: Optional[MessageFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + ) -> MessagesResponse: + """ + Get and filter many messages from the database. 
+ """ + query = ( + message_filter_to_query(message_filter) + if message_filter + else MessageDBModel.select() + ) + + query = query.paginate(page, pagination) + + messages = [model_to_message(item) for item in list(query)] + + return MessagesResponse( + messages=messages, + pagination_page=page, + pagination_per_page=pagination, + pagination_total=query.count(), + pagination_item="messages", + ) + + async def get_message( + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, + channel: Optional[str] = None, + ) -> GenericMessage: + """ + Get a single message from the database by its item hash. + """ + query = MessageDBModel.select().where(MessageDBModel.item_hash == item_hash) + + if message_type: + query = query.where(MessageDBModel.type == message_type.value) + if channel: + query = query.where(MessageDBModel.channel == channel) + + item = query.first() + + if item: + return model_to_message(item) + + raise MessageNotFoundError(f"No such hash {item_hash}") + + async def watch_messages( + self, + message_filter: Optional[MessageFilter] = None, + ) -> AsyncIterable[AlephMessage]: + """ + Watch new messages as they are added to the database. 
+ """ + query = ( + message_filter_to_query(message_filter) + if message_filter + else MessageDBModel.select() + ) + + async for item in query: + yield model_to_message(item) diff --git a/src/aleph/sdk/conf.py b/src/aleph/sdk/conf.py index f8d798c6..626b1ba8 100644 --- a/src/aleph/sdk/conf.py +++ b/src/aleph/sdk/conf.py @@ -43,6 +43,15 @@ class Settings(BaseSettings): DNS_STATIC_DOMAIN = "static.public.aleph.sh" DNS_RESOLVERS = ["9.9.9.9", "1.1.1.1"] + CACHE_DATABASE_PATH: Path = Field( + default=Path("cache", "message_cache.sqlite"), + description="Path to the db database", + ) + CACHE_FILES_PATH: Path = Field( + default=Path("cache", "files"), + description="Path to the db files", + ) + class Config: env_prefix = "ALEPH_" case_sensitive = False diff --git a/src/aleph/sdk/db/aggregate.py b/src/aleph/sdk/db/aggregate.py new file mode 100644 index 00000000..f06c18f4 --- /dev/null +++ b/src/aleph/sdk/db/aggregate.py @@ -0,0 +1,31 @@ +from typing import Dict + +from aleph_message.models import AggregateMessage +from peewee import CharField, FloatField, Model +from playhouse.sqlite_ext import JSONField + +from .common import pydantic_json_dumps + + +class AggregateDBModel(Model): + """ + A simple database model for storing aleph.im Aggregates. 
+ """ + + original_message_hash = CharField(primary_key=True) + address = CharField(index=True) + key = CharField() + channel = CharField(null=True) + content = JSONField(json_dumps=pydantic_json_dumps, null=True) + time = FloatField() + + +def aggregate_to_model(message: AggregateMessage) -> Dict: + return { + "original_message_hash": str(message.item_hash), + "address": str(message.sender), + "key": str(message.content.key), + "channel": message.channel, + "content": message.content.content, + "time": message.time, + } diff --git a/src/aleph/sdk/db/common.py b/src/aleph/sdk/db/common.py new file mode 100644 index 00000000..2b4ccb40 --- /dev/null +++ b/src/aleph/sdk/db/common.py @@ -0,0 +1,34 @@ +import json +from functools import partial +from typing import Generic, Optional, TypeVar + +from playhouse.sqlite_ext import JSONField +from pydantic import BaseModel +from pydantic.json import pydantic_encoder + +T = TypeVar("T", bound=BaseModel) + + +pydantic_json_dumps = partial(json.dumps, default=pydantic_encoder) + + +class PydanticField(JSONField, Generic[T]): + """ + A field for storing pydantic model types as JSON in a database. Uses json for serialization. 
+ """ + + type: T + + def __init__(self, *args, **kwargs): + self.type = kwargs.pop("type") + super().__init__(*args, **kwargs) + + def db_value(self, value: Optional[T]) -> Optional[str]: + if value is None: + return None + return pydantic_json_dumps(value) + + def python_value(self, value: Optional[str]) -> Optional[T]: + if not value: + return None + return self.type.parse_raw(value) diff --git a/src/aleph/sdk/db/message.py b/src/aleph/sdk/db/message.py new file mode 100644 index 00000000..962de88c --- /dev/null +++ b/src/aleph/sdk/db/message.py @@ -0,0 +1,124 @@ +from typing import Any, Dict, Iterable + +from aleph_message import parse_message +from aleph_message.models import AlephMessage, MessageConfirmation +from peewee import BooleanField, CharField, FloatField, IntegerField, Model +from playhouse.shortcuts import model_to_dict +from playhouse.sqlite_ext import JSONField + +from ..query.filters import MessageFilter +from .common import PydanticField, pydantic_json_dumps + + +class MessageDBModel(Model): + """ + A simple database model for storing AlephMessage objects. 
+ """ + + item_hash = CharField(primary_key=True) + chain = CharField(5) + type = CharField(9) + sender = CharField() + channel = CharField(null=True) + confirmations: PydanticField[MessageConfirmation] = PydanticField( + type=MessageConfirmation, null=True + ) + confirmed = BooleanField(null=True) + signature = CharField(null=True) + size = IntegerField(null=True) + time = FloatField() + item_type = CharField(7) + item_content = CharField(null=True) + hash_type = CharField(6, null=True) + content = JSONField(json_dumps=pydantic_json_dumps) + forgotten_by = CharField(null=True) + tags = JSONField(json_dumps=pydantic_json_dumps, null=True) + key = CharField(null=True) + ref = CharField(null=True) + content_type = CharField(null=True) + + +def message_to_model(message: AlephMessage) -> Dict: + return { + "item_hash": str(message.item_hash), + "chain": message.chain, + "type": message.type, + "sender": message.sender, + "channel": message.channel, + "confirmations": message.confirmations[0] if message.confirmations else None, + "confirmed": message.confirmed, + "signature": message.signature, + "size": message.size, + "time": message.time, + "item_type": message.item_type, + "item_content": message.item_content, + "hash_type": message.hash_type, + "content": message.content, + "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None, + "tags": message.content.content.get("tags", None) + if hasattr(message.content, "content") + else None, + "key": message.content.key if hasattr(message.content, "key") else None, + "ref": message.content.ref if hasattr(message.content, "ref") else None, + "content_type": message.content.type + if hasattr(message.content, "type") + else None, + } + + +def model_to_message(item: Any) -> AlephMessage: + item.confirmations = [item.confirmations] if item.confirmations else [] + item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None + + to_exclude = [ + MessageDBModel.tags, + MessageDBModel.ref, + 
MessageDBModel.key, + MessageDBModel.content_type, + ] + + item_dict = model_to_dict(item, exclude=to_exclude) + return parse_message(item_dict) + + +def query_field(field_name, field_values: Iterable[str]): + field = getattr(MessageDBModel, field_name) + values = list(field_values) + + if len(values) == 1: + return field == values[0] + return field.in_(values) + + +def message_filter_to_query(filter: MessageFilter) -> MessageDBModel: + query = MessageDBModel.select().order_by(MessageDBModel.time.desc()) + conditions = [] + if filter.message_types: + conditions.append( + query_field("type", [type.value for type in filter.message_types]) + ) + if filter.content_keys: + conditions.append(query_field("key", filter.content_keys)) + if filter.content_types: + conditions.append(query_field("content_type", filter.content_types)) + if filter.refs: + conditions.append(query_field("ref", filter.refs)) + if filter.addresses: + conditions.append(query_field("sender", filter.addresses)) + if filter.tags: + for tag in filter.tags: + conditions.append(MessageDBModel.tags.contains(tag)) + if filter.hashes: + conditions.append(query_field("item_hash", filter.hashes)) + if filter.channels: + conditions.append(query_field("channel", filter.channels)) + if filter.chains: + conditions.append(query_field("chain", filter.chains)) + if filter.start_date: + conditions.append(MessageDBModel.time >= filter.start_date) + if filter.end_date: + conditions.append(MessageDBModel.time <= filter.end_date) + + if conditions: + query = query.where(*conditions) + return query diff --git a/src/aleph/sdk/db/post.py b/src/aleph/sdk/db/post.py new file mode 100644 index 00000000..b842e7d6 --- /dev/null +++ b/src/aleph/sdk/db/post.py @@ -0,0 +1,140 @@ +from typing import Any, Dict, Iterable + +from aleph_message.models import MessageConfirmation, PostMessage +from peewee import BooleanField, CharField, FloatField, IntegerField, Model +from playhouse.shortcuts import model_to_dict +from playhouse.sqlite_ext 
 import JSONField
+
+from ..query.filters import PostFilter
+from ..query.responses import Post
+from .common import PydanticField, pydantic_json_dumps
+
+
+class PostDBModel(Model):
+    """
+    A simple database model for storing Post objects.
+    """
+
+    original_item_hash = CharField(primary_key=True)
+    original_type = CharField()
+    original_signature = CharField()
+    item_hash = CharField()
+    chain = CharField(5)
+    type = CharField(index=True)
+    sender = CharField()
+    channel = CharField(null=True)
+    confirmations: PydanticField[MessageConfirmation] = PydanticField(
+        type=MessageConfirmation, null=True
+    )
+    confirmed = BooleanField()
+    signature = CharField()
+    size = IntegerField(null=True)
+    time = FloatField()
+    item_type = CharField(7)
+    item_content = CharField(null=True)
+    content = JSONField(json_dumps=pydantic_json_dumps)
+    tags = JSONField(json_dumps=pydantic_json_dumps, null=True)
+    key = CharField(null=True)
+    ref = CharField(null=True)
+    content_type = CharField(null=True)
+
+    @classmethod
+    def get_all_fields(cls):
+        return cls._meta.sorted_field_names
+
+
+def post_to_model(post: Post) -> Dict:
+    return {
+        "item_hash": str(post.original_item_hash),
+        "chain": post.chain.value,
+        "type": post.type,
+        "sender": post.sender,
+        "channel": post.channel,
+        "confirmations": post.confirmations[0] if post.confirmations else None,
+        "confirmed": post.confirmed,
+        "signature": post.signature,
+        "size": post.size,
+        "time": post.time,
+        "original_item_hash": str(post.original_item_hash),
+        "original_type": post.original_type if post.original_type else post.type,
+        "original_signature": post.original_signature
+        if post.original_signature
+        else post.signature,
+        "item_type": post.item_type,
+        "item_content": post.item_content,
+        "content": post.content,
+        "tags": post.content.content.get("tags", None)
+        if hasattr(post.content, "content")
+        else None,
+        "ref": post.ref,
+    }
+
+
+def message_to_post(message: PostMessage) -> Post:
+    return Post(
chain=message.chain, + item_hash=message.item_hash, + sender=message.sender, + type=message.content.type, + channel=message.channel, + confirmed=message.confirmed if message.confirmed else False, + confirmations=message.confirmations if message.confirmations else [], + content=message.content, + item_content=message.item_content, + item_type=message.item_type, + signature=message.signature, + size=message.size if message.size else len(message.content.json().encode()), + time=message.time, + original_item_hash=message.item_hash, + original_signature=message.signature, + original_type=message.content.type, + hash=message.item_hash, + ref=message.content.ref, + ) + + +def model_to_post(item: Any) -> Post: + to_exclude = [PostDBModel.tags] + model_dict = model_to_dict(item, exclude=to_exclude) + model_dict["confirmations"] = ( + [model_dict["confirmations"]] if model_dict["confirmations"] else [] + ) + model_dict["hash"] = model_dict["item_hash"] + return Post.parse_obj(model_dict) + + +def query_field(field_name, field_values: Iterable[str]): + field = getattr(PostDBModel, field_name) + values = list(field_values) + + if len(values) == 1: + return field == values[0] + return field.in_(values) + + +def post_filter_to_query(filter: PostFilter) -> PostDBModel: + query = PostDBModel.select().order_by(PostDBModel.time.desc()) + conditions = [] + if filter.types: + conditions.append(query_field("type", filter.types)) + if filter.refs: + conditions.append(query_field("ref", filter.refs)) + if filter.addresses: + conditions.append(query_field("sender", filter.addresses)) + if filter.tags: + for tag in filter.tags: + conditions.append(PostDBModel.tags.contains(tag)) + if filter.hashes: + conditions.append(query_field("original_item_hash", filter.hashes)) + if filter.channels: + conditions.append(query_field("channel", filter.channels)) + if filter.chains: + conditions.append(query_field("chain", filter.chains)) + if filter.start_date: + conditions.append(PostDBModel.time >= 
filter.start_date)
+    if filter.end_date:
+        conditions.append(PostDBModel.time <= filter.end_date)
+
+    if conditions:
+        query = query.where(*conditions)
+    return query
diff --git a/src/aleph/sdk/exceptions.py b/src/aleph/sdk/exceptions.py
index 39972f7f..4a7e3178 100644
--- a/src/aleph/sdk/exceptions.py
+++ b/src/aleph/sdk/exceptions.py
@@ -78,3 +78,11 @@ def __init__(self, required_funds: float, available_funds: float):
         super().__init__(
             f"Insufficient funds: required {required_funds}, available {available_funds}"
         )
+
+
+class InvalidCacheDatabaseSchema(Exception):
+    """
+    The cache's database schema is invalid.
+    """
+
+    pass
diff --git a/src/aleph/sdk/vm/cache.py b/src/aleph/sdk/vm/cache.py
index ff5ca7c8..e02e5d85 100644
--- a/src/aleph/sdk/vm/cache.py
+++ b/src/aleph/sdk/vm/cache.py
@@ -19,7 +19,7 @@ def sanitize_cache_key(key: str) -> CacheKey:


 class BaseVmCache(abc.ABC):
-    """Virtual Machines can use this cache to store temporary data in memory on the host."""
+    """Virtual Machines can use this cache to store temporary data in memory on the host."""

     @abc.abstractmethod
     async def get(self, key: str) -> Optional[bytes]:
@@ -43,7 +43,7 @@ async def keys(self, pattern: str = "*") -> List[str]:


 class VmCache(BaseVmCache):
-    """Virtual Machines can use this cache to store temporary data in memory on the host."""
+    """Virtual Machines can use this cache to store temporary data in memory on the host."""

     session: aiohttp.ClientSession
     cache: Dict[str, bytes]
@@ -74,7 +74,7 @@ def __init__(

     async def get(self, key: str) -> Optional[bytes]:
         sanitized_key = sanitize_cache_key(key)
-        async with self.session.get(f"{self.api_host}/cache/{sanitized_key}") as resp:
+        async with self.session.get(f"{self.api_host}/db/{sanitized_key}") as resp:
             if resp.status == 404:
                 return None

@@ -85,16 +85,14 @@ async def set(self, key: str, value: Union[str, bytes]) -> Any:
         sanitized_key = sanitize_cache_key(key)
         data = value if isinstance(value, bytes) else value.encode()
         async with 
self.session.put(
-            f"{self.api_host}/cache/{sanitized_key}", data=data
+            f"{self.api_host}/db/{sanitized_key}", data=data
         ) as resp:
             resp.raise_for_status()
             return await resp.json()

     async def delete(self, key: str) -> Any:
         sanitized_key = sanitize_cache_key(key)
-        async with self.session.delete(
-            f"{self.api_host}/cache/{sanitized_key}"
-        ) as resp:
+        async with self.session.delete(f"{self.api_host}/db/{sanitized_key}") as resp:
             resp.raise_for_status()
             return await resp.json()

@@ -103,15 +101,13 @@ async def keys(self, pattern: str = "*") -> List[str]:
             raise ValueError(
                 "Pattern may only contain letters, numbers, underscore, ?, *, ^, -"
             )
-        async with self.session.get(
-            f"{self.api_host}/cache/?pattern={pattern}"
-        ) as resp:
+        async with self.session.get(f"{self.api_host}/db/?pattern={pattern}") as resp:
             resp.raise_for_status()
             return await resp.json()


 class LocalVmCache(BaseVmCache):
-    """This is a local, dict-based cache that can be used for testing purposes."""
+    """This is a local, dict-based cache that can be used for testing purposes."""

     def __init__(self):
         self._cache: Dict[str, bytes] = {}
diff --git a/tests/unit/test_download.py b/tests/unit/test_download.py
index 377e6d41..b8fc299c 100644
--- a/tests/unit/test_download.py
+++ b/tests/unit/test_download.py
@@ -1,6 +1,9 @@
+from io import BytesIO
+
 import pytest

-from aleph.sdk import AlephHttpClient
+from aleph.sdk import AlephHttpClient, AuthenticatedAlephHttpClient
+from aleph.sdk.client import LightNode
 from aleph.sdk.conf import settings as sdk_settings


@@ -19,6 +22,19 @@ async def test_download(file_hash: str, expected_size: int):
         assert file_size == expected_size


+@pytest.mark.asyncio
+async def test_download_light_node(solana_account):
+    session = AuthenticatedAlephHttpClient(
+        solana_account, api_server=sdk_settings.API_HOST
+    )
+    async with LightNode(session=session) as node:
+        file_content = await node.download_file(
+            "QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH"
+        )
+        file_size = 
len(file_content) + assert file_size == 5 + + @pytest.mark.parametrize( "file_hash,expected_size", [ @@ -32,3 +48,20 @@ async def test_download_ipfs(file_hash: str, expected_size: int): file_content = await client.download_file_ipfs(file_hash) # 5817703 B FILE file_size = len(file_content) assert file_size == expected_size + + +@pytest.mark.asyncio +async def test_download_to_buffer_light_node(solana_account): + session = AuthenticatedAlephHttpClient( + solana_account, api_server=sdk_settings.API_HOST + ) + async with LightNode(session=session) as node: + item_hash = "QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH" + del node[item_hash] + buffer = BytesIO() + await node.download_file_to_buffer( + "QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH", + buffer, + ) + file_size = buffer.getbuffer().nbytes + assert file_size == 5 diff --git a/tests/unit/test_light_node.py b/tests/unit/test_light_node.py new file mode 100644 index 00000000..318babc9 --- /dev/null +++ b/tests/unit/test_light_node.py @@ -0,0 +1,262 @@ +import json +import os +from pathlib import Path +from typing import Any, Dict +from unittest.mock import AsyncMock, MagicMock + +import pytest +import pytest_asyncio +from aleph_message.models import ( + AggregateMessage, + ForgetMessage, + MessageType, + PostMessage, + ProgramMessage, + StoreMessage, +) +from aleph_message.status import MessageStatus + +from aleph.sdk import AuthenticatedAlephHttpClient +from aleph.sdk.client.light_node import LightNode +from aleph.sdk.conf import settings +from aleph.sdk.types import Account, StorageEnum + + +class MockPostResponse: + def __init__(self, response_message: Any, sync: bool): + self.response_message = response_message + self.sync = sync + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + ... 
+ + @property + def status(self): + return 200 if self.sync else 202 + + def raise_for_status(self): + if self.status not in [200, 202]: + raise Exception("Bad status code") + + async def json(self): + message_status = "processed" if self.sync else "pending" + return { + "message_status": message_status, + "publication_status": {"status": "success", "failed": []}, + "hash": "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy", + "message": self.response_message, + } + + async def text(self): + return json.dumps(await self.json()) + + +class MockGetResponse: + def __init__(self, response_message, page=1): + self.response_message = response_message + self.page = page + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + ... + + @property + def status(self): + return 200 + + def raise_for_status(self): + if self.status != 200: + raise Exception("Bad status code") + + async def json(self): + return self.response_message(self.page) + + +@pytest.fixture +def mock_session_with_two_messages( + ethereum_account: Account, raw_messages_response: Dict[str, Any] +) -> AuthenticatedAlephHttpClient: + http_session = AsyncMock() + http_session.post = MagicMock() + http_session.post.side_effect = lambda *args, **kwargs: MockPostResponse( + response_message={ + "type": "post", + "channel": "TEST", + "content": {"Hello": "World"}, + "key": "QmBlahBlahBlah", + "item_hash": "QmBlahBlahBlah", + }, + sync=kwargs.get("sync", False), + ) + http_session.get = MagicMock() + http_session.get.side_effect = lambda *args, **kwargs: MockGetResponse( + response_message=raw_messages_response, + page=kwargs.get("params", {}).get("page", 1), + ) + + client = AuthenticatedAlephHttpClient( + account=ethereum_account, api_server="http://localhost" + ) + client.http_session = http_session + + return client + + +@pytest.mark.asyncio +async def test_node_init(mock_session_with_two_messages): + node = LightNode(session=mock_session_with_two_messages) + await 
node.run() + assert node.session == mock_session_with_two_messages + assert len(node) >= 2 + + +@pytest_asyncio.fixture +async def mock_node_with_post_success(mock_session_with_two_messages): + node = LightNode(session=mock_session_with_two_messages) + await node.run() + return node + + +@pytest.mark.asyncio +async def test_create_post(mock_node_with_post_success): + async with mock_node_with_post_success as session: + content = {"Hello": "World"} + + post_message, message_status = await session.create_post( + post_content=content, + post_type="TEST", + channel="TEST", + sync=False, + ) + + assert mock_node_with_post_success.session.http_session.post.called_once + assert isinstance(post_message, PostMessage) + assert message_status == MessageStatus.PENDING + + +@pytest.mark.asyncio +async def test_create_aggregate(mock_node_with_post_success): + async with mock_node_with_post_success as session: + aggregate_message, message_status = await session.create_aggregate( + key="hello", + content={"Hello": "world"}, + channel="TEST", + ) + + assert mock_node_with_post_success.session.http_session.post.called_once + assert isinstance(aggregate_message, AggregateMessage) + + +@pytest.mark.asyncio +async def test_create_store(mock_node_with_post_success): + mock_ipfs_push_file = AsyncMock() + mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy" + + mock_node_with_post_success.ipfs_push_file = mock_ipfs_push_file + + async with mock_node_with_post_success as node: + _ = await node.create_store( + file_content=b"HELLO", + channel="TEST", + storage_engine=StorageEnum.ipfs, + ) + + _ = await node.create_store( + file_hash="QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy", + channel="TEST", + storage_engine=StorageEnum.ipfs, + ) + + mock_storage_push_file = AsyncMock() + mock_storage_push_file.return_value = ( + "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy" + ) + mock_node_with_post_success.storage_push_file = mock_storage_push_file + async with 
mock_node_with_post_success as node: + store_message, message_status = await node.create_store( + file_content=b"HELLO", + channel="TEST", + storage_engine=StorageEnum.storage, + ) + + assert mock_node_with_post_success.session.http_session.post.called + assert isinstance(store_message, StoreMessage) + + +@pytest.mark.asyncio +async def test_create_program(mock_node_with_post_success): + async with mock_node_with_post_success as node: + program_message, message_status = await node.create_program( + program_ref="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe", + entrypoint="main:app", + runtime="facefacefacefacefacefacefacefacefacefacefacefacefacefacefaceface", + channel="TEST", + metadata={"tags": ["test"]}, + ) + + assert mock_node_with_post_success.session.http_session.post.called_once + assert isinstance(program_message, ProgramMessage) + + +@pytest.mark.asyncio +async def test_forget(mock_node_with_post_success): + async with mock_node_with_post_success as node: + forget_message, message_status = await node.forget( + hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"], + reason="GDPR", + channel="TEST", + ) + + assert mock_node_with_post_success.session.http_session.post.called_once + assert isinstance(forget_message, ForgetMessage) + + +@pytest.mark.asyncio +async def test_download_file(mock_node_with_post_success): + mock_node_with_post_success.session.download_file = AsyncMock() + mock_node_with_post_success.session.download_file.return_value = b"HELLO" + + # remove file locally + if os.path.exists(settings.CACHE_FILES_PATH / Path("QmAndSoOn")): + os.remove(settings.CACHE_FILES_PATH / Path("QmAndSoOn")) + + # fetch from mocked response + async with mock_node_with_post_success as node: + file_content = await node.download_file( + file_hash="QmAndSoOn", + ) + + assert mock_node_with_post_success.session.http_session.get.called_once + assert file_content == b"HELLO" + + # fetch cached + async with mock_node_with_post_success as node: + 
file_content = await node.download_file( + file_hash="QmAndSoOn", + ) + + assert file_content == b"HELLO" + + +@pytest.mark.asyncio +async def test_submit_message(mock_node_with_post_success): + content = {"Hello": "World"} + async with mock_node_with_post_success as node: + message, status = await node.submit( + content={ + "address": "0x1234567890123456789012345678901234567890", + "time": 1234567890, + "type": "TEST", + "content": content, + }, + message_type=MessageType.post, + ) + + assert mock_node_with_post_success.session.http_session.post.called_once + assert message.content.content == content + assert status == MessageStatus.PENDING diff --git a/tests/unit/test_message_cache.py b/tests/unit/test_message_cache.py new file mode 100644 index 00000000..c8b445a3 --- /dev/null +++ b/tests/unit/test_message_cache.py @@ -0,0 +1,321 @@ +import json +from hashlib import sha256 +from typing import List + +import pytest +from aleph_message.models import ( + AlephMessage, + Chain, + MessageType, + PostContent, + PostMessage, +) + +from aleph.sdk.chains.ethereum import get_fallback_account +from aleph.sdk.client.message_cache import MessageCache +from aleph.sdk.db.post import message_to_post +from aleph.sdk.exceptions import MessageNotFoundError +from aleph.sdk.query.filters import MessageFilter, PostFilter + + +@pytest.mark.asyncio +async def test_base(aleph_messages): + # test add_many + cache = MessageCache() + cache.add(aleph_messages) + assert len(cache) >= len(aleph_messages) + + item_hashes = [message.item_hash for message in aleph_messages] + cached_messages = cache.get(item_hashes) + assert len(cached_messages) == len(aleph_messages) + + for message in aleph_messages: + assert cache[message.item_hash] == message + + for message in aleph_messages: + assert message.item_hash in cache + + for message in cache: + del cache[message.item_hash] + assert message.item_hash not in cache + + assert len(cache) == 0 + del cache + + +class TestMessageQueries: + messages: 
List[AlephMessage] + cache: MessageCache + + @pytest.fixture(autouse=True) + def class_setup(self, aleph_messages): + self.messages = aleph_messages + self.cache = MessageCache() + self.cache.add(self.messages) + + def class_teardown(self): + del self.cache + + @pytest.mark.asyncio + async def test_iterate(self): + assert len(self.cache) == len(self.messages) + for message in self.cache: + assert message in self.messages + + @pytest.mark.asyncio + async def test_addresses(self): + assert ( + self.messages[0] + in ( + await self.cache.get_messages( + message_filter=MessageFilter( + addresses=[self.messages[0].sender], + ) + ) + ).messages + ) + + @pytest.mark.asyncio + async def test_tags(self): + assert ( + len( + ( + await self.cache.get_messages( + message_filter=MessageFilter(tags=["thistagdoesnotexist"]) + ) + ).messages + ) + == 0 + ) + + @pytest.mark.asyncio + async def test_message_type(self): + assert ( + self.messages[1] + in ( + await self.cache.get_messages( + message_filter=MessageFilter(message_types=[MessageType.post]) + ) + ).messages + ) + + @pytest.mark.asyncio + async def test_refs(self): + assert ( + self.messages[1] + in ( + await self.cache.get_messages( + message_filter=MessageFilter(refs=[self.messages[1].content.ref]) + ) + ).messages + ) + + @pytest.mark.asyncio + async def test_hashes(self): + assert ( + self.messages[0] + in ( + await self.cache.get_messages( + message_filter=MessageFilter(hashes=[self.messages[0].item_hash]) + ) + ).messages + ) + + @pytest.mark.asyncio + async def test_pagination(self): + assert len((await self.cache.get_messages(pagination=1)).messages) == 1 + + @pytest.mark.asyncio + async def test_content_types(self): + assert ( + self.messages[1] + in ( + await self.cache.get_messages( + message_filter=MessageFilter( + content_types=[self.messages[1].content.type] + ) + ) + ).messages + ) + + @pytest.mark.asyncio + async def test_channels(self): + assert ( + self.messages[1] + in ( + await self.cache.get_messages( + 
message_filter=MessageFilter(channels=[self.messages[1].channel]) + ) + ).messages + ) + + @pytest.mark.asyncio + async def test_chains(self): + assert ( + self.messages[1] + in ( + await self.cache.get_messages( + message_filter=MessageFilter(chains=[self.messages[1].chain]) + ) + ).messages + ) + + @pytest.mark.asyncio + async def test_content_keys(self): + assert ( + self.messages[0] + in ( + await self.cache.get_messages( + message_filter=MessageFilter( + content_keys=[self.messages[0].content.key] + ) + ) + ).messages + ) + + +class TestPostQueries: + messages: List[AlephMessage] + cache: MessageCache + + @pytest.fixture(autouse=True) + def class_setup(self, aleph_messages): + self.messages = aleph_messages + self.cache = MessageCache() + self.cache.add(self.messages) + + def class_teardown(self): + del self.cache + + @pytest.mark.asyncio + async def test_addresses(self): + assert ( + message_to_post(self.messages[1]) + in ( + await self.cache.get_posts( + post_filter=PostFilter(addresses=[self.messages[1].sender]) + ) + ).posts + ) + + @pytest.mark.asyncio + async def test_tags(self): + assert ( + len( + ( + await self.cache.get_posts( + post_filter=PostFilter(tags=["thistagdoesnotexist"]) + ) + ).posts + ) + == 0 + ) + + @pytest.mark.asyncio + async def test_types(self): + assert ( + len( + ( + await self.cache.get_posts( + post_filter=PostFilter(types=["thistypedoesnotexist"]) + ) + ).posts + ) + == 0 + ) + + @pytest.mark.asyncio + async def test_channels(self): + assert ( + message_to_post(self.messages[1]) + in ( + await self.cache.get_posts( + post_filter=PostFilter(channels=[self.messages[1].channel]) + ) + ).posts + ) + + @pytest.mark.asyncio + async def test_chains(self): + assert ( + message_to_post(self.messages[1]) + in ( + await self.cache.get_posts( + post_filter=PostFilter(chains=[self.messages[1].chain]) + ) + ).posts + ) + + +@pytest.mark.asyncio +async def test_message_cache_listener(): + async def mock_message_stream(): + for i in range(3): 
+ content = PostContent( + content={"hello": f"world{i}"}, + type="test", + address=get_fallback_account().get_address(), + time=0, + ) + message = PostMessage( + sender=get_fallback_account().get_address(), + item_hash=sha256(json.dumps(content.dict()).encode()).hexdigest(), + chain=Chain.ETH.value, + type=MessageType.post.value, + item_type="inline", + time=0, + content=content, + item_content=json.dumps(content.dict()), + signature="", + ) + yield message + + cache = MessageCache() + # test listener + coro = cache.listen_to(mock_message_stream()) + await coro + assert len(cache) >= 3 + + +@pytest.mark.asyncio +async def test_fetch_aggregate(aleph_messages): + cache = MessageCache() + cache.add(aleph_messages) + + aggregate = await cache.fetch_aggregate( + aleph_messages[0].sender, aleph_messages[0].content.key + ) + + print(aggregate) + + assert aggregate == aleph_messages[0].content.content + + +@pytest.mark.asyncio +async def test_fetch_aggregates(aleph_messages): + cache = MessageCache() + cache.add(aleph_messages) + + aggregates = await cache.fetch_aggregates(aleph_messages[0].sender) + + assert aggregates == { + aleph_messages[0].content.key: aleph_messages[0].content.content + } + + +@pytest.mark.asyncio +async def test_get_message(aleph_messages): + cache = MessageCache() + cache.add(aleph_messages) + + message: AlephMessage = await cache.get_message(aleph_messages[0].item_hash) + + assert message == aleph_messages[0] + + +@pytest.mark.asyncio +async def test_get_message_fail(): + cache = MessageCache() + + with pytest.raises(MessageNotFoundError): + await cache.get_message("0x1234567890123456789012345678901234567890") From ae7309e2f2f5fd9116863c2c5e194fd6b486fdf3 Mon Sep 17 00:00:00 2001 From: mhh Date: Mon, 27 Nov 2023 17:18:36 +0100 Subject: [PATCH 02/18] Add `DateTimeField` to cache DB models; Add `create_instance` and adjust `create_program` methods on `LightNode` --- src/aleph/sdk/client/light_node.py | 57 +++++++++++++++++++++++++++ 
src/aleph/sdk/client/message_cache.py | 10 +++-- src/aleph/sdk/db/aggregate.py | 4 +- src/aleph/sdk/db/message.py | 4 +- src/aleph/sdk/db/post.py | 4 +- src/aleph/sdk/query/responses.py | 3 +- tests/unit/conftest.py | 5 ++- 7 files changed, 75 insertions(+), 12 deletions(-) diff --git a/src/aleph/sdk/client/light_node.py b/src/aleph/sdk/client/light_node.py index 91506e8f..a86431eb 100644 --- a/src/aleph/sdk/client/light_node.py +++ b/src/aleph/sdk/client/light_node.py @@ -319,6 +319,9 @@ async def create_program( vcpus: Optional[int] = None, timeout_seconds: Optional[float] = None, persistent: bool = False, + allow_amend: bool = False, + internet: bool = True, + aleph_api: bool = True, encoding: Encoding = Encoding.zip, volumes: Optional[List[Mapping]] = None, subscriptions: Optional[List[Mapping]] = None, @@ -340,6 +343,9 @@ async def create_program( vcpus=vcpus, timeout_seconds=timeout_seconds, persistent=persistent, + allow_amend=allow_amend, + internet=internet, + aleph_api=aleph_api, encoding=encoding, volumes=volumes, subscriptions=subscriptions, @@ -350,6 +356,57 @@ async def create_program( asyncio.create_task(self.delete_if_rejected(resp.item_hash)) return resp, status + async def create_instance( + self, + rootfs: str, + rootfs_size: int, + rootfs_name: str, + environment_variables: Optional[Mapping[str, str]] = None, + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + allow_amend: bool = False, + internet: bool = True, + aleph_api: bool = True, + encoding: Encoding = Encoding.zip, + volumes: Optional[List[Mapping]] = None, + volume_persistence: str = "host", + ssh_keys: Optional[List[str]] = None, + metadata: Optional[Mapping[str, Any]] = None, + ) -> Tuple[AlephMessage, MessageStatus]: + self.check_validity( + MessageType.instance, address, channel, 
dict(metadata) if metadata else None + ) + resp, status = await self.session.create_instance( + rootfs=rootfs, + rootfs_size=rootfs_size, + rootfs_name=rootfs_name, + environment_variables=environment_variables, + storage_engine=storage_engine, + channel=channel, + address=address, + sync=sync, + memory=memory, + vcpus=vcpus, + timeout_seconds=timeout_seconds, + allow_amend=allow_amend, + internet=internet, + aleph_api=aleph_api, + encoding=encoding, + volumes=volumes, + volume_persistence=volume_persistence, + ssh_keys=ssh_keys, + metadata=metadata, + ) + if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]: + self.add(resp) + asyncio.create_task(self.delete_if_rejected(resp.item_hash)) + return resp, status + async def forget( self, hashes: List[str], diff --git a/src/aleph/sdk/client/message_cache.py b/src/aleph/sdk/client/message_cache.py index c0789bac..e5d7212c 100644 --- a/src/aleph/sdk/client/message_cache.py +++ b/src/aleph/sdk/client/message_cache.py @@ -1,6 +1,6 @@ +import datetime import logging import typing -from datetime import datetime from pathlib import Path from typing import ( AsyncIterable, @@ -230,7 +230,7 @@ def _handle_amends(self, amend_messages: List[PostMessage]): self.missing_posts[ItemHash(amend.content.ref)] = amend continue - if datetime.fromtimestamp(amend.time) < original_post.last_updated: + if amend.time < original_post.last_updated: continue original_post.item_hash = amend.item_hash @@ -239,7 +239,7 @@ def _handle_amends(self, amend_messages: List[PostMessage]): original_post.original_type = amend.content.type original_post.address = amend.sender original_post.channel = amend.channel - original_post.last_updated = datetime.fromtimestamp(amend.time) + original_post.last_updated = amend.time post_data.append(model_to_dict(original_post)) with self.db.atomic(): PostDBModel.insert_many(post_data).on_conflict_replace().execute() @@ -254,7 +254,9 @@ def _handle_aggregates(self, aggregate_messages): if not existing_aggregate: 
aggregate_data.append(aggregate_to_model(aggregate)) continue - data = model_to_dict(existing_aggregate) + existing_aggregate.time = datetime.datetime.fromisoformat( + existing_aggregate.time + ) if aggregate.time > existing_aggregate.time: existing_aggregate.content.update(aggregate.content.content) existing_aggregate.time = aggregate.time diff --git a/src/aleph/sdk/db/aggregate.py b/src/aleph/sdk/db/aggregate.py index f06c18f4..f7b71741 100644 --- a/src/aleph/sdk/db/aggregate.py +++ b/src/aleph/sdk/db/aggregate.py @@ -1,7 +1,7 @@ from typing import Dict from aleph_message.models import AggregateMessage -from peewee import CharField, FloatField, Model +from peewee import CharField, DateTimeField, Model from playhouse.sqlite_ext import JSONField from .common import pydantic_json_dumps @@ -17,7 +17,7 @@ class AggregateDBModel(Model): key = CharField() channel = CharField(null=True) content = JSONField(json_dumps=pydantic_json_dumps, null=True) - time = FloatField() + time = DateTimeField() def aggregate_to_model(message: AggregateMessage) -> Dict: diff --git a/src/aleph/sdk/db/message.py b/src/aleph/sdk/db/message.py index 962de88c..450f50cb 100644 --- a/src/aleph/sdk/db/message.py +++ b/src/aleph/sdk/db/message.py @@ -2,7 +2,7 @@ from aleph_message import parse_message from aleph_message.models import AlephMessage, MessageConfirmation -from peewee import BooleanField, CharField, FloatField, IntegerField, Model +from peewee import BooleanField, CharField, DateTimeField, IntegerField, Model from playhouse.shortcuts import model_to_dict from playhouse.sqlite_ext import JSONField @@ -26,7 +26,7 @@ class MessageDBModel(Model): confirmed = BooleanField(null=True) signature = CharField(null=True) size = IntegerField(null=True) - time = FloatField() + time = DateTimeField() item_type = CharField(7) item_content = CharField(null=True) hash_type = CharField(6, null=True) diff --git a/src/aleph/sdk/db/post.py b/src/aleph/sdk/db/post.py index b842e7d6..0455ba4b 100644 --- 
a/src/aleph/sdk/db/post.py +++ b/src/aleph/sdk/db/post.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Iterable from aleph_message.models import MessageConfirmation, PostMessage -from peewee import BooleanField, CharField, FloatField, IntegerField, Model +from peewee import BooleanField, CharField, DateTimeField, IntegerField, Model from playhouse.shortcuts import model_to_dict from playhouse.sqlite_ext import JSONField @@ -29,7 +29,7 @@ class PostDBModel(Model): confirmed = BooleanField() signature = CharField() size = IntegerField(null=True) - time = FloatField() + time = DateTimeField() item_type = CharField(7) item_content = CharField(null=True) content = JSONField(json_dumps=pydantic_json_dumps) diff --git a/src/aleph/sdk/query/responses.py b/src/aleph/sdk/query/responses.py index 5fb91804..d61596af 100644 --- a/src/aleph/sdk/query/responses.py +++ b/src/aleph/sdk/query/responses.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import datetime from typing import Any, Dict, List, Optional, Union from aleph_message.models import ( @@ -35,7 +36,7 @@ class Post(BaseModel): description="Cryptographic signature of the message by the sender" ) size: int = Field(description="Size of the post") - time: float = Field(description="Timestamp of the post") + time: datetime = Field(description="Timestamp of the post") confirmations: List[MessageConfirmation] = Field( description="Number of confirmations" ) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 95cc7851..88f7c0b7 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -161,7 +161,10 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): ... - async def raise_for_status(self): + def raise_for_status(self): + ... + + async def close(self): ... 
@property From b6dab52f0d3a4d4b56d6d2c2e0882c7959a44dab Mon Sep 17 00:00:00 2001 From: mhh Date: Mon, 4 Dec 2023 09:40:37 +0100 Subject: [PATCH 03/18] Add `__init__.py` files for `db` and `query` modules --- src/aleph/sdk/db/__init__.py | 0 src/aleph/sdk/query/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/aleph/sdk/db/__init__.py create mode 100644 src/aleph/sdk/query/__init__.py diff --git a/src/aleph/sdk/db/__init__.py b/src/aleph/sdk/db/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aleph/sdk/query/__init__.py b/src/aleph/sdk/query/__init__.py new file mode 100644 index 00000000..e69de29b From 77bdf1b832b2541f6a90d1349a0b78494c89473a Mon Sep 17 00:00:00 2001 From: mhh Date: Fri, 8 Dec 2023 11:27:24 +0100 Subject: [PATCH 04/18] Adjust light_node.submit() to fit new interface --- src/aleph/sdk/client/light_node.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/aleph/sdk/client/light_node.py b/src/aleph/sdk/client/light_node.py index a86431eb..0afd2d43 100644 --- a/src/aleph/sdk/client/light_node.py +++ b/src/aleph/sdk/client/light_node.py @@ -436,16 +436,18 @@ async def submit( storage_engine: StorageEnum = StorageEnum.storage, allow_inlining: bool = True, sync: bool = False, - ) -> Tuple[AlephMessage, MessageStatus]: - resp, status = await self.session.submit( + raise_on_rejected: bool = True, + ) -> Tuple[AlephMessage, MessageStatus, Optional[Dict[str, Any]]]: + message, status, response = await self.session.submit( content=content, message_type=message_type, channel=channel, storage_engine=storage_engine, allow_inlining=allow_inlining, sync=sync, + raise_on_rejected=raise_on_rejected ) if status in [MessageStatus.PROCESSED, MessageStatus.PENDING]: - self.add(resp) - asyncio.create_task(self.delete_if_rejected(resp.item_hash)) - return resp, status + self.add(message) + asyncio.create_task(self.delete_if_rejected(message.item_hash)) + return message, status, 
response From bf508d8c712169d158be26538d466b2206d0da63 Mon Sep 17 00:00:00 2001 From: mhh Date: Fri, 8 Dec 2023 11:36:55 +0100 Subject: [PATCH 05/18] Fix test using new submit() --- tests/unit/test_light_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_light_node.py b/tests/unit/test_light_node.py index 318babc9..662f1b82 100644 --- a/tests/unit/test_light_node.py +++ b/tests/unit/test_light_node.py @@ -247,7 +247,7 @@ async def test_download_file(mock_node_with_post_success): async def test_submit_message(mock_node_with_post_success): content = {"Hello": "World"} async with mock_node_with_post_success as node: - message, status = await node.submit( + message, status, _ = await node.submit( content={ "address": "0x1234567890123456789012345678901234567890", "time": 1234567890, From b1ad1a5b8aaa55e4ef39aaab24e4a0e0fec67b2b Mon Sep 17 00:00:00 2001 From: mhh Date: Fri, 8 Dec 2023 11:40:37 +0100 Subject: [PATCH 06/18] Fix formatting --- src/aleph/sdk/client/light_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aleph/sdk/client/light_node.py b/src/aleph/sdk/client/light_node.py index 0afd2d43..3fc2439f 100644 --- a/src/aleph/sdk/client/light_node.py +++ b/src/aleph/sdk/client/light_node.py @@ -445,7 +445,7 @@ async def submit( storage_engine=storage_engine, allow_inlining=allow_inlining, sync=sync, - raise_on_rejected=raise_on_rejected + raise_on_rejected=raise_on_rejected, ) if status in [MessageStatus.PROCESSED, MessageStatus.PENDING]: self.add(message) From 0f6f8ca90acedffd4f4a7016e1433755fdd7103e Mon Sep 17 00:00:00 2001 From: mhh Date: Thu, 1 Feb 2024 13:55:00 +0100 Subject: [PATCH 07/18] Reformat with new black --- src/aleph/sdk/client/message_cache.py | 12 ++++-------- src/aleph/sdk/db/message.py | 14 ++++++++------ src/aleph/sdk/db/post.py | 14 ++++++++------ tests/unit/test_light_node.py | 6 ++---- 4 files changed, 22 insertions(+), 24 deletions(-) diff --git 
a/src/aleph/sdk/client/message_cache.py b/src/aleph/sdk/client/message_cache.py index e5d7212c..257c85e5 100644 --- a/src/aleph/sdk/client/message_cache.py +++ b/src/aleph/sdk/client/message_cache.py @@ -161,12 +161,10 @@ def __str__(self) -> str: return repr(self) @typing.overload - def add(self, messages: Iterable[AlephMessage]): - ... + def add(self, messages: Iterable[AlephMessage]): ... @typing.overload - def add(self, messages: AlephMessage): - ... + def add(self, messages: AlephMessage): ... def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]): """ @@ -281,12 +279,10 @@ def _handle_forgets(self, forget_messages): ).execute() @typing.overload - def get(self, item_hashes: Iterable[ItemHash]) -> List[AlephMessage]: - ... + def get(self, item_hashes: Iterable[ItemHash]) -> List[AlephMessage]: ... @typing.overload - def get(self, item_hashes: ItemHash) -> Optional[AlephMessage]: - ... + def get(self, item_hashes: ItemHash) -> Optional[AlephMessage]: ... def get( self, item_hashes: Union[ItemHash, Iterable[ItemHash]] diff --git a/src/aleph/sdk/db/message.py b/src/aleph/sdk/db/message.py index 450f50cb..29e553c1 100644 --- a/src/aleph/sdk/db/message.py +++ b/src/aleph/sdk/db/message.py @@ -55,14 +55,16 @@ def message_to_model(message: AlephMessage) -> Dict: "hash_type": message.hash_type, "content": message.content, "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None, - "tags": message.content.content.get("tags", None) - if hasattr(message.content, "content") - else None, + "tags": ( + message.content.content.get("tags", None) + if hasattr(message.content, "content") + else None + ), "key": message.content.key if hasattr(message.content, "key") else None, "ref": message.content.ref if hasattr(message.content, "ref") else None, - "content_type": message.content.type - if hasattr(message.content, "type") - else None, + "content_type": ( + message.content.type if hasattr(message.content, "type") else None + ), } diff --git 
a/src/aleph/sdk/db/post.py b/src/aleph/sdk/db/post.py index 0455ba4b..c8815f9b 100644 --- a/src/aleph/sdk/db/post.py +++ b/src/aleph/sdk/db/post.py @@ -57,15 +57,17 @@ def post_to_model(post: Post) -> Dict: "time": post.time, "original_item_hash": str(post.original_item_hash), "original_type": post.original_type if post.original_type else post.type, - "original_signature": post.original_signature - if post.original_signature - else post.signature, + "original_signature": ( + post.original_signature if post.original_signature else post.signature + ), "item_type": post.item_type, "item_content": post.item_content, "content": post.content, - "tags": post.content.content.get("tags", None) - if hasattr(post.content, "content") - else None, + "tags": ( + post.content.content.get("tags", None) + if hasattr(post.content, "content") + else None + ), "ref": post.ref, } diff --git a/tests/unit/test_light_node.py b/tests/unit/test_light_node.py index 662f1b82..3d610f10 100644 --- a/tests/unit/test_light_node.py +++ b/tests/unit/test_light_node.py @@ -30,8 +30,7 @@ def __init__(self, response_message: Any, sync: bool): async def __aenter__(self): return self - async def __aexit__(self, exc_type, exc_val, exc_tb): - ... + async def __aexit__(self, exc_type, exc_val, exc_tb): ... @property def status(self): @@ -62,8 +61,7 @@ def __init__(self, response_message, page=1): async def __aenter__(self): return self - async def __aexit__(self, exc_type, exc_val, exc_tb): - ... + async def __aexit__(self, exc_type, exc_val, exc_tb): ... 
@property def status(self): From 38c146e08309ceae32691866fe48dbe77d1be4ab Mon Sep 17 00:00:00 2001 From: Antonyjin Date: Tue, 6 Feb 2024 17:36:28 +0100 Subject: [PATCH 08/18] I had the following problem: ssl:True [SSLCertVerificationError: (1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate when using the function: AuthenticatedAlephHttpClient I searched on the internet for a way to solve this problem, but all the commands/advice given didn't work. So I thought it would be a good idea to give the user the option of specifying a specific SSL certificate if they wish. This worked in my case and gave me the option of continuing to use the SDK provided by Aleph. --- src/aleph/sdk/client/authenticated_http.py | 251 +++++++++++---------- src/aleph/sdk/client/http.py | 88 ++++---- 2 files changed, 173 insertions(+), 166 deletions(-) diff --git a/src/aleph/sdk/client/authenticated_http.py b/src/aleph/sdk/client/authenticated_http.py index 79385f4f..4e67f4d0 100644 --- a/src/aleph/sdk/client/authenticated_http.py +++ b/src/aleph/sdk/client/authenticated_http.py @@ -4,6 +4,7 @@ import time from pathlib import Path from typing import Any, Dict, List, Mapping, NoReturn, Optional, Tuple, Union +import ssl import aiohttp from aleph_message import parse_message @@ -66,18 +67,20 @@ class AuthenticatedAlephHttpClient(AlephHttpClient, AuthenticatedAlephClient): } def __init__( - self, - account: Account, - api_server: Optional[str] = None, - api_unix_socket: Optional[str] = None, - allow_unix_sockets: bool = True, - timeout: Optional[aiohttp.ClientTimeout] = None, + self, + account: Account, + api_server: Optional[str] = None, + api_unix_socket: Optional[str] = None, + allow_unix_sockets: bool = True, + timeout: Optional[aiohttp.ClientTimeout] = None, + ssl_context: Optional[ssl.SSLContext] = None, ): super().__init__( api_server=api_server, api_unix_socket=api_unix_socket, allow_unix_sockets=allow_unix_sockets, timeout=timeout, + 
ssl_context=ssl_context, ) self.account = account @@ -192,8 +195,8 @@ async def _handle_broadcast_error(response: aiohttp.ClientResponse) -> NoReturn: raise BroadcastError(error_msg) async def _handle_broadcast_deprecated_response( - self, - response: aiohttp.ClientResponse, + self, + response: aiohttp.ClientResponse, ) -> None: if response.status != 200: await self._handle_broadcast_error(response) @@ -210,16 +213,16 @@ async def _broadcast_deprecated(self, message_dict: Mapping[str, Any]) -> None: url = "/api/v0/ipfs/pubsub/pub" logger.debug(f"Posting message on {url}") async with self.http_session.post( - url, - json={ - "topic": "ALEPH-TEST", - "data": message_dict, - }, + url, + json={ + "topic": "ALEPH-TEST", + "data": message_dict, + }, ) as response: await self._handle_broadcast_deprecated_response(response) async def _handle_broadcast_response( - self, response: aiohttp.ClientResponse, sync: bool, raise_on_rejected: bool + self, response: aiohttp.ClientResponse, sync: bool, raise_on_rejected: bool ) -> Tuple[Dict[str, Any], MessageStatus]: if response.status in (200, 202): status = await response.json() @@ -239,10 +242,10 @@ async def _handle_broadcast_response( await self._handle_broadcast_error(response) async def _broadcast( - self, - message: AlephMessage, - sync: bool, - raise_on_rejected: bool = True, + self, + message: AlephMessage, + sync: bool, + raise_on_rejected: bool = True, ) -> Tuple[Dict[str, Any], MessageStatus]: """ Broadcast a message on the aleph.im network. @@ -256,11 +259,11 @@ async def _broadcast( message_dict = message.dict(include=self.BROADCAST_MESSAGE_FIELDS) async with self.http_session.post( - url, - json={ - "sync": sync, - "message": message_dict, - }, + url, + json={ + "sync": sync, + "message": message_dict, + }, ) as response: # The endpoint may be unavailable on this node, try the deprecated version. 
if response.status in (404, 405): @@ -275,15 +278,15 @@ async def _broadcast( ) async def create_post( - self, - post_content, - post_type: str, - ref: Optional[str] = None, - address: Optional[str] = None, - channel: Optional[str] = None, - inline: bool = True, - storage_engine: StorageEnum = StorageEnum.storage, - sync: bool = False, + self, + post_content, + post_type: str, + ref: Optional[str] = None, + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + sync: bool = False, ) -> Tuple[PostMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() @@ -306,13 +309,13 @@ async def create_post( return message, status async def create_aggregate( - self, - key: str, - content: Mapping[str, Any], - address: Optional[str] = None, - channel: Optional[str] = None, - inline: bool = True, - sync: bool = False, + self, + key: str, + content: Mapping[str, Any], + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + sync: bool = False, ) -> Tuple[AggregateMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() @@ -333,17 +336,17 @@ async def create_aggregate( return message, status async def create_store( - self, - address: Optional[str] = None, - file_content: Optional[bytes] = None, - file_path: Optional[Union[str, Path]] = None, - file_hash: Optional[str] = None, - guess_mime_type: bool = False, - ref: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - extra_fields: Optional[dict] = None, - channel: Optional[str] = None, - sync: bool = False, + self, + address: Optional[str] = None, + file_content: Optional[bytes] = None, + file_path: Optional[Union[str, Path]] = None, + file_hash: Optional[str] = None, + guess_mime_type: bool = False, + ref: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + extra_fields: 
Optional[dict] = None, + channel: Optional[str] = None, + sync: bool = False, ) -> Tuple[StoreMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() @@ -407,26 +410,26 @@ async def create_store( return message, status async def create_program( - self, - program_ref: str, - entrypoint: str, - runtime: str, - environment_variables: Optional[Mapping[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = False, - memory: Optional[int] = None, - vcpus: Optional[int] = None, - timeout_seconds: Optional[float] = None, - persistent: bool = False, - allow_amend: bool = False, - internet: bool = True, - aleph_api: bool = True, - encoding: Encoding = Encoding.zip, - volumes: Optional[List[Mapping]] = None, - subscriptions: Optional[List[Mapping]] = None, - metadata: Optional[Mapping[str, Any]] = None, + self, + program_ref: str, + entrypoint: str, + runtime: str, + environment_variables: Optional[Mapping[str, str]] = None, + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + persistent: bool = False, + allow_amend: bool = False, + internet: bool = True, + aleph_api: bool = True, + encoding: Encoding = Encoding.zip, + volumes: Optional[List[Mapping]] = None, + subscriptions: Optional[List[Mapping]] = None, + metadata: Optional[Mapping[str, Any]] = None, ) -> Tuple[ProgramMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() @@ -500,26 +503,26 @@ async def create_program( return message, status async def create_instance( - self, - rootfs: str, - rootfs_size: int, - rootfs_name: str, - environment_variables: Optional[Mapping[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - 
channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = False, - memory: Optional[int] = None, - vcpus: Optional[int] = None, - timeout_seconds: Optional[float] = None, - allow_amend: bool = False, - internet: bool = True, - aleph_api: bool = True, - encoding: Encoding = Encoding.zip, - volumes: Optional[List[Mapping]] = None, - volume_persistence: str = "host", - ssh_keys: Optional[List[str]] = None, - metadata: Optional[Mapping[str, Any]] = None, + self, + rootfs: str, + rootfs_size: int, + rootfs_name: str, + environment_variables: Optional[Mapping[str, str]] = None, + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + allow_amend: bool = False, + internet: bool = True, + aleph_api: bool = True, + encoding: Encoding = Encoding.zip, + volumes: Optional[List[Mapping]] = None, + volume_persistence: str = "host", + ssh_keys: Optional[List[str]] = None, + metadata: Optional[Mapping[str, Any]] = None, ) -> Tuple[InstanceMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() @@ -591,13 +594,13 @@ async def create_instance( raise ValueError(f"Unknown error code {error_code}: {rejected_message}") async def forget( - self, - hashes: List[str], - reason: Optional[str], - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = False, + self, + hashes: List[str], + reason: Optional[str], + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, ) -> Tuple[ForgetMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() @@ -625,12 +628,12 @@ def compute_sha256(s: str) -> str: return h.hexdigest() async def 
_prepare_aleph_message( - self, - message_type: MessageType, - content: Dict[str, Any], - channel: Optional[str], - allow_inlining: bool = True, - storage_engine: StorageEnum = StorageEnum.storage, + self, + message_type: MessageType, + content: Dict[str, Any], + channel: Optional[str], + allow_inlining: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, ) -> AlephMessage: message_dict: Dict[str, Any] = { "sender": self.account.get_address(), @@ -667,14 +670,14 @@ async def _prepare_aleph_message( return parse_message(message_dict) async def submit( - self, - content: Dict[str, Any], - message_type: MessageType, - channel: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - allow_inlining: bool = True, - sync: bool = False, - raise_on_rejected: bool = True, + self, + content: Dict[str, Any], + message_type: MessageType, + channel: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + allow_inlining: bool = True, + sync: bool = False, + raise_on_rejected: bool = True, ) -> Tuple[AlephMessage, MessageStatus, Optional[Dict[str, Any]]]: message = await self._prepare_aleph_message( message_type=message_type, @@ -689,11 +692,11 @@ async def submit( return message, message_status, response async def _storage_push_file_with_message( - self, - file_content: bytes, - store_content: StoreContent, - channel: Optional[str] = None, - sync: bool = False, + self, + file_content: bytes, + store_content: StoreContent, + channel: Optional[str] = None, + sync: bool = False, ) -> Tuple[StoreMessage, MessageStatus]: """Push a file to the storage service.""" data = aiohttp.FormData() @@ -727,14 +730,14 @@ async def _storage_push_file_with_message( return message, message_status async def _upload_file_native( - self, - address: str, - file_content: bytes, - guess_mime_type: bool = False, - ref: Optional[str] = None, - extra_fields: Optional[dict] = None, - channel: Optional[str] = None, - sync: bool = False, + self, + address: 
str, + file_content: bytes, + guess_mime_type: bool = False, + ref: Optional[str] = None, + extra_fields: Optional[dict] = None, + channel: Optional[str] = None, + sync: bool = False, ) -> Tuple[StoreMessage, MessageStatus]: file_hash = hashlib.sha256(file_content).hexdigest() if magic and guess_mime_type: diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py index 0a46896b..0ae01605 100644 --- a/src/aleph/sdk/client/http.py +++ b/src/aleph/sdk/client/http.py @@ -1,5 +1,6 @@ import json import logging +import ssl from io import BytesIO from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Type @@ -30,11 +31,12 @@ class AlephHttpClient(AlephClient): http_session: aiohttp.ClientSession def __init__( - self, - api_server: Optional[str] = None, - api_unix_socket: Optional[str] = None, - allow_unix_sockets: bool = True, - timeout: Optional[aiohttp.ClientTimeout] = None, + self, + api_server: Optional[str] = None, + api_unix_socket: Optional[str] = None, + allow_unix_sockets: bool = True, + timeout: Optional[aiohttp.ClientTimeout] = None, + ssl_context: Optional[ssl.SSLContext] = None, ): """AlephClient can use HTTP(S) or HTTP over Unix sockets. 
Unix sockets are used when running inside a virtual machine, @@ -48,6 +50,8 @@ def __init__( if unix_socket_path and allow_unix_sockets: check_unix_socket_valid(unix_socket_path) connector = aiohttp.UnixConnector(path=unix_socket_path) + elif ssl_context: + connector = aiohttp.TCPConnector(ssl=ssl_context) else: connector = None @@ -79,7 +83,7 @@ async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: params: Dict[str, Any] = {"keys": key} async with self.http_session.get( - f"/api/v0/aggregates/{address}.json", params=params + f"/api/v0/aggregates/{address}.json", params=params ) as resp: resp.raise_for_status() result = await resp.json() @@ -87,7 +91,7 @@ async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: return data.get(key) async def fetch_aggregates( - self, address: str, keys: Optional[Iterable[str]] = None + self, address: str, keys: Optional[Iterable[str]] = None ) -> Dict[str, Dict]: keys_str = ",".join(keys) if keys else "" params: Dict[str, Any] = {} @@ -95,8 +99,8 @@ async def fetch_aggregates( params["keys"] = keys_str async with self.http_session.get( - f"/api/v0/aggregates/{address}.json", - params=params, + f"/api/v0/aggregates/{address}.json", + params=params, ) as resp: resp.raise_for_status() result = await resp.json() @@ -104,12 +108,12 @@ async def fetch_aggregates( return data async def get_posts( - self, - page_size: int = 200, - page: int = 1, - post_filter: Optional[PostFilter] = None, - ignore_invalid_messages: Optional[bool] = True, - invalid_messages_log_level: Optional[int] = logging.NOTSET, + self, + page_size: int = 200, + page: int = 1, + post_filter: Optional[PostFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> PostsResponse: ignore_invalid_messages = ( True if ignore_invalid_messages is None else ignore_invalid_messages @@ -153,9 +157,9 @@ async def get_posts( ) async def download_file_to_buffer( - self, 
- file_hash: str, - output_buffer: Writable[bytes], + self, + file_hash: str, + output_buffer: Writable[bytes], ) -> None: """ Download a file from the storage engine and write it to the specified output buffer. @@ -164,7 +168,7 @@ async def download_file_to_buffer( """ async with self.http_session.get( - f"/api/v0/storage/raw/{file_hash}" + f"/api/v0/storage/raw/{file_hash}" ) as response: if response.status == 200: await copy_async_readable_to_buffer( @@ -180,9 +184,9 @@ async def download_file_to_buffer( raise FileTooLarge(f"The file from {file_hash} is too large") async def download_file_ipfs_to_buffer( - self, - file_hash: str, - output_buffer: Writable[bytes], + self, + file_hash: str, + output_buffer: Writable[bytes], ) -> None: """ Download a file from the storage engine and write it to the specified output buffer. @@ -192,7 +196,7 @@ async def download_file_ipfs_to_buffer( """ async with aiohttp.ClientSession() as session: async with session.get( - f"https://ipfs.aleph.im/ipfs/{file_hash}" + f"https://ipfs.aleph.im/ipfs/{file_hash}" ) as response: if response.status == 200: await copy_async_readable_to_buffer( @@ -202,8 +206,8 @@ async def download_file_ipfs_to_buffer( response.raise_for_status() async def download_file( - self, - file_hash: str, + self, + file_hash: str, ) -> bytes: """ Get a file from the storage engine as raw bytes. @@ -217,8 +221,8 @@ async def download_file( return buffer.getvalue() async def download_file_ipfs( - self, - file_hash: str, + self, + file_hash: str, ) -> bytes: """ Get a file from the ipfs storage engine as raw bytes. 
@@ -232,12 +236,12 @@ async def download_file_ipfs( return buffer.getvalue() async def get_messages( - self, - page_size: int = 200, - page: int = 1, - message_filter: Optional[MessageFilter] = None, - ignore_invalid_messages: Optional[bool] = True, - invalid_messages_log_level: Optional[int] = logging.NOTSET, + self, + page_size: int = 200, + page: int = 1, + message_filter: Optional[MessageFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> MessagesResponse: ignore_invalid_messages = ( True if ignore_invalid_messages is None else ignore_invalid_messages @@ -259,7 +263,7 @@ async def get_messages( params["pagination"] = str(page_size) async with self.http_session.get( - "/api/v0/messages.json", params=params + "/api/v0/messages.json", params=params ) as resp: resp.raise_for_status() response_json = await resp.json() @@ -294,9 +298,9 @@ async def get_messages( ) async def get_message( - self, - item_hash: str, - message_type: Optional[Type[GenericMessage]] = None, + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, ) -> GenericMessage: async with self.http_session.get(f"/api/v0/messages/{item_hash}") as resp: try: @@ -321,8 +325,8 @@ async def get_message( return message async def get_message_error( - self, - item_hash: str, + self, + item_hash: str, ) -> Optional[Dict[str, Any]]: async with self.http_session.get(f"/api/v0/messages/{item_hash}") as resp: try: @@ -344,14 +348,14 @@ async def get_message_error( } async def watch_messages( - self, - message_filter: Optional[MessageFilter] = None, + self, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: message_filter = message_filter or MessageFilter() params = message_filter.as_http_params() async with self.http_session.ws_connect( - "/api/ws0/messages", params=params + "/api/ws0/messages", params=params ) as ws: logger.debug("Websocket connected") async for msg in ws: From 
ec71aaf9e0103f51e487807b5f72bd377eb01955 Mon Sep 17 00:00:00 2001 From: Antonyjin Date: Tue, 6 Feb 2024 17:47:05 +0100 Subject: [PATCH 09/18] Bug: problem with the ssl certificate I had the following problem: ssl:True [SSLCertVerificationError: (1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate when using the function: AuthenticatedAlephHttpClient I searched on the internet for a way to solve this problem, but all the commands/advice given didn't work. So I thought it would be a good idea to give the user the option of specifying a specific SSL certificate if they wish. This worked in my case and gave me the option of continuing to use the SDK provided by Aleph. --- src/aleph/sdk/client/authenticated_http.py | 250 ++++++++++----------- src/aleph/sdk/client/http.py | 88 ++++---- 2 files changed, 169 insertions(+), 169 deletions(-) diff --git a/src/aleph/sdk/client/authenticated_http.py b/src/aleph/sdk/client/authenticated_http.py index 4e67f4d0..be2b0593 100644 --- a/src/aleph/sdk/client/authenticated_http.py +++ b/src/aleph/sdk/client/authenticated_http.py @@ -67,13 +67,13 @@ class AuthenticatedAlephHttpClient(AlephHttpClient, AuthenticatedAlephClient): } def __init__( - self, - account: Account, - api_server: Optional[str] = None, - api_unix_socket: Optional[str] = None, - allow_unix_sockets: bool = True, - timeout: Optional[aiohttp.ClientTimeout] = None, - ssl_context: Optional[ssl.SSLContext] = None, + self, + account: Account, + api_server: Optional[str] = None, + api_unix_socket: Optional[str] = None, + allow_unix_sockets: bool = True, + timeout: Optional[aiohttp.ClientTimeout] = None, + ssl_context: Optional[ssl.SSLContext] = None, ): super().__init__( api_server=api_server, @@ -195,8 +195,8 @@ async def _handle_broadcast_error(response: aiohttp.ClientResponse) -> NoReturn: raise BroadcastError(error_msg) async def _handle_broadcast_deprecated_response( - self, - response: aiohttp.ClientResponse, + self, + 
response: aiohttp.ClientResponse, ) -> None: if response.status != 200: await self._handle_broadcast_error(response) @@ -213,16 +213,16 @@ async def _broadcast_deprecated(self, message_dict: Mapping[str, Any]) -> None: url = "/api/v0/ipfs/pubsub/pub" logger.debug(f"Posting message on {url}") async with self.http_session.post( - url, - json={ - "topic": "ALEPH-TEST", - "data": message_dict, - }, + url, + json={ + "topic": "ALEPH-TEST", + "data": message_dict, + }, ) as response: await self._handle_broadcast_deprecated_response(response) async def _handle_broadcast_response( - self, response: aiohttp.ClientResponse, sync: bool, raise_on_rejected: bool + self, response: aiohttp.ClientResponse, sync: bool, raise_on_rejected: bool ) -> Tuple[Dict[str, Any], MessageStatus]: if response.status in (200, 202): status = await response.json() @@ -242,10 +242,10 @@ async def _handle_broadcast_response( await self._handle_broadcast_error(response) async def _broadcast( - self, - message: AlephMessage, - sync: bool, - raise_on_rejected: bool = True, + self, + message: AlephMessage, + sync: bool, + raise_on_rejected: bool = True, ) -> Tuple[Dict[str, Any], MessageStatus]: """ Broadcast a message on the aleph.im network. @@ -259,11 +259,11 @@ async def _broadcast( message_dict = message.dict(include=self.BROADCAST_MESSAGE_FIELDS) async with self.http_session.post( - url, - json={ - "sync": sync, - "message": message_dict, - }, + url, + json={ + "sync": sync, + "message": message_dict, + }, ) as response: # The endpoint may be unavailable on this node, try the deprecated version. 
if response.status in (404, 405): @@ -278,15 +278,15 @@ async def _broadcast( ) async def create_post( - self, - post_content, - post_type: str, - ref: Optional[str] = None, - address: Optional[str] = None, - channel: Optional[str] = None, - inline: bool = True, - storage_engine: StorageEnum = StorageEnum.storage, - sync: bool = False, + self, + post_content, + post_type: str, + ref: Optional[str] = None, + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + sync: bool = False, ) -> Tuple[PostMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() @@ -309,13 +309,13 @@ async def create_post( return message, status async def create_aggregate( - self, - key: str, - content: Mapping[str, Any], - address: Optional[str] = None, - channel: Optional[str] = None, - inline: bool = True, - sync: bool = False, + self, + key: str, + content: Mapping[str, Any], + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + sync: bool = False, ) -> Tuple[AggregateMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() @@ -336,17 +336,17 @@ async def create_aggregate( return message, status async def create_store( - self, - address: Optional[str] = None, - file_content: Optional[bytes] = None, - file_path: Optional[Union[str, Path]] = None, - file_hash: Optional[str] = None, - guess_mime_type: bool = False, - ref: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - extra_fields: Optional[dict] = None, - channel: Optional[str] = None, - sync: bool = False, + self, + address: Optional[str] = None, + file_content: Optional[bytes] = None, + file_path: Optional[Union[str, Path]] = None, + file_hash: Optional[str] = None, + guess_mime_type: bool = False, + ref: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + extra_fields: 
Optional[dict] = None, + channel: Optional[str] = None, + sync: bool = False, ) -> Tuple[StoreMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() @@ -410,26 +410,26 @@ async def create_store( return message, status async def create_program( - self, - program_ref: str, - entrypoint: str, - runtime: str, - environment_variables: Optional[Mapping[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = False, - memory: Optional[int] = None, - vcpus: Optional[int] = None, - timeout_seconds: Optional[float] = None, - persistent: bool = False, - allow_amend: bool = False, - internet: bool = True, - aleph_api: bool = True, - encoding: Encoding = Encoding.zip, - volumes: Optional[List[Mapping]] = None, - subscriptions: Optional[List[Mapping]] = None, - metadata: Optional[Mapping[str, Any]] = None, + self, + program_ref: str, + entrypoint: str, + runtime: str, + environment_variables: Optional[Mapping[str, str]] = None, + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + persistent: bool = False, + allow_amend: bool = False, + internet: bool = True, + aleph_api: bool = True, + encoding: Encoding = Encoding.zip, + volumes: Optional[List[Mapping]] = None, + subscriptions: Optional[List[Mapping]] = None, + metadata: Optional[Mapping[str, Any]] = None, ) -> Tuple[ProgramMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() @@ -503,26 +503,26 @@ async def create_program( return message, status async def create_instance( - self, - rootfs: str, - rootfs_size: int, - rootfs_name: str, - environment_variables: Optional[Mapping[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - 
channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = False, - memory: Optional[int] = None, - vcpus: Optional[int] = None, - timeout_seconds: Optional[float] = None, - allow_amend: bool = False, - internet: bool = True, - aleph_api: bool = True, - encoding: Encoding = Encoding.zip, - volumes: Optional[List[Mapping]] = None, - volume_persistence: str = "host", - ssh_keys: Optional[List[str]] = None, - metadata: Optional[Mapping[str, Any]] = None, + self, + rootfs: str, + rootfs_size: int, + rootfs_name: str, + environment_variables: Optional[Mapping[str, str]] = None, + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + allow_amend: bool = False, + internet: bool = True, + aleph_api: bool = True, + encoding: Encoding = Encoding.zip, + volumes: Optional[List[Mapping]] = None, + volume_persistence: str = "host", + ssh_keys: Optional[List[str]] = None, + metadata: Optional[Mapping[str, Any]] = None, ) -> Tuple[InstanceMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() @@ -594,13 +594,13 @@ async def create_instance( raise ValueError(f"Unknown error code {error_code}: {rejected_message}") async def forget( - self, - hashes: List[str], - reason: Optional[str], - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = False, + self, + hashes: List[str], + reason: Optional[str], + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, ) -> Tuple[ForgetMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() @@ -628,12 +628,12 @@ def compute_sha256(s: str) -> str: return h.hexdigest() async def 
_prepare_aleph_message( - self, - message_type: MessageType, - content: Dict[str, Any], - channel: Optional[str], - allow_inlining: bool = True, - storage_engine: StorageEnum = StorageEnum.storage, + self, + message_type: MessageType, + content: Dict[str, Any], + channel: Optional[str], + allow_inlining: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, ) -> AlephMessage: message_dict: Dict[str, Any] = { "sender": self.account.get_address(), @@ -670,14 +670,14 @@ async def _prepare_aleph_message( return parse_message(message_dict) async def submit( - self, - content: Dict[str, Any], - message_type: MessageType, - channel: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - allow_inlining: bool = True, - sync: bool = False, - raise_on_rejected: bool = True, + self, + content: Dict[str, Any], + message_type: MessageType, + channel: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + allow_inlining: bool = True, + sync: bool = False, + raise_on_rejected: bool = True, ) -> Tuple[AlephMessage, MessageStatus, Optional[Dict[str, Any]]]: message = await self._prepare_aleph_message( message_type=message_type, @@ -692,11 +692,11 @@ async def submit( return message, message_status, response async def _storage_push_file_with_message( - self, - file_content: bytes, - store_content: StoreContent, - channel: Optional[str] = None, - sync: bool = False, + self, + file_content: bytes, + store_content: StoreContent, + channel: Optional[str] = None, + sync: bool = False, ) -> Tuple[StoreMessage, MessageStatus]: """Push a file to the storage service.""" data = aiohttp.FormData() @@ -730,14 +730,14 @@ async def _storage_push_file_with_message( return message, message_status async def _upload_file_native( - self, - address: str, - file_content: bytes, - guess_mime_type: bool = False, - ref: Optional[str] = None, - extra_fields: Optional[dict] = None, - channel: Optional[str] = None, - sync: bool = False, + self, + address: 
str, + file_content: bytes, + guess_mime_type: bool = False, + ref: Optional[str] = None, + extra_fields: Optional[dict] = None, + channel: Optional[str] = None, + sync: bool = False, ) -> Tuple[StoreMessage, MessageStatus]: file_hash = hashlib.sha256(file_content).hexdigest() if magic and guess_mime_type: diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py index 0ae01605..b8a62cf3 100644 --- a/src/aleph/sdk/client/http.py +++ b/src/aleph/sdk/client/http.py @@ -1,8 +1,8 @@ import json import logging -import ssl from io import BytesIO from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Type +import ssl import aiohttp from aleph_message import parse_message @@ -31,12 +31,12 @@ class AlephHttpClient(AlephClient): http_session: aiohttp.ClientSession def __init__( - self, - api_server: Optional[str] = None, - api_unix_socket: Optional[str] = None, - allow_unix_sockets: bool = True, - timeout: Optional[aiohttp.ClientTimeout] = None, - ssl_context: Optional[ssl.SSLContext] = None, + self, + api_server: Optional[str] = None, + api_unix_socket: Optional[str] = None, + allow_unix_sockets: bool = True, + timeout: Optional[aiohttp.ClientTimeout] = None, + ssl_context: Optional[ssl.SSLContext] = None, ): """AlephClient can use HTTP(S) or HTTP over Unix sockets. 
Unix sockets are used when running inside a virtual machine, @@ -83,7 +83,7 @@ async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: params: Dict[str, Any] = {"keys": key} async with self.http_session.get( - f"/api/v0/aggregates/{address}.json", params=params + f"/api/v0/aggregates/{address}.json", params=params ) as resp: resp.raise_for_status() result = await resp.json() @@ -91,7 +91,7 @@ async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: return data.get(key) async def fetch_aggregates( - self, address: str, keys: Optional[Iterable[str]] = None + self, address: str, keys: Optional[Iterable[str]] = None ) -> Dict[str, Dict]: keys_str = ",".join(keys) if keys else "" params: Dict[str, Any] = {} @@ -99,8 +99,8 @@ async def fetch_aggregates( params["keys"] = keys_str async with self.http_session.get( - f"/api/v0/aggregates/{address}.json", - params=params, + f"/api/v0/aggregates/{address}.json", + params=params, ) as resp: resp.raise_for_status() result = await resp.json() @@ -108,12 +108,12 @@ async def fetch_aggregates( return data async def get_posts( - self, - page_size: int = 200, - page: int = 1, - post_filter: Optional[PostFilter] = None, - ignore_invalid_messages: Optional[bool] = True, - invalid_messages_log_level: Optional[int] = logging.NOTSET, + self, + page_size: int = 200, + page: int = 1, + post_filter: Optional[PostFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> PostsResponse: ignore_invalid_messages = ( True if ignore_invalid_messages is None else ignore_invalid_messages @@ -157,9 +157,9 @@ async def get_posts( ) async def download_file_to_buffer( - self, - file_hash: str, - output_buffer: Writable[bytes], + self, + file_hash: str, + output_buffer: Writable[bytes], ) -> None: """ Download a file from the storage engine and write it to the specified output buffer. 
@@ -168,7 +168,7 @@ async def download_file_to_buffer( """ async with self.http_session.get( - f"/api/v0/storage/raw/{file_hash}" + f"/api/v0/storage/raw/{file_hash}" ) as response: if response.status == 200: await copy_async_readable_to_buffer( @@ -184,9 +184,9 @@ async def download_file_to_buffer( raise FileTooLarge(f"The file from {file_hash} is too large") async def download_file_ipfs_to_buffer( - self, - file_hash: str, - output_buffer: Writable[bytes], + self, + file_hash: str, + output_buffer: Writable[bytes], ) -> None: """ Download a file from the storage engine and write it to the specified output buffer. @@ -196,7 +196,7 @@ async def download_file_ipfs_to_buffer( """ async with aiohttp.ClientSession() as session: async with session.get( - f"https://ipfs.aleph.im/ipfs/{file_hash}" + f"https://ipfs.aleph.im/ipfs/{file_hash}" ) as response: if response.status == 200: await copy_async_readable_to_buffer( @@ -206,8 +206,8 @@ async def download_file_ipfs_to_buffer( response.raise_for_status() async def download_file( - self, - file_hash: str, + self, + file_hash: str, ) -> bytes: """ Get a file from the storage engine as raw bytes. @@ -221,8 +221,8 @@ async def download_file( return buffer.getvalue() async def download_file_ipfs( - self, - file_hash: str, + self, + file_hash: str, ) -> bytes: """ Get a file from the ipfs storage engine as raw bytes. 
@@ -236,12 +236,12 @@ async def download_file_ipfs( return buffer.getvalue() async def get_messages( - self, - page_size: int = 200, - page: int = 1, - message_filter: Optional[MessageFilter] = None, - ignore_invalid_messages: Optional[bool] = True, - invalid_messages_log_level: Optional[int] = logging.NOTSET, + self, + page_size: int = 200, + page: int = 1, + message_filter: Optional[MessageFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> MessagesResponse: ignore_invalid_messages = ( True if ignore_invalid_messages is None else ignore_invalid_messages @@ -263,7 +263,7 @@ async def get_messages( params["pagination"] = str(page_size) async with self.http_session.get( - "/api/v0/messages.json", params=params + "/api/v0/messages.json", params=params ) as resp: resp.raise_for_status() response_json = await resp.json() @@ -298,9 +298,9 @@ async def get_messages( ) async def get_message( - self, - item_hash: str, - message_type: Optional[Type[GenericMessage]] = None, + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, ) -> GenericMessage: async with self.http_session.get(f"/api/v0/messages/{item_hash}") as resp: try: @@ -325,8 +325,8 @@ async def get_message( return message async def get_message_error( - self, - item_hash: str, + self, + item_hash: str, ) -> Optional[Dict[str, Any]]: async with self.http_session.get(f"/api/v0/messages/{item_hash}") as resp: try: @@ -348,14 +348,14 @@ async def get_message_error( } async def watch_messages( - self, - message_filter: Optional[MessageFilter] = None, + self, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: message_filter = message_filter or MessageFilter() params = message_filter.as_http_params() async with self.http_session.ws_connect( - "/api/ws0/messages", params=params + "/api/ws0/messages", params=params ) as ws: logger.debug("Websocket connected") async for msg in ws: From 
302e9fc325bd617265b35293b30b3de5e7fea26a Mon Sep 17 00:00:00 2001 From: Antony JIN <91880456+Antonyjin@users.noreply.github.com> Date: Thu, 8 Feb 2024 10:56:58 +0100 Subject: [PATCH 10/18] Update src/aleph/sdk/client/http.py Co-authored-by: Mike Hukiewitz <70762838+MHHukiewitz@users.noreply.github.com> --- src/aleph/sdk/client/http.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py index b8a62cf3..28e575cc 100644 --- a/src/aleph/sdk/client/http.py +++ b/src/aleph/sdk/client/http.py @@ -47,11 +47,11 @@ def __init__( raise ValueError("Missing API host") unix_socket_path = api_unix_socket or settings.API_UNIX_SOCKET - if unix_socket_path and allow_unix_sockets: + if ssl_context: + connector = aiohttp.TCPConnector(ssl=ssl_context) + elif unix_socket_path and allow_unix_sockets: check_unix_socket_valid(unix_socket_path) connector = aiohttp.UnixConnector(path=unix_socket_path) - elif ssl_context: - connector = aiohttp.TCPConnector(ssl=ssl_context) else: connector = None From bc057c4f2f44f828c18e124077769d21a6ce8811 Mon Sep 17 00:00:00 2001 From: Antonyjin Date: Thu, 8 Feb 2024 16:01:02 +0100 Subject: [PATCH 11/18] Fix: Unit test failed but now succeeds The problem was with the connector type: src/aleph/sdk/client/http.py:55:25 : error : Incompatible types in assignment (expression has type "UnixConnector", variable has type "Optional[TCPConnector]") I have therefore declared the connector as a union of possible BaseConnector or None types. This means it can be any aiohttp BaseConnector or None. 
--- src/aleph/sdk/client/authenticated_http.py | 2 +- src/aleph/sdk/client/http.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/aleph/sdk/client/authenticated_http.py b/src/aleph/sdk/client/authenticated_http.py index be2b0593..3b4e5f78 100644 --- a/src/aleph/sdk/client/authenticated_http.py +++ b/src/aleph/sdk/client/authenticated_http.py @@ -1,10 +1,10 @@ import hashlib import json import logging +import ssl import time from pathlib import Path from typing import Any, Dict, List, Mapping, NoReturn, Optional, Tuple, Union -import ssl import aiohttp from aleph_message import parse_message diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py index 28e575cc..c79a07a5 100644 --- a/src/aleph/sdk/client/http.py +++ b/src/aleph/sdk/client/http.py @@ -1,8 +1,8 @@ import json import logging -from io import BytesIO -from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Type import ssl +from io import BytesIO +from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Type, Union import aiohttp from aleph_message import parse_message @@ -46,6 +46,7 @@ def __init__( if not self.api_server: raise ValueError("Missing API host") + connector: Union[aiohttp.BaseConnector, None] unix_socket_path = api_unix_socket or settings.API_UNIX_SOCKET if ssl_context: connector = aiohttp.TCPConnector(ssl=ssl_context) From 1cd2b7d52491dc9f277432d433f606ca7796d789 Mon Sep 17 00:00:00 2001 From: Hugo Herter Date: Thu, 8 Feb 2024 17:55:39 +0100 Subject: [PATCH 12/18] Fix: An ETH account could not be initialized from its mnemonic Users could not easily import or migrate accounts using their mnemonic representation. Solution: Add a static method `from_mnemonic` on the `ETHAccount` class. Discussion: This is a first step and this behaviour can be extended to more chains in the future. 
--- src/aleph/sdk/chains/ethereum.py | 5 +++++ tests/unit/test_chain_ethereum.py | 13 ++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/aleph/sdk/chains/ethereum.py b/src/aleph/sdk/chains/ethereum.py index 4f00cd7e..124fbee7 100644 --- a/src/aleph/sdk/chains/ethereum.py +++ b/src/aleph/sdk/chains/ethereum.py @@ -36,6 +36,11 @@ def get_address(self) -> str: def get_public_key(self) -> str: return "0x" + get_public_key(private_key=self._account.key).hex() + @staticmethod + def from_mnemonic(mnemonic: str) -> "ETHAccount": + Account.enable_unaudited_hdwallet_features() + return ETHAccount(private_key=Account.from_mnemonic(mnemonic=mnemonic).key) + def get_fallback_account(path: Optional[Path] = None) -> ETHAccount: return ETHAccount(private_key=get_fallback_private_key(path=path)) diff --git a/tests/unit/test_chain_ethereum.py b/tests/unit/test_chain_ethereum.py index 9a602b3d..c05207a1 100644 --- a/tests/unit/test_chain_ethereum.py +++ b/tests/unit/test_chain_ethereum.py @@ -5,7 +5,7 @@ import pytest from aleph.sdk.chains.common import get_verification_buffer -from aleph.sdk.chains.ethereum import get_fallback_account, verify_signature +from aleph.sdk.chains.ethereum import get_fallback_account, verify_signature, ETHAccount from aleph.sdk.exceptions import BadSignatureError @@ -156,3 +156,14 @@ async def test_sign_raw(ethereum_account): assert isinstance(signature, bytes) verify_signature(signature, ethereum_account.get_address(), buffer) + + +def test_from_mnemonic(): + mnemonic = ( + "indoor dish desk flag debris potato excuse depart ticket judge file exit" + ) + account = ETHAccount.from_mnemonic(mnemonic) + assert ( + account.get_public_key() + == "0x0226cc24348fbe0c2912fbb0aa4408e089bb0ae488a88ac46bb13290629a737646" + ) From 03de4b3c3d1feaceaf09fe71a26afc46f93bf736 Mon Sep 17 00:00:00 2001 From: Hugo Herter Date: Thu, 8 Feb 2024 18:07:43 +0100 Subject: [PATCH 13/18] fixup! 
Fix: An ETH account could not be initialized from its mnemonic --- tests/unit/test_chain_ethereum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_chain_ethereum.py b/tests/unit/test_chain_ethereum.py index c05207a1..84b8ae26 100644 --- a/tests/unit/test_chain_ethereum.py +++ b/tests/unit/test_chain_ethereum.py @@ -5,7 +5,7 @@ import pytest from aleph.sdk.chains.common import get_verification_buffer -from aleph.sdk.chains.ethereum import get_fallback_account, verify_signature, ETHAccount +from aleph.sdk.chains.ethereum import ETHAccount, get_fallback_account, verify_signature from aleph.sdk.exceptions import BadSignatureError From fed1d95ae7e2a68a3a86e13e89fffa68c6066b10 Mon Sep 17 00:00:00 2001 From: Hugo Herter Date: Thu, 8 Feb 2024 18:01:20 +0100 Subject: [PATCH 14/18] Fix: `create_instance` required a program `encoding` Problem: Instances don't use program encodings. This argument was left from refactoring. Solution: Drop the argument. --- src/aleph/sdk/client/abstract.py | 1 - src/aleph/sdk/client/authenticated_http.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/aleph/sdk/client/abstract.py b/src/aleph/sdk/client/abstract.py index 20a04e43..03581960 100644 --- a/src/aleph/sdk/client/abstract.py +++ b/src/aleph/sdk/client/abstract.py @@ -363,7 +363,6 @@ async def create_instance( allow_amend: bool = False, internet: bool = True, aleph_api: bool = True, - encoding: Encoding = Encoding.zip, volumes: Optional[List[Mapping]] = None, volume_persistence: str = "host", ssh_keys: Optional[List[str]] = None, diff --git a/src/aleph/sdk/client/authenticated_http.py b/src/aleph/sdk/client/authenticated_http.py index 3b4e5f78..95a1babf 100644 --- a/src/aleph/sdk/client/authenticated_http.py +++ b/src/aleph/sdk/client/authenticated_http.py @@ -518,7 +518,6 @@ async def create_instance( allow_amend: bool = False, internet: bool = True, aleph_api: bool = True, - encoding: Encoding = Encoding.zip, volumes: Optional[List[Mapping]] = 
None, volume_persistence: str = "host", ssh_keys: Optional[List[str]] = None, From 1c29e0f9234f27a12b383f6269929c3de5a59585 Mon Sep 17 00:00:00 2001 From: Hugo Herter Date: Thu, 8 Feb 2024 18:33:30 +0100 Subject: [PATCH 15/18] Fix: Payment could not be specified for instances Problem: Users of the SDK could not create instances with a specific payment method such as token streams. Solution: Add a new argument, `payment`, to `AuthenticatedAlephClient` and use it in the "Instance" messages generated. The argument is optional and defaults to "hold" on "ETH" for backward compatibility. Discussion: The argument is added just after mandatory arguments so it can be made mandatory in the future. This may however break backward compatibility with existing code that does not call the function using keywords arguments. --- src/aleph/sdk/client/abstract.py | 3 ++ src/aleph/sdk/client/authenticated_http.py | 7 ++++- tests/unit/test_asynchronous.py | 34 ++++++++++++++++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/src/aleph/sdk/client/abstract.py b/src/aleph/sdk/client/abstract.py index 03581960..3ffc388b 100644 --- a/src/aleph/sdk/client/abstract.py +++ b/src/aleph/sdk/client/abstract.py @@ -20,6 +20,7 @@ AlephMessage, MessagesResponse, MessageType, + Payment, PostMessage, ) from aleph_message.models.execution.program import Encoding @@ -352,6 +353,7 @@ async def create_instance( rootfs: str, rootfs_size: int, rootfs_name: str, + payment: Optional[Payment] = None, environment_variables: Optional[Mapping[str, str]] = None, storage_engine: StorageEnum = StorageEnum.storage, channel: Optional[str] = None, @@ -374,6 +376,7 @@ async def create_instance( :param rootfs: Root filesystem to use :param rootfs_size: Size of root filesystem :param rootfs_name: Name of root filesystem + :param payment: Payment method used to pay for the instance :param environment_variables: Environment variables to pass to the program :param storage_engine: Storage engine to use 
(Default: "storage") :param channel: Channel to use (Default: "TEST") diff --git a/src/aleph/sdk/client/authenticated_http.py b/src/aleph/sdk/client/authenticated_http.py index 95a1babf..cf75d986 100644 --- a/src/aleph/sdk/client/authenticated_http.py +++ b/src/aleph/sdk/client/authenticated_http.py @@ -12,6 +12,7 @@ AggregateContent, AggregateMessage, AlephMessage, + Chain, ForgetContent, ForgetMessage, InstanceContent, @@ -25,7 +26,7 @@ StoreContent, StoreMessage, ) -from aleph_message.models.execution.base import Encoding +from aleph_message.models.execution.base import Encoding, Payment, PaymentType from aleph_message.models.execution.environment import ( FunctionEnvironment, MachineResources, @@ -507,6 +508,7 @@ async def create_instance( rootfs: str, rootfs_size: int, rootfs_name: str, + payment: Optional[Payment] = None, environment_variables: Optional[Mapping[str, str]] = None, storage_engine: StorageEnum = StorageEnum.storage, channel: Optional[str] = None, @@ -530,6 +532,8 @@ async def create_instance( vcpus = vcpus or settings.DEFAULT_VM_VCPUS timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT + payment = payment or Payment(chain=Chain.ETH, type=PaymentType.hold) + content = InstanceContent( address=address, allow_amend=allow_amend, @@ -563,6 +567,7 @@ async def create_instance( time=time.time(), authorized_keys=ssh_keys, metadata=metadata, + payment=payment, ) message, status, response = await self.submit( content=content.dict(exclude_none=True), diff --git a/tests/unit/test_asynchronous.py b/tests/unit/test_asynchronous.py index c13df757..ef8b67ca 100644 --- a/tests/unit/test_asynchronous.py +++ b/tests/unit/test_asynchronous.py @@ -4,9 +4,12 @@ import pytest as pytest from aleph_message.models import ( AggregateMessage, + Chain, ForgetMessage, InstanceMessage, MessageType, + Payment, + PaymentType, PostMessage, ProgramMessage, StoreMessage, @@ -108,12 +111,39 @@ async def test_create_instance(mock_session_with_post_success): 
rootfs_name="rootfs", channel="TEST", metadata={"tags": ["test"]}, + payment=Payment( + chain=Chain.AVAX, + receiver="0x4145f182EF2F06b45E50468519C1B92C60FBd4A0", + type=PaymentType.superfluid, + ), ) assert mock_session_with_post_success.http_session.post.called_once assert isinstance(instance_message, InstanceMessage) +@pytest.mark.asyncio +async def test_create_instance_no_payment(mock_session_with_post_success): + """Test that an instance can be created with no payment specified. + It should in this case default to "holding" on "ETH". + """ + async with mock_session_with_post_success as session: + instance_message, message_status = await session.create_instance( + rootfs="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe", + rootfs_size=1, + rootfs_name="rootfs", + channel="TEST", + metadata={"tags": ["test"]}, + payment=None, + ) + + assert instance_message.content.payment.type == PaymentType.hold + assert instance_message.content.payment.chain == Chain.ETH + + assert mock_session_with_post_success.http_session.post.called_once + assert isinstance(instance_message, InstanceMessage) + + @pytest.mark.asyncio async def test_forget(mock_session_with_post_success): async with mock_session_with_post_success as session: @@ -199,4 +229,8 @@ async def test_create_instance_insufficient_funds_error( rootfs_name="rootfs", channel="TEST", metadata={"tags": ["test"]}, + payment=Payment( + chain=Chain.ETH, + type=PaymentType.hold, + ), ) From a3e120ab11c63a0e91dc591e565bbec6696a4f24 Mon Sep 17 00:00:00 2001 From: Mike Hukiewitz <70762838+MHHukiewitz@users.noreply.github.com> Date: Mon, 12 Feb 2024 16:33:40 +0100 Subject: [PATCH 16/18] Fix: Loosen aleph-message dependency (#92) Problem: too strict aleph-message dependency Solution: loosen it to accept compatible versions to 0.4.2 Co-authored-by: mhh --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index a132f9bc..de505203 100644 --- a/setup.cfg +++ 
b/setup.cfg @@ -38,7 +38,7 @@ install_requires = eciespy>=0.3.13; python_version>="3.11" typing_extensions typer - aleph-message==0.4.1 + aleph-message~=0.4.3 eth_account>=0.4.0 # Required to fix a dependency issue with parsimonious and Python3.11 eth_abi==4.0.0b2; python_version>="3.11" From b6c37149dd4591fef604f11e8ab3c230b14adafb Mon Sep 17 00:00:00 2001 From: Mike Hukiewitz <70762838+MHHukiewitz@users.noreply.github.com> Date: Mon, 12 Feb 2024 16:36:46 +0100 Subject: [PATCH 17/18] Feature: Add Deprecation Message (#103) Problem: importing the removed AlephClient and AuthenticatedAlephClient classes, or the old synchronous/asynchronous modules, fails without any guidance Solution: raise descriptive ImportError messages pointing users to the replacement client classes Co-authored-by: Hugo Herter --- src/aleph/sdk/__init__.py | 19 +++++++++++++++ src/aleph/sdk/client/abstract.py | 42 +++++++++++++++++++++----------- tests/unit/test_init.py | 24 ++++++++++++++++++ 3 files changed, 71 insertions(+), 14 deletions(-) diff --git a/src/aleph/sdk/__init__.py b/src/aleph/sdk/__init__.py index c14b64f6..a3ecc693 100644 --- a/src/aleph/sdk/__init__.py +++ b/src/aleph/sdk/__init__.py @@ -12,3 +12,22 @@ del get_distribution, DistributionNotFound __all__ = ["AlephHttpClient", "AuthenticatedAlephHttpClient"] + + +def __getattr__(name): + if name == "AlephClient": + raise ImportError( + "AlephClient has been turned into an abstract class. Please use `AlephHttpClient` instead." + ) + elif name == "AuthenticatedAlephClient": + raise ImportError( + "AuthenticatedAlephClient has been turned into an abstract class. Please use `AuthenticatedAlephHttpClient` instead." + ) + elif name == "synchronous": + raise ImportError( + "The 'aleph.sdk.synchronous' type is deprecated and has been removed from the aleph SDK. Please use `aleph.sdk.client.AlephHttpClient` instead." + ) + elif name == "asynchronous": + raise ImportError( + "The 'aleph.sdk.asynchronous' type is deprecated and has been removed from the aleph SDK. Please use `aleph.sdk.client.AlephHttpClient` instead." 
+ ) diff --git a/src/aleph/sdk/client/abstract.py b/src/aleph/sdk/client/abstract.py index 3ffc388b..3335ad86 100644 --- a/src/aleph/sdk/client/abstract.py +++ b/src/aleph/sdk/client/abstract.py @@ -43,7 +43,7 @@ async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: :param address: Address of the owner of the aggregate :param key: Key of the aggregate """ - pass + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") @abstractmethod async def fetch_aggregates( @@ -55,7 +55,7 @@ async def fetch_aggregates( :param address: Address of the owner of the aggregate :param keys: Keys of the aggregates to fetch (Default: all items) """ - pass + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") @abstractmethod async def get_posts( @@ -75,7 +75,7 @@ async def get_posts( :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ - pass + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") async def get_posts_iterator( self, @@ -110,7 +110,7 @@ async def download_file( :param file_hash: The hash of the file to retrieve. 
""" - pass + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") async def download_file_ipfs( self, @@ -168,7 +168,7 @@ async def get_messages( :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ - pass + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") async def get_messages_iterator( self, @@ -203,7 +203,7 @@ async def get_message( :param item_hash: Hash of the message to fetch :param message_type: Type of message to fetch """ - pass + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") @abstractmethod def watch_messages( @@ -215,7 +215,7 @@ def watch_messages( :param message_filter: Filter to apply to the messages """ - pass + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") class AuthenticatedAlephClient(AlephClient): @@ -243,7 +243,9 @@ async def create_post( :param storage_engine: An optional storage engine to use for the message, if not inlined (Default: "storage") :param sync: If true, waits for the message to be processed by the API server (Default: False) """ - pass + raise NotImplementedError( + "Did you mean to import `AuthenticatedAlephHttpClient`?" + ) @abstractmethod async def create_aggregate( @@ -265,7 +267,9 @@ async def create_aggregate( :param inline: Whether to write content inside the message (Default: True) :param sync: If true, waits for the message to be processed by the API server (Default: False) """ - pass + raise NotImplementedError( + "Did you mean to import `AuthenticatedAlephHttpClient`?" + ) @abstractmethod async def create_store( @@ -297,7 +301,9 @@ async def create_store( :param channel: Channel to post the message to (Default: "TEST") :param sync: If true, waits for the message to be processed by the API server (Default: False) """ - pass + raise NotImplementedError( + "Did you mean to import `AuthenticatedAlephHttpClient`?" 
+ ) @abstractmethod async def create_program( @@ -345,7 +351,9 @@ async def create_program( :param subscriptions: Patterns of aleph.im messages to forward to the program's event receiver :param metadata: Metadata to attach to the message """ - pass + raise NotImplementedError( + "Did you mean to import `AuthenticatedAlephHttpClient`?" + ) @abstractmethod async def create_instance( @@ -394,7 +402,9 @@ async def create_instance( :param ssh_keys: SSH keys to authorize access to the VM :param metadata: Metadata to attach to the message """ - pass + raise NotImplementedError( + "Did you mean to import `AuthenticatedAlephHttpClient`?" + ) @abstractmethod async def forget( @@ -419,7 +429,9 @@ async def forget( :param address: Address to use (Default: account.get_address()) :param sync: If true, waits for the message to be processed by the API server (Default: False) """ - pass + raise NotImplementedError( + "Did you mean to import `AuthenticatedAlephHttpClient`?" + ) @abstractmethod async def submit( @@ -444,7 +456,9 @@ async def submit( :param sync: If true, waits for the message to be processed by the API server (Default: False) :param raise_on_rejected: Whether to raise an exception if the message is rejected (Default: True) """ - pass + raise NotImplementedError( + "Did you mean to import `AuthenticatedAlephHttpClient`?" 
+ ) async def ipfs_push(self, content: Mapping) -> str: """ diff --git a/tests/unit/test_init.py b/tests/unit/test_init.py index 85a1ba69..664783a3 100644 --- a/tests/unit/test_init.py +++ b/tests/unit/test_init.py @@ -1,5 +1,29 @@ +import pytest + from aleph.sdk import __version__ def test_version(): assert __version__ != "" + + +def test_deprecation(): + with pytest.raises(ImportError): + from aleph.sdk import AlephClient # noqa + + with pytest.raises(ImportError): + from aleph.sdk import AuthenticatedAlephClient # noqa + + with pytest.raises(ImportError): + from aleph.sdk import synchronous # noqa + + with pytest.raises(ImportError): + from aleph.sdk import asynchronous # noqa + + with pytest.raises(ImportError): + import aleph.sdk.synchronous # noqa + + with pytest.raises(ImportError): + import aleph.sdk.asynchronous # noqa + + from aleph.sdk import AlephHttpClient # noqa From 12ac7d5cb77f11f3bbf35e45f815822aa0a443fb Mon Sep 17 00:00:00 2001 From: mhh Date: Mon, 12 Feb 2024 18:02:48 +0100 Subject: [PATCH 18/18] Add payment argument to LightNode.create_instance() --- src/aleph/sdk/client/light_node.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/aleph/sdk/client/light_node.py b/src/aleph/sdk/client/light_node.py index 3fc2439f..478ef875 100644 --- a/src/aleph/sdk/client/light_node.py +++ b/src/aleph/sdk/client/light_node.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union from aleph_message.models import AlephMessage, Chain, MessageType -from aleph_message.models.execution.base import Encoding +from aleph_message.models.execution.base import Encoding, Payment from aleph_message.status import MessageStatus from ..query.filters import MessageFilter @@ -361,6 +361,7 @@ async def create_instance( rootfs: str, rootfs_size: int, rootfs_name: str, + payment: Optional[Payment] = None, environment_variables: Optional[Mapping[str, str]] = None, storage_engine: StorageEnum = StorageEnum.storage, 
channel: Optional[str] = None, @@ -372,7 +373,6 @@ async def create_instance( allow_amend: bool = False, internet: bool = True, aleph_api: bool = True, - encoding: Encoding = Encoding.zip, volumes: Optional[List[Mapping]] = None, volume_persistence: str = "host", ssh_keys: Optional[List[str]] = None, @@ -385,6 +385,7 @@ async def create_instance( rootfs=rootfs, rootfs_size=rootfs_size, rootfs_name=rootfs_name, + payment=payment, environment_variables=environment_variables, storage_engine=storage_engine, channel=channel, @@ -396,7 +397,6 @@ async def create_instance( allow_amend=allow_amend, internet=internet, aleph_api=aleph_api, - encoding=encoding, volumes=volumes, volume_persistence=volume_persistence, ssh_keys=ssh_keys,