[FEAT][Auto Swarm]
Kye committed Mar 30, 2024
1 parent 14b67e7 commit fc3151c
Showing 3 changed files with 227 additions and 1 deletion.
156 changes: 156 additions & 0 deletions servers/autoswarm/api.py
@@ -0,0 +1,156 @@
from dotenv import load_dotenv
import os

import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from sse_starlette.sse import EventSourceResponse

from swarms_cloud.auth_with_swarms_cloud import authenticate_user, fetch_api_key_info
from swarms_cloud.schema.cog_vlm_schemas import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatMessageResponse,
    ModelCard,
    ModelList,
    UsageInfo,
)
from swarms_cloud.calculate_pricing import calculate_pricing, count_tokens
from swarms_cloud.log_api_request_to_supabase import log_to_supabase, ModelAPILogEntry

# Load environment variables from .env file
load_dotenv()

# Environment variables
MODEL_PATH = os.environ.get("COGVLM_MODEL_PATH", "THUDM/cogvlm-chat-hf")
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "lmsys/vicuna-7b-v1.5")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Parse as a boolean: os.environ.get returns a string whenever the variable is set
QUANT_ENABLED = os.environ.get("QUANT_ENABLED", "true").lower() in ("true", "1", "yes")

# Create a FastAPI app
app = FastAPI(debug=True)


# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """
    List the available models as model cards, so clients can query
    which models are available for use.
    """
    model_card = ModelCard(
        id="cogvlm-chat-17b"
    )  # can be replaced by your model id, e.g. cogagent-chat-18b
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(
    request: ChatCompletionRequest, token: str = Depends(authenticate_user)
):
    try:
        if len(request.messages) < 1 or request.messages[-1].role == "assistant":
            raise HTTPException(status_code=400, detail="Invalid request")

        logger.debug(f"Request: {request}")

        gen_params = dict(
            messages=request.messages,
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_tokens or 1024,
            echo=False,
            stream=request.stream,
        )

        # NOTE: `predict`, `generate_cogvlm`, `model`, and `tokenizer` are not
        # defined in this file; they are assumed to come from the CogVLM
        # serving module.
        if request.stream:
            generate = predict(request.model, gen_params)
            return EventSourceResponse(generate, media_type="text/event-stream")

        # Generate the response synchronously
        response = generate_cogvlm(model, tokenizer, gen_params)

        usage = UsageInfo()

        message = ChatMessageResponse(
            role="assistant",
            content=response["text"],
        )

        # Build the API log entry for Supabase
        entry = ModelAPILogEntry(
            user_id=fetch_api_key_info(token),
            model_id="41a2869c-5f8d-403f-83bb-1f06c56bad47",
            input_tokens=count_tokens(request.messages, tokenizer, request.model),
            output_tokens=count_tokens(response["text"], tokenizer, request.model),
            all_cost=calculate_pricing(
                texts=[message.content], tokenizer=tokenizer, rate_per_million=15.0
            ),
            input_cost=calculate_pricing(
                texts=[message.content], tokenizer=tokenizer, rate_per_million=15.0
            ),
            # calculate_pricing takes a list of texts, as in the calls above
            output_cost=calculate_pricing(
                texts=[response["text"]], tokenizer=tokenizer, rate_per_million=15.0
            )
            * 5,
            messages=request.messages,
            # temperature=request.temperature,
            top_p=request.top_p,
            # echo=request.echo,
            stream=request.stream,
            repetition_penalty=request.repetition_penalty,
            max_tokens=request.max_tokens,
        )

        # Log the entry to Supabase
        log_to_supabase(entry=entry)

        logger.debug(f"==== message ====\n{message}")
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=message,
        )

        task_usage = UsageInfo.parse_obj(response["usage"])
        for usage_key, usage_value in task_usage.dict().items():
            setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

        out = ChatCompletionResponse(
            model=request.model,
            choices=[choice_data],
            object="chat.completion",
            usage=usage,
        )

        return out
    except HTTPException:
        # Re-raise intentional HTTP errors (e.g. the 400 above) instead of
        # masking them as 500s
        raise
    except Exception as e:
        logger.error(f"Error: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")


if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("MODEL_API_PORT", 8000)),
        # workers=5,
        log_level="info",
        use_colors=True,
        # reload=True,
    )
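
For reference, here is a minimal client sketch for the chat endpoint above. It assumes the server is running locally on the default MODEL_API_PORT (8000) and that authenticate_user accepts a bearer token; the token and payload values are illustrative, not part of this commit.

import requests

# Illustrative values; the path and field names mirror ChatCompletionRequest above
url = "http://localhost:8000/v1/chat/completions"
headers = {"Authorization": "Bearer YOUR_API_KEY"}  # placeholder credential
payload = {
    "model": "cogvlm-chat-17b",
    "messages": [{"role": "user", "content": "Summarize the swarm architecture."}],
    "temperature": 0.7,
    "top_p": 0.9,
    "max_tokens": 256,
    "stream": False,
}

response = requests.post(url, headers=headers, json=payload, timeout=120)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])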
2 changes: 1 addition & 1 deletion servers/cogvlm/cogvlm.py
@@ -89,7 +89,7 @@
     torch_dtype=torch_type,
     low_cpu_mem_usage=True,
     quantization_config=bnb_config,
-)#.eval()
+) # .eval()

 model = prepare_model_for_ddp_inference(model)
70 changes: 70 additions & 0 deletions swarms_cloud/schema/auto_swarm_schema.py
@@ -0,0 +1,70 @@
import uuid
from pydantic import BaseModel, Field
from typing import Optional, Sequence, Dict, List


class AutoSwarmSchemaResponse(BaseModel):
    """
    Represents the schema for an auto swarm.

    Attributes:
        id (str): The ID of the swarm.
        api_key (Optional[str]): The API key for the swarm.
        swarm_name (Optional[str]): The name of the swarm.
        num_of_agents (Optional[int]): The number of agents in the swarm.
        messages (Optional[Dict[str, str]]): The messages for the swarm.
        num_loops (Optional[int]): The number of loops for the swarm.
        streaming (Optional[bool]): Indicates if the swarm is streaming.
        tasks (Optional[Sequence[str]]): The tasks for the swarm.
        max_tokens (Optional[int]): The maximum number of tokens for the swarm.
        documents (Optional[Sequence[str]]): The documents for the swarm.
        response_format (Optional[str]): The response format for the swarm.
        stopping_token (Optional[List[str]]): The stopping tokens for the swarm.
        n (Optional[int]): The number of choices for the swarm.
    """

    # default_factory generates a fresh UUID per instance; a module-level
    # uuid.uuid4() would give every instance the same default ID
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    api_key: Optional[str] = None
    swarm_name: Optional[str] = None
    num_of_agents: Optional[int] = None
    messages: Optional[Dict[str, str]] = None
    num_loops: Optional[int] = 1
    streaming: Optional[bool] = False
    tasks: Optional[Sequence[str]] = None
    max_tokens: Optional[int] = 32096
    documents: Optional[Sequence[str]] = None
    response_format: Optional[str] = None
    stopping_token: Optional[List[str]] = []
    n: Optional[int] = 1


class AutoSwarmResponse(BaseModel):
    """
    Represents the response for an auto swarm.

    Attributes:
        id (str): The ID of the auto swarm.
        swarm_name (Optional[str]): The name of the auto swarm (optional).
        num_of_agents (Optional[int]): The number of agents in the auto swarm (optional).
        messages (Optional[Dict[str, str]]): Additional messages related to the auto swarm (optional).
        num_loops (Optional[int]): The number of loops to run the auto swarm (optional, default: 1).
        streaming (Optional[bool]): Indicates if the auto swarm should be streamed (optional, default: False).
        tasks (Optional[Sequence[str]]): The tasks to be performed by the auto swarm (optional).
        max_tokens (Optional[int]): The maximum number of tokens to generate for each task (optional).
        response_format (Optional[str]): The format of the response (optional).
        stopping_token (Optional[List[str]]): The stopping token(s) for each task (optional).
        number_of_choices (Optional[int]): The number of choices to generate for each task (optional, default: 1).
    """

    id: str
    swarm_name: Optional[str] = None
    num_of_agents: Optional[int] = None
    messages: Optional[Dict[str, str]] = None
    num_loops: Optional[int] = 1
    streaming: Optional[bool] = False
    tasks: Optional[Sequence[str]] = None
    max_tokens: Optional[int] = None
    response_format: Optional[str] = None
    stopping_token: Optional[List[str]] = None
    number_of_choices: Optional[int] = 1
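
For illustration, a short sketch of how these models might be used; the field values below are made up:

# Each AutoSwarmSchemaResponse receives its own UUID via default_factory
schema = AutoSwarmSchemaResponse(swarm_name="doc-summarizer", num_of_agents=3)
print(schema.id)

swarm = AutoSwarmResponse(
    id=schema.id,
    swarm_name="doc-summarizer",
    tasks=["summarize the attached documents"],
    max_tokens=4096,
)
print(swarm.dict())  # pydantic v1-style serialization, matching .dict() usage in api.py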
