Skip to content

Commit

Permalink
add posts table for caching posts; handle amend messages; refactor no…
Browse files Browse the repository at this point in the history
…de.py as a package
  • Loading branch information
MHHukiewitz committed Sep 6, 2023
1 parent 30f505e commit 603afef
Show file tree
Hide file tree
Showing 4 changed files with 372 additions and 209 deletions.
285 changes: 76 additions & 209 deletions src/aleph/sdk/node.py → src/aleph/sdk/node/__init__.py
Original file line number Diff line number Diff line change
@@ -1,233 +1,57 @@
import asyncio
import json
import logging
import typing
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import (
Any,
AsyncIterable,
Coroutine,
Dict,
Generic,
Iterable,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)

from aleph_message import MessagesResponse, parse_message
from aleph_message.models import (
AlephMessage,
Chain,
ItemHash,
MessageConfirmation,
MessageType,
)
from aleph_message import MessagesResponse
from aleph_message.models import AlephMessage, Chain, ItemHash, MessageType, PostMessage
from aleph_message.models.execution.base import Encoding
from aleph_message.status import MessageStatus
from peewee import (
BooleanField,
CharField,
FloatField,
IntegerField,
Model,
SqliteDatabase,
)
from playhouse.shortcuts import model_to_dict
from playhouse.sqlite_ext import JSONField
from pydantic import BaseModel

from aleph.sdk import AuthenticatedAlephClient
from aleph.sdk.base import AlephClientBase, AuthenticatedAlephClientBase
from aleph.sdk.conf import settings
from aleph.sdk.exceptions import MessageNotFoundError
from aleph.sdk.models import PostsResponse
from aleph.sdk.types import GenericMessage, StorageEnum

# Module-level SQLite database handle; its location is read from settings.CACHE_DATABASE_PATH.
db = SqliteDatabase(settings.CACHE_DATABASE_PATH)
# Type variable bounded to pydantic models; PydanticField is generic over it.
T = TypeVar("T", bound=BaseModel)


class JSONDictEncoder(json.JSONEncoder):
    """JSON encoder that serializes pydantic models through their ``dict()`` form."""

    def default(self, obj):
        """
        Return a JSON-ready replacement for ``obj``.

        Pydantic models are converted with ``obj.dict()``; every other object
        is handed to the stock ``json.JSONEncoder.default``.
        """
        if not isinstance(obj, BaseModel):
            return super().default(obj)
        return obj.dict()


# json.dumps pre-bound to JSONDictEncoder, so pydantic models inside a value are dumped via dict().
pydantic_json_dumps = partial(json.dumps, cls=JSONDictEncoder)


class PydanticField(JSONField, Generic[T]):
    """
    JSON-backed database field whose values are instances of a pydantic model.
    The model class is supplied through the mandatory ``type`` keyword argument.
    """

    type: T

    def __init__(self, *args, **kwargs):
        """
        Take ``type`` out of the keyword arguments, keep it on the field and
        forward everything else to the parent field constructor.
        """
        model_cls = kwargs.pop("type")
        self.type = model_cls
        super().__init__(*args, **kwargs)
from ..base import BaseAlephClient, BaseAuthenticatedAlephClient
from ..client import AuthenticatedAlephClient
from ..conf import settings
from ..exceptions import MessageNotFoundError
from ..models import PostsResponse
from ..types import GenericMessage, StorageEnum
from .common import db
from .message import MessageModel, get_message_query, message_to_model, model_to_message
from .post import PostModel, get_post_query, message_to_post, model_to_post

def db_value(self, value: Optional[T]) -> Optional[str]:
    """Serialize ``value`` with its ``json()`` method for storage; ``None`` stays ``None``."""
    return None if value is None else value.json()

def python_value(self, value: Optional[str]) -> Optional[T]:
    """Rebuild the model from stored text via ``self.type.parse_raw``; ``None`` stays ``None``."""
    return None if value is None else self.type.parse_raw(value)


class MessageModel(Model):
"""
A simple database model for storing AlephMessage objects.

Rows are produced by message_to_model and turned back into messages by
model_to_message. The table is keyed by ``item_hash``.
"""

