diff --git a/src/aleph/sdk/node.py b/src/aleph/sdk/node/__init__.py
similarity index 71%
rename from src/aleph/sdk/node.py
rename to src/aleph/sdk/node/__init__.py
index a9548e67..1477ac2c 100644
--- a/src/aleph/sdk/node.py
+++ b/src/aleph/sdk/node/__init__.py
@@ -1,16 +1,13 @@
 import asyncio
-import json
 import logging
 import typing
 from datetime import datetime
-from functools import partial
 from pathlib import Path
 from typing import (
     Any,
     AsyncIterable,
     Coroutine,
     Dict,
-    Generic,
     Iterable,
     Iterator,
     List,
@@ -18,203 +15,26 @@
     Optional,
     Tuple,
     Type,
-    TypeVar,
     Union,
 )
 
-from aleph_message import MessagesResponse, parse_message
-from aleph_message.models import (
-    AlephMessage,
-    Chain,
-    ItemHash,
-    MessageConfirmation,
-    MessageType,
-)
+from aleph_message import MessagesResponse
+from aleph_message.models import AlephMessage, Chain, ItemHash, MessageType, PostMessage
 from aleph_message.models.execution.base import Encoding
 from aleph_message.status import MessageStatus
-from peewee import (
-    BooleanField,
-    CharField,
-    FloatField,
-    IntegerField,
-    Model,
-    SqliteDatabase,
-)
-from playhouse.shortcuts import model_to_dict
-from playhouse.sqlite_ext import JSONField
-from pydantic import BaseModel
-
-from aleph.sdk import AuthenticatedAlephClient
-from aleph.sdk.base import AlephClientBase, AuthenticatedAlephClientBase
-from aleph.sdk.conf import settings
-from aleph.sdk.exceptions import MessageNotFoundError
-from aleph.sdk.models import PostsResponse
-from aleph.sdk.types import GenericMessage, StorageEnum
-
-db = SqliteDatabase(settings.CACHE_DATABASE_PATH)
-T = TypeVar("T", bound=BaseModel)
-
-
-class JSONDictEncoder(json.JSONEncoder):
-    def default(self, obj):
-        if isinstance(obj, BaseModel):
-            return obj.dict()
-        return json.JSONEncoder.default(self, obj)
-
-
-pydantic_json_dumps = partial(json.dumps, cls=JSONDictEncoder)
-
-
-class PydanticField(JSONField, Generic[T]):
-    """
-    A field for storing pydantic model types as JSON in a database. Uses json for serialization.
-    """
-
-    type: T
-    def __init__(self, *args, **kwargs):
-        self.type = kwargs.pop("type")
-        super().__init__(*args, **kwargs)
+from ..base import BaseAlephClient, BaseAuthenticatedAlephClient
+from ..client import AuthenticatedAlephClient
+from ..conf import settings
+from ..exceptions import MessageNotFoundError
+from ..models import PostsResponse
+from ..types import GenericMessage, StorageEnum
+from .common import db
+from .message import MessageModel, get_message_query, message_to_model, model_to_message
+from .post import PostModel, get_post_query, message_to_post, model_to_post
 
-    def db_value(self, value: Optional[T]) -> Optional[str]:
-        if value is None:
-            return None
-        return value.json()
-
-    def python_value(self, value: Optional[str]) -> Optional[T]:
-        if value is None:
-            return None
-        return self.type.parse_raw(value)
-
-
-class MessageModel(Model):
-    """
-    A simple database model for storing AlephMessage objects.
-    """
-    item_hash = CharField(primary_key=True)
-    chain = CharField(5)
-    type = CharField(9)
-    sender = CharField()
-    channel = CharField(null=True)
-    confirmations: PydanticField[MessageConfirmation] = PydanticField(
-        type=MessageConfirmation, null=True
-    )
-    confirmed = BooleanField(null=True)
-    signature = CharField(null=True)
-    size = IntegerField(null=True)
-    time = FloatField()
-    item_type = CharField(7)
-    item_content = CharField(null=True)
-    hash_type = CharField(6, null=True)
-    content = JSONField(json_dumps=pydantic_json_dumps)
-    forgotten_by = CharField(null=True)
-    tags = JSONField(json_dumps=pydantic_json_dumps, null=True)
-    key = CharField(null=True)
-    ref = CharField(null=True)
-    content_type = CharField(null=True)
-
-    class Meta:
-        database = db
-
-
-def message_to_model(message: AlephMessage) -> Dict:
-    return {
-        "item_hash": str(message.item_hash),
-        "chain": message.chain,
-        "type": message.type,
-        "sender": message.sender,
-        "channel": message.channel,
-        "confirmations": message.confirmations[0] if message.confirmations else None,
-        "confirmed": message.confirmed,
-        "signature": message.signature,
-        "size": message.size,
-        "time": message.time,
-        "item_type": message.item_type,
-        "item_content": message.item_content,
-        "hash_type": message.hash_type,
-        "content": message.content,
-        "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None,
-        "tags": message.content.content.get("tags", None)
-        if hasattr(message.content, "content")
-        else None,
-        "key": message.content.key if hasattr(message.content, "key") else None,
-        "ref": message.content.ref if hasattr(message.content, "ref") else None,
-        "content_type": message.content.type
-        if hasattr(message.content, "type")
-        else None,
-    }
-
-
-def model_to_message(item: Any) -> AlephMessage:
-    item.confirmations = [item.confirmations] if item.confirmations else []
-    item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None
-
-    to_exclude = [
-        MessageModel.tags,
-        MessageModel.ref,
-        MessageModel.key,
-        MessageModel.content_type,
-    ]
-
-    item_dict = model_to_dict(item, exclude=to_exclude)
-    return parse_message(item_dict)
-
-
-def query_field(field_name, field_values: Iterable[str]):
-    field = getattr(MessageModel, field_name)
-    values = list(field_values)
-
-    if len(values) == 1:
-        return field == values[0]
-    return field.in_(values)
-
-
-def get_message_query(
-    message_type: Optional[MessageType] = None,
-    content_keys: Optional[Iterable[str]] = None,
-    content_types: Optional[Iterable[str]] = None,
-    refs: Optional[Iterable[str]] = None,
-    addresses: Optional[Iterable[str]] = None,
-    tags: Optional[Iterable[str]] = None,
-    hashes: Optional[Iterable[str]] = None,
-    channels: Optional[Iterable[str]] = None,
-    chains: Optional[Iterable[str]] = None,
-    start_date: Optional[Union[datetime, float]] = None,
-    end_date: Optional[Union[datetime, float]] = None,
-):
-    query = MessageModel.select().order_by(MessageModel.time.desc())
-    conditions = []
-    if message_type:
-        conditions.append(query_field("type", [message_type.value]))
-    if content_keys:
-        conditions.append(query_field("key", content_keys))
-    if content_types:
-        conditions.append(query_field("content_type", content_types))
-    if refs:
-        conditions.append(query_field("ref", refs))
-    if addresses:
-        conditions.append(query_field("sender", addresses))
-    if tags:
-        for tag in tags:
-            conditions.append(MessageModel.tags.contains(tag))
-    if hashes:
-        conditions.append(query_field("item_hash", hashes))
-    if channels:
conditions.append(query_field("channel", channels)) - if chains: - conditions.append(query_field("chain", chains)) - if start_date: - conditions.append(MessageModel.time >= start_date) - if end_date: - conditions.append(MessageModel.time <= end_date) - - if conditions: - query = query.where(*conditions) - return query - - -class MessageCache(AlephClientBase): +class MessageCache(BaseAlephClient): """ A wrapper around a sqlite3 database for caching AlephMessage objects. @@ -222,12 +42,16 @@ class MessageCache(AlephClientBase): """ _instance_count = 0 # Class-level counter for active instances + missing_posts: Dict[ItemHash, PostMessage] = {} + """A dict of all posts by item_hash and their amend messages that are missing from the cache.""" def __init__(self): if db.is_closed(): db.connect() if not MessageModel.table_exists(): db.create_tables([MessageModel]) + if not PostModel.table_exists(): + db.create_tables([PostModel]) MessageCache._instance_count += 1 @@ -270,17 +94,57 @@ def __repr__(self) -> str: def __str__(self) -> str: return repr(self) - @staticmethod - def add(messages: Union[AlephMessage, Iterable[AlephMessage]]): + def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]): if isinstance(messages, typing.get_args(AlephMessage)): messages = [messages] - data_source = (message_to_model(message) for message in messages) - MessageModel.insert_many(data_source).on_conflict_replace().execute() + message_data = (message_to_model(message) for message in messages) + MessageModel.insert_many(message_data).on_conflict_replace().execute() + + # Add posts and their amends to the PostModel + post_data = [] + amend_messages = [] + for message in messages: + if message.item_type != MessageType.post: + continue + if message.content.type == "amend": + amend_messages.append(message) + else: + post = message_to_post(message).dict() + post_data.append(post) + # Check if we can now add any amend messages that had missing refs + if message.item_hash in self.missing_posts: + amend_messages += self.missing_posts.pop(message.item_hash) + + PostModel.insert_many(post_data).on_conflict_replace().execute() + + # Handle amends in second step to avoid missing original posts + post_data = [] + for message in amend_messages: + # Find the original post and update it + original_post = MessageModel.get( + MessageModel.item_hash == message.content.ref + ) + if not original_post: + latest_amend = self.missing_posts.get(ItemHash(message.content.ref)) + if latest_amend and message.time < latest_amend.time: + self.missing_posts[ItemHash(message.content.ref)] = message + continue + if datetime.fromtimestamp(message.time) < original_post.last_updated: + continue + original_post.item_hash = message.item_hash + original_post.content = message.content.content + original_post.original_item_hash = message.content.ref + original_post.original_type = message.content.type + original_post.address = message.sender + original_post.channel = message.channel + original_post.last_updated = datetime.fromtimestamp(message.time) + post_data.append(original_post) + + PostModel.insert_many(post_data).on_conflict_replace().execute() - @staticmethod def get( - item_hashes: Union[Union[ItemHash, str], Iterable[Union[ItemHash, str]]] + self, item_hashes: Union[Union[ItemHash, str], Iterable[Union[ItemHash, str]]] ) -> List[AlephMessage]: """ Get many messages from the cache by their item hash. 
@@ -347,12 +211,11 @@ async def get_posts(
         chains: Optional[Iterable[str]] = None,
         start_date: Optional[Union[datetime, float]] = None,
         end_date: Optional[Union[datetime, float]] = None,
-        ignore_invalid_messages: bool = True,
-        invalid_messages_log_level: int = logging.NOTSET,
+        ignore_invalid_messages: Optional[bool] = True,
+        invalid_messages_log_level: Optional[int] = logging.NOTSET,
     ) -> PostsResponse:
-        query = get_message_query(
-            message_type=MessageType.post,
-            content_types=types,
+        query = get_post_query(
+            types=types,
             refs=refs,
             addresses=addresses,
             tags=tags,
@@ -365,7 +228,7 @@ async def get_posts(
 
         query = query.paginate(page, pagination)
 
-        posts = [model_to_message(item) for item in list(query)]
+        posts = [model_to_post(item) for item in list(query)]
 
         return PostsResponse(
             posts=posts,
@@ -383,6 +246,7 @@ async def get_messages(
         pagination: int = 200,
         page: int = 1,
         message_type: Optional[MessageType] = None,
+        message_types: Optional[Iterable[MessageType]] = None,
         content_types: Optional[Iterable[str]] = None,
         content_keys: Optional[Iterable[str]] = None,
         refs: Optional[Iterable[str]] = None,
@@ -393,14 +257,15 @@ async def get_messages(
         chains: Optional[Iterable[str]] = None,
         start_date: Optional[Union[datetime, float]] = None,
         end_date: Optional[Union[datetime, float]] = None,
-        ignore_invalid_messages: bool = True,
-        invalid_messages_log_level: int = logging.NOTSET,
+        ignore_invalid_messages: Optional[bool] = True,
+        invalid_messages_log_level: Optional[int] = logging.NOTSET,
     ) -> MessagesResponse:
         """
         Get many messages from the cache.
         """
+        message_types = message_types or [message_type] if message_type else None
         query = get_message_query(
-            message_type=message_type,
+            message_types=message_types,
             content_keys=content_keys,
             content_types=content_types,
             refs=refs,
@@ -451,6 +316,7 @@ async def get_message(
     async def watch_messages(
         self,
         message_type: Optional[MessageType] = None,
+        message_types: Optional[Iterable[MessageType]] = None,
         content_types: Optional[Iterable[str]] = None,
         content_keys: Optional[Iterable[str]] = None,
         refs: Optional[Iterable[str]] = None,
@@ -465,8 +331,9 @@ async def watch_messages(
         """
         Watch messages from the cache.
         """
+        message_types = message_types or [message_type] if message_type else None
         query = get_message_query(
-            message_type=message_type,
+            message_types=message_types,
             content_keys=content_keys,
             content_types=content_types,
             refs=refs,
@@ -483,7 +350,7 @@ async def watch_messages(
             yield model_to_message(item)
 
 
-class DomainNode(MessageCache, AuthenticatedAlephClientBase):
+class DomainNode(MessageCache, BaseAuthenticatedAlephClient):
     """
     A Domain Node is a queryable proxy for Aleph Messages that are stored in
     a database cache and/or in the Aleph network.
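For orientation, here is a minimal usage sketch of the cache workflow after this refactor. It only relies on names introduced or kept by this diff (`MessageCache`, the new `message_types` filter on `get_messages`, `get_posts`); the network client `AlephClient` and `settings.API_HOST` are assumed to exist in the SDK as before, and the page size is arbitrary.

```python
# Sketch of the cache workflow from this diff (usage assumed, not part of the diff itself).
import asyncio

from aleph_message.models import MessageType

from aleph.sdk.client import AlephClient  # existing network client, assumed import path
from aleph.sdk.conf import settings
from aleph.sdk.node import MessageCache  # refactored into aleph/sdk/node/__init__.py above


async def main() -> None:
    cache = MessageCache()

    # Pull a page of messages from the API and mirror them into the local SQLite cache.
    async with AlephClient(api_server=settings.API_HOST) as client:
        response = await client.get_messages(pagination=50)
    cache.add(response.messages)

    # The `message_types` filter added in this diff accepts several types at once.
    cached = await cache.get_messages(message_types=[MessageType.post, MessageType.store])
    print(f"{len(cached.messages)} cached messages")

    # POST-type messages are also indexed separately and can be queried as posts.
    posts = await cache.get_posts(pagination=10)
    print(f"{len(posts.posts)} cached posts")


asyncio.run(main())
```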
diff --git a/src/aleph/sdk/node/common.py b/src/aleph/sdk/node/common.py
new file mode 100644
index 00000000..baed8b39
--- /dev/null
+++ b/src/aleph/sdk/node/common.py
@@ -0,0 +1,44 @@
+import json
+from functools import partial
+from typing import Generic, Optional, TypeVar
+
+from peewee import SqliteDatabase
+from playhouse.sqlite_ext import JSONField
+from pydantic import BaseModel
+
+from aleph.sdk.conf import settings
+
+db = SqliteDatabase(settings.CACHE_DATABASE_PATH)
+T = TypeVar("T", bound=BaseModel)
+
+
+class JSONDictEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, BaseModel):
+            return obj.dict()
+        return json.JSONEncoder.default(self, obj)
+
+
+pydantic_json_dumps = partial(json.dumps, cls=JSONDictEncoder)
+
+
+class PydanticField(JSONField, Generic[T]):
+    """
+    A field for storing pydantic model types as JSON in a database. Uses json for serialization.
+    """
+
+    type: T
+
+    def __init__(self, *args, **kwargs):
+        self.type = kwargs.pop("type")
+        super().__init__(*args, **kwargs)
+
+    def db_value(self, value: Optional[T]) -> Optional[str]:
+        if value is None:
+            return None
+        return value.json()
+
+    def python_value(self, value: Optional[str]) -> Optional[T]:
+        if value is None:
+            return None
+        return self.type.parse_raw(value)
diff --git a/src/aleph/sdk/node/message.py b/src/aleph/sdk/node/message.py
new file mode 100644
index 00000000..a3327d2a
--- /dev/null
+++ b/src/aleph/sdk/node/message.py
@@ -0,0 +1,137 @@
+from datetime import datetime
+from typing import Any, Dict, Iterable, Optional, Union
+
+from aleph_message import parse_message
+from aleph_message.models import AlephMessage, MessageConfirmation, MessageType
+from peewee import BooleanField, CharField, FloatField, IntegerField, Model
+from playhouse.shortcuts import model_to_dict
+from playhouse.sqlite_ext import JSONField
+
+from aleph.sdk.node.common import PydanticField, db, pydantic_json_dumps
+
+
+class MessageModel(Model):
+    """
+    A simple database model for storing AlephMessage objects.
+    """
+
+    item_hash = CharField(primary_key=True)
+    chain = CharField(5)
+    type = CharField(9)
+    sender = CharField()
+    channel = CharField(null=True)
+    confirmations: PydanticField[MessageConfirmation] = PydanticField(
+        type=MessageConfirmation, null=True
+    )
+    confirmed = BooleanField(null=True)
+    signature = CharField(null=True)
+    size = IntegerField(null=True)
+    time = FloatField()
+    item_type = CharField(7)
+    item_content = CharField(null=True)
+    hash_type = CharField(6, null=True)
+    content = JSONField(json_dumps=pydantic_json_dumps)
+    forgotten_by = CharField(null=True)
+    tags = JSONField(json_dumps=pydantic_json_dumps, null=True)
+    key = CharField(null=True)
+    ref = CharField(null=True)
+    content_type = CharField(null=True)
+
+    class Meta:
+        database = db
+
+
+def message_to_model(message: AlephMessage) -> Dict:
+    return {
+        "item_hash": str(message.item_hash),
+        "chain": message.chain,
+        "type": message.type,
+        "sender": message.sender,
+        "channel": message.channel,
+        "confirmations": message.confirmations[0] if message.confirmations else None,
+        "confirmed": message.confirmed,
+        "signature": message.signature,
+        "size": message.size,
+        "time": message.time,
+        "item_type": message.item_type,
+        "item_content": message.item_content,
+        "hash_type": message.hash_type,
+        "content": message.content,
+        "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None,
+        "tags": message.content.content.get("tags", None)
+        if hasattr(message.content, "content")
+        else None,
+        "key": message.content.key if hasattr(message.content, "key") else None,
+        "ref": message.content.ref if hasattr(message.content, "ref") else None,
+        "content_type": message.content.type
+        if hasattr(message.content, "type")
+        else None,
+    }
+
+
+def model_to_message(item: Any) -> AlephMessage:
+    item.confirmations = [item.confirmations] if item.confirmations else []
+    item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None
+
+    to_exclude = [
+        MessageModel.tags,
+        MessageModel.ref,
+        MessageModel.key,
+        MessageModel.content_type,
+    ]
+
+    item_dict = model_to_dict(item, exclude=to_exclude)
+    return parse_message(item_dict)
+
+
+def query_field(field_name, field_values: Iterable[str]):
+    field = getattr(MessageModel, field_name)
+    values = list(field_values)
+
+    if len(values) == 1:
+        return field == values[0]
+    return field.in_(values)
+
+
+def get_message_query(
+    message_types: Optional[Iterable[MessageType]] = None,
+    content_keys: Optional[Iterable[str]] = None,
+    content_types: Optional[Iterable[str]] = None,
+    refs: Optional[Iterable[str]] = None,
+    addresses: Optional[Iterable[str]] = None,
+    tags: Optional[Iterable[str]] = None,
+    hashes: Optional[Iterable[str]] = None,
+    channels: Optional[Iterable[str]] = None,
+    chains: Optional[Iterable[str]] = None,
+    start_date: Optional[Union[datetime, float]] = None,
+    end_date: Optional[Union[datetime, float]] = None,
+):
+    query = MessageModel.select().order_by(MessageModel.time.desc())
+    conditions = []
+    if message_types:
+        conditions.append(query_field("type", [type.value for type in message_types]))
+    if content_keys:
+        conditions.append(query_field("key", content_keys))
+    if content_types:
+        conditions.append(query_field("content_type", content_types))
+    if refs:
+        conditions.append(query_field("ref", refs))
+    if addresses:
+        conditions.append(query_field("sender", addresses))
+    if tags:
+        for tag in tags:
+            conditions.append(MessageModel.tags.contains(tag))
+    if hashes:
+        conditions.append(query_field("item_hash", hashes))
+    if channels:
+ conditions.append(query_field("channel", channels)) + if chains: + conditions.append(query_field("chain", chains)) + if start_date: + conditions.append(MessageModel.time >= start_date) + if end_date: + conditions.append(MessageModel.time <= end_date) + + if conditions: + query = query.where(*conditions) + return query diff --git a/src/aleph/sdk/node/post.py b/src/aleph/sdk/node/post.py new file mode 100644 index 00000000..b68a421d --- /dev/null +++ b/src/aleph/sdk/node/post.py @@ -0,0 +1,115 @@ +from datetime import datetime +from typing import Any, Dict, Iterable, Optional, Union + +from aleph_message.models import PostMessage +from peewee import CharField, DateTimeField, Model +from playhouse.shortcuts import model_to_dict +from playhouse.sqlite_ext import JSONField + +from aleph.sdk.models import Post +from aleph.sdk.node.common import db, pydantic_json_dumps + + +class PostModel(Model): + """ + A simple database model for storing AlephMessage objects. + """ + + original_item_hash = CharField(primary_key=True) + item_hash = CharField() + content = JSONField(json_dumps=pydantic_json_dumps) + original_type = CharField() + address = CharField() + ref = CharField(null=True) + channel = CharField(null=True) + created = DateTimeField() + last_updated = DateTimeField() + tags = JSONField(json_dumps=pydantic_json_dumps, null=True) + chain = CharField(5) + + class Meta: + database = db + + +def post_to_model(post: Post) -> Dict: + return { + "item_hash": str(post.item_hash), + "content": post.content, + "original_item_hash": str(post.original_item_hash), + "original_type": post.original_type, + "address": post.address, + "ref": post.ref, + "channel": post.channel, + "created": post.created, + "last_updated": post.last_updated, + } + + +def message_to_post(message: PostMessage) -> Post: + return Post.parse_obj( + { + "item_hash": str(message.item_hash), + "content": message.content, + "original_item_hash": str(message.item_hash), + "original_type": message.content.type + if hasattr(message.content, "type") + else None, + "address": message.sender, + "ref": message.content.ref if hasattr(message.content, "ref") else None, + "channel": message.channel, + "created": datetime.fromtimestamp(message.time), + "last_updated": datetime.fromtimestamp(message.time), + } + ) + + +def model_to_post(item: Any) -> Post: + to_exclude = [PostModel.tags, PostModel.chain] + return Post.parse_obj(model_to_dict(item, exclude=to_exclude)) + + +def query_field(field_name, field_values: Iterable[str]): + field = getattr(PostModel, field_name) + values = list(field_values) + + if len(values) == 1: + return field == values[0] + return field.in_(values) + + +def get_post_query( + types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, +): + query = PostModel.select().order_by(PostModel.created.desc()) + conditions = [] + if types: + conditions.append(query_field("original_type", types)) + if refs: + conditions.append(query_field("ref", refs)) + if addresses: + conditions.append(query_field("address", addresses)) + if tags: + for tag in tags: + conditions.append(PostModel.tags.contains(tag)) + if hashes: + conditions.append(query_field("item_hash", hashes)) + if channels: + 
conditions.append(query_field("channel", channels)) + if chains: + conditions.append(query_field("chain", chains)) + if start_date: + conditions.append(PostModel.time >= start_date) + if end_date: + conditions.append(PostModel.time <= end_date) + + if conditions: + query = query.where(*conditions) + return query