Skip to content

Commit

Permalink
feat: Get integration arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Nov 13, 2024
1 parent 1e0ab6c commit 193c764
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 74 deletions.
76 changes: 3 additions & 73 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,90 +12,20 @@
from temporalio import activity
from temporalio.exceptions import ApplicationError

from ...autogen.Tools import (
BaseIntegrationDef,
BraveIntegrationDef,
BrowserbaseCompleteSessionIntegrationDef,
BrowserbaseContextIntegrationDef,
BrowserbaseCreateSessionIntegrationDef,
BrowserbaseExtensionIntegrationDef,
BrowserbaseGetSessionConnectUrlIntegrationDef,
BrowserbaseGetSessionIntegrationDef,
BrowserbaseGetSessionLiveUrlsIntegrationDef,
BrowserbaseListSessionsIntegrationDef,
DummyIntegrationDef,
EmailIntegrationDef,
RemoteBrowserIntegrationDef,
SpiderIntegrationDef,
Tool,
WeatherIntegrationDef,
WikipediaIntegrationDef,
)
from ...autogen.Tools import Tool
from ...clients import (
litellm, # We dont directly import `acompletion` so we can mock it
)
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...common.utils.template import render_template
from ...env import anthropic_api_key, debug
from ..utils import get_handler_with_filtered_params
from ..utils import get_handler_with_filtered_params, get_integration_arguments
from .base_evaluate import base_evaluate

COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"


def _get_integration_arguments(tool: Tool):
providers_map = {
"brave": BraveIntegrationDef,
"dummy": DummyIntegrationDef,
"email": EmailIntegrationDef,
"spider": SpiderIntegrationDef,
"wikipedia": WikipediaIntegrationDef,
"weather": WeatherIntegrationDef,
"browserbase": {
"create_context": BrowserbaseContextIntegrationDef,
"install_extension_from_github": BrowserbaseExtensionIntegrationDef,
"list_sessions": BrowserbaseListSessionsIntegrationDef,
"create_session": BrowserbaseCreateSessionIntegrationDef,
"get_session": BrowserbaseGetSessionIntegrationDef,
"complete_session": BrowserbaseCompleteSessionIntegrationDef,
"get_live_urls": BrowserbaseGetSessionLiveUrlsIntegrationDef,
"get_connect_url": BrowserbaseGetSessionConnectUrlIntegrationDef,
},
"remote_browser": RemoteBrowserIntegrationDef,
}

integration: BaseIntegrationDef | dict[str, BaseIntegrationDef] = providers_map.get(
tool.integration.provider
)
if isinstance(integration, dict):
integration: BaseIntegrationDef = integration.get(tool.integration.method)

properties = {
"type": "object",
"properties": {},
"required": [],
}

arguments: BaseModel | Any | None = integration.arguments
if not arguments:
return properties

if isinstance(arguments, BaseModel):
for fld_name, fld_annotation in arguments.model_fields.items():
properties["properties"][fld_name] = {
"type": fld_annotation.annotation,
"description": fld_name,
}
if fld_annotation.is_required:
properties["required"].append(fld_name)

elif isinstance(arguments, dict):
properties["properties"] = arguments

return properties