item_hash = CharField(primary_key=True)
chain = CharField(5)
type = CharField(9)
sender = CharField()
channel = CharField(null=True)
# Only the first confirmation of a message is stored (see message_to_model).
confirmations: PydanticField[MessageConfirmation] = PydanticField(
type=MessageConfirmation, null=True
)
confirmed = BooleanField(null=True)
signature = CharField(null=True)
size = IntegerField(null=True)
time = FloatField()
item_type = CharField(7)
item_content = CharField(null=True)
hash_type = CharField(6, null=True)
content = JSONField(json_dumps=pydantic_json_dumps)
# Only the first entry of the message's forgotten_by list is stored (see message_to_model).
forgotten_by = CharField(null=True)
# The four columns below are copied out of message.content by message_to_model so that
# get_message_query can filter on them; model_to_message excludes them again before parsing.
tags = JSONField(json_dumps=pydantic_json_dumps, null=True)
key = CharField(null=True)
ref = CharField(null=True)
content_type = CharField(null=True)

class Meta:
# Bind the model to the module-level database handle.
database = db


def message_to_model(message: AlephMessage) -> Dict:
    """
    Flatten a message into the column dict expected by MessageModel.

    Only the first entry of ``confirmations`` and ``forgotten_by`` is kept.
    ``tags``, ``key``, ``ref`` and ``content_type`` are copied out of
    ``message.content`` when the corresponding attribute exists and are
    ``None`` otherwise.

    :param message: the message to convert
    :return: a dict keyed by MessageModel field names
    """
    content = message.content
    # The nested payload can be missing, None or a non-dict value; only a dict
    # can carry a "tags" entry, so anything else yields no tags instead of raising.
    payload = getattr(content, "content", None)
    tags = payload.get("tags", None) if isinstance(payload, dict) else None

    return {
        "item_hash": str(message.item_hash),
        "chain": message.chain,
        "type": message.type,
        "sender": message.sender,
        "channel": message.channel,
        "confirmations": message.confirmations[0] if message.confirmations else None,
        "confirmed": message.confirmed,
        "signature": message.signature,
        "size": message.size,
        "time": message.time,
        "item_type": message.item_type,
        "item_content": message.item_content,
        "hash_type": message.hash_type,
        "content": content,
        "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None,
        "tags": tags,
        "key": getattr(content, "key", None),
        "ref": getattr(content, "ref", None),
        "content_type": getattr(content, "type", None),
    }


def model_to_message(item: Any) -> AlephMessage:
    """
    Convert a MessageModel row back into a parsed message.

    The single stored confirmation / forgotten_by value is wrapped in a list
    again (the row object is modified in place), the cache-only columns are
    left out, and the resulting dict is passed to ``parse_message``.
    """
    if item.confirmations:
        item.confirmations = [item.confirmations]
    else:
        item.confirmations = []

    if item.forgotten_by:
        item.forgotten_by = [item.forgotten_by]
    else:
        item.forgotten_by = None

    # Columns that exist only to make the cache queryable.
    cache_only_fields = [
        MessageModel.tags,
        MessageModel.ref,
        MessageModel.key,
        MessageModel.content_type,
    ]

    return parse_message(model_to_dict(item, exclude=cache_only_fields))


def query_field(field_name, field_values: Iterable[str]):
    """
    Build a filter expression on the MessageModel field called ``field_name``.

    Exactly one value produces an equality comparison; any other number of
    values produces an ``in_`` membership test.
    """
    column = getattr(MessageModel, field_name)
    candidates = [*field_values]

    if len(candidates) != 1:
        return column.in_(candidates)
    return column == candidates[0]


