Skip to content

Commit

Permalink
feat: Get intergration arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Nov 12, 2024
1 parent c5970cb commit 005369e
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from litellm.types.utils import Choices, ModelResponse
from temporalio import activity
from temporalio.exceptions import ApplicationError
from pydantic import BaseModel

from ...autogen.Tools import (
BraveIntegrationDef,
Expand All @@ -28,6 +29,7 @@
Tool,
WeatherIntegrationDef,
WikipediaIntegrationDef,
BaseIntegrationDef,
)
from ...clients import (
litellm, # We dont directly import `acompletion` so we can mock it
Expand Down Expand Up @@ -63,20 +65,33 @@ def _get_integration_arguments(tool: Tool):
"remote_browser": RemoteBrowserIntegrationDef,
}

integration = providers_map.get(tool.integration.provider)
integration: BaseIntegrationDef | dict[str, BaseIntegrationDef] = providers_map.get(tool.integration.provider)
if isinstance(integration, dict):
integration = integration.get(tool.integration.method)
integration: BaseIntegrationDef = integration.get(tool.integration.method)

return integration.model_fields["arguments"].annotation if integration else None
properties = {
"type": "object",
"properties": {},
"required": [],
}

arguments: BaseModel | Any | None = integration.arguments
if not arguments:
return properties

if isinstance(arguments, BaseModel):
for fld_name, fld_annotation in arguments.model_fields.items():
properties["properties"][fld_name] = {
"type": fld_annotation.annotation,
"description": fld_name,
}
if fld_annotation.is_required:
properties["required"].append(fld_name)

def _annotation_input_schema(annotation: Any) -> dict:
# TODO: implement
def _tool(x: annotation):
pass
elif isinstance(arguments, dict):
properties["properties"] = arguments

lc_tool: BaseTool = tool_decorator(_tool)
return lc_tool.get_input_jsonschema()
return properties


def format_tool(tool: Tool) -> dict:
Expand Down Expand Up @@ -125,8 +140,7 @@ def format_tool(tool: Tool) -> dict:
formatted["function"]["parameters"] = json_schema

elif tool.type == "integration" and tool.integration:
if annotation := _get_integration_arguments(tool):
formatted["function"]["parameters"] = _annotation_input_schema(annotation)
formatted["function"]["parameters"] = _get_integration_arguments(tool)

elif tool.type == "api_call" and tool.api_call:
formatted["function"]["parameters"] = tool.api_call.schema_
Expand Down

0 comments on commit 005369e

Please sign in to comment.