Skip to content

Commit

Permalink
Merge pull request #539 from Pingdred/async_agent
Browse files Browse the repository at this point in the history
Async agent
  • Loading branch information
pieroit authored Nov 9, 2023
2 parents 239abc1 + e53f9fd commit 1578b7a
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 10 deletions.
4 changes: 4 additions & 0 deletions core/cat/factory/custom_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ def _llm_type(self):
def _call(self, prompt, stop=None):
return "AI: You did not configure a Language Model. " \
"Do it in the settings!"

async def _acall(self, prompt, stop=None):
return "AI: You did not configure a Language Model. " \
"Do it in the settings!"


# elaborated from
Expand Down
14 changes: 7 additions & 7 deletions core/cat/looking_glass/agent_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, cat):
self.cat = cat


def execute_tool_agent(self, agent_input, allowed_tools):
async def execute_tool_agent(self, agent_input, allowed_tools):

allowed_tools_names = [t.name for t in allowed_tools]
# TODO: dynamic input_variables as in the main prompt
Expand Down Expand Up @@ -69,11 +69,11 @@ def execute_tool_agent(self, agent_input, allowed_tools):
verbose=True
)

out = agent_executor(agent_input)
out = await agent_executor.acall(agent_input)
return out


def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix, working_memory: WorkingMemory):
async def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix, working_memory: WorkingMemory):

input_variables = [i for i in agent_input.keys() if i in prompt_prefix + prompt_suffix]
# memory chain (second step)
Expand All @@ -88,13 +88,13 @@ def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix, workin
verbose=True
)

out = memory_chain(agent_input, callbacks=[NewTokenHandler(self.cat, working_memory)])
out = await memory_chain.acall(agent_input, callbacks=[NewTokenHandler(self.cat, working_memory)])
out["output"] = out["text"]
del out["text"]
return out


def execute_agent(self, working_memory):
async def execute_agent(self, working_memory):
"""Instantiate the Agent with tools.
The method formats the main prompt and gather the allowed tools. It also instantiates a conversational Agent
Expand Down Expand Up @@ -134,7 +134,7 @@ def execute_agent(self, working_memory):
log.debug(f"{len(allowed_tools)} allowed tools retrived.")

try:
tools_result = self.execute_tool_agent(agent_input, allowed_tools)
tools_result = await self.execute_tool_agent(agent_input, allowed_tools)

# If tools_result["output"] is None the LLM has used the fake tool none_of_the_others
# so no relevant information has been obtained from the tools.
Expand Down Expand Up @@ -177,7 +177,7 @@ def execute_agent(self, working_memory):
#Adding the tools_output key in agent input, needed by the memory chain
agent_input["tools_output"] = ""
# Execute the memory chain
out = self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix, working_memory)
out = await self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix, working_memory)

return out

Expand Down
4 changes: 2 additions & 2 deletions core/cat/looking_glass/cheshire_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification", wo
)
)

def __call__(self, user_message_json):
async def __call__(self, user_message_json):
"""Call the Cat instance.
This method is called on the user's message received from the client.
Expand Down Expand Up @@ -456,7 +456,7 @@ def __call__(self, user_message_json):

# reply with agent
try:
cat_message = self.agent_manager.execute_agent(user_working_memory)
cat_message = await self.agent_manager.execute_agent(user_working_memory)
except Exception as e:
# This error happens when the LLM
# does not respect prompt instructions.
Expand Down
2 changes: 1 addition & 1 deletion core/cat/routes/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def receive_message(ccat: CheshireCat, user_id: str = "user"):
user_message["user_id"] = user_id

# Run the `ccat` object's method in a threadpool since it might be a CPU-bound operation.
cat_message = await run_in_threadpool(ccat, user_message)
cat_message = await ccat(user_message)

# Send the response message back to the user.
await manager.send_personal_message(cat_message, user_id)
Expand Down

0 comments on commit 1578b7a

Please sign in to comment.