feat: add Terminator for agent to terminate (#254)
Co-authored-by: Guohao Li <[email protected]>
Co-authored-by: lig <[email protected]>
3 people authored Oct 25, 2023
1 parent 9a9d718 commit 8ace5e4
Showing 20 changed files with 475 additions and 56 deletions.
3 changes: 1 addition & 2 deletions camel/agents/__init__.py
@@ -12,7 +12,7 @@
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .base import BaseAgent
from .chat_agent import ChatAgent, ChatAgentResponse
from .chat_agent import ChatAgent
from .task_agent import (
TaskSpecifyAgent,
TaskPlannerAgent,
@@ -28,7 +28,6 @@
__all__ = [
'BaseAgent',
'ChatAgent',
'ChatAgentResponse',
'TaskSpecifyAgent',
'TaskPlannerAgent',
'TaskCreationAgent',
78 changes: 43 additions & 35 deletions camel/agents/chat_agent.py
@@ -26,35 +26,12 @@
from camel.functions import OpenAIFunction
from camel.messages import BaseMessage, FunctionCallingMessage, OpenAIMessage
from camel.models import BaseModelBackend, ModelFactory
from camel.responses import ChatAgentResponse
from camel.terminators import ResponseTerminator, TokenLimitTerminator
from camel.typing import ModelType, RoleType
from camel.utils import get_model_encoding, openai_api_key_required


@dataclass(frozen=True)
class ChatAgentResponse:
r"""Response of a ChatAgent.
Attributes:
msgs (List[BaseMessage]): A list of zero, one or several messages.
If the list is empty, there is some error in message generation.
If the list has one message, this is normal mode.
If the list has several messages, this is the critic mode.
terminated (bool): A boolean indicating whether the agent decided
to terminate the chat session.
info (Dict[str, Any]): Extra information about the chat message.
"""
msgs: List[BaseMessage]
terminated: bool
info: Dict[str, Any]

@property
def msg(self):
if len(self.msgs) != 1:
raise RuntimeError("Property msg is only available "
"for a single message in msgs.")
return self.msgs[0]


@dataclass(frozen=True)
class ChatRecord:
r"""Historical records of who made what message.
@@ -116,8 +93,11 @@ class ChatAgent(BaseAgent):
is performed. (default: :obj:`None`)
output_language (str, optional): The language to be output by the
agent. (default: :obj:`None`)
function_list (Optional[List[OpenAIFunction]]): List of available
function_list (List[OpenAIFunction], optional): List of available
:obj:`OpenAIFunction`. (default: :obj:`None`)
response_terminators (List[ResponseTerminator], optional): List of
:obj:`ResponseTerminator` bound to one chat agent.
(default: :obj:`None`)
"""

def __init__(
@@ -128,6 +108,7 @@ def __init__(
message_window_size: Optional[int] = None,
output_language: Optional[str] = None,
function_list: Optional[List[OpenAIFunction]] = None,
response_terminators: Optional[List[ResponseTerminator]] = None,
) -> None:

self.orig_sys_message: BaseMessage = system_message
@@ -153,6 +134,9 @@ def __init__(
self.model_token_limit: int = self.model_backend.token_limit

self.terminated: bool = False
self.token_limit_terminator = TokenLimitTerminator(
self.model_token_limit)
self.response_terminators = response_terminators or []
self.stored_messages: List[ChatRecord]
self.init_messages()

@@ -165,6 +149,9 @@ def reset(self):
"""
self.terminated = False
self.init_messages()
self.token_limit_terminator.reset()
for terminator in self.response_terminators:
terminator.reset()

@property
def system_message(self) -> BaseMessage:
@@ -252,6 +239,7 @@ def update_messages(self, role: str,
r"""Updates the stored messages list with a new message.
Args:
role (str): Role of the message at the backend.
message (BaseMessage): The new message to add to the stored
messages.
@@ -306,8 +294,11 @@ def step(
openai_messages, num_tokens = self.preprocess_messages(messages)

# Terminate when number of tokens exceeds the limit
if num_tokens >= self.model_token_limit:
return self.step_token_exceed(num_tokens, called_funcs)
self.terminated, termination_reason = \
self.token_limit_terminator.is_terminated(num_tokens)
if self.terminated and termination_reason is not None:
return self.step_token_exceed(num_tokens, called_funcs,
termination_reason)

# Obtain LLM's response and validate it
response = self.model_backend.run(openai_messages)
@@ -333,6 +324,23 @@ def step(
called_funcs.append(func_record)
else:
# Function calling disabled or chat stopped

# Loop over the response terminators and collect a list of
# (terminated, termination_reason) tuples, one per terminator
termination = [
terminator.is_terminated(output_messages)
for terminator in self.response_terminators
]
# Terminate the agent if any of the terminators terminates it
self.terminated, termination_reason = next(
((terminated, termination_reason)
for terminated, termination_reason in termination
if terminated), (False, None))
# For now only retain the first termination reason
if self.terminated and termination_reason is not None:
finish_reasons = [termination_reason] * len(finish_reasons)

info = self.get_info(
response_id,
usage_dict,
@@ -456,29 +464,29 @@ def handle_stream_response(
usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
return output_messages, finish_reasons, usage_dict, response_id

def step_token_exceed(
self, num_tokens: int,
called_funcs: List[FunctionCallingRecord]) -> ChatAgentResponse:
def step_token_exceed(self, num_tokens: int,
called_funcs: List[FunctionCallingRecord],
termination_reason: str) -> ChatAgentResponse:
r"""Return trivial response containing number of tokens and information
of called functions when the number of tokens exceeds the limit.
Args:
num_tokens (int): Number of tokens in the messages.
called_funcs (List[FunctionCallingRecord]): List of information
objects of functions called in the current step.
termination_reason (str): String of termination reason.
Returns:
ChatAgentResponse: The struct containing trivial outputs and
information about token number and called functions.
"""

self.terminated = True
output_messages: List[BaseMessage] = []

info = self.get_info(
None,
None,
["max_tokens_exceeded"],
[termination_reason],
num_tokens,
called_funcs,
)
@@ -496,11 +504,11 @@ def step_function_call(
r"""Execute the function with arguments following the model's response.
Args:
response (Dict[str, Any]): the response obtained by calling the
response (Dict[str, Any]): The response obtained by calling the
model.
Returns:
tuple: a tuple consisting of two obj:`FunctionCallingMessage`,
tuple: A tuple consisting of two obj:`FunctionCallingMessage`,
one about the arguments and the other about the execution
result, and a struct for logging information about this
function call.
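As a reading aid (not part of this commit), here is a minimal sketch of how the new `response_terminators` argument might be wired into a `ChatAgent`. The `ResponseWordsTerminator` constructor arguments, the exact `BaseMessage` construction, and the `"termination_reasons"` info key are assumptions not shown in this diff.

```python
# Sketch only: usage of the response_terminators hook added above.
# ResponseWordsTerminator's arguments and the info key name are assumed.
from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.terminators import ResponseWordsTerminator
from camel.typing import RoleType

sys_msg = BaseMessage(role_name="Assistant", role_type=RoleType.ASSISTANT,
                      meta_dict=None, content="You are a helpful assistant.")

# Assumed interface: stop once the word "goodbye" has appeared twice.
terminator = ResponseWordsTerminator(words_dict={"goodbye": 2})

agent = ChatAgent(system_message=sys_msg,
                  response_terminators=[terminator])

user_msg = BaseMessage(role_name="User", role_type=RoleType.USER,
                       meta_dict=None, content="Wrap up and say goodbye.")
response = agent.step(user_msg)

# When a terminator fires, agent.terminated is set and the reason is
# propagated through the response info (key name assumed).
print(response.terminated, response.info.get("termination_reasons"))
```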
6 changes: 4 additions & 2 deletions camel/agents/critic_agent.py
@@ -17,8 +17,9 @@

from colorama import Fore

from camel.agents import ChatAgent, ChatAgentResponse
from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.responses import ChatAgentResponse
from camel.typing import ModelType
from camel.utils import get_first_int, print_text_animated

@@ -149,7 +150,8 @@ def reduce_step(
critic, getting the option, and parsing the choice.
Args:
messages (Sequence[BaseMessage]): A list of BaseMessage objects.
input_messages (Sequence[BaseMessage]): A list of BaseMessage
objects.
Returns:
ChatAgentResponse: A `ChatAgentResponse` object includes the
8 changes: 2 additions & 6 deletions camel/agents/embodied_agent.py
@@ -15,13 +15,9 @@

from colorama import Fore

from camel.agents import (
BaseToolAgent,
ChatAgent,
ChatAgentResponse,
HuggingFaceToolAgent,
)
from camel.agents import BaseToolAgent, ChatAgent, HuggingFaceToolAgent
from camel.messages import BaseMessage
from camel.responses import ChatAgentResponse
from camel.typing import ModelType
from camel.utils import PythonInterpreter, print_text_animated

6 changes: 3 additions & 3 deletions camel/functions/search_functions.py
@@ -50,7 +50,7 @@ def search_wiki(entity: str) -> str:


def search_google(query: str) -> List[Dict[str, Any]]:
r"""Use google search engine to search information for the given query.
r"""Use Google search engine to search information for the given query.
Args:
query (string): The query to be searched.
@@ -135,7 +135,7 @@ def search_google(query: str) -> List[Dict[str, Any]]:
responses.append({"error": "google search failed."})

except requests.RequestException:
responses.append({"erro": "google search failed."})
responses.append({"error": "google search failed."})

return responses

@@ -266,7 +266,7 @@ def summarize_text(text: str, query: str) -> str:

def search_google_and_summarize(query: str) -> str:
r"""Search webs for information. Given a query, this function will use
the google search engine to search for related information from the
the Google search engine to search for related information from the
internet, and then return a summarized answer.
Args:
2 changes: 1 addition & 1 deletion camel/human.py
@@ -15,8 +15,8 @@

from colorama import Fore

from camel.agents import ChatAgentResponse
from camel.messages import BaseMessage
from camel.responses import ChatAgentResponse
from camel.utils import print_text_animated


2 changes: 0 additions & 2 deletions camel/messages/base.py
@@ -34,8 +34,6 @@ class BaseMessage:
:obj:`RoleType.ASSISTANT` or :obj:`RoleType.USER`.
meta_dict (Optional[Dict[str, str]]): Additional metadata dictionary
for the message.
role (str): The role of the message in OpenAI chat system, either
:obj:`"system"`, :obj:`"user"`, or :obj:`"assistant"`.
content (str): The content of the message.
"""
role_name: str
18 changes: 18 additions & 0 deletions camel/responses/__init__.py
@@ -0,0 +1,18 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .agent_responses import ChatAgentResponse

__all__ = [
'ChatAgentResponse',
]
42 changes: 42 additions & 0 deletions camel/responses/agent_responses.py
@@ -0,0 +1,42 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from dataclasses import dataclass
from typing import Any, Dict, List

from camel.messages import BaseMessage


@dataclass(frozen=True)
class ChatAgentResponse:
r"""Response of a ChatAgent.
Attributes:
msgs (List[BaseMessage]): A list of zero, one or several messages.
If the list is empty, there is some error in message generation.
If the list has one message, this is normal mode.
If the list has several messages, this is the critic mode.
terminated (bool): A boolean indicating whether the agent decided
to terminate the chat session.
info (Dict[str, Any]): Extra information about the chat message.
"""
msgs: List[BaseMessage]
terminated: bool
info: Dict[str, Any]

@property
def msg(self):
if len(self.msgs) != 1:
raise RuntimeError("Property msg is only available "
"for a single message in msgs.")
return self.msgs[0]
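A short illustration (not part of the diff) of the `msg` convenience property defined above. The `BaseMessage` constructor arguments follow the attributes documented in camel/messages/base.py and are otherwise assumed.

```python
# Sketch only: exercising ChatAgentResponse.msg.
from camel.messages import BaseMessage
from camel.responses import ChatAgentResponse
from camel.typing import RoleType

msg = BaseMessage(role_name="Assistant", role_type=RoleType.ASSISTANT,
                  meta_dict=None, content="Hello!")

# Normal mode: exactly one message, so .msg is available.
single = ChatAgentResponse(msgs=[msg], terminated=False, info={})
print(single.msg.content)  # -> "Hello!"

# Critic or error modes carry zero or several messages; .msg then raises.
empty = ChatAgentResponse(msgs=[], terminated=False, info={})
try:
    empty.msg
except RuntimeError as exc:
    print(exc)
```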
2 changes: 1 addition & 1 deletion camel/societies/role_playing.py
@@ -19,11 +19,11 @@
TaskPlannerAgent,
TaskSpecifyAgent,
)
from camel.agents.chat_agent import ChatAgentResponse
from camel.generators import SystemMessageGenerator
from camel.human import Human
from camel.messages import BaseMessage
from camel.prompts import TextPrompt
from camel.responses import ChatAgentResponse
from camel.typing import ModelType, RoleType, TaskType


23 changes: 23 additions & 0 deletions camel/terminators/__init__.py
@@ -0,0 +1,23 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .base import BaseTerminator
from .response_terminator import ResponseWordsTerminator, ResponseTerminator
from .token_limit_terminator import TokenLimitTerminator

__all__ = [
'BaseTerminator',
'ResponseTerminator',
'ResponseWordsTerminator',
'TokenLimitTerminator',
]
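For orientation, a minimal sketch of the terminator interface as it is used in camel/agents/chat_agent.py above: `is_terminated` is assumed to return a `(terminated, reason)` tuple and `reset` to clear state between sessions, matching the calls shown in that diff.

```python
# Sketch only: the TokenLimitTerminator calls mirrored from chat_agent.py.
from camel.terminators import TokenLimitTerminator

# Constructed with the model's token limit, as in ChatAgent.__init__ above.
terminator = TokenLimitTerminator(4096)

terminated, reason = terminator.is_terminated(5000)
if terminated:
    # The exact reason string comes from the terminator and is not
    # shown in this diff.
    print(f"Terminated: {reason}")

terminator.reset()  # clear state before reusing the terminator
```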
