Skip to content

Commit

Permalink
Merge pull request #441 from zAlweNy26/refactored-ws-messages
Browse files Browse the repository at this point in the history
Refactored WS messages functionality
  • Loading branch information
pieroit authored Aug 29, 2023
2 parents dc2e9be + 2f75f19 commit c6f41a8
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 35 deletions.
32 changes: 28 additions & 4 deletions core/cat/looking_glass/cheshire_cat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
from copy import deepcopy
import traceback

from typing import Literal, get_args
import langchain
import os
from cat.log import log
Expand All @@ -12,6 +12,7 @@
from cat.memory.long_term_memory import LongTermMemory
from cat.looking_glass.agent_manager import AgentManager

MSG_TYPES = Literal["notification", "chat"]

# main class
class CheshireCat:
Expand All @@ -21,7 +22,7 @@ class CheshireCat:
Attributes
----------
web_socket_notifications : list
ws_messages : list
List of notifications to be sent to the frontend.
"""
Expand All @@ -33,7 +34,7 @@ def __init__(self):
"""

# bootstrap the cat!
# reinstantiate MadHatter (reloads all plugins' hooks and tools)
# reinstantiate MadHatter (reloads all plugins' hooks and tools)
self.mad_hatter = MadHatter(self)

# allows plugins to do something before cat components are loaded
Expand All @@ -59,7 +60,7 @@ def __init__(self):

# queue of cat messages not directly related to last user input
# i.e. finished uploading a file
self.web_socket_notifications = []
self.ws_messages = []

def load_natural_language(self):
"""Load Natural Language related objects.
Expand Down Expand Up @@ -281,6 +282,29 @@ def store_new_message_in_working_memory(self, user_message_json):

self.working_memory["user_message_json"]["prompt_settings"] = prompt_settings

def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification"):
"""Send a message via websocket.
This method is useful for sending a message via websocket directly without passing through the LLM
Parameters
----------
type : str
The type of the message. Should be either `notification` or `chat`
content : str
The content of the message.
"""

options = get_args(MSG_TYPES)

if msg_type not in options:
raise ValueError(f"The message type `{msg_type}` is not valid. Valid types: {', '.join(options)}")

self.ws_messages.append({
"type": msg_type,
"content": content
})

def get_base_url(self):
"""Allows the Cat expose the base url."""
secure = os.getenv('CORE_USE_SECURE_PROTOCOLS', '')
Expand Down
29 changes: 6 additions & 23 deletions core/cat/rabbit_hole.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,36 +217,19 @@ def file_to_docs(
blob = Blob(data=file_bytes,
mimetype=content_type,
source=source).from_data(data=file_bytes,
mime_type=content_type)
mime_type=content_type,
path=source)
# Parser based on the mime type
parser = MimeTypeBasedParser(handlers=self.file_handlers)

# Parse the text
self.send_rabbit_thought("I'm parsing the content. Big content could require some minutes...")
self.cat.send_ws_message("I'm parsing the content. Big content could require some minutes...")
text = parser.parse(blob)

self.send_rabbit_thought(f"Parsing completed. Now let's go with reading process...")
self.cat.send_ws_message(f"Parsing completed. Now let's go with reading process...")
docs = self.split_text(text, chunk_size, chunk_overlap)
return docs

def send_rabbit_thought(self, thought):
"""Append a message to the notification list.
This method receives a string and creates the message to append to the list of notifications.
Parameters
----------
thought : str
Text of the message to append to the notification list.
"""

self.cat.web_socket_notifications.append({
"error": False,
"type": "notification",
"content": thought,
"why": {},
})

def store_documents(self, docs: List[Document], source: str) -> None:
"""Add documents to the Cat's declarative memory.
Expand Down Expand Up @@ -284,7 +267,7 @@ def store_documents(self, docs: List[Document], source: str) -> None:
if time.time() - time_last_notification > time_interval:
time_last_notification = time.time()
perc_read = int(d / len(docs) * 100)
self.send_rabbit_thought(f"Read {perc_read}% of {source}")
self.cat.send_ws_message(f"Read {perc_read}% of {source}")

doc.metadata["source"] = source
doc.metadata["when"] = time.time()
Expand All @@ -310,7 +293,7 @@ def store_documents(self, docs: List[Document], source: str) -> None:
finished_reading_message = f"Finished reading {source}, " \
f"I made {len(docs)} thoughts on it."

self.send_rabbit_thought(finished_reading_message)
self.cat.send_ws_message(finished_reading_message)

print(f"\n\nDone uploading {source}")

Expand Down
16 changes: 8 additions & 8 deletions core/cat/routes/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
router = APIRouter()

# This constant sets the interval (in seconds) at which the system checks for notifications.
NOTIFICATION_CHECK_INTERVAL = 1 # seconds
QUEUE_CHECK_INTERVAL = 1 # seconds


class ConnectionManager:
Expand Down Expand Up @@ -64,24 +64,24 @@ async def receive_message(websocket: WebSocket, ccat: object):
await manager.send_personal_message(cat_message, websocket)


async def check_notification(websocket: WebSocket, ccat: object):
async def check_messages(websocket: WebSocket, ccat):
"""
Periodically check if there are any new notifications from the `ccat` object and send them to the user.
Periodically check if there are any new notifications from the `ccat` instance and send them to the user.
"""
while True:
if ccat.web_socket_notifications:
if ccat.ws_messages:
# extract from FIFO list websocket notification
notification = ccat.web_socket_notifications.pop(0)
notification = ccat.ws_messages.pop(0)
await manager.send_personal_message(notification, websocket)

# Sleep for the specified interval before checking for notifications again.
await asyncio.sleep(NOTIFICATION_CHECK_INTERVAL)
await asyncio.sleep(QUEUE_CHECK_INTERVAL)


@router.websocket_route("/ws")
async def websocket_endpoint(websocket: WebSocket):
"""
Endpoint to handle incoming WebSocket connections, process messages, and check for notifications.
Endpoint to handle incoming WebSocket connections, process messages, and check for messages.
"""

# Retrieve the `ccat` instance from the application's state.
Expand All @@ -94,7 +94,7 @@ async def websocket_endpoint(websocket: WebSocket):
# Process messages and check for notifications concurrently.
await asyncio.gather(
receive_message(websocket, ccat),
check_notification(websocket, ccat)
check_messages(websocket, ccat)
)
except WebSocketDisconnect:
# Handle the event where the user disconnects their WebSocket.
Expand Down

0 comments on commit c6f41a8

Please sign in to comment.