Skip to content

Commit

Permalink
Merge pull request #565 from cheshire-cat-ai/back_to_threadpool
Browse files Browse the repository at this point in the history
Back to sync agent
  • Loading branch information
pieroit authored Nov 14, 2023
2 parents 0986789 + 50a5275 commit 1979671
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
18 changes: 9 additions & 9 deletions core/cat/looking_glass/agent_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, cat):
self.cat = cat


async def execute_tool_agent(self, agent_input, allowed_tools):
def execute_tool_agent(self, agent_input, allowed_tools):

allowed_tools_names = [t.name for t in allowed_tools]
# TODO: dynamic input_variables as in the main prompt
Expand Down Expand Up @@ -69,11 +69,11 @@ async def execute_tool_agent(self, agent_input, allowed_tools):
verbose=True
)

out = await agent_executor.acall(agent_input)
out = agent_executor(agent_input)
return out


async def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix, working_memory: WorkingMemory):
def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix, working_memory: WorkingMemory):

input_variables = [i for i in agent_input.keys() if i in prompt_prefix + prompt_suffix]
# memory chain (second step)
Expand All @@ -88,13 +88,13 @@ async def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix,
verbose=True
)

out = await memory_chain.acall(agent_input, callbacks=[NewTokenHandler(self.cat, working_memory)])
out = memory_chain(agent_input, callbacks=[NewTokenHandler(self.cat, working_memory)])
out["output"] = out["text"]
del out["text"]
return out


async def execute_agent(self, working_memory):
def execute_agent(self, working_memory):
"""Instantiate the Agent with tools.
The method formats the main prompt and gather the allowed tools. It also instantiates a conversational Agent
Expand Down Expand Up @@ -134,11 +134,11 @@ async def execute_agent(self, working_memory):
log.debug(f"{len(allowed_tools)} allowed tools retrived.")

try:
tools_result = await self.execute_tool_agent(agent_input, allowed_tools)
tools_result = self.execute_tool_agent(agent_input, allowed_tools)

# If tools_result["output"] is None the LLM has used the fake tool none_of_the_others
# so no relevant information has been obtained from the tools.
if tools_result["output"] != None:
if tools_result["output"] is not None:

# Extract of intermediate steps in the format ((tool_name, tool_input), output)
used_tools = list(map(lambda x:((x[0].tool, x[0].tool_input), x[1]), tools_result["intermediate_steps"]))
Expand All @@ -160,7 +160,7 @@ async def execute_agent(self, working_memory):
agent_input["tools_output"] = "## Tools output: \n" + tools_result["output"] if tools_result["output"] else ""

# Execute the memory chain
out = await self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix, working_memory)
out = self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix, working_memory)

# If some tools are used the intermediate step are added to the agent output
out["intermediate_steps"] = used_tools
Expand All @@ -177,7 +177,7 @@ async def execute_agent(self, working_memory):
#Adding the tools_output key in agent input, needed by the memory chain
agent_input["tools_output"] = ""
# Execute the memory chain
out = await self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix, working_memory)
out = self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix, working_memory)

return out

Expand Down
6 changes: 3 additions & 3 deletions core/cat/looking_glass/cheshire_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification", wo
)
)

async def __call__(self, user_message_json):
def __call__(self, user_message_json):
"""Call the Cat instance.
This method is called on the user's message received from the client.
Expand Down Expand Up @@ -456,7 +456,7 @@ async def __call__(self, user_message_json):

# reply with agent
try:
cat_message = await self.agent_manager.execute_agent(user_working_memory)
cat_message = self.agent_manager.execute_agent(user_working_memory)
except Exception as e:
# This error happens when the LLM
# does not respect prompt instructions.
Expand All @@ -465,7 +465,7 @@ async def __call__(self, user_message_json):
error_description = str(e)

log.error(error_description)
if not "Could not parse LLM output: `" in error_description:
if "Could not parse LLM output: `" not in error_description:
raise e

unparsable_llm_output = error_description.replace("Could not parse LLM output: `", "").replace("`", "")
Expand Down
2 changes: 1 addition & 1 deletion core/cat/routes/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def receive_message(ccat: CheshireCat, user_id: str = "user"):
user_message["user_id"] = user_id

# Run the `ccat` object's method in a threadpool since it might be a CPU-bound operation.
cat_message = await ccat(user_message)
cat_message = await run_in_threadpool(ccat, user_message)

# Send the response message back to the user.
await manager.send_personal_message(cat_message, user_id)
Expand Down

0 comments on commit 1979671

Please sign in to comment.