diff --git a/camel/toolkits/__init__.py b/camel/toolkits/__init__.py index d1658959c..b7454890f 100644 --- a/camel/toolkits/__init__.py +++ b/camel/toolkits/__init__.py @@ -17,6 +17,7 @@ OpenAIFunction, get_openai_function_schema, get_openai_tool_schema, + generate_docstring, ) from .open_api_specs.security_config import openapi_security_config @@ -46,6 +47,7 @@ 'OpenAIFunction', 'get_openai_function_schema', 'get_openai_tool_schema', + "generate_docstring", 'openapi_security_config', 'GithubToolkit', 'MathToolkit', diff --git a/camel/toolkits/function_tool.py b/camel/toolkits/function_tool.py index 804bc894e..f1e74402a 100644 --- a/camel/toolkits/function_tool.py +++ b/camel/toolkits/function_tool.py @@ -11,8 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +import logging import warnings -from inspect import Parameter, signature +from inspect import Parameter, getsource, signature from typing import Any, Callable, Dict, Mapping, Optional, Tuple from docstring_parser import parse @@ -21,8 +22,15 @@ from pydantic import create_model from pydantic.fields import FieldInfo +from camel.agents import ChatAgent +from camel.configs import ChatGPTConfig +from camel.messages import BaseMessage +from camel.models import BaseModelBackend, ModelFactory +from camel.types import ModelPlatformType, ModelType from camel.utils import get_pydantic_object_schema, to_pascal +logger = logging.getLogger(__name__) + def _remove_a_key(d: Dict, remove_key: Any) -> None: r"""Remove a key from a dictionary recursively.""" @@ -143,6 +151,84 @@ def _create_mol(name, field): return openai_tool_schema +def generate_docstring( + code: str, + model: Optional[BaseModelBackend] = None, +) -> str: + """Generates a docstring for a given function code using LLM. + + This function leverages a language model to generate a + PEP 8/PEP 257-compliant docstring for a provided Python function. + If no model is supplied, a default GPT_4O_MINI is used. + + Args: + code (str): The source code of the function. + model (Optional[BaseModelBackend]): An optional language model backend + instance. If not provided, a default GPT_4O_MINI is used. + + Returns: + str: The generated docstring. + """ + # Create the docstring prompt + docstring_prompt = ''' + **Role**: Generate professional Python docstrings conforming to + PEP 8/PEP 257. + + **Requirements**: + - Use appropriate format: reST, Google, or NumPy, as needed. + - Include parameters, return values, and exceptions. + - Reference any existing docstring in the function and + retain useful information. + + **Input**: Python function. + + **Output**: Docstring content (plain text, no code markers). + + **Example:** + + Input: + ```python + def add(a: int, b: int) -> int: + return a + b + ``` + + Output: + Adds two numbers. + Args: + a (int): The first number. + b (int): The second number. + + Returns: + int: The sum of the two numbers. + + **Task**: Generate a docstring for the function below. + + ''' + # Create the assistant model if not provided + if not model: + model = ModelFactory.create( + model_platform=ModelPlatformType.OPENAI, + model_type=ModelType.GPT_4O_MINI, + model_config_dict=ChatGPTConfig(temperature=1.0).as_dict(), + ) + # Initialize assistant with system message and model + assistant_sys_msg = BaseMessage.make_assistant_message( + role_name="Assistant", + content="You are a helpful assistant.", + ) + docstring_assistant = ChatAgent(assistant_sys_msg, model=model) + + # Create user message to prompt the assistant + user_msg = BaseMessage.make_user_message( + role_name="User", + content=docstring_prompt + code, + ) + + # Get the response containing the generated docstring + response = docstring_assistant.step(user_msg) + return response.msg.content + + class FunctionTool: r"""An abstraction of a function that OpenAI chat models can call. See https://platform.openai.com/docs/api-reference/chat/create. @@ -151,23 +237,59 @@ class FunctionTool: provide a user-defined tool schema to override. Args: - func (Callable): The function to call.The tool schema is parsed from - the signature and docstring by default. - openai_tool_schema (Optional[Dict[str, Any]], optional): A user-defined - openai tool schema to override the default result. + func (Callable): The function to call. The tool schema is parsed from + the function signature and docstring by default. + openai_tool_schema (Optional[Dict[str, Any]], optional): A + user-defined OpenAI tool schema to override the default result. (default: :obj:`None`) + use_schema_assistant (Optional[bool], optional): Whether to enable the + use of a schema assistant model to automatically generate the + schema if validation fails or no valid schema is provided. + (default: :obj:`False`) + schema_assistant_model (Optional[BaseModelBackend], optional): An + assistant model (e.g., an LLM model) used to generate the schema + if `use_schema_assistant` is enabled and no valid schema is + provided. + (default: :obj:`None`) + schema_generation_max_retries (int, optional): The maximum + number of attempts to retry schema generation using the schema + assistant model if the previous attempts fail. + (default: 2) """ def __init__( self, func: Callable, openai_tool_schema: Optional[Dict[str, Any]] = None, + use_schema_assistant: Optional[bool] = False, + schema_assistant_model: Optional[BaseModelBackend] = None, + schema_generation_max_retries: int = 2, ) -> None: self.func = func self.openai_tool_schema = openai_tool_schema or get_openai_tool_schema( func ) + if use_schema_assistant: + try: + self.validate_openai_tool_schema(self.openai_tool_schema) + except Exception: + print( + f"Warning: No valid schema found for " + f"{self.func.__name__}. " + f"Attempting to generate one using LLM." + ) + schema = self.generate_openai_tool_schema( + schema_generation_max_retries, schema_assistant_model + ) + if schema: + self.openai_tool_schema = schema + else: + raise ValueError( + f"Failed to generate valid schema for " + f"{self.func.__name__}" + ) + @staticmethod def validate_openai_tool_schema( openai_tool_schema: Dict[str, Any], @@ -260,8 +382,8 @@ def set_openai_function_schema( r"""Sets the schema of the function within the OpenAI tool schema. Args: - openai_function_schema (Dict[str, Any]): The function schema to set - within the OpenAI tool schema. + openai_function_schema (Dict[str, Any]): The function schema to + set within the OpenAI tool schema. """ self.openai_tool_schema["function"] = openai_function_schema @@ -362,6 +484,72 @@ def set_parameter(self, param_name: str, value: Dict[str, Any]): param_name ] = value + def generate_openai_tool_schema( + self, + max_retries: int, + schema_assistant: Optional[BaseModelBackend] = None, + ) -> Dict[str, Any]: + r"""Generates an OpenAI tool schema for the specified function. + + This method uses a language model (LLM) to generate the OpenAI tool + schema for the specified function by first generating a docstring and + then creating a schema based on the function's source code. If no LLM + is provided, it defaults to initializing a GPT_4O_MINI model. The + schema generation and validation process is retried up to + `max_retries` times in case of failure. + + + Args: + max_retries (int): The maximum number of retries for schema + generation and validation if the process fails. + schema_assistant (Optional[BaseModelBackend]): An optional LLM + backend model used for generating the docstring and schema. If + not provided, a GPT_4O_MINI model will be created. + + Returns: + Dict[str, Any]: The generated OpenAI tool schema for the function. + + Raises: + ValueError: If schema generation or validation fails after the + maximum number of retries, a ValueError is raised, + prompting manual schema setting. + """ + if not schema_assistant: + logger.warning( + "Warning: No model provided. " + "Use GPT_4O_MINI to generate the schema." + ) + schema_assistant = ModelFactory.create( + model_platform=ModelPlatformType.OPENAI, + model_type=ModelType.GPT_4O_MINI, + model_config_dict=ChatGPTConfig(temperature=1.0).as_dict(), + ) + code = getsource(self.func) + retries = 0 + # Retry loop to handle schema generation and validation + while retries < max_retries: + try: + # Generate the docstring and the schema + docstring = generate_docstring(code, schema_assistant) + self.func.__doc__ = docstring + schema = get_openai_tool_schema(self.func) + # Validate the schema + self.validate_openai_tool_schema(schema) + return schema + + except Exception as e: + retries += 1 + if retries == max_retries: + raise ValueError( + f"Failed to generate the OpenAI tool Schema after " + f"{max_retries} retries. " + f"Please set the OpenAI tool schema for " + f"function {self.func.__name__} manually." + ) from e + logger.warning("Schema validation failed. Retrying...") + + return {} + @property def parameters(self) -> Dict[str, Any]: r"""Getter method for the property :obj:`parameters`. diff --git a/examples/tool_call/generate_openai_tool_schema_example.py b/examples/tool_call/generate_openai_tool_schema_example.py new file mode 100644 index 000000000..6229640e0 --- /dev/null +++ b/examples/tool_call/generate_openai_tool_schema_example.py @@ -0,0 +1,89 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== + +import os + +from camel.agents import ChatAgent +from camel.configs import ChatGPTConfig +from camel.messages import BaseMessage +from camel.models import ModelFactory +from camel.toolkits import FunctionTool +from camel.types import ModelPlatformType, ModelType + +# Set OpenAI API key +api_key = os.getenv("OPENAI_API_KEY") +if not api_key: + raise ValueError("API key not found in environment variables.") + + +# Define a function which does't have a docstring +def get_perfect_square(n: int) -> int: + return n**2 + + +# Create a model instance +model_config_dict = ChatGPTConfig(temperature=1.0).as_dict() +agent_model = ModelFactory.create( + model_platform=ModelPlatformType.OPENAI, + model_type=ModelType.GPT_4O_MINI, + model_config_dict=model_config_dict, +) + +# Create a FunctionTool with the function +function_tool = FunctionTool( + get_perfect_square, + schema_assistant_model=agent_model, + use_schema_assistant=True, +) +print("\nGenerated OpenAI Tool Schema:") +print(function_tool.get_openai_tool_schema()) + +# Set system message for the assistant +assistant_sys_msg = BaseMessage.make_assistant_message( + role_name="Assistant", content="You are a helpful assistant." +) + +# Create a ChatAgent with the tool +camel_agent = ChatAgent( + system_message=assistant_sys_msg, model=agent_model, tools=[function_tool] +) +camel_agent.reset() + +# Define a user message +user_prompt = "What is the perfect square of 2024?" +user_msg = BaseMessage.make_user_message(role_name="User", content=user_prompt) + +# Get response from the assistant +response = camel_agent.step(user_msg) +print("\nAssistant Response:") +print(response.msg.content) + +print(""" +=============================================================================== +Warning: No model provided. Use GPT_4O_MINI to generate the schema for +the function get_perfect_square. Attempting to generate one using LLM. +Successfully generated the OpenAI tool schema for +the function get_perfect_square. + +Generated OpenAI Tool Schema: +{'type': 'function', 'function': {'name': 'get_perfect_square', +'description': 'Calculates the perfect square of a given integer.', +'parameters': {'properties': {'n': {'type': 'integer', +'description': 'The integer to be squared.'}}, 'required': ['n'], +'type': 'object'}}} + +[FunctionCallingRecord(func_name='get_perfect_square', args={'n': 2024}, +result={'result': 4096576})] +=============================================================================== +""") diff --git a/test/toolkits/test_generate_openai_tool_schema.py b/test/toolkits/test_generate_openai_tool_schema.py new file mode 100644 index 000000000..275332561 --- /dev/null +++ b/test/toolkits/test_generate_openai_tool_schema.py @@ -0,0 +1,224 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +from unittest.mock import MagicMock, patch + +import pytest + +from camel.agents import ChatAgent +from camel.models import BaseModelBackend +from camel.toolkits import FunctionTool, generate_docstring + + +def sample_function(a: int, b: str = "default") -> bool: + """ + This function checks if the integer is positive and + if the string is non-empty. + + Args: + a (int): The integer value to check. + b (str): The string to verify. Default is 'default'. + + Returns: + bool: True if both conditions are met, otherwise False. + """ + return a > 0 and len(b) > 0 + + +@patch.object(FunctionTool, 'validate_openai_tool_schema') +@patch.object(FunctionTool, 'generate_openai_tool_schema') +def test_generate_openai_tool_schema( + mock_generate_schema, mock_validate_schema +): + # Mock the validate_openai_tool_schema to raise an exception + mock_validate_schema.side_effect = Exception("Invalid schema") + + # Mock the generate_openai_tool_schema to return a specific schema + mock_schema = { + 'type': 'function', + 'function': { + 'name': 'sample_function', + 'description': ( + 'This function checks if the integer is positive and\n' + 'if the string is non-empty.' + ), + 'parameters': { + 'type': 'object', + 'properties': { + 'a': { + 'type': 'integer', + 'description': 'The integer value to check.', + }, + 'b': { + 'type': 'string', + 'description': ( + "The string to verify. Default is 'default'." + ), + 'default': 'default', + }, + }, + 'required': ['a'], + }, + }, + } + mock_generate_schema.return_value = mock_schema + + # Create FunctionTool instance with use_schema_assistant=True + function_tool = FunctionTool( + func=sample_function, + use_schema_assistant=True, + schema_assistant_model=None, + ) + + # Assert that the generated schema matches the expected schema + assert function_tool.openai_tool_schema == mock_schema + mock_generate_schema.assert_called_once() + + +@pytest.fixture +def mock_model(): + # Create a mock model to simulate BaseModelBackend behavior + mock_model = MagicMock(spec=BaseModelBackend) + mock_model.model_type = MagicMock() + mock_model.model_type.value_for_tiktoken = "mock_value_for_tiktoken" + mock_model.model_config_dict = {} + mock_model.value_for_tiktoken = MagicMock(return_value=1000) + return mock_model + + +@patch('camel.models.ModelFactory.create') +@patch.object(ChatAgent, 'step') +def test_generate_docstring( + mock_chat_agent_step, mock_model_factory, mock_model +): + code = """ + def sample_function(a: int, b: str = "default") -> bool: + return a > 0 and len(b) > 0 + """ + + # Mock the model factory to return the mock model + mock_model_factory.return_value = mock_model + + # Ensure mock_model has required attributes + mock_model.model_type = MagicMock() + mock_model.model_type.value_for_tiktoken = "mock_value_for_tiktoken" + mock_model.model_config_dict = {} + mock_model.value_for_tiktoken = MagicMock(return_value=1000) + + # Mock ChatAgent's step method return value + mock_message = MagicMock() + mock_message.content = ( + "This function checks if the integer is positive and " + "if the string is non-empty.\n" + "Args:\n a (int): The integer value to check.\n" + " b (str): The string to verify. Default is 'default'.\n" + "Returns:\n bool: True if both conditions are met, " + "otherwise False." + ) + mock_response = MagicMock() + mock_response.msgs = [mock_message] + mock_response.msg = mock_message + mock_response.terminated = True + mock_chat_agent_step.return_value = mock_response + + # Generate docstring + try: + docstring = generate_docstring(code, mock_model) + except AttributeError as e: + pytest.fail( + f"generate_docstring() raised AttributeError unexpectedly: {e}" + ) + except RuntimeError as e: + pytest.fail( + f"generate_docstring() raised RuntimeError unexpectedly: {e}" + ) + + expected_docstring = ( + "This function checks if the integer is positive and " + "if the string is non-empty.\n" + "Args:\n a (int): The integer value to check.\n" + " b (str): The string to verify. Default is 'default'.\n" + "Returns:\n bool: True if both conditions are met, " + "otherwise False." + ) + + assert docstring == expected_docstring + + +@patch('camel.models.ModelFactory.create') +@patch.object(ChatAgent, 'step') +def test_function_tool_generate_schema_with_retries( + mock_chat_agent_step, mock_model_factory +): + # Mock the model factory to return a mock model + mock_model = MagicMock(spec=BaseModelBackend) + mock_model_factory.return_value = mock_model + + # Mock ChatAgent's step method to simulate retries + mock_message = MagicMock() + mock_message.content = ( + "This function checks if the integer is positive and\n" + "if the string is non-empty.\n" + "Args:\n a (int): The integer value to check.\n" + " b (str): The string to verify. Default is 'default'.\n" + "Returns:\n bool: True if both conditions are met, otherwise False." + ) + mock_response = MagicMock() + mock_response.msgs = [mock_message] + mock_response.msg = mock_message + mock_response.terminated = True + + # Configure the step method to fail the first time + # and succeed the second time + mock_chat_agent_step.side_effect = [ + Exception("Validation failed"), + mock_response, + ] + + # Create FunctionTool instance with use_schema_assistant=True + function_tool = FunctionTool( + func=sample_function, + use_schema_assistant=True, + schema_assistant_model=mock_model, + schema_generation_max_retries=2, + ) + + expected_schema = { + 'type': 'function', + 'function': { + 'name': 'sample_function', + 'description': ( + 'This function checks if the integer is positive and\n' + 'if the string is non-empty.' + ), + 'parameters': { + 'type': 'object', + 'properties': { + 'a': { + 'type': 'integer', + 'description': 'The integer value to check.', + }, + 'b': { + 'type': 'string', + 'description': ( + "The string to verify. Default is 'default'." + ), + 'default': 'default', + }, + }, + 'required': ['a'], + }, + }, + } + + assert function_tool.openai_tool_schema == expected_schema