Skip to content

Commit

Permalink
fix post problems
Browse files Browse the repository at this point in the history
  • Loading branch information
MHHukiewitz committed Sep 6, 2023
1 parent 603afef commit 6bf3144
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 13 deletions.
5 changes: 4 additions & 1 deletion src/aleph/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,15 @@ class Post(BaseModel):
)
address: str = Field(description="The address of the sender of the POST message")
ref: Optional[str] = Field(description="Other message referenced by this one")
channel: str = Field(description="The channel where the POST message was published")
channel: Optional[str] = Field(description="The channel where the POST message was published")
created: datetime = Field(description="The time when the POST message was created")
last_updated: datetime = Field(
description="The time when the POST message was last updated"
)

class Config:
allow_extra = False


class PostsResponse(PaginationResponse):
"""Response from an Aleph node API on the path /api/v0/posts.json"""
Expand Down
18 changes: 11 additions & 7 deletions src/aleph/sdk/node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,23 +98,27 @@ def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]):
if isinstance(messages, typing.get_args(AlephMessage)):
messages = [messages]

messages = list(messages)

message_data = (message_to_model(message) for message in messages)
MessageModel.insert_many(message_data).on_conflict_replace().execute()

# Add posts and their amends to the PostModel
post_data = []
amend_messages = []
for message in messages:
if message.item_type != MessageType.post:
if message.type != MessageType.post.value:
continue
if message.content.type == "amend":
amend_messages.append(message)
else:
post = message_to_post(message).dict()
post_data.append(post)
# Check if we can now add any amend messages that had missing refs
if message.item_hash in self.missing_posts:
amend_messages += self.missing_posts.pop(message.item_hash)
continue
post = message_to_post(message).dict()
post["chain"] = message.chain.value
post["tags"] = message.content.content.get("tags", None)
post_data.append(post)
# Check if we can now add any amend messages that had missing refs
if message.item_hash in self.missing_posts:
amend_messages += self.missing_posts.pop(message.item_hash)

PostModel.insert_many(post_data).on_conflict_replace().execute()

Expand Down
2 changes: 1 addition & 1 deletion src/aleph/sdk/node/post.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def message_to_post(message: PostMessage) -> Post:
"ref": message.content.ref if hasattr(message.content, "ref") else None,
"channel": message.channel,
"created": datetime.fromtimestamp(message.time),
"last_updated": datetime.fromtimestamp(message.time),
"last_updated": datetime.fromtimestamp(message.time)
}
)

Expand Down
9 changes: 5 additions & 4 deletions tests/unit/test_node_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from aleph.sdk.chains.ethereum import get_fallback_account
from aleph.sdk.exceptions import MessageNotFoundError
from aleph.sdk.node import MessageCache
from aleph.sdk.node import MessageCache, message_to_post


@pytest.mark.asyncio
Expand Down Expand Up @@ -137,7 +137,7 @@ def class_teardown(self):
@pytest.mark.asyncio
async def test_addresses(self):
items = (await self.cache.get_posts(addresses=[self.messages[1].sender])).posts
assert items[0] == self.messages[1]
assert items[0] == message_to_post(self.messages[1])

@pytest.mark.asyncio
async def test_tags(self):
Expand All @@ -153,15 +153,16 @@ async def test_types(self):

@pytest.mark.asyncio
async def test_channels(self):
print(self.messages[1])
assert (await self.cache.get_posts(channels=[self.messages[1].channel])).posts[
0
] == self.messages[1]
] == message_to_post(self.messages[1])

@pytest.mark.asyncio
async def test_chains(self):
assert (await self.cache.get_posts(chains=[self.messages[1].chain])).posts[
0
] == self.messages[1]
] == message_to_post(self.messages[1])


@pytest.mark.asyncio
Expand Down

0 comments on commit 6bf3144

Please sign in to comment.