diff --git a/chatbot.py b/chatbot.py index 551d892..007e564 100644 --- a/chatbot.py +++ b/chatbot.py @@ -53,9 +53,10 @@ def cdc(*args, **kwargs): mattermost_scheme = os.getenv("MATTERMOST_SCHEME", "https") mattermost_port = int(os.getenv("MATTERMOST_PORT", "443")) mattermost_basepath = os.getenv("MATTERMOST_BASEPATH", "/api/v4") +# pylint: disable=invalid-envvar-default mattermost_cert_verify = os.getenv( "MATTERMOST_CERT_VERIFY", True -) # pylint: disable=invalid-envvar-default +) mattermost_token = os.getenv("MATTERMOST_TOKEN", "") mattermost_ignore_sender_id = os.getenv("MATTERMOST_IGNORE_SENDER_ID", "") mattermost_username = os.getenv("MATTERMOST_USERNAME", "") @@ -350,6 +351,42 @@ def extract_post_data(post, event_data): return message, channel_id, sender_name, root_id, post_id, channel_display_name +def get_thread_posts(root_id, post_id): + messages = [] + chatbot_invoked = False + + thread = driver.posts.get_thread(root_id) + + # Sort the thread posts based on their create_at timestamp as the "order" prop is not suitable for this + sorted_posts = sorted(thread["posts"].values(), key=lambda x: x["create_at"]) + for thread_post in sorted_posts: + # We ignore our own post here as we might need to fetch/extract some content later. Refactor this as we want to cache results anyway and grab all URL contents, even from thread posts + if thread_post["id"] != post_id: + thread_sender_name = get_username_from_user_id(thread_post["user_id"]) + thread_message = thread_post["message"] + role = ( + "assistant" + if thread_post["user_id"] == driver.client.userid + else "user" + ) + messages.append( + { + "role": role, + "content": [ + { + "type": "text", + "text": f"[CONTEXT, from:{thread_sender_name}] {thread_message}", + } + ], + } + ) + + if role == "assistant": + chatbot_invoked = True + + return messages, chatbot_invoked + + async def message_handler(event): try: event_data = json.loads(event) @@ -371,43 +408,15 @@ async def message_handler(event): ) try: - # Retrieve the thread context messages = [] chatbot_invoked = False + + # Retrieve the thread context if root_id: - thread = driver.posts.get_thread(root_id) - # Sort the thread posts based on their create_at timestamp - sorted_posts = sorted( - thread["posts"].values(), key=lambda x: x["create_at"] + thread_messages, chatbot_invoked = get_thread_posts( + root_id, post_id ) - for thread_post in sorted_posts: - if thread_post["id"] != post_id: - thread_sender_name = get_username_from_user_id( - thread_post["user_id"] - ) - thread_message = thread_post["message"] - role = ( - "assistant" - if thread_post["user_id"] == driver.client.userid - else "user" - ) - messages.append( - { - "role": role, - "content": [ - { - "type": "text", - "text": f"[CONTEXT, from:{thread_sender_name}] {thread_message}", - } - ], - } - ) - - if role == "assistant": - chatbot_invoked = True - else: - # If the message is not part of a thread, reply to it to create a new thread - root_id = post["id"] + messages.extend(thread_messages) # Add the current message to the messages array if "@chatbot" is mentioned, the chatbot has already been invoked in the thread or its a DM if ( @@ -618,7 +627,9 @@ async def message_handler(event): message, messages, channel_id, - root_id, + ( + post_id if not root_id else root_id + ), # If the message is not part of a thread, reply to it to create a new thread sender_name, links, )