def get_message_query(
    message_type: Optional[MessageType] = None,
    content_keys: Optional[Iterable[str]] = None,
    content_types: Optional[Iterable[str]] = None,
    refs: Optional[Iterable[str]] = None,
    addresses: Optional[Iterable[str]] = None,
    tags: Optional[Iterable[str]] = None,
    hashes: Optional[Iterable[str]] = None,
    channels: Optional[Iterable[str]] = None,
    chains: Optional[Iterable[str]] = None,
    start_date: Optional[Union[datetime, float]] = None,
    end_date: Optional[Union[datetime, float]] = None,
    message_types: Optional[Iterable[MessageType]] = None,
):
    """
    Build a MessageModel select query, newest message first, from optional filters.

    Every filter that is provided is AND-ed into the query; filters that are
    left empty are ignored.

    :param message_type: a single message type to match
    :param content_keys: values matched against the ``key`` column
    :param content_types: values matched against the ``content_type`` column
    :param refs: values matched against the ``ref`` column
    :param addresses: values matched against the ``sender`` column
    :param tags: each tag must be contained in the ``tags`` column
    :param hashes: values matched against the ``item_hash`` column
    :param channels: values matched against the ``channel`` column
    :param chains: values matched against the ``chain`` column
    :param start_date: lower bound (inclusive) on ``time``; a datetime is converted
        to a POSIX timestamp (a naive datetime is interpreted as local time)
    :param end_date: upper bound (inclusive) on ``time``; converted like ``start_date``
    :param message_types: several message types to match; combined with
        ``message_type`` when both are given
    :return: the query, not yet executed
    """

    def _as_timestamp(moment: Union[datetime, float]) -> float:
        # The ``time`` column holds floats, so datetimes must be converted before comparing.
        return moment.timestamp() if isinstance(moment, datetime) else moment

    query = MessageModel.select().order_by(MessageModel.time.desc())
    conditions = []

    type_values = [item.value for item in message_types] if message_types else []
    if message_type and message_type.value not in type_values:
        type_values.append(message_type.value)
    if type_values:
        conditions.append(query_field("type", type_values))

    if content_keys:
        conditions.append(query_field("key", content_keys))
    if content_types:
        conditions.append(query_field("content_type", content_types))
    if refs:
        conditions.append(query_field("ref", refs))
    if addresses:
        conditions.append(query_field("sender", addresses))
    if tags:
        conditions.extend(MessageModel.tags.contains(tag) for tag in tags)
    if hashes:
        conditions.append(query_field("item_hash", hashes))
    if channels:
        conditions.append(query_field("channel", channels))
    if chains:
        conditions.append(query_field("chain", chains))
    # Compare against None so that a bound of 0 (the epoch) is not silently dropped.
    if start_date is not None:
        conditions.append(MessageModel.time >= _as_timestamp(start_date))
    if end_date is not None:
        conditions.append(MessageModel.time <= _as_timestamp(end_date))

    if conditions:
        query = query.where(*conditions)
    return query


class MessageCache(AlephClientBase):
class MessageCache(BaseAlephClient):
"""
A wrapper around a sqlite3 database for caching AlephMessage objects.
It can be used independently of a DomainNode to implement any kind of caching strategy.
"""

_instance_count = 0 # Class-level counter for active instances
missing_posts: Dict[ItemHash, PostMessage] = {}
"""A dict of all posts by item_hash and their amend messages that are missing from the cache."""

def __init__(self):
    """
    Open the shared database connection if it is closed, create the message
    and post tables when they do not exist yet, and count this instance.
    """
    if db.is_closed():
        db.connect()

    for model in (MessageModel, PostModel):
        if not model.table_exists():
            db.create_tables([model])

    MessageCache._instance_count += 1

Expand Down Expand Up @@ -270,17 +94,57 @@ def __repr__(self) -> str:
def __str__(self) -> str:
    """Return the same text as ``repr(self)``."""
    text = repr(self)
    return text

def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]):
    """
    Store one message or an iterable of messages in the cache.

    Every message is written to MessageModel. POST messages are additionally
    mirrored into PostModel: regular posts are inserted as they are, while
    "amend" posts update the post they reference. An amend whose target post
    is not cached yet is remembered in ``self.missing_posts`` (latest amend
    per target) and applied as soon as the target post is added.

    :param messages: a single message or an iterable of messages
    """
    # Function-scope import: only needed to turn amended rows back into dicts.
    from playhouse.shortcuts import model_to_dict

    if isinstance(messages, typing.get_args(AlephMessage)):
        messages = [messages]
    # Materialize once: the input is walked twice below and may be a one-shot iterator.
    messages = list(messages)

    message_data = (message_to_model(message) for message in messages)
    MessageModel.insert_many(message_data).on_conflict_replace().execute()

    # Add posts and their amends to the PostModel
    post_data = []
    amend_messages = []
    for message in messages:
        # The message kind lives in ``type`` (the field get_message_query filters on),
        # not in ``item_type``.
        if message.type != MessageType.post:
            continue
        if message.content.type == "amend":
            amend_messages.append(message)
            continue

        post_data.append(message_to_post(message).dict())
        # Check if we can now apply an amend that was waiting for this post.
        # missing_posts holds a single message per hash, so it is appended, not extended.
        pending_amend = self.missing_posts.pop(message.item_hash, None)
        if pending_amend is not None:
            amend_messages.append(pending_amend)

    PostModel.insert_many(post_data).on_conflict_replace().execute()

    # Handle amends in a second step so that posts from this batch already exist.
    amended_rows = []
    for message in amend_messages:
        # get_or_none returns None for a missing row instead of raising, which keeps
        # the "target not cached yet" branch reachable.
        # NOTE(review): the lookup keeps the original key (item_hash == ref); confirm
        # against PostModel's schema that this still matches after a previous amend.
        original_post = PostModel.get_or_none(
            PostModel.item_hash == message.content.ref
        )
        if original_post is None:
            target_hash = ItemHash(message.content.ref)
            latest_amend = self.missing_posts.get(target_hash)
            # Remember the amend when none is pending yet or when it is newer.
            if latest_amend is None or message.time > latest_amend.time:
                self.missing_posts[target_hash] = message
            continue

        amended_at = datetime.fromtimestamp(message.time)
        if amended_at < original_post.last_updated:
            # The cached post already reflects a more recent amend.
            continue

        original_post.item_hash = message.item_hash
        original_post.content = message.content.content
        original_post.original_item_hash = message.content.ref
        original_post.original_type = message.content.type
        original_post.address = message.sender
        original_post.channel = message.channel
        original_post.last_updated = amended_at
        # insert_many expects row dicts, not model instances.
        amended_rows.append(model_to_dict(original_post))

    PostModel.insert_many(amended_rows).on_conflict_replace().execute()

