
Commit

Merge branch 'Pingdred-carbonara' into develop
pieroit committed Apr 22, 2024
2 parents 59c1391 + 9c24f4f commit c100773
Showing 10 changed files with 169 additions and 65 deletions.
29 changes: 29 additions & 0 deletions core/cat/convo/messages.py
@@ -0,0 +1,29 @@
+
+
+
+
+
+from typing import List, Dict
+from cat.utils import BaseModelDict
+
+#class WorkingMemory(BaseModelDict):
+#    history : List = []
+
+
+class MessageWhy(BaseModelDict):
+    input: str
+    intermediate_steps: List
+    memory: dict
+
+
+class CatMessage(BaseModelDict):
+    type: str
+    content: str
+    user_id: str
+    why: MessageWhy
+
+
+class UserMessage(BaseModelDict):
+    text: str
+    user_id: str
+
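These models inherit from `BaseModelDict`, which lives in `cat.utils` and is not part of this diff. Judging by how the rest of the commit mixes attribute access with leftover dict-style access (e.g. `message["content"]` on a `CatMessage`), it presumably wraps a pydantic model with `dict`-like accessors. A minimal sketch of that assumed behavior, not the project's actual implementation:

```python
# Hypothetical sketch of BaseModelDict (the real one is in cat.utils,
# not shown in this diff): a pydantic model that also answers to
# legacy dict-style access, so older plugin code keeps working.
from pydantic import BaseModel, ConfigDict


class BaseModelDict(BaseModel):
    model_config = ConfigDict(extra="allow")  # tolerate extra keys set by plugins

    def __getitem__(self, key):
        # legacy read: msg["text"] behaves like msg.text
        return getattr(self, key)

    def __setitem__(self, key, value):
        # legacy write: msg["text"] = ... behaves like msg.text = ...
        setattr(self, key, value)

    def get(self, key, default=None):
        return getattr(self, key, default)
```

With such a base, `UserMessage(text="hi", user_id="u1")["text"]` and `.text` return the same value, which is what lets this commit migrate call sites gradually.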
6 changes: 3 additions & 3 deletions core/cat/experimental/form/cat_form.py
@@ -51,7 +51,7 @@ def submit(self, form_data) -> str:
     def confirm(self) -> bool:
 
         # Get user message
-        user_message = self.cat.working_memory["user_message_json"]["text"]
+        user_message = self.cat.working_memory.user_message_json.text
 
         # Confirm prompt
         confirm_prompt = \
@@ -79,7 +79,7 @@ def confirm(self) -> bool:
     def check_exit_intent(self) -> bool:
 
         # Get user message
-        user_message = self.cat.working_memory["user_message_json"]["text"]
+        user_message = self.cat.working_memory.user_message_json.text
 
         # Stop examples
         stop_examples = """
@@ -117,7 +117,7 @@ def check_exit_intent(self) -> bool:
     def next(self):
 
         # could we enrich prompt completion with episodic/declarative memories?
-        #self.cat.working_memory["episodic_memories"] = []
+        #self.cat.working_memory.episodic_memories = []
 
         if self.check_exit_intent():
             self._state = CatFormState.CLOSED
17 changes: 8 additions & 9 deletions core/cat/looking_glass/agent_manager.py
@@ -48,7 +48,7 @@ async def execute_procedures_agent(self, agent_input, stray):
 
         # gather recalled procedures
         recalled_procedures_names = set()
-        for p in stray.working_memory["procedural_memories"]:
+        for p in stray.working_memory.procedural_memories:
             procedure = p[0]
             if procedure.metadata["type"] in ["tool","form"] and procedure.metadata["trigger_type"] in ["description", "start_example"]:
                 recalled_procedures_names.add(procedure.metadata["source"])
@@ -128,7 +128,7 @@ async def execute_procedures_agent(self, agent_input, stray):
         if "form" in out.keys():
             FormClass = allowed_procedures.get(out["form"], None)
             f = FormClass(stray)
-            stray.working_memory["active_form"] = f
+            stray.working_memory.active_form = f
             # let the form reply directly
             out = f.next()
             out["return_direct"] = True
@@ -137,12 +137,11 @@ async def execute_procedures_agent(self, agent_input, stray):
 
     async def execute_form_agent(self, stray):
 
-        active_form = stray.working_memory.get("active_form", None)
+        active_form = stray.working_memory.active_form
         if active_form:
-            log.warning(active_form._state)
             # closing form if state is closed
             if active_form._state == CatFormState.CLOSED:
-                del stray.working_memory["active_form"]
+                stray.working_memory.active_form = None
             else:
                 # continue form
                 return active_form.next()
@@ -201,7 +200,7 @@ async def execute_agent(self, stray):
 
         # Select and run useful procedures
         intermediate_steps = []
-        procedural_memories = stray.working_memory["procedural_memories"]
+        procedural_memories = stray.working_memory.procedural_memories
         if len(procedural_memories) > 0:
 
             log.debug(f"Procedural memories retrived: {len(procedural_memories)}.")
@@ -272,17 +271,17 @@ def format_agent_input(self, stray):
 
         # format memories to be inserted in the prompt
         episodic_memory_formatted_content = self.agent_prompt_episodic_memories(
-            stray.working_memory["episodic_memories"]
+            stray.working_memory.episodic_memories
        )
         declarative_memory_formatted_content = self.agent_prompt_declarative_memories(
-            stray.working_memory["declarative_memories"]
+            stray.working_memory.declarative_memories
         )
 
         # format conversation history to be inserted in the prompt
         conversation_history_formatted_content = stray.stringify_chat_history()
 
         return {
-            "input": stray.working_memory["user_message_json"]["text"], # TODO: deprecate, since it is included in chat history
+            "input": stray.working_memory.user_message_json.text, # TODO: deprecate, since it is included in chat history
             "episodic_memory": episodic_memory_formatted_content,
             "declarative_memory": declarative_memory_formatted_content,
             "chat_history": conversation_history_formatted_content,
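One behavioral detail in `execute_form_agent`: since `working_memory.active_form` is now a typed field with a `None` default (see `working_memory.py` below), a closed form is cleared by assignment instead of `del`, and the `.get("active_form", None)` guard becomes unnecessary. A tiny standalone illustration of the pattern (hypothetical stand-in class, not the project's code):

```python
# "No active form" is now a value (None), not a missing dict key.
from typing import Any, Optional

from pydantic import BaseModel


class DemoMemory(BaseModel):
    active_form: Optional[Any] = None  # stand-in for the CatForm field


wm = DemoMemory()
assert wm.active_form is None  # field always exists: no KeyError, no .get()
wm.active_form = "a form"      # a form becomes active
wm.active_form = None          # closing it is an assignment, not `del`
```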
105 changes: 73 additions & 32 deletions core/cat/looking_glass/stray_cat.py
@@ -1,7 +1,7 @@
 import time
 import asyncio
 import traceback
-from typing import Literal, get_args, List, Dict
+from typing import Literal, get_args, List, Dict, Union
 
 from langchain.docstore.document import Document
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -13,9 +13,11 @@
 from cat.looking_glass.cheshire_cat import CheshireCat
 from cat.looking_glass.callbacks import NewTokenHandler
 from cat.memory.working_memory import WorkingMemory
+from cat.convo.messages import CatMessage, UserMessage, MessageWhy
 
 MSG_TYPES = Literal["notification", "chat", "error", "chat_token"]
 
+
 # The Stray cat goes around tools and hook, making troubles
 class StrayCat:
     """User/session based object containing working memory and a few utility pointers"""
@@ -70,7 +72,7 @@ def send_ws_message(self, content: str, msg_type: MSG_TYPES="notification"):
                 {
                     "type": msg_type,
                     "name": "GenericError",
-                    "description": content
+                    "description": str(content)
                 }
             )
         else:
@@ -81,6 +83,36 @@ def send_ws_message(self, content: str, msg_type: MSG_TYPES="notification"):
                     "content": content
                 }
             )
+
+    def send_chat_message(self, message: Union[str, CatMessage], save=False):
+        if isinstance(message, str):
+            message = CatMessage(
+                msg_type="chat",
+                user_id=self.user_id,
+                content=message,
+            )
+
+        if save:
+            self.working_memory.update_conversation_history(who="AI", message=message["content"], why=message["why"])
+
+        self.__main_loop.call_soon_threadsafe(
+            self.__ws_messages.put_nowait,
+            {
+                **message.to_dict()
+            }
+        )
+
+    def send_notification(self, content: str):
+        self.send_ws_message(
+            content=content,
+            msg_type="notification"
+        )
+
+    def send_error(self, error):
+        self.send_ws_message(
+            content=error,
+            msg_type="error"
+        )
 
     def recall_relevant_memories_to_working_memory(self, query=None):
         """Retrieve context from memory.
@@ -112,15 +144,15 @@ def recall_relevant_memories_to_working_memory(self, query=None):
 
         if query is None:
             # If query is not provided, use the user's message as the query
-            recall_query = self.working_memory["user_message_json"]["text"]
+            recall_query = self.working_memory.user_message_json.text
 
         # We may want to search in memory
         recall_query = self.mad_hatter.execute_hook("cat_recall_query", recall_query, cat=self)
         log.info(f"Recall query: '{recall_query}'")
 
         # Embed recall query
         recall_query_embedding = self.embedder.embed_query(recall_query)
-        self.working_memory["recall_query"] = recall_query
+        self.working_memory.recall_query = recall_query
 
         # hook to do something before recall begins
         self.mad_hatter.execute_hook("before_cat_recalls_memories", cat=self)
@@ -167,7 +199,7 @@ def recall_relevant_memories_to_working_memory(self, query=None):
             vector_memory = getattr(self.memory.vectors, memory_type)
             memories = vector_memory.recall_memories_from_embedding(**config)
 
-            self.working_memory[memory_key] = memories
+            setattr(self.working_memory, memory_key, memories) # self.working_memory.procedural_memories = ...
 
         # hook to modify/enrich retrieved memories
         self.mad_hatter.execute_hook("after_cat_recalls_memories", cat=self)
@@ -202,14 +234,14 @@ def llm(self, prompt: str, stream: bool = False) -> str:
         if isinstance(self._llm, BaseChatModel):
             return self._llm.call_as_llm(prompt, callbacks=callbacks)
 
-    async def __call__(self, message):
+    async def __call__(self, message_dict):
         """Call the Cat instance.
 
         This method is called on the user's message received from the client.
 
         Parameters
         ----------
-        message : dict
+        message_dict : dict
             Dictionary received from the Websocket client.
         save : bool, optional
             If True, the user's message is stored in the chat history. Default is True.
@@ -226,20 +258,23 @@ async def __call__(self, message):
             answer. This is formatted in a dictionary to be sent as a JSON via Websocket to the client.
 
         """
-        log.info(message)
+
+        # Parse websocket message into UserMessage obj
+        user_message = UserMessage.model_validate(message_dict)
+        log.info(user_message)
 
         # set a few easy access variables
-        self.working_memory["user_message_json"] = message
+        self.working_memory.user_message_json = user_message
 
         # hook to modify/enrich user input
-        self.working_memory["user_message_json"] = self.mad_hatter.execute_hook(
+        self.working_memory.user_message_json = self.mad_hatter.execute_hook(
             "before_cat_reads_message",
-            self.working_memory["user_message_json"],
+            self.working_memory.user_message_json,
             cat=self
         )
 
         # text of latest Human message
-        user_message_text = self.working_memory["user_message_json"]["text"]
+        user_message_text = self.working_memory.user_message_json.text
 
         # update conversation history (Human turn)
         self.working_memory.update_conversation_history(who="Human", message=user_message_text)
@@ -309,31 +344,37 @@ async def __call__(self, message):
 
         # build data structure for output (response and why with memories)
         # TODO: these 3 lines are a mess, simplify
-        episodic_report = [dict(d[0]) | {"score": float(d[1]), "id": d[3]} for d in self.working_memory["episodic_memories"]]
-        declarative_report = [dict(d[0]) | {"score": float(d[1]), "id": d[3]} for d in self.working_memory["declarative_memories"]]
-        procedural_report = [dict(d[0]) | {"score": float(d[1]), "id": d[3]} for d in self.working_memory["procedural_memories"]]
+        episodic_report = [dict(d[0]) | {"score": float(d[1]), "id": d[3]} for d in self.working_memory.episodic_memories]
+        declarative_report = [dict(d[0]) | {"score": float(d[1]), "id": d[3]} for d in self.working_memory.declarative_memories]
+        procedural_report = [dict(d[0]) | {"score": float(d[1]), "id": d[3]} for d in self.working_memory.procedural_memories]
 
-        final_output = {
-            "type": "chat",
-            "user_id": self.user_id,
-            "content": str(cat_message.get("output")),
-            "why": {
-                "input": cat_message.get("input"),
-                "intermediate_steps": cat_message.get("intermediate_steps", []),
-                "memory": {
-                    "episodic": episodic_report,
-                    "declarative": declarative_report,
-                    "procedural": procedural_report,
-                },
-            },
-        }
+        # why this response?
+        why = MessageWhy(
+            input=cat_message.get("input", ""),
+            intermediate_steps=cat_message.get("intermediate_steps", []),
+            memory={
+                "episodic": episodic_report,
+                "declarative": declarative_report,
+                "procedural": procedural_report,
+            }
+        )
+
+        # prepare final cat message
+        final_output = CatMessage(
+            type="chat",
+            user_id=self.user_id,
+            content=str(cat_message.get("output")),
+            why=why
+        )
 
         # run message through plugins
         final_output = self.mad_hatter.execute_hook("before_cat_sends_message", final_output, cat=self)
 
         # update conversation history (AI turn)
-        self.working_memory.update_conversation_history(who="AI", message=final_output["content"], why=final_output["why"])
+        self.working_memory.update_conversation_history(who="AI", message=final_output.content, why=final_output.why)
 
-        return final_output
+        # send message back to client
+        return final_output.dict()
 
     def run(self, user_message_json):
         return self.loop.run_until_complete(
@@ -424,7 +465,7 @@ def stringify_chat_history(self, latest_n: int = 5) -> str:
         """
 
-        history = self.working_memory["history"][-latest_n:]
+        history = self.working_memory.history[-latest_n:]
 
         history_string = ""
         for turn in history:
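The net effect in `__call__`: the raw websocket dict is validated into a `UserMessage` on the way in, and the reply leaves as a `CatMessage` serialized back to a dict. A self-contained sketch of that round-trip using plain pydantic models mirroring `messages.py` (the real classes extend `BaseModelDict`, and the commit calls the older `.dict()` rather than `model_dump()`):

```python
# Standalone round-trip mirroring the new __call__ flow.
from typing import List

from pydantic import BaseModel


class MessageWhy(BaseModel):
    input: str
    intermediate_steps: List
    memory: dict


class CatMessage(BaseModel):
    type: str
    content: str
    user_id: str
    why: MessageWhy


class UserMessage(BaseModel):
    text: str
    user_id: str


# 1. websocket payload -> validated UserMessage (raises on missing fields)
user_message = UserMessage.model_validate({"text": "hello", "user_id": "user_1"})

# 2. agent output -> structured CatMessage with its "why" attached
reply = CatMessage(
    type="chat",
    user_id=user_message.user_id,
    content="meow",
    why=MessageWhy(input=user_message.text, intermediate_steps=[], memory={}),
)

# 3. back to a JSON-serializable dict for the websocket client
print(reply.model_dump())
```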
2 changes: 1 addition & 1 deletion core/cat/mad_hatter/core_plugin/hooks/agent.py
@@ -64,7 +64,7 @@ def agent_fast_reply(fast_reply, cat) -> Union[None, Dict]:
     Example 2: don't remember (no uploaded documents about topic)
     ```python
-    num_declarative_memories = len( cat.working_memory["declarative_memories"] )
+    num_declarative_memories = len( cat.working_memory.declarative_memories )
     if num_declarative_memories == 0:
         return {
             "output": "Sorry, I have no memories about that."
7 changes: 1 addition & 6 deletions core/cat/mad_hatter/core_plugin/hooks/flow.py
@@ -113,7 +113,7 @@ def cat_recall_query(user_message: str, cat) -> str:
     Returns
     -------
     Edited string to be used for context retrieval in memory. The returned string is further stored in the
-    Working Memory at `cat.working_memory["memory_query"]`.
+    Working Memory at `cat.working_memory.recall_query`.
 
     Notes
     -----
@@ -128,11 +128,6 @@ def cat_recall_query(user_message: str, cat) -> str:
     arXiv preprint arXiv:2212.10496.
     """
-    # example 1: HyDE embedding
-    # return cat.hypothetis_chain.run(user_message)
-
-    # example 2: Condense recent conversation
-    # TODO
-
     # here we just return the latest user message as is
     return user_message
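For plugin authors, the hook documented above can now read working memory through attributes as well. A hypothetical plugin override (the `hook` decorator import path follows Cheshire Cat plugin conventions; the condensing strategy is just an example, not part of this commit):

```python
from cat.mad_hatter.decorators import hook


@hook
def cat_recall_query(user_message: str, cat) -> str:
    # Example strategy: prepend the previous turn so recall has more context.
    history = cat.working_memory.history
    if history:
        user_message = f"{history[-1]['message']} {user_message}"
    return user_message
```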
17 changes: 11 additions & 6 deletions core/cat/memory/working_memory.py
@@ -1,8 +1,12 @@
 
 import time
+from typing import List
+from cat.utils import BaseModelDict
+from cat.convo.messages import UserMessage
+from cat.experimental.form import CatForm
 
 
-class WorkingMemory(dict):
+class WorkingMemory(BaseModelDict):
     """Cat's volatile memory.
 
     Handy class that behaves like a `dict` to store temporary custom data.
@@ -18,10 +22,11 @@ class WorkingMemory(dict):
     the conversation turns between the Human and the AI.
     """
 
-    def __init__(self):
-        # The constructor instantiates a `dict` with a 'history' key to store conversation history
-        # and the asyncio queue to manage the session notifications
-        super().__init__(history=[])
+    # stores conversation history
+    history: List = []
+    recall_query: str = ""
+    user_message_json : None | UserMessage = None
+    active_form: None | CatForm = None
 
     def update_conversation_history(self, who, message, why={}):
         """Update the conversation history.
@@ -37,6 +42,6 @@ def update_conversation_history(self, who, message, why={}):
         """
         # append latest message in conversation
-        self["history"].append({"who": who, "message": message, "why": why, "when": time.time()})
+        self.history.append({"who": who, "message": message, "why": why, "when": time.time()})
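A quick standalone illustration of what the typed fields change for callers (`DemoWorkingMemory` is a hypothetical mirror of the class above in plain pydantic, since `WorkingMemory` itself needs the Cat runtime's imports):

```python
import time
from typing import List

from pydantic import BaseModel


class DemoWorkingMemory(BaseModel):
    # mirrors the fields declared above; pydantic copies the mutable
    # default per instance, so a shared [] is not a bug here
    history: List = []
    recall_query: str = ""

    def update_conversation_history(self, who, message, why={}):
        self.history.append(
            {"who": who, "message": message, "why": why, "when": time.time()}
        )


wm = DemoWorkingMemory()
wm.update_conversation_history(who="Human", message="hi")
assert wm.history[0]["who"] == "Human"
assert wm.recall_query == ""  # fields exist up front, with typed defaults
```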


(diffs for the remaining 3 changed files not shown)
