Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api,integrations): Working integrations for tool-call step #521

Merged
merged 2 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ LITELLM_REDIS_PASSWORD=<your_litellm_redis_password>
# AGENTS_API_DEBUG=false
# EMBEDDING_MODEL_ID=Alibaba-NLP/gte-large-en-v1.5
# NUM_GPUS=1
# INTEGRATION_SERVICE_URL=http://integrations:8000

# Temporal
# --------
Expand Down
25 changes: 12 additions & 13 deletions agents-api/agents_api/activities/execute_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from temporalio import activity

from ..autogen.openapi_model import IntegrationDef
from ..clients import integrations
from ..common.protocol.tasks import StepContext
from ..env import testing
from ..models.tools import get_tool_args_from_metadata
Expand All @@ -24,31 +25,29 @@ async def execute_integration(
developer_id=developer_id, agent_id=agent_id, task_id=task_id
)

arguments = merged_tool_args.get(tool_name, {}) | arguments
arguments = (
merged_tool_args.get(tool_name, {}) | (integration.arguments or {}) | arguments
)

try:
if integration.provider == "dummy":
return arguments

else:
raise NotImplementedError(
f"Unknown integration provider: {integration.provider}"
)
return await integrations.run_integration_service(
provider=integration.provider,
setup=integration.setup,
method=integration.method,
arguments=arguments,
)

except BaseException as e:
if activity.in_activity():
activity.logger.error(f"Error in execute_integration: {e}")

raise


async def mock_execute_integration(
context: StepContext,
tool_name: str,
integration: IntegrationDef,
arguments: dict[str, Any],
) -> Any:
return arguments

mock_execute_integration = execute_integration