@staticmethod
def get(
item_hashes: Union[Union[ItemHash, str], Iterable[Union[ItemHash, str]]]
self, item_hashes: Union[Union[ItemHash, str], Iterable[Union[ItemHash, str]]]
) -> List[AlephMessage]:
"""
Get many messages from the cache by their item hash.
Expand Down Expand Up @@ -347,12 +211,11 @@ async def get_posts(
chains: Optional[Iterable[str]] = None,
start_date: Optional[Union[datetime, float]] = None,
end_date: Optional[Union[datetime, float]] = None,
ignore_invalid_messages: bool = True,
invalid_messages_log_level: int = logging.NOTSET,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
) -> PostsResponse:
query = get_message_query(
message_type=MessageType.post,
content_types=types,
query = get_post_query(
types=types,
refs=refs,
addresses=addresses,
tags=tags,
Expand All @@ -365,7 +228,7 @@ async def get_posts(

query = query.paginate(page, pagination)

posts = [model_to_message(item) for item in list(query)]
posts = [model_to_post(item) for item in list(query)]

return PostsResponse(
posts=posts,
Expand All @@ -383,6 +246,7 @@ async def get_messages(
pagination: int = 200,
page: int = 1,
message_type: Optional[MessageType] = None,
message_types: Optional[Iterable[MessageType]] = None,
content_types: Optional[Iterable[str]] = None,
content_keys: Optional[Iterable[str]] = None,
refs: Optional[Iterable[str]] = None,
Expand All @@ -393,14 +257,15 @@ async def get_messages(
chains: Optional[Iterable[str]] = None,
start_date: Optional[Union[datetime, float]] = None,
end_date: Optional[Union[datetime, float]] = None,
ignore_invalid_messages: bool = True,
invalid_messages_log_level: int = logging.NOTSET,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
) -> MessagesResponse:
"""
Get many messages from the cache.
"""
# Parenthesized so an explicit message_types is kept even when message_type is None;
# the single legacy message_type is only used as a fallback.
message_types = message_types or ([message_type] if message_type else None)
query = get_message_query(
message_type=message_type,
message_types=message_types,
content_keys=content_keys,
content_types=content_types,
refs=refs,
Expand Down Expand Up @@ -451,6 +316,7 @@ async def get_message(
async def watch_messages(
self,
message_type: Optional[MessageType] = None,
message_types: Optional[Iterable[MessageType]] = None,
content_types: Optional[Iterable[str]] = None,
content_keys: Optional[Iterable[str]] = None,
refs: Optional[Iterable[str]] = None,
Expand All @@ -465,8 +331,9 @@ async def watch_messages(
"""
Watch messages from the cache.
"""
# Parenthesized so an explicit message_types is kept even when message_type is None;
# the single legacy message_type is only used as a fallback.
message_types = message_types or ([message_type] if message_type else None)
query = get_message_query(
message_type=message_type,
message_types=message_types,
content_keys=content_keys,
content_types=content_types,
refs=refs,
Expand All @@ -483,7 +350,7 @@ async def watch_messages(
yield model_to_message(item)


class DomainNode(MessageCache, AuthenticatedAlephClientBase):
class DomainNode(MessageCache, BaseAuthenticatedAlephClient):
"""
A Domain Node is a queryable proxy for Aleph Messages that are stored in a database cache and/or in the Aleph
network.
Expand Down
Loading

0 comments on commit 603afef

Please sign in to comment.