diff --git a/core/cat/mad_hatter/decorators/tool.py b/core/cat/mad_hatter/decorators/tool.py index c97f6295..7e3ba72f 100644 --- a/core/cat/mad_hatter/decorators/tool.py +++ b/core/cat/mad_hatter/decorators/tool.py @@ -9,7 +9,7 @@ # All @tool decorated functions in plugins become a CatTool. # The difference between base langchain Tool and CatTool is that CatTool has an instance of the cat as attribute (set by the MadHatter) -class CatTool(BaseTool): +class CatTool: def __init__( self, name: str, @@ -18,12 +18,7 @@ def __init__( examples: List[str] = [], ): description = func.__doc__.strip() - - # call parent contructor - super().__init__( - name=name, func=func, description=description, return_direct=return_direct - ) - + self.func = func self.procedure_type = "tool" self.name = name @@ -44,14 +39,9 @@ def start_examples(self): def __repr__(self) -> str: return f"CatTool(name={self.name}, return_direct={self.return_direct}, description={self.description})" - # we run tools always async, even if they are not defined so in a plugin - def _run(self, input_by_llm: str, stray) -> str: + def run(self, input_by_llm: str, stray) -> str: return self.func(input_by_llm, cat=stray) - # we run tools always async, even if they are not defined so in a plugin - async def _arun(self, input_by_llm, stray): - pass - # override `extra = 'forbid'` for Tool pydantic model in langchain class Config: extra = "allow" diff --git a/core/tests/agents/test_main_agent.py b/core/tests/agents/test_main_agent.py index 7f054edf..d3989498 100644 --- a/core/tests/agents/test_main_agent.py +++ b/core/tests/agents/test_main_agent.py @@ -17,7 +17,7 @@ def test_main_agent_instantiation(main_agent): @pytest.mark.asyncio # to test async functions async def test_execute_main_agent(main_agent, stray): # empty agent execution - out = await main_agent.execute(stray) + out = main_agent.execute(stray) assert isinstance(out, AgentOutput) assert not out.return_direct assert out.intermediate_steps == [] diff --git a/core/tests/looking_glass/test_stray_cat.py b/core/tests/looking_glass/test_stray_cat.py index a7cf40f7..c3030847 100644 --- a/core/tests/looking_glass/test_stray_cat.py +++ b/core/tests/looking_glass/test_stray_cat.py @@ -29,7 +29,7 @@ def test_stray_nlp(stray): def test_stray_call(stray): msg = {"text": "Where do I go?", "user_id": "Alice"} - reply = stray.loop.run_until_complete(stray.__call__(msg)) + reply = stray.__call__(msg) assert isinstance(reply, CatMessage) assert "You did not configure" in reply.content @@ -57,7 +57,7 @@ def test_recall_to_working_memory(stray): msg = {"text": msg_text, "user_id": "Alice"} # send message - stray.loop.run_until_complete(stray.__call__(msg)) + stray.__call__(msg) # recall after episodic memory was stored stray.recall_relevant_memories_to_working_memory(msg_text)