execute_integration = activity.defn(name="execute_integration")(
execute_integration if not testing else mock_execute_integration
Expand Down
18 changes: 6 additions & 12 deletions agents-api/agents_api/autogen/Tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,11 @@ class IntegrationDef(BaseModel):
)
provider: Literal[
"dummy",
"dall-e",
"duckduckgo",
"hackernews",
"dalle_image_generator",
"duckduckgo_search",
"hacker_news",
"weather",
"wikipedia",
"twitter",
"webpage",
"requests",
]
"""
The provider of the integration
Expand Down Expand Up @@ -132,14 +129,11 @@ class IntegrationDefUpdate(BaseModel):
provider: (
Literal[
"dummy",
"dall-e",
"duckduckgo",
"hackernews",
"dalle_image_generator",
"duckduckgo_search",
"hacker_news",
"weather",
"wikipedia",
"twitter",
"webpage",
"requests",
]
| None
) = None
Expand Down
31 changes: 31 additions & 0 deletions agents-api/agents_api/clients/integrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Any, List

from beartype import beartype
from httpx import AsyncClient

from ..env import integration_service_url

__all__: List[str] = ["run_integration_service"]


@beartype
async def run_integration_service(
*,
provider: str,
arguments: dict,
setup: dict | None = None,
method: str | None = None,
) -> Any:
slug = f"{provider}/{method}" if method else provider
url = f"{integration_service_url}/execute/{slug}"

setup = setup or {}

async with AsyncClient() as client:
response = await client.post(
url,
json={"arguments": arguments, "setup": setup},
)
response.raise_for_status()

return response.json()
7 changes: 7 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@
embedding_dimensions: int = env.int("EMBEDDING_DIMENSIONS", default=1024)


# Integration service
# -------------------
integration_service_url: str = env.str(
"INTEGRATION_SERVICE_URL", default="http://0.0.0.0:8000"
)


# Temporal
# --------
temporal_worker_url: str = env.str("TEMPORAL_WORKER_URL", default="localhost:7233")
Expand Down
14 changes: 11 additions & 3 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ForeachStep,
GetStep,
IfElseWorkflowStep,
IntegrationDef,
LogStep,
MapReduceStep,
ParallelStep,
Expand Down Expand Up @@ -60,7 +61,7 @@

# WorkflowStep = (
# EvaluateStep # ✅
# | ToolCallStep # ❌ <--- high priority
# | ToolCallStep #
# | PromptStep # 🟡 <--- high priority
# | GetStep # ✅
# | SetStep # ✅
Expand Down Expand Up @@ -482,13 +483,20 @@ async def run(
call = tool_call["integration"]
tool_name = call["name"]
arguments = call["arguments"]
integration = next(
integration_spec = next(
(t for t in context.tools if t.name == tool_name), None
)

if integration is None:
if integration_spec is None:
raise ApplicationError(f"Integration {tool_name} not found")

integration = IntegrationDef(
provider=integration_spec.spec["provider"],
setup=integration_spec.spec["setup"],
method=integration_spec.spec["method"],
arguments=arguments,
)

tool_call_response = await workflow.execute_activity(
execute_integration,
args=[context, tool_name, integration, arguments],
Expand Down
1 change: 1 addition & 0 deletions agents-api/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ x--shared-environment: &shared-environment
COZO_HOST: ${COZO_HOST:-http://memory-store:9070}
DEBUG: ${AGENTS_API_DEBUG:-False}
EMBEDDING_MODEL_ID: ${EMBEDDING_MODEL_ID:-Alibaba-NLP/gte-large-en-v1.5}
INTEGRATION_SERVICE_URL: ${INTEGRATION_SERVICE_URL:-http://integrations:8000}
LITELLM_MASTER_KEY: ${LITELLM_MASTER_KEY}
LITELLM_URL: ${LITELLM_URL:-http://litellm:4000}
SUMMARIZATION_MODEL_NAME: ${SUMMARIZATION_MODEL_NAME:-gpt-4-turbo}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#/usr/bin/env python3
# /usr/bin/env python3

MIGRATION_ID = "add_forward_tool_calls_option"
CREATED_AT = 1727235852.744035
Expand Down
44 changes: 44 additions & 0 deletions agents-api/tests/sample_tasks/integration_example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
name: Simple multi step task

input_schema:
type: object
properties:
topics:
type: array
items:
type: string

tools:
- type: function
function:
name: generate_questions
description: Generate a list of questions for a given topic
parameters:
type: object
properties:
topic:
type: string
description: The topic to generate questions for

- type: integration
name: duckduckgo_search
integration:
provider: duckduckgo
setup:
api_key: <something>
arguments:
language: en-US

main:
- foreach:
in: _["topics"]
do:
prompt:
- role: system
content: |-
Generate a list of 10 questions for the topic {{_}} as valid yaml.
unwrap: true

- tool: duckduckgo_search
arguments:
query: "'\n'.join(_)"
64 changes: 62 additions & 2 deletions agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from agents_api.routers.tasks.create_task_execution import start_execution

from .fixtures import cozo_client, test_agent, test_developer_id
from .utils import patch_testing_temporal
from .utils import patch_integration_service, patch_testing_temporal

EMBEDDING_SIZE: int = 1024

Expand Down Expand Up @@ -441,7 +441,7 @@ async def _(
assert result["hello"] == data.input["test"]


@test("workflow: tool call integration type step")
@test("workflow: tool call integration dummy")
async def _(
client=cozo_client,
developer_id=test_developer_id,
Expand Down Expand Up @@ -494,6 +494,65 @@ async def _(
assert result["test"] == data.input["test"]


@test("workflow: tool call integration mocked weather")
async def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
data = CreateExecutionRequest(input={"test": "input"})

task = create_task(
developer_id=developer_id,
agent_id=agent.id,
data=CreateTaskRequest(
**{
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"tools": [
{
"type": "integration",
"name": "get_weather",
"integration": {
"provider": "weather",
"setup": {"openweathermap_api_key": "test"},
"arguments": {"test": "fake"},
},
}
],
"main": [
{
"tool": "get_weather",
"arguments": {"location": "_.test"},
},
],
}
),
client=client,
)

expected_output = {"temperature": 20, "humidity": 60}

async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
with patch_integration_service(expected_output) as mock_integration_service:
execution, handle = await start_execution(
developer_id=developer_id,
task_id=task.id,
data=data,
client=client,
)

assert handle is not None
assert execution.task_id == task.id
assert execution.input == data.input
mock_run_task_execution_workflow.assert_called_once()
mock_integration_service.assert_called_once()

result = await handle.result()
assert result == expected_output


# FIXME: This test is not working. It gets stuck
# @test("workflow: wait for input step start")
async def _(
Expand Down Expand Up @@ -1026,3 +1085,4 @@ async def _(
mock_run_task_execution_workflow.assert_called_once()

await handle.result()

10 changes: 10 additions & 0 deletions agents-api/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,13 @@ def patch_embed_acompletion(output={"role": "assistant", "content": "Hello, worl
acompletion.return_value = mock_model_response

yield embed, acompletion


@contextmanager
def patch_integration_service(output: dict = {"result": "ok"}):
with patch(
"agents_api.clients.integrations.run_integration_service"
) as run_integration_service:
run_integration_service.return_value = output

yield run_integration_service
12 changes: 6 additions & 6 deletions typespec/common/scalars.tsp
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ scalar JinjaTemplate extends string;
/** Integration provider name */
alias integrationProvider = (
| "dummy"
| "dall-e"
| "duckduckgo"
| "hackernews"
| "dalle_image_generator"
| "duckduckgo_search"
| "hacker_news"
| "weather"
| "wikipedia"
| "twitter"
| "webpage"
| "requests"
// | "twitter"
// | "webpage"
// | "requests"
);
Loading