[CLEANUP][API]
Kye Gomez authored and Kye Gomez committed Aug 16, 2024
1 parent 7be3ce4 commit 0cc01b2
Showing 3 changed files with 47 additions and 53 deletions.
75 changes: 23 additions & 52 deletions api.py
@@ -2,7 +2,6 @@
import json
import os

import tiktoken
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from swarms import Agent, OpenAIChat
@@ -13,58 +12,14 @@
AgentOutput,
ModelList,
ModelSchema,
AllAgentsSchema,
AgentCreationOutput,
)
from swarms_memory import ChromaDB
from swarms.models.tiktoken_wrapper import TikTokenizer

logger.info("Starting the agent API server...")


def count_tokens(text: str) -> int:
try:
# Get the encoding for the specific model
enc = tiktoken.get_encoding("cl100k_base")

# Encode the text
tokens = enc.encode(text)

# Count the tokens
return len(tokens)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Error counting tokens: {e}")


# async def model_router(model_name: str):
# """
# Function to switch to the specified model.

# Parameters:
# - model_name (str): The name of the model to switch to.

# Returns:
# - None

# Raises:
# - None

# """
# # Logic to switch to the specified model
# if model_name == "gpt-4o":
# # Switch to OpenAIChat model
# llm = OpenAIChat(max_tokens=4000, model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))
# elif model_name == "gpt-4-vision-preview":
# # Switch to GPT4VisionAPI model
# llm = GPT4VisionAPI(
# max_tokens=4000,
# )
# elif model_name == "Anthropic":
# # Switch to Anthropic model
# llm = Anthropic(anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"))
# else:
# # Invalid model name
# raise HTTPException(status_code=400, detail=f"Invalid model name: {model_name}")

# return llm

llm = OpenAIChat(
max_tokens=4000,
model_name="gpt-4o",
@@ -162,6 +117,23 @@ async def list_models():
return models


@app.get("/v1/agents", response_model=AllAgentsSchema)
async def list_agents(request: Request):
"""
An endpoint to list available models. It returns a list of model names.
This is useful for clients to query and understand what models are available for use.
"""
logger.info("Listing available agents...")

agents = AllAgentsSchema(
AgentCreationOutput(
name="Agent 1",
description="Description 1",
created_at=1628584185,
)
)


@app.post("/v1/agent/completions", response_model=AgentOutput)
async def agent_completions(agent_input: AgentInput):
try:
@@ -217,14 +189,13 @@ async def agent_completions(agent_input: AgentInput):
logger.info(f"Agent response: {completions}")

# Costs calculation
all_input_tokens = count_tokens(agent_history)
output_tokens = count_tokens(completions)
all_input_tokens = TikTokenizer().count_tokens(agent_history)
output_tokens = TikTokenizer().count_tokens(completions)
total_costs = all_input_tokens + output_tokens
logger.info(f"Token counts: {total_costs}")

# Prepare the output
out = AgentOutput(
agent=agent_input,
completions=ChatCompletionResponse(
model=model_name,
object="chat.completion",
@@ -238,7 +209,7 @@ async def agent_completions(agent_input: AgentInput):
},
}
],
usage_info=UsageInfo(
usage=UsageInfo(
prompt_tokens=all_input_tokens,
completion_tokens=output_tokens,
total_tokens=total_costs,
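
The hunks above replace the local tiktoken helper with the TikTokenizer wrapper from swarms, but they construct a fresh tokenizer for each count. Below is a minimal sketch of reusing a single instance; it relies only on the TikTokenizer().count_tokens(text) call already shown in this diff, and the module-level tokenizer and helper name are illustrative rather than part of the commit:

from swarms.models.tiktoken_wrapper import TikTokenizer

tokenizer = TikTokenizer()  # constructed once instead of per request (hypothetical refactor)

def count_request_tokens(agent_history: str, completions: str) -> tuple[int, int, int]:
    # Mirrors the endpoint's arithmetic: prompt tokens, completion tokens, and their sum.
    prompt_tokens = tokenizer.count_tokens(agent_history)
    completion_tokens = tokenizer.count_tokens(completions)
    return prompt_tokens, completion_tokens, prompt_tokens + completion_tokens
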
5 changes: 4 additions & 1 deletion parallel_swarm_api.py
@@ -100,7 +100,10 @@ def run_parallel_swarm_completions(


@app.post("v1/swarms", response_model=AllSwarmsSchema)
def get_all_swarms(request: Request):
return AllSwarmsSchema(
swarms=[
SwarmAPISchema(
20 changes: 20 additions & 0 deletions swarms_cloud/schema/agent_api_schemas.py
@@ -146,3 +146,23 @@ class ParallelSwarmAPIOutput(BaseModel):
# )

# print(full_example.dict())


class AgentCreationOutput(BaseModel):
    # default_factory ensures each instance gets its own id and timestamp,
    # instead of a single value frozen when the class is defined.
    id: str = Field(default_factory=lambda: uuid.uuid4().hex)
    name: str = Field(description="The name of the agent.")
    description: str = Field(description="The description of the agent.")
    tags: str = Field(
        description="The tags associated with the agent, example: Finance Agent, Chat Agent, Math Agent"
    )
    use_cases: Dict[str, str] = Field(
        description="The use cases of the agent, example: {'use_case_1': 'Use case 1 description', 'use_case_2': 'Use case 2 description'}"
    )
    created_at: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "TGSC"


class AllAgentsSchema(BaseModel):
agents: List[AgentCreationOutput] = Field(
description="The list of agents available."
)
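
A minimal construction check for the new schemas. The field values are placeholders, and the .dict() call follows the Pydantic v1-style API already used elsewhere in this file (full_example.dict()):

example = AllAgentsSchema(
    agents=[
        AgentCreationOutput(
            name="Example Agent",
            description="Illustrative entry only.",
            tags="Chat Agent",
            use_cases={"use_case_1": "Use case 1 description"},
        )
    ]
)
print(example.dict())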
