diff --git a/core/cat/agents/base_agent.py b/core/cat/agents/base_agent.py index f94155ca..d9116d50 100644 --- a/core/cat/agents/base_agent.py +++ b/core/cat/agents/base_agent.py @@ -13,5 +13,5 @@ class AgentOutput(BaseModelDict): class BaseAgent(ABC): @abstractmethod - async def execute(*args, **kwargs) -> AgentOutput: + def execute(*args, **kwargs) -> AgentOutput: pass \ No newline at end of file diff --git a/core/cat/agents/form_agent.py b/core/cat/agents/form_agent.py index d393c8fd..8be84fd3 100644 --- a/core/cat/agents/form_agent.py +++ b/core/cat/agents/form_agent.py @@ -5,7 +5,7 @@ class FormAgent(BaseAgent): - async def execute(self, stray) -> AgentOutput: + def execute(self, stray) -> AgentOutput: # get active form from working memory active_form = stray.working_memory.active_form diff --git a/core/cat/agents/main_agent.py b/core/cat/agents/main_agent.py index 7ea5a022..e52389bb 100644 --- a/core/cat/agents/main_agent.py +++ b/core/cat/agents/main_agent.py @@ -26,7 +26,7 @@ def __init__(self): else: self.verbose = False - async def execute(self, stray) -> AgentOutput: + def execute(self, stray) -> AgentOutput: """Execute the agents. Returns @@ -66,7 +66,7 @@ async def execute(self, stray) -> AgentOutput: # run tools and forms procedures_agent = ProceduresAgent() - procedures_agent_out : AgentOutput = await procedures_agent.execute(stray) + procedures_agent_out : AgentOutput = procedures_agent.execute(stray) if procedures_agent_out.return_direct: return procedures_agent_out @@ -74,7 +74,7 @@ async def execute(self, stray) -> AgentOutput: # - no procedures were recalled or selected or # - procedures have all return_direct=False memory_agent = MemoryAgent() - memory_agent_out : AgentOutput = await memory_agent.execute( + memory_agent_out : AgentOutput = memory_agent.execute( # TODO: should all agents only receive stray? stray, prompt_prefix, prompt_suffix ) diff --git a/core/cat/agents/memory_agent.py b/core/cat/agents/memory_agent.py index 2308c483..22c34f95 100644 --- a/core/cat/agents/memory_agent.py +++ b/core/cat/agents/memory_agent.py @@ -11,7 +11,7 @@ class MemoryAgent(BaseAgent): - async def execute(self, stray, prompt_prefix, prompt_suffix) -> AgentOutput: + def execute(self, stray, prompt_prefix, prompt_suffix) -> AgentOutput: prompt_variables = stray.working_memory.agent_input.model_dump() sys_prompt = prompt_prefix + prompt_suffix diff --git a/core/cat/agents/procedures_agent.py b/core/cat/agents/procedures_agent.py index 86baa1b3..9c28a77a 100644 --- a/core/cat/agents/procedures_agent.py +++ b/core/cat/agents/procedures_agent.py @@ -24,10 +24,10 @@ class ProceduresAgent(BaseAgent): form_agent = FormAgent() allowed_procedures: Dict[str, CatTool | CatForm] = {} - async def execute(self, stray) -> AgentOutput: + def execute(self, stray) -> AgentOutput: # Run active form if present - form_output: AgentOutput = await self.form_agent.execute(stray) + form_output: AgentOutput = self.form_agent.execute(stray) if form_output.return_direct: return form_output @@ -38,7 +38,7 @@ async def execute(self, stray) -> AgentOutput: log.debug(f"Procedural memories retrived: {len(procedural_memories)}.") try: - procedures_result: AgentOutput = await self.execute_procedures(stray) + procedures_result: AgentOutput = self.execute_procedures(stray) if procedures_result.return_direct: # exit agent if a return_direct procedure was executed return procedures_result @@ -64,7 +64,7 @@ async def execute(self, stray) -> AgentOutput: return AgentOutput() - async def execute_procedures(self, stray): + def execute_procedures(self, stray): # using some hooks mad_hatter = MadHatter() @@ -87,13 +87,13 @@ async def execute_procedures(self, stray): ) # Execute chain and obtain a choice of procedure from the LLM - llm_action: LLMAction = await self.execute_chain(stray, procedures_prompt_template, allowed_procedures) + llm_action: LLMAction = self.execute_chain(stray, procedures_prompt_template, allowed_procedures) # route execution to subagents - return await self.execute_subagents(stray, llm_action, allowed_procedures) + return self.execute_subagents(stray, llm_action, allowed_procedures) - async def execute_chain(self, stray, procedures_prompt_template, allowed_procedures) -> LLMAction: + def execute_chain(self, stray, procedures_prompt_template, allowed_procedures) -> LLMAction: # Prepare info to fill up the prompt prompt_variables = { @@ -136,7 +136,7 @@ async def execute_chain(self, stray, procedures_prompt_template, allowed_procedu return llm_action - async def execute_subagents(self, stray, llm_action, allowed_procedures): + def execute_subagents(self, stray, llm_action, allowed_procedures): # execute chosen tool / form # loop over allowed tools and forms if llm_action.action: @@ -144,7 +144,7 @@ async def execute_subagents(self, stray, llm_action, allowed_procedures): try: if Plugin._is_cat_tool(chosen_procedure): # execute tool - tool_output = await chosen_procedure._arun(llm_action.action_input, stray=stray) + tool_output = chosen_procedure.run(llm_action.action_input, stray=stray) return AgentOutput( output=tool_output, return_direct=chosen_procedure.return_direct, @@ -158,7 +158,7 @@ async def execute_subagents(self, stray, llm_action, allowed_procedures): # store active form in working memory stray.working_memory.active_form = form_instance # execute form - return await self.form_agent.execute(stray) + return self.form_agent.execute(stray) except Exception as e: log.error(f"Error executing {chosen_procedure.procedure_type} `{chosen_procedure.name}`") diff --git a/core/cat/looking_glass/stray_cat.py b/core/cat/looking_glass/stray_cat.py index 5108c4fc..968159fb 100644 --- a/core/cat/looking_glass/stray_cat.py +++ b/core/cat/looking_glass/stray_cat.py @@ -43,8 +43,6 @@ def __init__( self.__main_loop = main_loop - self.__loop = asyncio.new_event_loop() - def __repr__(self): return f"StrayCat(user_id={self.user_id})" @@ -342,7 +340,7 @@ def llm(self, prompt: str, stream: bool = False) -> str: return output - async def __call__(self, message_dict): + def __call__(self, message_dict): """Call the Cat instance. This method is called on the user's message received from the client. @@ -408,7 +406,7 @@ async def __call__(self, message_dict): # reply with agent try: - agent_output: AgentOutput = await self.main_agent.execute(self) + agent_output: AgentOutput = self.main_agent.execute(self) except Exception as e: # This error happens when the LLM # does not respect prompt instructions. @@ -472,7 +470,7 @@ async def __call__(self, message_dict): def run(self, user_message_json, return_message=False): try: - cat_message = self.loop.run_until_complete(self.__call__(user_message_json)) + cat_message = self.__call__(user_message_json) if return_message: # return the message for HTTP usage return cat_message @@ -648,7 +646,3 @@ def main_agent(self): @property def white_rabbit(self): return CheshireCat().white_rabbit - - @property - def loop(self): - return self.__loop diff --git a/core/cat/mad_hatter/decorators/tool.py b/core/cat/mad_hatter/decorators/tool.py index b9cfd16e..7e3ba72f 100644 --- a/core/cat/mad_hatter/decorators/tool.py +++ b/core/cat/mad_hatter/decorators/tool.py @@ -9,7 +9,7 @@ # All @tool decorated functions in plugins become a CatTool. # The difference between base langchain Tool and CatTool is that CatTool has an instance of the cat as attribute (set by the MadHatter) -class CatTool(BaseTool): +class CatTool: def __init__( self, name: str, @@ -18,12 +18,7 @@ def __init__( examples: List[str] = [], ): description = func.__doc__.strip() - - # call parent contructor - super().__init__( - name=name, func=func, description=description, return_direct=return_direct - ) - + self.func = func self.procedure_type = "tool" self.name = name @@ -44,21 +39,8 @@ def start_examples(self): def __repr__(self) -> str: return f"CatTool(name={self.name}, return_direct={self.return_direct}, description={self.description})" - # we run tools always async, even if they are not defined so in a plugin - def _run(self, input_by_llm: str) -> str: - pass # do nothing - - # we run tools always async, even if they are not defined so in a plugin - async def _arun(self, input_by_llm, stray): - - # await if the tool is async - if inspect.iscoroutinefunction(self.func): - return await self.func(input_by_llm, cat=stray) - - # run in executor if the tool is not async - return await stray.loop.run_in_executor( - None, self.func, input_by_llm, stray - ) + def run(self, input_by_llm: str, stray) -> str: + return self.func(input_by_llm, cat=stray) # override `extra = 'forbid'` for Tool pydantic model in langchain class Config: diff --git a/core/tests/agents/test_main_agent.py b/core/tests/agents/test_main_agent.py index 7f054edf..d3989498 100644 --- a/core/tests/agents/test_main_agent.py +++ b/core/tests/agents/test_main_agent.py @@ -17,7 +17,7 @@ def test_main_agent_instantiation(main_agent): @pytest.mark.asyncio # to test async functions async def test_execute_main_agent(main_agent, stray): # empty agent execution - out = await main_agent.execute(stray) + out = main_agent.execute(stray) assert isinstance(out, AgentOutput) assert not out.return_direct assert out.intermediate_steps == [] diff --git a/core/tests/looking_glass/test_stray_cat.py b/core/tests/looking_glass/test_stray_cat.py index a7cf40f7..c3030847 100644 --- a/core/tests/looking_glass/test_stray_cat.py +++ b/core/tests/looking_glass/test_stray_cat.py @@ -29,7 +29,7 @@ def test_stray_nlp(stray): def test_stray_call(stray): msg = {"text": "Where do I go?", "user_id": "Alice"} - reply = stray.loop.run_until_complete(stray.__call__(msg)) + reply = stray.__call__(msg) assert isinstance(reply, CatMessage) assert "You did not configure" in reply.content @@ -57,7 +57,7 @@ def test_recall_to_working_memory(stray): msg = {"text": msg_text, "user_id": "Alice"} # send message - stray.loop.run_until_complete(stray.__call__(msg)) + stray.__call__(msg) # recall after episodic memory was stored stray.recall_relevant_memories_to_working_memory(msg_text)