Skip to content

Commit

Permalink
fix/test: ensure polls are being correctly processed (#1714)
Browse files Browse the repository at this point in the history
* fix: use answer_id from data, not options

* fix: correctly deserialize question for polls

* ci: add tests for polls

* test: make poll dict test more resilient
  • Loading branch information
AstreaTSS authored Jul 19, 2024
1 parent 42df28b commit fcd8efe
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 4 deletions.
4 changes: 2 additions & 2 deletions interactions/api/events/processors/message_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def _on_raw_message_poll_vote_add(self, event: "RawGatewayEvent") -> None:
event.data["channel_id"],
event.data["message_id"],
event.data["user_id"],
event.data["option"],
event.data["answer_id"],
)
)

Expand All @@ -118,6 +118,6 @@ async def _on_raw_message_poll_vote_remove(self, event: "RawGatewayEvent") -> No
event.data["channel_id"],
event.data["message_id"],
event.data["user_id"],
event.data["option"],
event.data["answer_id"],
)
)
2 changes: 1 addition & 1 deletion interactions/models/discord/poll.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class PollResults(DictSerializationMixin):

@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class Poll(DictSerializationMixin):
question: PollMedia = attrs.field(repr=False)
question: PollMedia = attrs.field(repr=False, converter=PollMedia.from_dict)
"""The question of the poll. Only text media is supported."""
answers: list[PollAnswer] = attrs.field(repr=False, factory=list, converter=PollAnswer.from_list)
"""Each of the answers available in the poll, up to 10."""
Expand Down
93 changes: 92 additions & 1 deletion tests/test_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from asyncio import AbstractEventLoop
from contextlib import suppress
from datetime import datetime
from datetime import datetime, timedelta

import pytest
import pytest_asyncio
Expand Down Expand Up @@ -33,6 +33,8 @@
ParagraphText,
Message,
GuildVoice,
Poll,
PollMedia,
)
from interactions.models.discord.asset import Asset
from interactions.models.discord.components import ActionRow, Button, StringSelectMenu
Expand Down Expand Up @@ -432,6 +434,95 @@ async def test_components(bot: Client, channel: GuildText) -> None:
await thread.delete()


@pytest.mark.asyncio
async def test_polls(bot: Client, channel: GuildText) -> None:
msg = await channel.send("Polls Tests")
thread = await msg.create_thread("Test Thread")

try:
poll_1 = Poll.create("Test Poll", duration=1, answers=["Answer 1", "Answer 2"])
test_data_1 = {
"question": {"text": "Test Poll"},
"layout_type": 1,
"duration": 1,
"allow_multiselect": False,
"answers": [{"poll_media": {"text": "Answer 1"}}, {"poll_media": {"text": "Answer 2"}}],
}
poll_1_dict = poll_1.to_dict()
for key in poll_1_dict.keys():
assert poll_1_dict[key] == test_data_1[key]

msg_1 = await thread.send(poll=poll_1)

assert msg_1.poll is not None
assert msg_1.poll.question.to_dict() == PollMedia(text="Test Poll").to_dict()
assert msg_1.poll.expiry <= msg_1.created_at + timedelta(hours=1, minutes=1)
poll_1_answer_medias = [poll_answer.poll_media.to_dict() for poll_answer in msg_1.poll.answers]
assert poll_1_answer_medias == [
PollMedia.create(text="Answer 1").to_dict(),
PollMedia.create(text="Answer 2").to_dict(),
]

poll_2 = Poll.create("Test Poll 2", duration=1, allow_multiselect=True)
poll_2.add_answer("Answer 1")
poll_2.add_answer("Answer 2")
test_data_2 = {
"question": {"text": "Test Poll 2"},
"layout_type": 1,
"duration": 1,
"allow_multiselect": True,
"answers": [{"poll_media": {"text": "Answer 1"}}, {"poll_media": {"text": "Answer 2"}}],
}
poll_2_dict = poll_2.to_dict()
for key in poll_2_dict.keys():
assert poll_2_dict[key] == test_data_2[key]
msg_2 = await thread.send(poll=poll_2)

assert msg_2.poll is not None
assert msg_2.poll.question.to_dict() == PollMedia(text="Test Poll 2").to_dict()
assert msg_2.poll.expiry <= msg_2.created_at + timedelta(hours=1, minutes=1)
assert msg_2.poll.allow_multiselect
poll_2_answer_medias = [poll_answer.poll_media.to_dict() for poll_answer in msg_2.poll.answers]
assert poll_2_answer_medias == [
PollMedia.create(text="Answer 1").to_dict(),
PollMedia.create(text="Answer 2").to_dict(),
]

poll_3 = Poll.create(
"Test Poll 3",
duration=1,
answers=[PollMedia.create(text="One", emoji="1️⃣"), PollMedia.create(text="Two", emoji="2️⃣")],
)
test_data_3 = {
"question": {"text": "Test Poll 3"},
"layout_type": 1,
"duration": 1,
"allow_multiselect": False,
"answers": [
{"poll_media": {"text": "One", "emoji": {"name": "1️⃣", "animated": False}}},
{"poll_media": {"text": "Two", "emoji": {"name": "2️⃣", "animated": False}}},
],
}
poll_3_dict = poll_3.to_dict()
for key in poll_3_dict.keys():
assert poll_3_dict[key] == test_data_3[key]

msg_3 = await thread.send(poll=poll_3)

assert msg_3.poll is not None
assert msg_3.poll.question.to_dict() == PollMedia(text="Test Poll 3").to_dict()
assert msg_3.poll.expiry <= msg_3.created_at + timedelta(hours=1, minutes=1)
poll_3_answer_medias = [poll_answer.poll_media.to_dict() for poll_answer in msg_3.poll.answers]
assert poll_3_answer_medias == [
PollMedia.create(text="One", emoji="1️⃣").to_dict(),
PollMedia.create(text="Two", emoji="2️⃣").to_dict(),
]

finally:
with suppress(interactions.errors.NotFound):
await thread.delete()


@pytest.mark.asyncio
async def test_webhooks(bot: Client, guild: Guild, channel: GuildText) -> None:
test_thread = await channel.create_thread("Test Thread")
Expand Down

0 comments on commit fcd8efe

Please sign in to comment.