Skip to content

Commit

Permalink
Refactor thread posts code into get_thread_posts
Browse files Browse the repository at this point in the history
  • Loading branch information
elegiggle committed Apr 17, 2024
1 parent d82dc0b commit 422091b
Showing 1 changed file with 46 additions and 35 deletions.
81 changes: 46 additions & 35 deletions chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ def cdc(*args, **kwargs):
mattermost_scheme = os.getenv("MATTERMOST_SCHEME", "https")
mattermost_port = int(os.getenv("MATTERMOST_PORT", "443"))
mattermost_basepath = os.getenv("MATTERMOST_BASEPATH", "/api/v4")
# pylint: disable=invalid-envvar-default
mattermost_cert_verify = os.getenv(
"MATTERMOST_CERT_VERIFY", True
) # pylint: disable=invalid-envvar-default
)
mattermost_token = os.getenv("MATTERMOST_TOKEN", "")
mattermost_ignore_sender_id = os.getenv("MATTERMOST_IGNORE_SENDER_ID", "")
mattermost_username = os.getenv("MATTERMOST_USERNAME", "")
Expand Down Expand Up @@ -350,6 +351,42 @@ def extract_post_data(post, event_data):
return message, channel_id, sender_name, root_id, post_id, channel_display_name


def get_thread_posts(root_id, post_id):
messages = []
chatbot_invoked = False

thread = driver.posts.get_thread(root_id)

# Sort the thread posts based on their create_at timestamp as the "order" prop is not suitable for this
sorted_posts = sorted(thread["posts"].values(), key=lambda x: x["create_at"])
for thread_post in sorted_posts:
# We ignore our own post here as we might need to fetch/extract some content later. Refactor this as we want to cache results anyway and grab all URL contents, even from thread posts
if thread_post["id"] != post_id:
thread_sender_name = get_username_from_user_id(thread_post["user_id"])
thread_message = thread_post["message"]
role = (
"assistant"
if thread_post["user_id"] == driver.client.userid
else "user"
)
messages.append(
{
"role": role,
"content": [
{
"type": "text",
"text": f"[CONTEXT, from:{thread_sender_name}] {thread_message}",
}
],
}
)

if role == "assistant":
chatbot_invoked = True

return messages, chatbot_invoked


async def message_handler(event):
try:
event_data = json.loads(event)
Expand All @@ -371,43 +408,15 @@ async def message_handler(event):
)

try:
# Retrieve the thread context
messages = []
chatbot_invoked = False

# Retrieve the thread context
if root_id:
thread = driver.posts.get_thread(root_id)
# Sort the thread posts based on their create_at timestamp
sorted_posts = sorted(
thread["posts"].values(), key=lambda x: x["create_at"]
thread_messages, chatbot_invoked = get_thread_posts(
root_id, post_id
)
for thread_post in sorted_posts:
if thread_post["id"] != post_id:
thread_sender_name = get_username_from_user_id(
thread_post["user_id"]
)
thread_message = thread_post["message"]
role = (
"assistant"
if thread_post["user_id"] == driver.client.userid
else "user"
)
messages.append(
{
"role": role,
"content": [
{
"type": "text",
"text": f"[CONTEXT, from:{thread_sender_name}] {thread_message}",
}
],
}
)

if role == "assistant":
chatbot_invoked = True
else:
# If the message is not part of a thread, reply to it to create a new thread
root_id = post["id"]
messages.extend(thread_messages)

# Add the current message to the messages array if "@chatbot" is mentioned, the chatbot has already been invoked in the thread or its a DM
if (
Expand Down Expand Up @@ -618,7 +627,9 @@ async def message_handler(event):
message,
messages,
channel_id,
root_id,
(
post_id if not root_id else root_id
), # If the message is not part of a thread, reply to it to create a new thread
sender_name,
links,
)
Expand Down

0 comments on commit 422091b

Please sign in to comment.