diff --git a/examples/httpgateway.py b/examples/httpgateway.py index c2cb3bdf..24ed6ba1 100644 --- a/examples/httpgateway.py +++ b/examples/httpgateway.py @@ -7,7 +7,7 @@ from aleph.sdk.chains.common import get_fallback_private_key from aleph.sdk.chains.ethereum import ETHAccount -from aleph.sdk.client import AuthenticatedAlephClient +from aleph.sdk.client import AuthenticatedAlephHttpClient app = web.Application() routes = web.RouteTableDef() @@ -32,7 +32,7 @@ async def source_post(request): return web.json_response( {"status": "error", "message": "unauthorized secret"} ) - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=app["account"], api_server="https://api2.aleph.im" ) as session: message, _status = await session.create_post( diff --git a/examples/metrics.py b/examples/metrics.py index 381db6be..d8f8a0cc 100644 --- a/examples/metrics.py +++ b/examples/metrics.py @@ -1,7 +1,6 @@ """ Server metrics upload. """ -# -*- coding: utf-8 -*- - +import asyncio import os import platform import time @@ -12,9 +11,9 @@ from aleph_message.status import MessageStatus from aleph.sdk.chains.ethereum import get_fallback_account -from aleph.sdk.client import AuthenticatedAlephClient, AuthenticatedUserSessionSync +from aleph.sdk.client import AuthenticatedAlephHttpClient from aleph.sdk.conf import settings def get_sysinfo(): uptime = int(time.time() - psutil.boot_time()) @@ -53,10 +54,12 @@ def get_cpu_cores(): return [c._asdict() for c in psutil.cpu_times_percent(0, percpu=True)] -def send_metrics( - session: AuthenticatedUserSessionSync, metrics +async def send_metrics( + session: AuthenticatedAlephHttpClient, metrics ) -> Tuple[AlephMessage, MessageStatus]: - return session.create_aggregate(key="metrics", content=metrics, channel="SYSINFO") + return await session.create_aggregate( + key="metrics", content=metrics, channel="SYSINFO" + ) def collect_metrics(): @@ -68,17 +71,17 @@ def collect_metrics(): } -def main(): +async def main(): account = get_fallback_account() - with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=settings.API_HOST ) as session: while True: metrics = collect_metrics() - message, status = await send_metrics(session, metrics) + message, status = await send_metrics(session, metrics) - time.sleep(10) + await asyncio.sleep(10) print("sent", message.item_hash) if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/examples/mqtt.py b/examples/mqtt.py index eff32121..b08538f9 100644 --- a/examples/mqtt.py +++ b/examples/mqtt.py @@ -10,7 +10,7 @@ from aleph.sdk.chains.common import get_fallback_private_key from aleph.sdk.chains.ethereum import ETHAccount -from aleph.sdk.client import AuthenticatedAlephClient +from aleph.sdk.client import AuthenticatedAlephHttpClient from aleph.sdk.conf import settings @@ -26,8 +26,8 @@ def get_input_data(value): return value.decode("utf-8") -def send_metrics(account, metrics): -    with AuthenticatedAlephClient( +async def send_metrics(account, metrics): +    async with AuthenticatedAlephHttpClient( account=account, api_server=settings.API_HOST ) as session: - return session.create_aggregate( + return await session.create_aggregate( @@ -100,7 +100,7 @@ async def gateway( if not userdata["received"]: await client.reconnect() - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=settings.API_HOST ) as session: for key, value in state.items(): diff --git a/examples/store.py b/examples/store.py index 6ce5662c..b6c7a862 100644 --- a/examples/store.py +++
b/examples/store.py @@ -6,7 +6,7 @@ from aleph.sdk.chains.common import get_fallback_private_key from aleph.sdk.chains.ethereum import ETHAccount -from aleph.sdk.client import AuthenticatedAlephClient +from aleph.sdk.client import AuthenticatedAlephHttpClient from aleph.sdk.conf import settings DEFAULT_SERVER = "https://api2.aleph.im" @@ -23,7 +23,7 @@ async def print_output_hash(message: StoreMessage, status: MessageStatus): async def do_upload(account, engine, channel, filename=None, file_hash=None): - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=settings.API_HOST ) as session: print(filename, account.get_address()) diff --git a/src/aleph/sdk/__init__.py b/src/aleph/sdk/__init__.py index c66fe9d6..c14b64f6 100644 --- a/src/aleph/sdk/__init__.py +++ b/src/aleph/sdk/__init__.py @@ -1,6 +1,6 @@ from pkg_resources import DistributionNotFound, get_distribution -from aleph.sdk.client import AlephClient, AuthenticatedAlephClient +from aleph.sdk.client import AlephHttpClient, AuthenticatedAlephHttpClient try: # Change here if project is renamed and does not equal the package name @@ -11,4 +11,4 @@ finally: del get_distribution, DistributionNotFound -__all__ = ["AlephClient", "AuthenticatedAlephClient"] +__all__ = ["AlephHttpClient", "AuthenticatedAlephHttpClient"] diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client.py deleted file mode 100644 index f79f0ceb..00000000 --- a/src/aleph/sdk/client.py +++ /dev/null @@ -1,1425 +0,0 @@ -import asyncio -import hashlib -import json -import logging -import queue -import threading -import time -import warnings -from datetime import datetime -from io import BytesIO -from pathlib import Path -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Dict, - Iterable, - List, - Mapping, - NoReturn, - Optional, - Tuple, - Type, - TypeVar, - Union, -) - -import aiohttp -from aleph_message.models import ( - AggregateContent, - AggregateMessage, - AlephMessage, - ForgetContent, - ForgetMessage, - ItemHash, - ItemType, - MessageType, - PostContent, - PostMessage, - ProgramContent, - ProgramMessage, - StoreContent, - StoreMessage, - parse_message, -) -from aleph_message.models.execution.base import Encoding -from aleph_message.status import MessageStatus -from pydantic import ValidationError -from pydantic.json import pydantic_encoder - -from aleph.sdk.types import Account, GenericMessage, StorageEnum -from aleph.sdk.utils import Writable, copy_async_readable_to_buffer - -from .base import BaseAlephClient, BaseAuthenticatedAlephClient -from .conf import settings -from .exceptions import ( - BroadcastError, - FileTooLarge, - InvalidMessageError, - MessageNotFoundError, - MultipleMessagesError, -) -from .models import MessagesResponse, Post, PostsResponse -from .utils import check_unix_socket_valid, get_message_type_value - -logger = logging.getLogger(__name__) - -try: - import magic -except ImportError: - logger.info("Could not import library 'magic', MIME type detection disabled") - magic = None # type:ignore - -T = TypeVar("T") - - -def async_wrapper(f): - """ - Copies the docstring of wrapped functions. - """ - - wrapped = getattr(AuthenticatedAlephClient, f.__name__) - f.__doc__ = wrapped.__doc__ - - -def wrap_async(func: Callable[..., Awaitable[T]]) -> Callable[..., T]: - """Wrap an asynchronous function into a synchronous one, - for easy use in synchronous code. 
- """ - - def func_caller(*args, **kwargs): - loop = asyncio.get_event_loop() - return loop.run_until_complete(func(*args, **kwargs)) - - # Copy wrapped function interface: - func_caller.__doc__ = func.__doc__ - func_caller.__annotations__ = func.__annotations__ - func_caller.__defaults__ = func.__defaults__ - func_caller.__kwdefaults__ = func.__kwdefaults__ - return func_caller - - -async def run_async_watcher( - *args, output_queue: queue.Queue, api_server: Optional[str], **kwargs -): - async with AlephClient(api_server=api_server) as session: - async for message in session.watch_messages(*args, **kwargs): - output_queue.put(message) - - -def watcher_thread(output_queue: queue.Queue, api_server: Optional[str], args, kwargs): - asyncio.run( - run_async_watcher( - output_queue=output_queue, api_server=api_server, *args, **kwargs - ) - ) - - -class UserSessionSync: - """ - A sync version of `UserSession`, used in sync code. - - This class is returned by the context manager of `UserSession` and is - intended as a wrapper around the methods of `UserSession` and not as a public class. - The methods are fully typed to enable static type checking, but most (all) methods - should look like this (using args and kwargs for brevity, but the functions should - be fully typed): - - >>> def func(self, *args, **kwargs): - >>> return self._wrap(self.async_session.func)(*args, **kwargs) - """ - - def __init__(self, async_session: "AlephClient"): - self.async_session = async_session - - def _wrap(self, method: Callable[..., Awaitable[T]], *args, **kwargs): - return wrap_async(method)(*args, **kwargs) - - def get_messages( - self, - pagination: int = 200, - page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[List[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, - ignore_invalid_messages: bool = True, - invalid_messages_log_level: int = logging.NOTSET, - ) -> MessagesResponse: - return self._wrap( - self.async_session.get_messages, - pagination=pagination, - page=page, - message_type=message_type, - message_types=message_types, - content_types=content_types, - content_keys=content_keys, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, - ignore_invalid_messages=ignore_invalid_messages, - invalid_messages_log_level=invalid_messages_log_level, - ) - - # @async_wrapper - def get_message( - self, - item_hash: str, - message_type: Optional[Type[GenericMessage]] = None, - channel: Optional[str] = None, - ) -> GenericMessage: - return self._wrap( - self.async_session.get_message, - item_hash=item_hash, - message_type=message_type, - channel=channel, - ) - - def fetch_aggregate( - self, - address: str, - key: str, - limit: int = 100, - ) -> Dict[str, Dict]: - return self._wrap(self.async_session.fetch_aggregate, address, key, limit) - - def fetch_aggregates( - self, - address: str, - keys: Optional[Iterable[str]] = None, - limit: int = 100, - ) -> Dict[str, Dict]: - return self._wrap(self.async_session.fetch_aggregates, address, keys, limit) - - def get_posts( 
- self, - pagination: int = 200, - page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, - ) -> PostsResponse: - return self._wrap( - self.async_session.get_posts, - pagination=pagination, - page=page, - types=types, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, - ) - - def download_file(self, file_hash: str) -> bytes: - return self._wrap(self.async_session.download_file, file_hash=file_hash) - - def download_file_ipfs(self, file_hash: str) -> bytes: - return self._wrap( - self.async_session.download_file_ipfs, - file_hash=file_hash, - ) - - def download_file_to_buffer( - self, file_hash: str, output_buffer: Writable[bytes] - ) -> bytes: - return self._wrap( - self.async_session.download_file_to_buffer, - file_hash=file_hash, - output_buffer=output_buffer, - ) - - def download_file_ipfs_to_buffer( - self, file_hash: str, output_buffer: Writable[bytes] - ) -> bytes: - return self._wrap( - self.async_session.download_file_ipfs_to_buffer, - file_hash=file_hash, - output_buffer=output_buffer, - ) - - def watch_messages( - self, - message_type: Optional[MessageType] = None, - content_types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, - ) -> Iterable[AlephMessage]: - """ - Iterate over current and future matching messages synchronously. - - Runs the `watch_messages` asynchronous generator in a thread. 
- """ - output_queue: queue.Queue[AlephMessage] = queue.Queue() - thread = threading.Thread( - target=watcher_thread, - args=( - output_queue, - self.async_session.api_server, - ( - message_type, - content_types, - refs, - addresses, - tags, - hashes, - channels, - chains, - start_date, - end_date, - ), - {}, - ), - ) - thread.start() - while True: - yield output_queue.get() - - -class AuthenticatedUserSessionSync(UserSessionSync): - async_session: "AuthenticatedAlephClient" - - def __init__(self, async_session: "AuthenticatedAlephClient"): - super().__init__(async_session=async_session) - - def ipfs_push(self, content: Mapping) -> str: - return self._wrap(self.async_session.ipfs_push, content=content) - - def storage_push(self, content: Mapping) -> str: - return self._wrap(self.async_session.storage_push, content=content) - - def ipfs_push_file(self, file_content: Union[str, bytes]) -> str: - return self._wrap(self.async_session.ipfs_push_file, file_content=file_content) - - def storage_push_file(self, file_content: Union[str, bytes]) -> str: - return self._wrap( - self.async_session.storage_push_file, file_content=file_content - ) - - def create_post( - self, - post_content, - post_type: str, - ref: Optional[str] = None, - address: Optional[str] = None, - channel: Optional[str] = None, - inline: bool = True, - storage_engine: StorageEnum = StorageEnum.storage, - sync: bool = False, - ) -> Tuple[PostMessage, MessageStatus]: - return self._wrap( - self.async_session.create_post, - post_content=post_content, - post_type=post_type, - ref=ref, - address=address, - channel=channel, - inline=inline, - storage_engine=storage_engine, - sync=sync, - ) - - def create_aggregate( - self, - key: str, - content: Mapping[str, Any], - address: Optional[str] = None, - channel: Optional[str] = None, - inline: bool = True, - sync: bool = False, - ) -> Tuple[AggregateMessage, MessageStatus]: - return self._wrap( - self.async_session.create_aggregate, - key=key, - content=content, - address=address, - channel=channel, - inline=inline, - sync=sync, - ) - - def create_store( - self, - address: Optional[str] = None, - file_content: Optional[bytes] = None, - file_path: Optional[Union[str, Path]] = None, - file_hash: Optional[str] = None, - guess_mime_type: bool = False, - ref: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - extra_fields: Optional[dict] = None, - channel: Optional[str] = None, - sync: bool = False, - ) -> Tuple[StoreMessage, MessageStatus]: - return self._wrap( - self.async_session.create_store, - address=address, - file_content=file_content, - file_path=file_path, - file_hash=file_hash, - guess_mime_type=guess_mime_type, - ref=ref, - storage_engine=storage_engine, - extra_fields=extra_fields, - channel=channel, - sync=sync, - ) - - def create_program( - self, - program_ref: str, - entrypoint: str, - runtime: str, - environment_variables: Optional[Mapping[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = False, - memory: Optional[int] = None, - vcpus: Optional[int] = None, - timeout_seconds: Optional[float] = None, - persistent: bool = False, - encoding: Encoding = Encoding.zip, - volumes: Optional[List[Mapping]] = None, - subscriptions: Optional[List[Mapping]] = None, - metadata: Optional[Mapping[str, Any]] = None, - ) -> Tuple[ProgramMessage, MessageStatus]: - return self._wrap( - self.async_session.create_program, - program_ref=program_ref, - entrypoint=entrypoint, 
- runtime=runtime, - environment_variables=environment_variables, - storage_engine=storage_engine, - channel=channel, - address=address, - sync=sync, - memory=memory, - vcpus=vcpus, - timeout_seconds=timeout_seconds, - persistent=persistent, - encoding=encoding, - volumes=volumes, - subscriptions=subscriptions, - metadata=metadata, - ) - - def forget( - self, - hashes: List[str], - reason: Optional[str], - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = False, - ) -> Tuple[ForgetMessage, MessageStatus]: - return self._wrap( - self.async_session.forget, - hashes=hashes, - reason=reason, - storage_engine=storage_engine, - channel=channel, - address=address, - sync=sync, - ) - - def submit( - self, - content: Dict[str, Any], - message_type: MessageType, - channel: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - allow_inlining: bool = True, - sync: bool = False, - ) -> Tuple[AlephMessage, MessageStatus]: - return self._wrap( - self.async_session.submit, - content=content, - message_type=message_type, - channel=channel, - storage_engine=storage_engine, - allow_inlining=allow_inlining, - sync=sync, - ) - - -class AlephClient(BaseAlephClient): - api_server: str - http_session: aiohttp.ClientSession - - def __init__( - self, - api_server: Optional[str] = None, - api_unix_socket: Optional[str] = None, - allow_unix_sockets: bool = True, - timeout: Optional[aiohttp.ClientTimeout] = None, - ): - """AlephClient can use HTTP(S) or HTTP over Unix sockets. - Unix sockets are used when running inside a virtual machine, - and can be shared across containers in a more secure way than TCP ports. - """ - self.api_server = api_server or settings.API_HOST - if not self.api_server: - raise ValueError("Missing API host") - - unix_socket_path = api_unix_socket or settings.API_UNIX_SOCKET - if unix_socket_path and allow_unix_sockets: - check_unix_socket_valid(unix_socket_path) - connector = aiohttp.UnixConnector(path=unix_socket_path) - else: - connector = None - - # ClientSession timeout defaults to a private sentinel object and may not be None. 
- self.http_session = ( - aiohttp.ClientSession( - base_url=self.api_server, connector=connector, timeout=timeout - ) - if timeout - else aiohttp.ClientSession( - base_url=self.api_server, - connector=connector, - ) - ) - - def __enter__(self) -> UserSessionSync: - return UserSessionSync(async_session=self) - - def __exit__(self, exc_type, exc_val, exc_tb): - close_fut = self.http_session.close() - try: - loop = asyncio.get_running_loop() - loop.run_until_complete(close_fut) - except RuntimeError: - asyncio.run(close_fut) - - async def __aenter__(self) -> "AlephClient": - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.http_session.close() - - async def fetch_aggregate( - self, - address: str, - key: str, - limit: int = 100, - ) -> Dict[str, Dict]: - params: Dict[str, Any] = {"keys": key} - if limit: - params["limit"] = limit - - async with self.http_session.get( - f"/api/v0/aggregates/{address}.json", params=params - ) as resp: - result = await resp.json() - data = result.get("data", dict()) - return data.get(key) - - async def fetch_aggregates( - self, - address: str, - keys: Optional[Iterable[str]] = None, - limit: int = 100, - ) -> Dict[str, Dict]: - keys_str = ",".join(keys) if keys else "" - params: Dict[str, Any] = {} - if keys_str: - params["keys"] = keys_str - if limit: - params["limit"] = limit - - async with self.http_session.get( - f"/api/v0/aggregates/{address}.json", - params=params, - ) as resp: - result = await resp.json() - data = result.get("data", dict()) - return data - - async def get_posts( - self, - pagination: int = 200, - page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, - ignore_invalid_messages: Optional[bool] = True, - invalid_messages_log_level: Optional[int] = logging.NOTSET, - ) -> PostsResponse: - ignore_invalid_messages = ( - True if ignore_invalid_messages is None else ignore_invalid_messages - ) - invalid_messages_log_level = ( - logging.NOTSET - if invalid_messages_log_level is None - else invalid_messages_log_level - ) - - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if types is not None: - params["types"] = ",".join(types) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - - async with self.http_session.get("/api/v0/posts.json", params=params) as resp: - resp.raise_for_status() - response_json = await resp.json() - posts_raw = response_json["posts"] - - posts: List[Post] = [] - for post_raw in posts_raw: - try: - posts.append(Post.parse_obj(post_raw)) - except 
ValidationError as e: - if not ignore_invalid_messages: - raise e - if invalid_messages_log_level: - logger.log(level=invalid_messages_log_level, msg=e) - return PostsResponse( - posts=posts, - pagination_page=response_json["pagination_page"], - pagination_total=response_json["pagination_total"], - pagination_per_page=response_json["pagination_per_page"], - pagination_item=response_json["pagination_item"], - ) - - async def download_file_to_buffer( - self, - file_hash: str, - output_buffer: Writable[bytes], - ) -> None: - """ - Download a file from the storage engine and write it to the specified output buffer. - :param file_hash: The hash of the file to retrieve. - :param output_buffer: Writable binary buffer. The file will be written to this buffer. - """ - - async with self.http_session.get( - f"/api/v0/storage/raw/{file_hash}" - ) as response: - if response.status == 200: - await copy_async_readable_to_buffer( - response.content, output_buffer, chunk_size=16 * 1024 - ) - if response.status == 413: - ipfs_hash = ItemHash(file_hash) - if ipfs_hash.item_type == ItemType.ipfs: - return await self.download_file_ipfs_to_buffer( - file_hash, output_buffer - ) - else: - raise FileTooLarge(f"The file from {file_hash} is too large") - - async def download_file_ipfs_to_buffer( - self, - file_hash: str, - output_buffer: Writable[bytes], - ) -> None: - """ - Download a file from the storage engine and write it to the specified output buffer. - - :param file_hash: The hash of the file to retrieve. - :param output_buffer: The binary output buffer to write the file data to. - """ - async with aiohttp.ClientSession() as session: - async with session.get( - f"https://ipfs.aleph.im/ipfs/{file_hash}" - ) as response: - if response.status == 200: - await copy_async_readable_to_buffer( - response.content, output_buffer, chunk_size=16 * 1024 - ) - else: - response.raise_for_status() - - async def download_file( - self, - file_hash: str, - ) -> bytes: - """ - Get a file from the storage engine as raw bytes. - - Warning: Downloading large files can be slow and memory intensive. - - :param file_hash: The hash of the file to retrieve. - """ - buffer = BytesIO() - await self.download_file_to_buffer(file_hash, output_buffer=buffer) - return buffer.getvalue() - - async def download_file_ipfs( - self, - file_hash: str, - ) -> bytes: - """ - Get a file from the ipfs storage engine as raw bytes. - - Warning: Downloading large files can be slow. - - :param file_hash: The hash of the file to retrieve. 
- """ - buffer = BytesIO() - await self.download_file_ipfs_to_buffer(file_hash, output_buffer=buffer) - return buffer.getvalue() - - async def get_messages( - self, - pagination: int = 200, - page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, - ignore_invalid_messages: Optional[bool] = True, - invalid_messages_log_level: Optional[int] = logging.NOTSET, - ) -> MessagesResponse: - ignore_invalid_messages = ( - True if ignore_invalid_messages is None else ignore_invalid_messages - ) - invalid_messages_log_level = ( - logging.NOTSET - if invalid_messages_log_level is None - else invalid_messages_log_level - ) - - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if message_type is not None: - warnings.warn( - "The message_type parameter is deprecated, please use message_types instead.", - DeprecationWarning, - ) - params["msgType"] = message_type.value - if message_types is not None: - params["msgTypes"] = ",".join([t.value for t in message_types]) - print(params["msgTypes"]) - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if content_keys is not None: - params["contentKeys"] = ",".join(content_keys) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - - async with self.http_session.get( - "/api/v0/messages.json", params=params - ) as resp: - resp.raise_for_status() - response_json = await resp.json() - messages_raw = response_json["messages"] - - # All messages may not be valid according to the latest specification in - # aleph-message. This allows the user to specify how errors should be handled. 
- messages: List[AlephMessage] = [] - for message_raw in messages_raw: - try: - message = parse_message(message_raw) - messages.append(message) - except KeyError as e: - if not ignore_invalid_messages: - raise e - logger.log( - level=invalid_messages_log_level, - msg=f"KeyError: Field '{e.args[0]}' not found", - ) - except ValidationError as e: - if not ignore_invalid_messages: - raise e - if invalid_messages_log_level: - logger.log(level=invalid_messages_log_level, msg=e) - - return MessagesResponse( - messages=messages, - pagination_page=response_json["pagination_page"], - pagination_total=response_json["pagination_total"], - pagination_per_page=response_json["pagination_per_page"], - pagination_item=response_json["pagination_item"], - ) - - async def get_message( - self, - item_hash: str, - message_type: Optional[Type[GenericMessage]] = None, - channel: Optional[str] = None, - ) -> GenericMessage: - messages_response = await self.get_messages( - hashes=[item_hash], - channels=[channel] if channel else None, - ) - if len(messages_response.messages) < 1: - raise MessageNotFoundError(f"No such hash {item_hash}") - if len(messages_response.messages) != 1: - raise MultipleMessagesError( - f"Multiple messages found for the same item_hash `{item_hash}`" - ) - message: GenericMessage = messages_response.messages[0] - if message_type: - expected_type = get_message_type_value(message_type) - if message.type != expected_type: - raise TypeError( - f"The message type '{message.type}' " - f"does not match the expected type '{expected_type}'" - ) - return message - - async def watch_messages( - self, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, - ) -> AsyncIterable[AlephMessage]: - params: Dict[str, Any] = dict() - - if message_type is not None: - warnings.warn( - "The message_type parameter is deprecated, please use message_types instead.", - DeprecationWarning, - ) - params["msgType"] = message_type.value - if message_types is not None: - params["msgTypes"] = ",".join([t.value for t in message_types]) - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if content_keys is not None: - params["contentKeys"] = ",".join(content_keys) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - - async with self.http_session.ws_connect( - "/api/ws0/messages", params=params - ) as ws: - logger.debug("Websocket 
connected") - async for msg in ws: - if msg.type == aiohttp.WSMsgType.TEXT: - if msg.data == "close cmd": - await ws.close() - break - else: - data = json.loads(msg.data) - yield parse_message(data) - elif msg.type == aiohttp.WSMsgType.ERROR: - break - - -class AuthenticatedAlephClient(AlephClient, BaseAuthenticatedAlephClient): - account: Account - - BROADCAST_MESSAGE_FIELDS = { - "sender", - "chain", - "signature", - "type", - "item_hash", - "item_type", - "item_content", - "time", - "channel", - } - - def __init__( - self, - account: Account, - api_server: Optional[str], - api_unix_socket: Optional[str] = None, - allow_unix_sockets: bool = True, - timeout: Optional[aiohttp.ClientTimeout] = None, - ): - super().__init__( - api_server=api_server, - api_unix_socket=api_unix_socket, - allow_unix_sockets=allow_unix_sockets, - timeout=timeout, - ) - self.account = account - - def __enter__(self) -> "AuthenticatedUserSessionSync": - return AuthenticatedUserSessionSync(async_session=self) - - async def __aenter__(self) -> "AuthenticatedAlephClient": - return self - - async def ipfs_push(self, content: Mapping) -> str: - """ - Push arbitrary content as JSON to the IPFS service. - - :param content: The dict-like content to upload - """ - url = "/api/v0/ipfs/add_json" - logger.debug(f"Pushing to IPFS on {url}") - - async with self.http_session.post(url, json=content) as resp: - resp.raise_for_status() - return (await resp.json()).get("hash") - - async def storage_push(self, content: Mapping) -> str: - """ - Push arbitrary content as JSON to the storage service. - - :param content: The dict-like content to upload - """ - url = "/api/v0/storage/add_json" - logger.debug(f"Pushing to storage on {url}") - - async with self.http_session.post(url, json=content) as resp: - resp.raise_for_status() - return (await resp.json()).get("hash") - - async def ipfs_push_file(self, file_content: Union[str, bytes]) -> str: - """ - Push a file to the IPFS service. - - :param file_content: The file content to upload - """ - data = aiohttp.FormData() - data.add_field("file", file_content) - - url = "/api/v0/ipfs/add_file" - logger.debug(f"Pushing file to IPFS on {url}") - - async with self.http_session.post(url, data=data) as resp: - resp.raise_for_status() - return (await resp.json()).get("hash") - - async def storage_push_file(self, file_content) -> str: - """ - Push a file to the storage service. - """ - data = aiohttp.FormData() - data.add_field("file", file_content) - - url = "/api/v0/storage/add_file" - logger.debug(f"Posting file on {url}") - - async with self.http_session.post(url, data=data) as resp: - resp.raise_for_status() - return (await resp.json()).get("hash") - - @staticmethod - def _log_publication_status(publication_status: Mapping[str, Any]): - status = publication_status.get("status") - failures = publication_status.get("failed") - - if status == "success": - return - elif status == "warning": - logger.warning("Broadcast failed on the following network(s): %s", failures) - elif status == "error": - logger.error( - "Broadcast failed on all protocols. The message was not published." 
- ) - else: - raise ValueError( - f"Invalid response from server, status in missing or unknown: '{status}'" - ) - - @staticmethod - async def _handle_broadcast_error(response: aiohttp.ClientResponse) -> NoReturn: - if response.status == 500: - # Assume a broadcast error, no need to read the JSON - if response.content_type == "application/json": - error_msg = "Internal error - broadcast failed on all protocols" - else: - error_msg = f"Internal error - the message was not broadcast: {await response.text()}" - - logger.error(error_msg) - raise BroadcastError(error_msg) - elif response.status == 422: - errors = await response.json() - logger.error( - "The message could not be processed because of the following errors: %s", - errors, - ) - raise InvalidMessageError(errors) - else: - error_msg = ( - f"Unexpected HTTP response ({response.status}: {await response.text()})" - ) - logger.error(error_msg) - raise BroadcastError(error_msg) - - async def _handle_broadcast_deprecated_response( - self, - response: aiohttp.ClientResponse, - ) -> None: - if response.status != 200: - await self._handle_broadcast_error(response) - else: - publication_status = await response.json() - self._log_publication_status(publication_status) - - async def _broadcast_deprecated(self, message_dict: Mapping[str, Any]) -> None: - """ - Broadcast a message on the Aleph network using the deprecated - /ipfs/pubsub/pub/ endpoint. - """ - - url = "/api/v0/ipfs/pubsub/pub" - logger.debug(f"Posting message on {url}") - - async with self.http_session.post( - url, - json={"topic": "ALEPH-TEST", "data": json.dumps(message_dict)}, - ) as response: - await self._handle_broadcast_deprecated_response(response) - - async def _handle_broadcast_response( - self, response: aiohttp.ClientResponse, sync: bool - ) -> MessageStatus: - if response.status in (200, 202): - status = await response.json() - self._log_publication_status(status["publication_status"]) - - if response.status == 202: - if sync: - logger.warning( - "Timed out while waiting for processing of sync message" - ) - return MessageStatus.PENDING - - return MessageStatus.PROCESSED - - else: - await self._handle_broadcast_error(response) - - async def _broadcast( - self, - message: AlephMessage, - sync: bool, - ) -> MessageStatus: - """ - Broadcast a message on the Aleph network. - - Uses the POST /messages/ endpoint or the deprecated /ipfs/pubsub/pub/ endpoint - if the first method is not available. - """ - - url = "/api/v0/messages" - logger.debug(f"Posting message on {url}") - - message_dict = message.dict(include=self.BROADCAST_MESSAGE_FIELDS) - - async with self.http_session.post( - url, - json={"sync": sync, "message": message_dict}, - ) as response: - # The endpoint may be unavailable on this node, try the deprecated version. - if response.status in (404, 405): - logger.warning( - "POST /messages/ not found. Defaulting to legacy endpoint..." 
- ) - await self._broadcast_deprecated(message_dict=message_dict) - return MessageStatus.PENDING - else: - message_status = await self._handle_broadcast_response( - response=response, sync=sync - ) - return message_status - - async def create_post( - self, - post_content, - post_type: str, - ref: Optional[str] = None, - address: Optional[str] = None, - channel: Optional[str] = None, - inline: bool = True, - storage_engine: StorageEnum = StorageEnum.storage, - sync: bool = False, - ) -> Tuple[PostMessage, MessageStatus]: - address = address or settings.ADDRESS_TO_USE or self.account.get_address() - - content = PostContent( - type=post_type, - address=address, - content=post_content, - time=time.time(), - ref=ref, - ) - - return await self.submit( - content=content.dict(exclude_none=True), - message_type=MessageType.post, - channel=channel, - allow_inlining=inline, - storage_engine=storage_engine, - sync=sync, - ) - - async def create_aggregate( - self, - key: str, - content: Mapping[str, Any], - address: Optional[str] = None, - channel: Optional[str] = None, - inline: bool = True, - sync: bool = False, - ) -> Tuple[AggregateMessage, MessageStatus]: - address = address or settings.ADDRESS_TO_USE or self.account.get_address() - - content_ = AggregateContent( - key=key, - address=address, - content=content, - time=time.time(), - ) - - return await self.submit( - content=content_.dict(exclude_none=True), - message_type=MessageType.aggregate, - channel=channel, - allow_inlining=inline, - sync=sync, - ) - - async def create_store( - self, - address: Optional[str] = None, - file_content: Optional[bytes] = None, - file_path: Optional[Union[str, Path]] = None, - file_hash: Optional[str] = None, - guess_mime_type: bool = False, - ref: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - extra_fields: Optional[dict] = None, - channel: Optional[str] = None, - sync: bool = False, - ) -> Tuple[StoreMessage, MessageStatus]: - address = address or settings.ADDRESS_TO_USE or self.account.get_address() - - extra_fields = extra_fields or {} - - if file_hash is None: - if file_content is None: - if file_path is None: - raise ValueError( - "Please specify at least a file_content, a file_hash or a file_path" - ) - else: - file_content = Path(file_path).read_bytes() - - if storage_engine == StorageEnum.storage: - file_hash = await self.storage_push_file(file_content=file_content) - elif storage_engine == StorageEnum.ipfs: - file_hash = await self.ipfs_push_file(file_content=file_content) - else: - raise ValueError(f"Unknown storage engine: '{storage_engine}'") - - assert file_hash, "File hash should not be empty" - - if magic is None: - pass - elif file_content and guess_mime_type and ("mime_type" not in extra_fields): - extra_fields["mime_type"] = magic.from_buffer(file_content, mime=True) - - if ref: - extra_fields["ref"] = ref - - values = { - "address": address, - "item_type": storage_engine, - "item_hash": file_hash, - "time": time.time(), - } - if extra_fields is not None: - values.update(extra_fields) - - content = StoreContent(**values) - - return await self.submit( - content=content.dict(exclude_none=True), - message_type=MessageType.store, - channel=channel, - allow_inlining=True, - sync=sync, - ) - - async def create_program( - self, - program_ref: str, - entrypoint: str, - runtime: str, - environment_variables: Optional[Mapping[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = 
False, - memory: Optional[int] = None, - vcpus: Optional[int] = None, - timeout_seconds: Optional[float] = None, - persistent: bool = False, - encoding: Encoding = Encoding.zip, - volumes: Optional[List[Mapping]] = None, - subscriptions: Optional[List[Mapping]] = None, - metadata: Optional[Mapping[str, Any]] = None, - ) -> Tuple[ProgramMessage, MessageStatus]: - address = address or settings.ADDRESS_TO_USE or self.account.get_address() - - volumes = volumes if volumes is not None else [] - memory = memory or settings.DEFAULT_VM_MEMORY - vcpus = vcpus or settings.DEFAULT_VM_VCPUS - timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT - - # TODO: Check that program_ref, runtime and data_ref exist - - # Register the different ways to trigger a VM - if subscriptions: - # Trigger on HTTP calls and on Aleph message subscriptions. - triggers = { - "http": True, - "persistent": persistent, - "message": subscriptions, - } - else: - # Trigger on HTTP calls. - triggers = {"http": True, "persistent": persistent} - - content = ProgramContent( - **{ - "type": "vm-function", - "address": address, - "allow_amend": False, - "code": { - "encoding": encoding, - "entrypoint": entrypoint, - "ref": program_ref, - "use_latest": True, - }, - "on": triggers, - "environment": { - "reproducible": False, - "internet": True, - "aleph_api": True, - }, - "variables": environment_variables, - "resources": { - "vcpus": vcpus, - "memory": memory, - "seconds": timeout_seconds, - }, - "runtime": { - "ref": runtime, - "use_latest": True, - "comment": "Official Aleph runtime" - if runtime == settings.DEFAULT_RUNTIME_ID - else "", - }, - "volumes": volumes, - "time": time.time(), - "metadata": metadata, - } - ) - - # Ensure that the version of aleph-message used supports the field. - assert content.on.persistent == persistent - - return await self.submit( - content=content.dict(exclude_none=True), - message_type=MessageType.program, - channel=channel, - storage_engine=storage_engine, - sync=sync, - ) - - async def forget( - self, - hashes: List[str], - reason: Optional[str], - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = False, - ) -> Tuple[ForgetMessage, MessageStatus]: - address = address or settings.ADDRESS_TO_USE or self.account.get_address() - - content = ForgetContent( - hashes=hashes, - reason=reason, - address=address, - time=time.time(), - ) - - return await self.submit( - content=content.dict(exclude_none=True), - message_type=MessageType.forget, - channel=channel, - storage_engine=storage_engine, - allow_inlining=True, - sync=sync, - ) - - @staticmethod - def compute_sha256(s: str) -> str: - h = hashlib.sha256() - h.update(s.encode("utf-8")) - return h.hexdigest() - - async def _prepare_aleph_message( - self, - message_type: MessageType, - content: Dict[str, Any], - channel: Optional[str], - allow_inlining: bool = True, - storage_engine: StorageEnum = StorageEnum.storage, - ) -> AlephMessage: - message_dict: Dict[str, Any] = { - "sender": self.account.get_address(), - "chain": self.account.CHAIN, - "type": message_type, - "content": content, - "time": time.time(), - "channel": channel, - } - - # Use the Pydantic encoder to serialize types like UUID, datetimes, etc. 
- item_content: str = json.dumps( - content, separators=(",", ":"), default=pydantic_encoder - ) - - if allow_inlining and (len(item_content) < settings.MAX_INLINE_SIZE): - message_dict["item_content"] = item_content - message_dict["item_hash"] = self.compute_sha256(item_content) - message_dict["item_type"] = ItemType.inline - else: - if storage_engine == StorageEnum.ipfs: - message_dict["item_hash"] = await self.ipfs_push( - content=content, - ) - message_dict["item_type"] = ItemType.ipfs - else: # storage - assert storage_engine == StorageEnum.storage - message_dict["item_hash"] = await self.storage_push( - content=content, - ) - message_dict["item_type"] = ItemType.storage - - message_dict = await self.account.sign_message(message_dict) - return parse_message(message_dict) - - async def submit( - self, - content: Dict[str, Any], - message_type: MessageType, - channel: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - allow_inlining: bool = True, - sync: bool = False, - ) -> Tuple[AlephMessage, MessageStatus]: - message = await self._prepare_aleph_message( - message_type=message_type, - content=content, - channel=channel, - allow_inlining=allow_inlining, - storage_engine=storage_engine, - ) - message_status = await self._broadcast(message=message, sync=sync) - return message, message_status diff --git a/src/aleph/sdk/client/__init__.py b/src/aleph/sdk/client/__init__.py new file mode 100644 index 00000000..9ee25dd9 --- /dev/null +++ b/src/aleph/sdk/client/__init__.py @@ -0,0 +1,10 @@ +from .abstract import AlephClient, AuthenticatedAlephClient +from .authenticated_http import AuthenticatedAlephHttpClient +from .http import AlephHttpClient + +__all__ = [ + "AlephClient", + "AuthenticatedAlephClient", + "AlephHttpClient", + "AuthenticatedAlephHttpClient", +] diff --git a/src/aleph/sdk/base.py b/src/aleph/sdk/client/abstract.py similarity index 59% rename from src/aleph/sdk/base.py rename to src/aleph/sdk/client/abstract.py index a5b2c266..26a51221 100644 --- a/src/aleph/sdk/base.py +++ b/src/aleph/sdk/client/abstract.py @@ -2,7 +2,6 @@ import logging from abc import ABC, abstractmethod -from datetime import datetime from pathlib import Path from typing import ( Any, @@ -26,76 +25,52 @@ from aleph_message.models.execution.program import Encoding from aleph_message.status import MessageStatus -from aleph.sdk.models import PostsResponse -from aleph.sdk.types import GenericMessage, StorageEnum +from ..query.filters import MessageFilter, PostFilter +from ..query.responses import PostsResponse +from ..types import GenericMessage, StorageEnum +from ..utils import Writable DEFAULT_PAGE_SIZE = 200 -class BaseAlephClient(ABC): +class AlephClient(ABC): @abstractmethod - async def fetch_aggregate( - self, - address: str, - key: str, - limit: int = 100, - ) -> Dict[str, Dict]: + async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: """ Fetch a value from the aggregate store by owner address and item key. :param address: Address of the owner of the aggregate :param key: Key of the aggregate - :param limit: Maximum number of items to fetch (Default: 100) """ pass @abstractmethod async def fetch_aggregates( - self, - address: str, - keys: Optional[Iterable[str]] = None, - limit: int = 100, + self, address: str, keys: Optional[Iterable[str]] = None ) -> Dict[str, Dict]: """ Fetch key-value pairs from the aggregate store by owner address. 
:param address: Address of the owner of the aggregate :param keys: Keys of the aggregates to fetch (Default: all items) - :param limit: Maximum number of items to fetch (Default: 100) """ pass @abstractmethod async def get_posts( self, - pagination: int = DEFAULT_PAGE_SIZE, + page_size: int = DEFAULT_PAGE_SIZE, page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> PostsResponse: """ Fetch a list of posts from the network. - :param pagination: Number of items to fetch (Default: 200) + :param page_size: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) - :param types: Types of posts to fetch (Default: all types) - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Chains of the posts to fetch (Default: all chains) - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param post_filter: Filter to apply to the posts (Default: None) :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ @@ -103,44 +78,20 @@ async def get_posts( async def get_posts_iterator( self, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ) -> AsyncIterable[PostMessage]: """ Fetch all filtered posts, returning an async iterator and fetching them page by page. Might return duplicates but will always return all posts. 
- :param types: Types of posts to fetch (Default: all types) - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Chains of the posts to fetch (Default: all chains) - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param post_filter: Filter to apply to the posts (Default: None) """ page = 1 resp = None while resp is None or len(resp.posts) > 0: resp = await self.get_posts( page=page, - types=types, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + post_filter=post_filter, ) page += 1 for post in resp.posts: @@ -160,43 +111,59 @@ async def download_file( """ pass + async def download_file_ipfs( + self, + file_hash: str, + ) -> bytes: + """ + Get a file from the ipfs storage engine as raw bytes. + + Warning: Downloading large files can be slow. + + :param file_hash: The hash of the file to retrieve. + """ + raise NotImplementedError() + + async def download_file_ipfs_to_buffer( + self, + file_hash: str, + output_buffer: Writable[bytes], + ) -> None: + """ + Download a file from the storage engine and write it to the specified output buffer. + + :param file_hash: The hash of the file to retrieve. + :param output_buffer: The binary output buffer to write the file data to. + """ + raise NotImplementedError() + + async def download_file_to_buffer( + self, + file_hash: str, + output_buffer: Writable[bytes], + ) -> None: + """ + Download a file from the storage engine and write it to the specified output buffer. + :param file_hash: The hash of the file to retrieve. + :param output_buffer: Writable binary buffer. The file will be written to this buffer. + """ + raise NotImplementedError() + @abstractmethod async def get_messages( self, - pagination: int = DEFAULT_PAGE_SIZE, + page_size: int = DEFAULT_PAGE_SIZE, page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> MessagesResponse: """ Fetch a list of messages from the network. 
- :param pagination: Number of items to fetch (Default: 200) + :param page_size: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) - :param message_type: [DEPRECATED] Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param message_types: Filter by message types, can be any combination of "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param content_types: Filter by content type - :param content_keys: Filter by aggregate key - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Filter by sender address chain - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param message_filter: Filter to apply to the messages :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ @@ -204,50 +171,20 @@ async def get_messages( async def get_messages_iterator( self, - message_type: Optional[MessageType] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: """ Fetch all filtered messages, returning an async iterator and fetching them page by page. Might return duplicates but will always return all messages. 
- :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param content_types: Filter by content type - :param content_keys: Filter by content key - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Filter by sender address chain - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param message_filter: Filter to apply to the messages """ page = 1 resp = None while resp is None or len(resp.messages) > 0: resp = await self.get_messages( page=page, - message_type=message_type, - content_types=content_types, - content_keys=content_keys, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + message_filter=message_filter, ) page += 1 for message in resp.messages: @@ -272,39 +209,17 @@ async def get_message( @abstractmethod def watch_messages( self, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: """ Iterate over current and future matching messages asynchronously. - :param message_type: [DEPRECATED] Type of message to watch - :param message_types: Types of messages to watch - :param content_types: Content types to watch - :param content_keys: Filter by aggregate key - :param refs: References to watch - :param addresses: Addresses to watch - :param tags: Tags to watch - :param hashes: Hashes to watch - :param channels: Channels to watch - :param chains: Chains to watch - :param start_date: Start date from when to watch - :param end_date: End date until when to watch + :param message_filter: Filter to apply to the messages """ pass -class BaseAuthenticatedAlephClient(BaseAlephClient): +class AuthenticatedAlephClient(AlephClient): @abstractmethod async def create_post( self, @@ -318,7 +233,7 @@ async def create_post( sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: """ - Create a POST message on the Aleph network. It is associated with a channel and owned by an account. + Create a POST message on the aleph.im network. It is associated with a channel and owned by an account. :param post_content: The content of the message :param post_type: An arbitrary content type that helps to describe the post_content @@ -368,7 +283,7 @@ async def create_store( sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: """ - Create a STORE message to store a file on the Aleph network. + Create a STORE message to store a file on the aleph.im network. Can be passed either a file path, an IPFS hash or the file's content as raw bytes. 
@@ -422,7 +337,7 @@ async def create_program(
         :param persistent: Whether the program should be persistent or not (Default: False)
         :param encoding: Encoding to use (Default: Encoding.zip)
         :param volumes: Volumes to mount
-        :param subscriptions: Patterns of Aleph messages to forward to the program's event receiver
+        :param subscriptions: Patterns of aleph.im messages to forward to the program's event receiver
         :param metadata: Metadata to attach to the message
         """
         pass
@@ -474,3 +389,19 @@ async def submit(
         :param sync: If true, waits for the message to be processed by the API server (Default: False)
         """
         pass
+
+    async def ipfs_push(self, content: Mapping) -> str:
+        """
+        Push arbitrary content as JSON to the IPFS service.
+
+        :param content: The dict-like content to upload
+        """
+        raise NotImplementedError()
+
+    async def storage_push(self, content: Mapping) -> str:
+        """
+        Push arbitrary content as JSON to the storage service.
+
+        :param content: The dict-like content to upload
+        """
+        raise NotImplementedError()
diff --git a/src/aleph/sdk/client/authenticated_http.py b/src/aleph/sdk/client/authenticated_http.py
new file mode 100644
index 00000000..6291467a
--- /dev/null
+++ b/src/aleph/sdk/client/authenticated_http.py
@@ -0,0 +1,560 @@
+import hashlib
+import json
+import logging
+import time
+from pathlib import Path
+from typing import Any, Dict, List, Mapping, NoReturn, Optional, Tuple, Union
+
+import aiohttp
+from aleph_message import parse_message
+from aleph_message.models import (
+    AggregateContent,
+    AggregateMessage,
+    AlephMessage,
+    ForgetContent,
+    ForgetMessage,
+    ItemType,
+    MessageType,
+    PostContent,
+    PostMessage,
+    ProgramContent,
+    ProgramMessage,
+    StoreContent,
+    StoreMessage,
+)
+from aleph_message.models.execution.base import Encoding
+from aleph_message.models.execution.environment import (
+    FunctionEnvironment,
+    MachineResources,
+)
+from aleph_message.models.execution.program import CodeContent, FunctionRuntime
+from aleph_message.models.execution.volume import MachineVolume
+from aleph_message.status import MessageStatus
+from pydantic.json import pydantic_encoder
+
+from ..conf import settings
+from ..exceptions import BroadcastError, InvalidMessageError
+from ..types import Account, StorageEnum
+from .abstract import AuthenticatedAlephClient
+from .http import AlephHttpClient
+
+logger = logging.getLogger(__name__)
+
+try:
+    import magic
+except ImportError:
+    logger.info("Could not import library 'magic', MIME type detection disabled")
+    magic = None  # type:ignore
+
+
+class AuthenticatedAlephHttpClient(AlephHttpClient, AuthenticatedAlephClient):
+    account: Account
+
+    BROADCAST_MESSAGE_FIELDS = {
+        "sender",
+        "chain",
+        "signature",
+        "type",
+        "item_hash",
+        "item_type",
+        "item_content",
+        "time",
+        "channel",
+    }
+
+    def __init__(
+        self,
+        account: Account,
+        api_server: Optional[str],
+        api_unix_socket: Optional[str] = None,
+        allow_unix_sockets: bool = True,
+        timeout: Optional[aiohttp.ClientTimeout] = None,
+    ):
+        super().__init__(
+            api_server=api_server,
+            api_unix_socket=api_unix_socket,
+            allow_unix_sockets=allow_unix_sockets,
+            timeout=timeout,
+        )
+        self.account = account
+
+    async def __aenter__(self) -> "AuthenticatedAlephHttpClient":
+        return self
+
+    async def ipfs_push(self, content: Mapping) -> str:
+        """
+        Push arbitrary content as JSON to the IPFS service.
+
+        :param content: The dict-like content to upload
+        """
+        url = "/api/v0/ipfs/add_json"
+        logger.debug(f"Pushing to IPFS on {url}")
+
+        async with self.http_session.post(url, json=content) as resp:
+            resp.raise_for_status()
+            return (await resp.json()).get("hash")
+
+    async def storage_push(self, content: Mapping) -> str:
+        """
+        Push arbitrary content as JSON to the storage service.
+
+        :param content: The dict-like content to upload
+        """
+        url = "/api/v0/storage/add_json"
+        logger.debug(f"Pushing to storage on {url}")
+
+        async with self.http_session.post(url, json=content) as resp:
+            resp.raise_for_status()
+            return (await resp.json()).get("hash")
+
+    async def ipfs_push_file(self, file_content: Union[str, bytes]) -> str:
+        """
+        Push a file to the IPFS service.
+
+        :param file_content: The file content to upload
+        """
+        data = aiohttp.FormData()
+        data.add_field("file", file_content)
+
+        url = "/api/v0/ipfs/add_file"
+        logger.debug(f"Pushing file to IPFS on {url}")
+
+        async with self.http_session.post(url, data=data) as resp:
+            resp.raise_for_status()
+            return (await resp.json()).get("hash")
+
+    async def storage_push_file(self, file_content) -> str:
+        """
+        Push a file to the storage service.
+        """
+        data = aiohttp.FormData()
+        data.add_field("file", file_content)
+
+        url = "/api/v0/storage/add_file"
+        logger.debug(f"Posting file on {url}")
+
+        async with self.http_session.post(url, data=data) as resp:
+            resp.raise_for_status()
+            return (await resp.json()).get("hash")
+
+    @staticmethod
+    def _log_publication_status(publication_status: Mapping[str, Any]):
+        status = publication_status.get("status")
+        failures = publication_status.get("failed")
+
+        if status == "success":
+            return
+        elif status == "warning":
+            logger.warning("Broadcast failed on the following network(s): %s", failures)
+        elif status == "error":
+            logger.error(
+                "Broadcast failed on all protocols. The message was not published."
+            )
+        else:
+            raise ValueError(
+                f"Invalid response from server, status is missing or unknown: '{status}'"
+            )
+
+    @staticmethod
+    async def _handle_broadcast_error(response: aiohttp.ClientResponse) -> NoReturn:
+        if response.status == 500:
+            # Assume a broadcast error, no need to read the JSON
+            if response.content_type == "application/json":
+                error_msg = "Internal error - broadcast failed on all protocols"
+            else:
+                error_msg = f"Internal error - the message was not broadcast: {await response.text()}"
+
+            logger.error(error_msg)
+            raise BroadcastError(error_msg)
+        elif response.status == 422:
+            errors = await response.json()
+            logger.error(
+                "The message could not be processed because of the following errors: %s",
+                errors,
+            )
+            raise InvalidMessageError(errors)
+        else:
+            error_msg = (
+                f"Unexpected HTTP response ({response.status}: {await response.text()})"
+            )
+            logger.error(error_msg)
+            raise BroadcastError(error_msg)
+
+    async def _handle_broadcast_deprecated_response(
+        self,
+        response: aiohttp.ClientResponse,
+    ) -> None:
+        if response.status != 200:
+            await self._handle_broadcast_error(response)
+        else:
+            publication_status = await response.json()
+            self._log_publication_status(publication_status)
+
+    async def _broadcast_deprecated(self, message_dict: Mapping[str, Any]) -> None:
+        """
+        Broadcast a message on the aleph.im network using the deprecated
+        /ipfs/pubsub/pub/ endpoint.
+ """ + + url = "/api/v0/ipfs/pubsub/pub" + logger.debug(f"Posting message on {url}") + + async with self.http_session.post( + url, + json={"topic": "ALEPH-TEST", "data": json.dumps(message_dict)}, + ) as response: + await self._handle_broadcast_deprecated_response(response) + + async def _handle_broadcast_response( + self, response: aiohttp.ClientResponse, sync: bool + ) -> MessageStatus: + if response.status in (200, 202): + status = await response.json() + self._log_publication_status(status["publication_status"]) + + if response.status == 202: + if sync: + logger.warning( + "Timed out while waiting for processing of sync message" + ) + return MessageStatus.PENDING + + return MessageStatus.PROCESSED + + else: + await self._handle_broadcast_error(response) + + async def _broadcast( + self, + message: AlephMessage, + sync: bool, + ) -> MessageStatus: + """ + Broadcast a message on the aleph.im network. + + Uses the POST /messages/ endpoint or the deprecated /ipfs/pubsub/pub/ endpoint + if the first method is not available. + """ + + url = "/api/v0/messages" + logger.debug(f"Posting message on {url}") + + message_dict = message.dict(include=self.BROADCAST_MESSAGE_FIELDS) + + async with self.http_session.post( + url, + json={"sync": sync, "message": message_dict}, + ) as response: + # The endpoint may be unavailable on this node, try the deprecated version. + if response.status in (404, 405): + logger.warning( + "POST /messages/ not found. Defaulting to legacy endpoint..." + ) + await self._broadcast_deprecated(message_dict=message_dict) + return MessageStatus.PENDING + else: + message_status = await self._handle_broadcast_response( + response=response, sync=sync + ) + return message_status + + async def create_post( + self, + post_content, + post_type: str, + ref: Optional[str] = None, + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + sync: bool = False, + ) -> Tuple[PostMessage, MessageStatus]: + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + content = PostContent( + type=post_type, + address=address, + content=post_content, + time=time.time(), + ref=ref, + ) + + return await self.submit( + content=content.dict(exclude_none=True), + message_type=MessageType.post, + channel=channel, + allow_inlining=inline, + storage_engine=storage_engine, + sync=sync, + ) + + async def create_aggregate( + self, + key: str, + content: Mapping[str, Any], + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + sync: bool = False, + ) -> Tuple[AggregateMessage, MessageStatus]: + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + content_ = AggregateContent( + key=key, + address=address, + content=content, + time=time.time(), + ) + + return await self.submit( + content=content_.dict(exclude_none=True), + message_type=MessageType.aggregate, + channel=channel, + allow_inlining=inline, + sync=sync, + ) + + async def create_store( + self, + address: Optional[str] = None, + file_content: Optional[bytes] = None, + file_path: Optional[Union[str, Path]] = None, + file_hash: Optional[str] = None, + guess_mime_type: bool = False, + ref: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + extra_fields: Optional[dict] = None, + channel: Optional[str] = None, + sync: bool = False, + ) -> Tuple[StoreMessage, MessageStatus]: + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + 
extra_fields = extra_fields or {} + + if file_hash is None: + if file_content is None: + if file_path is None: + raise ValueError( + "Please specify at least a file_content, a file_hash or a file_path" + ) + else: + file_content = Path(file_path).read_bytes() + + if storage_engine == StorageEnum.storage: + file_hash = await self.storage_push_file(file_content=file_content) + elif storage_engine == StorageEnum.ipfs: + file_hash = await self.ipfs_push_file(file_content=file_content) + else: + raise ValueError(f"Unknown storage engine: '{storage_engine}'") + + assert file_hash, "File hash should not be empty" + + if magic is None: + pass + elif file_content and guess_mime_type and ("mime_type" not in extra_fields): + extra_fields["mime_type"] = magic.from_buffer(file_content, mime=True) + + if ref: + extra_fields["ref"] = ref + + values = { + "address": address, + "item_type": storage_engine, + "item_hash": file_hash, + "time": time.time(), + } + if extra_fields is not None: + values.update(extra_fields) + + content = StoreContent(**values) + + return await self.submit( + content=content.dict(exclude_none=True), + message_type=MessageType.store, + channel=channel, + allow_inlining=True, + sync=sync, + ) + + async def create_program( + self, + program_ref: str, + entrypoint: str, + runtime: str, + environment_variables: Optional[Mapping[str, str]] = None, + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + persistent: bool = False, + encoding: Encoding = Encoding.zip, + volumes: Optional[List[Mapping]] = None, + subscriptions: Optional[List[Mapping]] = None, + metadata: Optional[Mapping[str, Any]] = None, + ) -> Tuple[ProgramMessage, MessageStatus]: + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + volumes = volumes if volumes is not None else [] + memory = memory or settings.DEFAULT_VM_MEMORY + vcpus = vcpus or settings.DEFAULT_VM_VCPUS + timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT + + # TODO: Check that program_ref, runtime and data_ref exist + + # Register the different ways to trigger a VM + if subscriptions: + # Trigger on HTTP calls and on aleph.im message subscriptions. + triggers = { + "http": True, + "persistent": persistent, + "message": subscriptions, + } + else: + # Trigger on HTTP calls. + triggers = {"http": True, "persistent": persistent} + + volumes: List[MachineVolume] = [ + MachineVolume.parse_obj(volume) for volume in volumes + ] + + content = ProgramContent( + type="vm-function", + address=address, + allow_amend=False, + code=CodeContent( + encoding=encoding, + entrypoint=entrypoint, + ref=program_ref, + use_latest=True, + ), + on=triggers, + environment=FunctionEnvironment( + reproducible=False, + internet=True, + aleph_api=True, + ), + variables=environment_variables, + resources=MachineResources( + vcpus=vcpus, + memory=memory, + seconds=timeout_seconds, + ), + runtime=FunctionRuntime( + ref=runtime, + use_latest=True, + comment="Official aleph.im runtime" + if runtime == settings.DEFAULT_RUNTIME_ID + else "", + ), + volumes=volumes, + time=time.time(), + metadata=metadata, + ) + + # Ensure that the version of aleph-message used supports the field. 
+ assert content.on.persistent == persistent + + return await self.submit( + content=content.dict(exclude_none=True), + message_type=MessageType.program, + channel=channel, + storage_engine=storage_engine, + sync=sync, + ) + + async def forget( + self, + hashes: List[str], + reason: Optional[str], + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + ) -> Tuple[ForgetMessage, MessageStatus]: + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + content = ForgetContent( + hashes=hashes, + reason=reason, + address=address, + time=time.time(), + ) + + return await self.submit( + content=content.dict(exclude_none=True), + message_type=MessageType.forget, + channel=channel, + storage_engine=storage_engine, + allow_inlining=True, + sync=sync, + ) + + @staticmethod + def compute_sha256(s: str) -> str: + h = hashlib.sha256() + h.update(s.encode("utf-8")) + return h.hexdigest() + + async def _prepare_aleph_message( + self, + message_type: MessageType, + content: Dict[str, Any], + channel: Optional[str], + allow_inlining: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + ) -> AlephMessage: + message_dict: Dict[str, Any] = { + "sender": self.account.get_address(), + "chain": self.account.CHAIN, + "type": message_type, + "content": content, + "time": time.time(), + "channel": channel, + } + + # Use the Pydantic encoder to serialize types like UUID, datetimes, etc. + item_content: str = json.dumps( + content, separators=(",", ":"), default=pydantic_encoder + ) + + if allow_inlining and (len(item_content) < settings.MAX_INLINE_SIZE): + message_dict["item_content"] = item_content + message_dict["item_hash"] = self.compute_sha256(item_content) + message_dict["item_type"] = ItemType.inline + else: + if storage_engine == StorageEnum.ipfs: + message_dict["item_hash"] = await self.ipfs_push( + content=content, + ) + message_dict["item_type"] = ItemType.ipfs + else: # storage + assert storage_engine == StorageEnum.storage + message_dict["item_hash"] = await self.storage_push( + content=content, + ) + message_dict["item_type"] = ItemType.storage + + message_dict = await self.account.sign_message(message_dict) + return parse_message(message_dict) + + async def submit( + self, + content: Dict[str, Any], + message_type: MessageType, + channel: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + allow_inlining: bool = True, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + message = await self._prepare_aleph_message( + message_type=message_type, + content=content, + channel=channel, + allow_inlining=allow_inlining, + storage_engine=storage_engine, + ) + message_status = await self._broadcast(message=message, sync=sync) + return message, message_status diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py new file mode 100644 index 00000000..93cbe837 --- /dev/null +++ b/src/aleph/sdk/client/http.py @@ -0,0 +1,337 @@ +import json +import logging +from io import BytesIO +from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Type + +import aiohttp +from aleph_message import parse_message +from aleph_message.models import AlephMessage, ItemHash, ItemType +from pydantic import ValidationError + +from ..conf import settings +from ..exceptions import FileTooLarge, MessageNotFoundError, MultipleMessagesError +from ..query.filters import MessageFilter, PostFilter +from ..query.responses import MessagesResponse, Post, 
PostsResponse
+from ..types import GenericMessage
+from ..utils import (
+    Writable,
+    check_unix_socket_valid,
+    copy_async_readable_to_buffer,
+    get_message_type_value,
+)
+from .abstract import AlephClient
+
+logger = logging.getLogger(__name__)
+
+
+class AlephHttpClient(AlephClient):
+    api_server: str
+    http_session: aiohttp.ClientSession
+
+    def __init__(
+        self,
+        api_server: Optional[str] = None,
+        api_unix_socket: Optional[str] = None,
+        allow_unix_sockets: bool = True,
+        timeout: Optional[aiohttp.ClientTimeout] = None,
+    ):
+        """AlephHttpClient can use HTTP(S) or HTTP over Unix sockets.
+        Unix sockets are used when running inside a virtual machine,
+        and can be shared across containers in a more secure way than TCP ports.
+        """
+        self.api_server = api_server or settings.API_HOST
+        if not self.api_server:
+            raise ValueError("Missing API host")
+
+        unix_socket_path = api_unix_socket or settings.API_UNIX_SOCKET
+        if unix_socket_path and allow_unix_sockets:
+            check_unix_socket_valid(unix_socket_path)
+            connector = aiohttp.UnixConnector(path=unix_socket_path)
+        else:
+            connector = None
+
+        # ClientSession timeout defaults to a private sentinel object and may not be None.
+        self.http_session = (
+            aiohttp.ClientSession(
+                base_url=self.api_server, connector=connector, timeout=timeout
+            )
+            if timeout
+            else aiohttp.ClientSession(
+                base_url=self.api_server,
+                connector=connector,
+            )
+        )
+
+    async def __aenter__(self) -> "AlephHttpClient":
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.http_session.close()
+
+    async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]:
+        params: Dict[str, Any] = {"keys": key}
+
+        async with self.http_session.get(
+            f"/api/v0/aggregates/{address}.json", params=params
+        ) as resp:
+            resp.raise_for_status()
+            result = await resp.json()
+            data = result.get("data", dict())
+            return data.get(key)
+
+    async def fetch_aggregates(
+        self, address: str, keys: Optional[Iterable[str]] = None
+    ) -> Dict[str, Dict]:
+        keys_str = ",".join(keys) if keys else ""
+        params: Dict[str, Any] = {}
+        if keys_str:
+            params["keys"] = keys_str
+
+        async with self.http_session.get(
+            f"/api/v0/aggregates/{address}.json",
+            params=params,
+        ) as resp:
+            resp.raise_for_status()
+            result = await resp.json()
+            data = result.get("data", dict())
+            return data
+
+    async def get_posts(
+        self,
+        page_size: int = 200,
+        page: int = 1,
+        post_filter: Optional[PostFilter] = None,
+        ignore_invalid_messages: Optional[bool] = True,
+        invalid_messages_log_level: Optional[int] = logging.NOTSET,
+    ) -> PostsResponse:
+        ignore_invalid_messages = (
+            True if ignore_invalid_messages is None else ignore_invalid_messages
+        )
+        invalid_messages_log_level = (
+            logging.NOTSET
+            if invalid_messages_log_level is None
+            else invalid_messages_log_level
+        )
+
+        if not post_filter:
+            params = {
+                "page": str(page),
+                "pagination": str(page_size),
+            }
+        else:
+            params = post_filter.as_http_params()
+            params["page"] = str(page)
+            params["pagination"] = str(page_size)
+
+        async with self.http_session.get("/api/v0/posts.json", params=params) as resp:
+            resp.raise_for_status()
+            response_json = await resp.json()
+            posts_raw = response_json["posts"]
+
+            posts: List[Post] = []
+            for post_raw in posts_raw:
+                try:
+                    posts.append(Post.parse_obj(post_raw))
+                except ValidationError as e:
+                    if not ignore_invalid_messages:
+                        raise e
+                    if invalid_messages_log_level:
+                        logger.log(level=invalid_messages_log_level, msg=e)
+            return PostsResponse(
+                posts=posts,
pagination_page=response_json["pagination_page"],
+                pagination_total=response_json["pagination_total"],
+                pagination_per_page=response_json["pagination_per_page"],
+                pagination_item=response_json["pagination_item"],
+            )
+
+    async def download_file_to_buffer(
+        self,
+        file_hash: str,
+        output_buffer: Writable[bytes],
+    ) -> None:
+        """
+        Download a file from the storage engine and write it to the specified output buffer.
+
+        :param file_hash: The hash of the file to retrieve.
+        :param output_buffer: Writable binary buffer. The file will be written to this buffer.
+        """
+
+        async with self.http_session.get(
+            f"/api/v0/storage/raw/{file_hash}"
+        ) as response:
+            if response.status == 200:
+                await copy_async_readable_to_buffer(
+                    response.content, output_buffer, chunk_size=16 * 1024
+                )
+            elif response.status == 413:
+                ipfs_hash = ItemHash(file_hash)
+                if ipfs_hash.item_type == ItemType.ipfs:
+                    return await self.download_file_ipfs_to_buffer(
+                        file_hash, output_buffer
+                    )
+                else:
+                    raise FileTooLarge(f"The file from {file_hash} is too large")
+            else:
+                # Surface unexpected HTTP errors instead of returning silently.
+                response.raise_for_status()
+
+    async def download_file_ipfs_to_buffer(
+        self,
+        file_hash: str,
+        output_buffer: Writable[bytes],
+    ) -> None:
+        """
+        Download a file from the storage engine and write it to the specified output buffer.
+
+        :param file_hash: The hash of the file to retrieve.
+        :param output_buffer: The binary output buffer to write the file data to.
+        """
+        async with aiohttp.ClientSession() as session:
+            async with session.get(
+                f"https://ipfs.aleph.im/ipfs/{file_hash}"
+            ) as response:
+                if response.status == 200:
+                    await copy_async_readable_to_buffer(
+                        response.content, output_buffer, chunk_size=16 * 1024
+                    )
+                else:
+                    response.raise_for_status()
+
+    async def download_file(
+        self,
+        file_hash: str,
+    ) -> bytes:
+        """
+        Get a file from the storage engine as raw bytes.
+
+        Warning: Downloading large files can be slow and memory intensive.
+
+        :param file_hash: The hash of the file to retrieve.
+        """
+        buffer = BytesIO()
+        await self.download_file_to_buffer(file_hash, output_buffer=buffer)
+        return buffer.getvalue()
+
+    async def download_file_ipfs(
+        self,
+        file_hash: str,
+    ) -> bytes:
+        """
+        Get a file from the ipfs storage engine as raw bytes.
+
+        Warning: Downloading large files can be slow.
+
+        :param file_hash: The hash of the file to retrieve.
+        """
+        buffer = BytesIO()
+        await self.download_file_ipfs_to_buffer(file_hash, output_buffer=buffer)
+        return buffer.getvalue()
+
+    async def get_messages(
+        self,
+        page_size: int = 200,
+        page: int = 1,
+        message_filter: Optional[MessageFilter] = None,
+        ignore_invalid_messages: Optional[bool] = True,
+        invalid_messages_log_level: Optional[int] = logging.NOTSET,
+    ) -> MessagesResponse:
+        ignore_invalid_messages = (
+            True if ignore_invalid_messages is None else ignore_invalid_messages
+        )
+        invalid_messages_log_level = (
+            logging.NOTSET
+            if invalid_messages_log_level is None
+            else invalid_messages_log_level
+        )
+
+        if not message_filter:
+            params = {
+                "page": str(page),
+                "pagination": str(page_size),
+            }
+        else:
+            params = message_filter.as_http_params()
+            params["page"] = str(page)
+            params["pagination"] = str(page_size)
+
+        async with self.http_session.get(
+            "/api/v0/messages.json", params=params
+        ) as resp:
+            resp.raise_for_status()
+            response_json = await resp.json()
+            messages_raw = response_json["messages"]
+
+            # All messages may not be valid according to the latest specification in
+            # aleph-message. This allows the user to specify how errors should be handled.
+ messages: List[AlephMessage] = [] + for message_raw in messages_raw: + try: + message = parse_message(message_raw) + messages.append(message) + except KeyError as e: + if not ignore_invalid_messages: + raise e + logger.log( + level=invalid_messages_log_level, + msg=f"KeyError: Field '{e.args[0]}' not found", + ) + except ValidationError as e: + if not ignore_invalid_messages: + raise e + if invalid_messages_log_level: + logger.log(level=invalid_messages_log_level, msg=e) + + return MessagesResponse( + messages=messages, + pagination_page=response_json["pagination_page"], + pagination_total=response_json["pagination_total"], + pagination_per_page=response_json["pagination_per_page"], + pagination_item=response_json["pagination_item"], + ) + + async def get_message( + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, + channel: Optional[str] = None, + ) -> GenericMessage: + messages_response = await self.get_messages( + message_filter=MessageFilter( + hashes=[item_hash], + channels=[channel] if channel else None, + ) + ) + if len(messages_response.messages) < 1: + raise MessageNotFoundError(f"No such hash {item_hash}") + if len(messages_response.messages) != 1: + raise MultipleMessagesError( + f"Multiple messages found for the same item_hash `{item_hash}`" + ) + message: GenericMessage = messages_response.messages[0] + if message_type: + expected_type = get_message_type_value(message_type) + if message.type != expected_type: + raise TypeError( + f"The message type '{message.type}' " + f"does not match the expected type '{expected_type}'" + ) + return message + + async def watch_messages( + self, + message_filter: Optional[MessageFilter] = None, + ) -> AsyncIterable[AlephMessage]: + message_filter = message_filter or MessageFilter() + params = message_filter.as_http_params() + + async with self.http_session.ws_connect( + "/api/ws0/messages", params=params + ) as ws: + logger.debug("Websocket connected") + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + if msg.data == "close cmd": + await ws.close() + break + else: + data = json.loads(msg.data) + yield parse_message(data) + elif msg.type == aiohttp.WSMsgType.ERROR: + break diff --git a/src/aleph/sdk/exceptions.py b/src/aleph/sdk/exceptions.py index 51762925..5f09e1bc 100644 --- a/src/aleph/sdk/exceptions.py +++ b/src/aleph/sdk/exceptions.py @@ -21,7 +21,7 @@ class MultipleMessagesError(QueryError): class BroadcastError(Exception): """ - Data could not be broadcast to the Aleph network. + Data could not be broadcast to the aleph.im network. """ pass @@ -29,7 +29,7 @@ class BroadcastError(Exception): class InvalidMessageError(BroadcastError): """ - The message could not be broadcast because it does not follow the Aleph + The message could not be broadcast because it does not follow the aleph.im message specification. 
""" diff --git a/src/aleph/sdk/models.py b/src/aleph/sdk/models.py deleted file mode 100644 index f5b1072b..00000000 --- a/src/aleph/sdk/models.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Any, Dict, List, Optional, Union - -from aleph_message.models import AlephMessage, BaseMessage, ChainRef, ItemHash -from pydantic import BaseModel, Field - - -class PaginationResponse(BaseModel): - pagination_page: int - pagination_total: int - pagination_per_page: int - pagination_item: str - - -class MessagesResponse(PaginationResponse): - """Response from an Aleph node API on the path /api/v0/messages.json""" - - messages: List[AlephMessage] - pagination_item = "messages" - - -class Post(BaseMessage): - """ - A post is a type of message that can be updated. Over the get_posts API - we get the latest version of a post. - """ - - hash: ItemHash = Field(description="Hash of the content (sha256 by default)") - original_item_hash: ItemHash = Field( - description="Hash of the original content (sha256 by default)" - ) - original_signature: Optional[str] = Field( - description="Cryptographic signature of the original message by the sender" - ) - original_type: str = Field( - description="The original, user-generated 'content-type' of the POST message" - ) - content: Dict[str, Any] = Field( - description="The content.content of the POST message" - ) - type: str = Field(description="The content.type of the POST message") - address: str = Field(description="The address of the sender of the POST message") - ref: Optional[Union[str, ChainRef]] = Field( - description="Other message referenced by this one" - ) - - -class PostsResponse(PaginationResponse): - """Response from an Aleph node API on the path /api/v0/posts.json""" - - posts: List[Post] - pagination_item = "posts" diff --git a/src/aleph/sdk/query/filters.py b/src/aleph/sdk/query/filters.py new file mode 100644 index 00000000..346f3a24 --- /dev/null +++ b/src/aleph/sdk/query/filters.py @@ -0,0 +1,162 @@ +from datetime import datetime +from typing import Dict, Iterable, Optional, Union + +from aleph_message.models import MessageType + +from ..utils import _date_field_to_float, serialize_list + + +class MessageFilter: + """ + A collection of filters that can be applied on message queries. 
+ :param message_types: Filter by message type + :param content_types: Filter by content type + :param content_keys: Filter by content key + :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) + :param addresses: Addresses of the posts to fetch (Default: all addresses) + :param tags: Tags of the posts to fetch (Default: all tags) + :param hashes: Specific item_hashes to fetch + :param channels: Channels of the posts to fetch (Default: all channels) + :param chains: Filter by sender address chain + :param start_date: Earliest date to fetch messages from + :param end_date: Latest date to fetch messages from + """ + + message_types: Optional[Iterable[MessageType]] + content_types: Optional[Iterable[str]] + content_keys: Optional[Iterable[str]] + refs: Optional[Iterable[str]] + addresses: Optional[Iterable[str]] + tags: Optional[Iterable[str]] + hashes: Optional[Iterable[str]] + channels: Optional[Iterable[str]] + chains: Optional[Iterable[str]] + start_date: Optional[Union[datetime, float]] + end_date: Optional[Union[datetime, float]] + + def __init__( + self, + message_types: Optional[Iterable[MessageType]] = None, + content_types: Optional[Iterable[str]] = None, + content_keys: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ): + self.message_types = message_types + self.content_types = content_types + self.content_keys = content_keys + self.refs = refs + self.addresses = addresses + self.tags = tags + self.hashes = hashes + self.channels = channels + self.chains = chains + self.start_date = start_date + self.end_date = end_date + + def as_http_params(self) -> Dict[str, str]: + """Convert the filters into a dict that can be used by an `aiohttp` client + as `params` to build the HTTP query string. + """ + + partial_result = { + "msgType": serialize_list( + [type.value for type in self.message_types] + if self.message_types + else None + ), + "contentTypes": serialize_list(self.content_types), + "contentKeys": serialize_list(self.content_keys), + "refs": serialize_list(self.refs), + "addresses": serialize_list(self.addresses), + "tags": serialize_list(self.tags), + "hashes": serialize_list(self.hashes), + "channels": serialize_list(self.channels), + "chains": serialize_list(self.chains), + "startDate": _date_field_to_float(self.start_date), + "endDate": _date_field_to_float(self.end_date), + } + + # Ensure all values are strings. + result: Dict[str, str] = {} + + # Drop empty values + for key, value in partial_result.items(): + if value: + assert isinstance(value, str), f"Value must be a string: `{value}`" + result[key] = value + + return result + + +class PostFilter: + """ + A collection of filters that can be applied on post queries. 
+ + """ + + types: Optional[Iterable[str]] + refs: Optional[Iterable[str]] + addresses: Optional[Iterable[str]] + tags: Optional[Iterable[str]] + hashes: Optional[Iterable[str]] + channels: Optional[Iterable[str]] + chains: Optional[Iterable[str]] + start_date: Optional[Union[datetime, float]] + end_date: Optional[Union[datetime, float]] + + def __init__( + self, + types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ): + self.types = types + self.refs = refs + self.addresses = addresses + self.tags = tags + self.hashes = hashes + self.channels = channels + self.chains = chains + self.start_date = start_date + self.end_date = end_date + + def as_http_params(self) -> Dict[str, str]: + """Convert the filters into a dict that can be used by an `aiohttp` client + as `params` to build the HTTP query string. + """ + + partial_result = { + "types": serialize_list(self.types), + "refs": serialize_list(self.refs), + "addresses": serialize_list(self.addresses), + "tags": serialize_list(self.tags), + "hashes": serialize_list(self.hashes), + "channels": serialize_list(self.channels), + "chains": serialize_list(self.chains), + "startDate": _date_field_to_float(self.start_date), + "endDate": _date_field_to_float(self.end_date), + } + + # Ensure all values are strings. + result: Dict[str, str] = {} + + # Drop empty values + for key, value in partial_result.items(): + if value: + assert isinstance(value, str), f"Value must be a string: `{value}`" + result[key] = value + + return result diff --git a/src/aleph/sdk/query/responses.py b/src/aleph/sdk/query/responses.py new file mode 100644 index 00000000..5fb91804 --- /dev/null +++ b/src/aleph/sdk/query/responses.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Union + +from aleph_message.models import ( + AlephMessage, + Chain, + ItemHash, + ItemType, + MessageConfirmation, +) +from pydantic import BaseModel, Field + + +class Post(BaseModel): + """ + A post is a type of message that can be updated. Over the get_posts API + we get the latest version of a post. 
+ """ + + chain: Chain = Field(description="Blockchain this post is associated with") + item_hash: ItemHash = Field(description="Unique hash for this post") + sender: str = Field(description="Address of the sender") + type: str = Field(description="Type of the POST message") + channel: Optional[str] = Field(description="Channel this post is associated with") + confirmed: bool = Field(description="Whether the post is confirmed or not") + content: Dict[str, Any] = Field(description="The content of the POST message") + item_content: Optional[str] = Field( + description="The POSTs content field as serialized JSON, if of type inline" + ) + item_type: ItemType = Field( + description="Type of the item content, usually 'inline' or 'storage' for POSTs" + ) + signature: Optional[str] = Field( + description="Cryptographic signature of the message by the sender" + ) + size: int = Field(description="Size of the post") + time: float = Field(description="Timestamp of the post") + confirmations: List[MessageConfirmation] = Field( + description="Number of confirmations" + ) + original_item_hash: ItemHash = Field(description="Hash of the original content") + original_signature: Optional[str] = Field( + description="Cryptographic signature of the original message" + ) + original_type: str = Field(description="The original type of the message") + hash: ItemHash = Field(description="Hash of the original item") + ref: Optional[Union[str, Any]] = Field( + description="Other message referenced by this one" + ) + + class Config: + allow_extra = False + + +class PaginationResponse(BaseModel): + pagination_page: int + pagination_total: int + pagination_per_page: int + pagination_item: str + + +class PostsResponse(PaginationResponse): + """Response from an aleph.im node API on the path /api/v0/posts.json""" + + posts: List[Post] + pagination_item = "posts" + + +class MessagesResponse(PaginationResponse): + """Response from an aleph.im node API on the path /api/v0/messages.json""" + + messages: List[AlephMessage] + pagination_item = "messages" diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index be56cc2c..810d7326 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -1,10 +1,11 @@ import errno import logging import os +from datetime import datetime from enum import Enum from pathlib import Path from shutil import make_archive -from typing import Protocol, Tuple, Type, TypeVar, Union +from typing import Iterable, Optional, Protocol, Tuple, Type, TypeVar, Union from zipfile import BadZipFile, ZipFile from aleph_message.models import MessageType @@ -116,3 +117,21 @@ def enum_as_str(obj: Union[str, Enum]) -> str: return obj.value return obj + + +def serialize_list(values: Optional[Iterable[str]]) -> Optional[str]: + if values: + return ",".join(values) + else: + return None + + +def _date_field_to_float(date: Optional[Union[datetime, float]]) -> Optional[float]: + if date is None: + return None + elif isinstance(date, float): + return date + elif hasattr(date, "timestamp"): + return date.timestamp() + else: + raise TypeError(f"Invalid type: `{type(date)}`") diff --git a/tests/integration/config.py b/tests/integration/config.py index 4ec95a27..3e613c18 100644 --- a/tests/integration/config.py +++ b/tests/integration/config.py @@ -1,3 +1,3 @@ -TARGET_NODE = "http://163.172.70.92:4024" +TARGET_NODE = "https://api1.aleph.im" REFERENCE_NODE = "https://api2.aleph.im" TEST_CHANNEL = "INTEGRATION_TESTS" diff --git a/tests/integration/itest_aggregates.py b/tests/integration/itest_aggregates.py index 
5c5d4648..31f5c6cc 100644 --- a/tests/integration/itest_aggregates.py +++ b/tests/integration/itest_aggregates.py @@ -3,7 +3,7 @@ import pytest -from aleph.sdk.client import AuthenticatedAlephClient +from aleph.sdk.client import AuthenticatedAlephHttpClient from aleph.sdk.types import Account from tests.integration.toolkit import try_until @@ -18,7 +18,7 @@ async def create_aggregate_on_target( receiver_node: str, channel="INTEGRATION_TESTS", ): - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=emitter_node ) as tx_session: aggregate_message, message_status = await tx_session.create_aggregate( @@ -38,7 +38,7 @@ async def create_aggregate_on_target( assert aggregate_message.content.address == account.get_address() assert aggregate_message.content.content == content - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=receiver_node ) as rx_session: aggregate_from_receiver = await try_until( diff --git a/tests/integration/itest_forget.py b/tests/integration/itest_forget.py index 29b6c6d9..a6cc141c 100644 --- a/tests/integration/itest_forget.py +++ b/tests/integration/itest_forget.py @@ -1,33 +1,21 @@ -from typing import Callable, Dict +import asyncio +from typing import Tuple import pytest +from aleph_message.models import ItemHash -from aleph.sdk.client import AuthenticatedAlephClient +from aleph.sdk.client import AuthenticatedAlephHttpClient +from aleph.sdk.query.filters import MessageFilter from aleph.sdk.types import Account from .config import REFERENCE_NODE, TARGET_NODE, TEST_CHANNEL -from .toolkit import try_until +from .toolkit import has_messages, has_no_messages, try_until async def create_and_forget_post( account: Account, emitter_node: str, receiver_node: str, channel=TEST_CHANNEL -) -> str: - async def wait_matching_posts( - item_hash: str, - condition: Callable[[Dict], bool], - timeout: int = 5, - ): - async with AuthenticatedAlephClient( - account=account, api_server=receiver_node - ) as rx_session: - return await try_until( - rx_session.get_posts, - condition, - timeout=timeout, - hashes=[item_hash], - ) - - async with AuthenticatedAlephClient( +) -> Tuple[ItemHash, ItemHash]: + async with AuthenticatedAlephHttpClient( account=account, api_server=emitter_node ) as tx_session: post_message, message_status = await tx_session.create_post( @@ -36,17 +24,21 @@ async def wait_matching_posts( channel="INTEGRATION_TESTS", ) - # Wait for the message to appear on the receiver. We don't check the values, - # they're checked in other integration tests. - get_post_response = await wait_matching_posts( - post_message.item_hash, - lambda response: len(response["posts"]) > 0, - ) - print(get_post_response) + async with AuthenticatedAlephHttpClient( + account=account, api_server=receiver_node + ) as rx_session: + await try_until( + rx_session.get_messages, + has_messages, + timeout=5, + message_filter=MessageFilter( + hashes=[post_message.item_hash], + ), + ) post_hash = post_message.item_hash reason = "This well thought-out content offends me!" 
- async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=emitter_node ) as tx_session: forget_message, forget_status = await tx_session.forget( @@ -54,27 +46,34 @@ async def wait_matching_posts( reason=reason, channel=channel, ) - assert forget_message.sender == account.get_address() assert forget_message.content.reason == reason assert forget_message.content.hashes == [post_hash] - - print(forget_message) + forget_hash = forget_message.item_hash # Wait until the message is forgotten - forgotten_posts = await wait_matching_posts( - post_hash, - lambda response: "forgotten_by" in response["posts"][0], - timeout=15, - ) + async with AuthenticatedAlephHttpClient( + account=account, api_server=receiver_node + ) as rx_session: + await try_until( + rx_session.get_messages, + has_messages, + timeout=5, + message_filter=MessageFilter( + hashes=[forget_hash], + ), + ) - assert len(forgotten_posts["posts"]) == 1 - forgotten_post = forgotten_posts["posts"][0] - assert forgotten_post["forgotten_by"] == [forget_message.item_hash] - assert forgotten_post["item_content"] is None - print(forgotten_post) + await try_until( + rx_session.get_messages, + has_no_messages, + timeout=5, + message_filter=MessageFilter( + hashes=[post_hash], + ), + ) - return post_hash + return post_hash, forget_hash @pytest.mark.asyncio @@ -83,7 +82,7 @@ async def test_create_and_forget_post_on_target(fixture_account): Create a post on the target node, then forget it and check that the change is propagated to the reference node. """ - _ = await create_and_forget_post(fixture_account, TARGET_NODE, REFERENCE_NODE) + _, _ = await create_and_forget_post(fixture_account, TARGET_NODE, REFERENCE_NODE) @pytest.mark.asyncio @@ -92,7 +91,7 @@ async def test_create_and_forget_post_on_reference(fixture_account): Create a post on the reference node, then forget it and check that the change is propagated to the target node. """ - _ = await create_and_forget_post(fixture_account, REFERENCE_NODE, TARGET_NODE) + _, _ = await create_and_forget_post(fixture_account, REFERENCE_NODE, TARGET_NODE) @pytest.mark.asyncio @@ -102,26 +101,33 @@ async def test_forget_a_forget_message(fixture_account): """ # TODO: this test should be moved to the PyAleph API tests, once a framework is in place. - post_hash = await create_and_forget_post(fixture_account, TARGET_NODE, TARGET_NODE) - async with AuthenticatedAlephClient( + post_hash, forget_hash = await create_and_forget_post( + fixture_account, TARGET_NODE, REFERENCE_NODE + ) + async with AuthenticatedAlephHttpClient( account=fixture_account, api_server=TARGET_NODE - ) as session: - get_post_response = await session.get_posts(hashes=[post_hash]) - assert len(get_post_response.posts) == 1 - post = get_post_response.posts[0] - - forget_message_hash = post.forgotten_by[0] - forget_message, forget_status = await session.forget( - hashes=[forget_message_hash], + ) as tx_session: + forget_message, forget_status = await tx_session.forget( + hashes=[forget_hash], reason="I want to remember this post. 
Maybe I can forget I forgot it?", channel=TEST_CHANNEL, ) print(forget_message) - get_forget_message_response = await session.get_messages( - hashes=[forget_message_hash], - channels=[TEST_CHANNEL], + # wait 5 seconds + await asyncio.sleep(5) + + async with AuthenticatedAlephHttpClient( + account=fixture_account, api_server=REFERENCE_NODE + ) as rx_session: + get_forget_message_response = await try_until( + rx_session.get_messages, + has_messages, + timeout=5, + message_filter=MessageFilter( + hashes=[forget_hash], + ), ) assert len(get_forget_message_response.messages) == 1 forget_message = get_forget_message_response.messages[0] diff --git a/tests/integration/itest_posts.py b/tests/integration/itest_posts.py index f30dc2b6..77b87b7f 100644 --- a/tests/integration/itest_posts.py +++ b/tests/integration/itest_posts.py @@ -1,20 +1,18 @@ import pytest -from aleph_message.models import MessagesResponse -from aleph.sdk.client import AuthenticatedAlephClient -from tests.integration.toolkit import try_until +from aleph.sdk.client import AuthenticatedAlephHttpClient +from aleph.sdk.query.filters import MessageFilter +from tests.integration.toolkit import has_messages, try_until from .config import REFERENCE_NODE, TARGET_NODE -async def create_message_on_target( - fixture_account, emitter_node: str, receiver_node: str -): +async def create_message_on_target(account, emitter_node: str, receiver_node: str): """ Create a POST message on the target node, then fetch it from the reference node. """ - async with AuthenticatedAlephClient( - account=fixture_account, api_server=emitter_node + async with AuthenticatedAlephHttpClient( + account=account, api_server=emitter_node ) as tx_session: post_message, message_status = await tx_session.create_post( post_content=None, @@ -22,17 +20,16 @@ async def create_message_on_target( channel="INTEGRATION_TESTS", ) - def response_contains_messages(response: MessagesResponse) -> bool: - return len(response.messages) > 0 - - async with AuthenticatedAlephClient( - account=fixture_account, api_server=receiver_node + async with AuthenticatedAlephHttpClient( + account=account, api_server=receiver_node ) as rx_session: responses = await try_until( rx_session.get_messages, - response_contains_messages, + has_messages, timeout=5, - hashes=[post_message.item_hash], + message_filter=MessageFilter( + hashes=[post_message.item_hash], + ), ) message_from_target = responses.messages[0] diff --git a/tests/integration/toolkit.py b/tests/integration/toolkit.py index 70bc3bbb..a72f9d6f 100644 --- a/tests/integration/toolkit.py +++ b/tests/integration/toolkit.py @@ -2,6 +2,8 @@ import time from typing import Awaitable, Callable, TypeVar +from aleph.sdk.query.responses import MessagesResponse + T = TypeVar("T") @@ -9,7 +11,7 @@ async def try_until( coroutine: Callable[..., Awaitable[T]], condition: Callable[[T], bool], timeout: float, - time_between_attempts: float = 0.5, + time_between_attempts: float = 1, *args, **kwargs, ) -> T: @@ -23,3 +25,11 @@ async def try_until( await asyncio.sleep(time_between_attempts) else: raise TimeoutError(f"No success in {timeout} seconds.") + + +def has_messages(response: MessagesResponse) -> bool: + return len(response.messages) > 0 + + +def has_no_messages(response: MessagesResponse) -> bool: + return len(response.messages) == 0 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 4f62c0c5..a51b1483 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,8 +1,10 @@ import json from pathlib import Path from tempfile import 
NamedTemporaryFile +from typing import Any, Callable, Dict, List import pytest as pytest +from aleph_message.models import AggregateMessage, AlephMessage, PostMessage import aleph.sdk.chains.ethereum as ethereum import aleph.sdk.chains.sol as solana @@ -46,7 +48,77 @@ def substrate_account() -> substrate.DOTAccount: @pytest.fixture -def messages(): +def json_messages(): messages_path = Path(__file__).parent / "messages.json" with open(messages_path) as f: return json.load(f) + + +@pytest.fixture +def aleph_messages() -> List[AlephMessage]: + return [ + AggregateMessage.parse_obj( + { + "item_hash": "5b26d949fe05e38f535ef990a89da0473f9d700077cced228f2d36e73fca1fd6", + "type": "AGGREGATE", + "chain": "ETH", + "sender": "0x51A58800b26AA1451aaA803d1746687cB88E0501", + "signature": "0xca5825b6b93390482b436cb7f28b4628f8c9f56dc6af08260c869b79dd6017c94248839bd9fd0ffa1230dc3b1f4f7572a8d1f6fed6c6e1fb4d70ccda0ab5d4f21b", + "item_type": "inline", + "item_content": '{"address":"0x51A58800b26AA1451aaA803d1746687cB88E0501","key":"0xce844d79e5c0c325490c530aa41e8f602f0b5999binance","content":{"1692026263168":{"version":"x25519-xsalsa20-poly1305","nonce":"RT4Lbqs7Xzk+op2XC+VpXgwOgg21BotN","ephemPublicKey":"CVW8ECE3m8BepytHMTLan6/jgIfCxGdnKmX47YirF08=","ciphertext":"VuGJ9vMkJSbaYZCCv6Zemx4ixeb+9IW8H1vFB9vLtz1a8d87R4BfYUisLoCQxRkeUXqfW0/KIGQ5idVjr8Yj7QnKglW5AJ8UX7wEWMhiRFLatpWP8P9FI2n8Z7Rblu7Oz/OeKnuljKL3KsalcUQSsFa/1qACsIoycPZ6Wq6t1mXxVxxJWzClLyKRihv1pokZGT9UWxh7+tpoMGlRdYainyAt0/RygFw+r8iCMOilHnyv4ndLkKQJXyttb0tdNr/gr57+9761+trioGSysLQKZQWW6Ih6aE8V9t3BenfzYwiCnfFw3YAAKBPMdm9QdIETyrOi7YhD/w==","sha256":"bbeb499f681aed2bc18b6f3b6a30d25254bd30fbfde43444e9085f3bcd075c3c"}},"time":1692026263.662}', + "content": { + "key": "0xce844d79e5c0c325490c530aa41e8f602f0b5999binance", + "time": 1692026263.662, + "address": "0x51A58800b26AA1451aaA803d1746687cB88E0501", + "content": { + "hello": "world", + }, + }, + "time": 1692026263.662, + "channel": "UNSLASHED", + "size": 734, + "confirmations": [], + "confirmed": False, + } + ), + PostMessage.parse_obj( + { + "item_hash": "70f3798fdc68ce0ee03715a5547ee24e2c3e259bf02e3f5d1e4bf5a6f6a5e99f", + "type": "POST", + "chain": "SOL", + "sender": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4", + "signature": "0x91616ee45cfba55742954ff87ebf86db4988bcc5e3334b49a4caa6436e28e28d4ab38667cbd4bfb8903abf8d71f70d9ceb2c0a8d0a15c04fc1af5657f0050c101b", + "item_type": "storage", + "item_content": None, + "content": { + "time": 1692026021.1257718, + "type": "aleph-network-metrics", + "address": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4", + "ref": "0123456789abcdef", + "content": { + "tags": ["mainnet"], + "hello": "world", + "version": "1.0", + }, + }, + "time": 1692026021.132849, + "channel": "aleph-scoring", + "size": 122537, + "confirmations": [], + "confirmed": False, + } + ), + ] + + +@pytest.fixture +def raw_messages_response(aleph_messages) -> Callable[[int], Dict[str, Any]]: + return lambda page: { + "messages": [message.dict() for message in aleph_messages] + if int(page) == 1 + else [], + "pagination_item": "messages", + "pagination_page": int(page), + "pagination_per_page": max(len(aleph_messages), 20), + "pagination_total": len(aleph_messages) if page == 1 else 0, + } diff --git a/tests/unit/test_asynchronous.py b/tests/unit/test_asynchronous.py index 8973263b..dbccbaa6 100644 --- a/tests/unit/test_asynchronous.py +++ b/tests/unit/test_asynchronous.py @@ -11,14 +11,14 @@ ) from aleph_message.status import MessageStatus -from aleph.sdk.client import AuthenticatedAlephClient +from 
aleph.sdk.client import AuthenticatedAlephHttpClient from aleph.sdk.types import Account, StorageEnum @pytest.fixture def mock_session_with_post_success( ethereum_account: Account, -) -> AuthenticatedAlephClient: +) -> AuthenticatedAlephHttpClient: class MockResponse: def __init__(self, sync: bool): self.sync = sync @@ -49,7 +49,7 @@ async def text(self): sync=kwargs.get("sync", False) ) - client = AuthenticatedAlephClient( + client = AuthenticatedAlephHttpClient( account=ethereum_account, api_server="http://localhost" ) client.http_session = http_session diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py index db788e0b..f5e0c800 100644 --- a/tests/unit/test_asynchronous_get.py +++ b/tests/unit/test_asynchronous_get.py @@ -3,14 +3,15 @@ from unittest.mock import AsyncMock import pytest -from aleph_message.models import MessagesResponse +from aleph_message.models import MessagesResponse, MessageType -from aleph.sdk.client import AlephClient +from aleph.sdk import AlephHttpClient from aleph.sdk.conf import settings -from aleph.sdk.models import PostsResponse +from aleph.sdk.query.filters import MessageFilter, PostFilter +from aleph.sdk.query.responses import PostsResponse -def make_mock_session(get_return_value: Dict[str, Any]) -> AlephClient: +def make_mock_session(get_return_value: Dict[str, Any]) -> AlephHttpClient: class MockResponse: async def __aenter__(self): return self @@ -22,6 +23,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): def status(self): return 200 + def raise_for_status(self): + ... + async def json(self): return get_return_value @@ -31,7 +35,7 @@ def get(self, *_args, **_kwargs): http_session = MockHttpSession() - client = AlephClient(api_server="http://localhost") + client = AlephHttpClient(api_server="http://localhost") client.http_session = http_session return client @@ -66,8 +70,13 @@ async def test_fetch_aggregates(): @pytest.mark.asyncio async def test_get_posts(): - async with AlephClient(api_server=settings.API_HOST) as session: - response: PostsResponse = await session.get_posts() + async with AlephHttpClient(api_server=settings.API_HOST) as session: + response: PostsResponse = await session.get_posts( + page_size=2, + post_filter=PostFilter( + channels=["TEST"], + ), + ) posts = response.posts assert len(posts) > 1 @@ -75,9 +84,12 @@ async def test_get_posts(): @pytest.mark.asyncio async def test_get_messages(): - async with AlephClient(api_server=settings.API_HOST) as session: + async with AlephHttpClient(api_server=settings.API_HOST) as session: response: MessagesResponse = await session.get_messages( - pagination=2, + page_size=2, + message_filter=MessageFilter( + message_types=[MessageType.post], + ), ) messages = response.messages diff --git a/tests/unit/test_chain_ethereum.py b/tests/unit/test_chain_ethereum.py index dea58c69..9a602b3d 100644 --- a/tests/unit/test_chain_ethereum.py +++ b/tests/unit/test_chain_ethereum.py @@ -82,8 +82,8 @@ async def test_verify_signature(ethereum_account): @pytest.mark.asyncio -async def test_verify_signature_with_processed_message(ethereum_account, messages): - message = messages[1] +async def test_verify_signature_with_processed_message(ethereum_account, json_messages): + message = json_messages[1] verify_signature( message["signature"], message["sender"], get_verification_buffer(message) ) diff --git a/tests/unit/test_chain_solana.py b/tests/unit/test_chain_solana.py index 5088158a..07b67602 100644 --- a/tests/unit/test_chain_solana.py +++ b/tests/unit/test_chain_solana.py 
@@ -103,8 +103,8 @@ async def test_verify_signature(solana_account): @pytest.mark.asyncio -async def test_verify_signature_with_processed_message(solana_account, messages): - message = messages[0] +async def test_verify_signature_with_processed_message(solana_account, json_messages): + message = json_messages[0] signature = json.loads(message["signature"])["signature"] verify_signature(signature, message["sender"], get_verification_buffer(message)) diff --git a/tests/unit/test_download.py b/tests/unit/test_download.py index b16e0d75..377e6d41 100644 --- a/tests/unit/test_download.py +++ b/tests/unit/test_download.py @@ -1,6 +1,6 @@ import pytest -from aleph.sdk import AlephClient +from aleph.sdk import AlephHttpClient from aleph.sdk.conf import settings as sdk_settings @@ -13,7 +13,7 @@ ) @pytest.mark.asyncio async def test_download(file_hash: str, expected_size: int): - async with AlephClient(api_server=sdk_settings.API_HOST) as client: + async with AlephHttpClient(api_server=sdk_settings.API_HOST) as client: file_content = await client.download_file(file_hash) # File is 5B file_size = len(file_content) assert file_size == expected_size @@ -28,7 +28,7 @@ async def test_download(file_hash: str, expected_size: int): ) @pytest.mark.asyncio async def test_download_ipfs(file_hash: str, expected_size: int): - async with AlephClient(api_server=sdk_settings.API_HOST) as client: + async with AlephHttpClient(api_server=sdk_settings.API_HOST) as client: file_content = await client.download_file_ipfs(file_hash) # 5817703 B FILE file_size = len(file_content) assert file_size == expected_size diff --git a/tests/unit/test_synchronous_get.py b/tests/unit/test_synchronous_get.py deleted file mode 100644 index eee26dcf..00000000 --- a/tests/unit/test_synchronous_get.py +++ /dev/null @@ -1,18 +0,0 @@ -from aleph_message.models import MessagesResponse, MessageType - -from aleph.sdk.client import AlephClient -from aleph.sdk.conf import settings - - -def test_get_post_messages(): - with AlephClient(api_server=settings.API_HOST) as session: - # TODO: Remove deprecated message_type parameter after message_types changes on pyaleph are deployed - response: MessagesResponse = session.get_messages( - pagination=2, - message_type=MessageType.post, - ) - - messages = response.messages - assert len(messages) > 1 - for message in messages: - assert message.type == MessageType.post
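With the synchronous client removed, the deleted test above maps onto the new asynchronous API as follows. This is a sketch under the same assumptions as the old test (a reachable `settings.API_HOST` serving more than one POST message), not part of the diff itself:

```python
import asyncio

from aleph_message.models import MessageType

from aleph.sdk import AlephHttpClient
from aleph.sdk.conf import settings
from aleph.sdk.query.filters import MessageFilter


async def get_post_messages() -> None:
    # The deprecated `pagination` and `message_type` parameters become
    # `page_size` and a MessageFilter with `message_types`.
    async with AlephHttpClient(api_server=settings.API_HOST) as session:
        response = await session.get_messages(
            page_size=2,
            message_filter=MessageFilter(message_types=[MessageType.post]),
        )
        assert len(response.messages) > 1
        for message in response.messages:
            assert message.type == MessageType.post


if __name__ == "__main__":
    asyncio.run(get_post_messages())
```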