def format_tool(tool: Tool) -> dict:
if tool.type == "computer_20241022":
return {
Expand Down Expand Up @@ -142,7 +72,7 @@ def format_tool(tool: Tool) -> dict:
formatted["function"]["parameters"] = json_schema

elif tool.type == "integration" and tool.integration:
formatted["function"]["parameters"] = _get_integration_arguments(tool)
formatted["function"]["parameters"] = get_integration_arguments(tool)

elif tool.type == "api_call" and tool.api_call:
formatted["function"]["parameters"] = tool.api_call.schema_
Expand Down
98 changes: 97 additions & 1 deletion agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,32 @@
import string
import time
import urllib.parse
from typing import Any, Callable, ParamSpec, TypeVar
from typing import Any, Callable, Literal, ParamSpec, TypeVar, get_origin

import re2
import zoneinfo
from beartype import beartype
from pydantic import BaseModel
from simpleeval import EvalWithCompoundTypes, SimpleEval

from ..autogen.openapi_model import SystemDef
from ..autogen.Tools import (
BraveSearchArguments,
BrowserbaseCompleteSessionArguments,
BrowserbaseContextArguments,
BrowserbaseCreateSessionArguments,
BrowserbaseExtensionArguments,
BrowserbaseGetSessionArguments,
BrowserbaseGetSessionConnectUrlArguments,
BrowserbaseGetSessionLiveUrlsArguments,
BrowserbaseListSessionsArguments,
EmailArguments,
RemoteBrowserArguments,
SpiderFetchArguments,
Tool,
WeatherGetArguments,
WikipediaSearchArguments,
)
from ..common.utils import yaml

T = TypeVar("T")
Expand Down Expand Up @@ -378,3 +396,81 @@ def get_handler(system: SystemDef) -> Callable:
raise NotImplementedError(
f"System call not implemented for {system.resource}.{system.operation}"
)


def _annotation_to_type(annotation: type) -> dict[str, str]:
type_, enum = None, None
if get_origin(annotation) is Literal:
type_ = "string"
enum = ",".join(annotation.__args__)
elif annotation is str:
type_ = "string"
elif annotation in (int, float):
type_ = "number"
elif annotation is list:
type_ = "array"
elif annotation is bool:
type_ = "boolean"
elif annotation == type(None):
type_ = "null"
else:
type_ = "object"

result = {
"type": type_,
}
if enum is not None:
result.update({"enum": enum})

return result


def get_integration_arguments(tool: Tool):
providers_map = {
"brave": BraveSearchArguments,
# "dummy": DummyIntegrationDef,
"email": EmailArguments,
"spider": SpiderFetchArguments,
"wikipedia": WikipediaSearchArguments,
"weather": WeatherGetArguments,
"browserbase": {
"create_context": BrowserbaseContextArguments,
"install_extension_from_github": BrowserbaseExtensionArguments,
"list_sessions": BrowserbaseListSessionsArguments,
"create_session": BrowserbaseCreateSessionArguments,
"get_session": BrowserbaseGetSessionArguments,
"complete_session": BrowserbaseCompleteSessionArguments,
"get_live_urls": BrowserbaseGetSessionLiveUrlsArguments,
"get_connect_url": BrowserbaseGetSessionConnectUrlArguments,
},
"remote_browser": RemoteBrowserArguments,
}
properties = {
"type": "object",
"properties": {},
"required": [],
}

integration_args: type[BaseModel] | dict[str, type[BaseModel]] | None = (
providers_map.get(tool.integration.provider)
)

if integration_args is None:
return properties

if isinstance(integration_args, dict):
integration_args: type[BaseModel] | None = integration_args.get(
tool.integration.method
)

if integration_args is None:
return properties

for fld_name, fld_annotation in integration_args.model_fields.items():
tp = _annotation_to_type(fld_annotation.annotation)
tp["description"] = fld_name
properties["properties"][fld_name] = tp
if fld_annotation.is_required:
properties["required"].append(fld_name)

return properties
26 changes: 26 additions & 0 deletions agents-api/tests/test_activities_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from datetime import datetime, timezone
from uuid import uuid4

from ward import test

from agents_api.activities.utils import get_integration_arguments
from agents_api.autogen.Tools import DummyIntegrationDef, Tool


@test("get_integration_arguments: dummy search")
async def _():
tool = Tool(
id=uuid4(),
name="tool1",
type="integration",
integration=DummyIntegrationDef(),
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
result = get_integration_arguments(tool)

assert result == {
"type": "object",
"properties": {},
"required": [],
}

0 comments on commit 193c764

Please sign in to comment.