From fc3151c865e41972d657bba024fadc24555aec69 Mon Sep 17 00:00:00 2001
From: Kye
Date: Sat, 30 Mar 2024 00:02:39 -0700
Subject: [PATCH] [FEAT][Auto Swarm]

---
 servers/autoswarm/api.py                 | 163 +++++++++++++++++++++++
 servers/cogvlm/cogvlm.py                 |   2 +-
 swarms_cloud/schema/auto_swarm_schema.py |  71 ++++++++++
 3 files changed, 235 insertions(+), 1 deletion(-)
 create mode 100644 servers/autoswarm/api.py
 create mode 100644 swarms_cloud/schema/auto_swarm_schema.py
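Reviewer note: a minimal sketch of how a client could exercise the new chat endpoint once this patch is applied. It assumes the server is reachable on the default port (MODEL_API_PORT, 8000) and that authenticate_user accepts a bearer token; the Authorization scheme and exact payload fields are assumptions and should be verified against ChatCompletionRequest.

    import requests

    # Hypothetical client call; the path, port, and model id come from api.py,
    # but the bearer-token auth scheme is an assumption.
    response = requests.post(
        "http://localhost:8000/v1/chat/completions",
        headers={"Authorization": "Bearer <YOUR_API_KEY>"},
        json={
            "model": "cogvlm-chat-17b",
            "messages": [{"role": "user", "content": "Describe the Swarms project."}],
            "temperature": 0.8,
            "top_p": 0.8,
            "max_tokens": 256,
            "stream": False,
        },
        timeout=120,
    )
    response.raise_for_status()
    print(response.json()["choices"][0]["message"]["content"])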
+ """ + model_card = ModelCard( + id="cogvlm-chat-17b" + ) # can be replaced by your model id like cogagent-chat-18b + return ModelList(data=[model_card]) + + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def create_chat_completion( + request: ChatCompletionRequest, token: str = Depends(authenticate_user) +): + try: + if len(request.messages) < 1 or request.messages[-1].role == "assistant": + raise HTTPException(status_code=400, detail="Invalid request") + + print(f"Request: {request}") + + gen_params = dict( + messages=request.messages, + temperature=request.temperature, + top_p=request.top_p, + max_tokens=request.max_tokens or 1024, + echo=False, + stream=request.stream, + ) + + if request.stream: + generate = predict(request.model, gen_params) + return EventSourceResponse(generate, media_type="text/event-stream") + + # Generate response + response = generate_cogvlm(model, tokenizer, gen_params) + + usage = UsageInfo() + + # ChatMessageResponse + message = ChatMessageResponse( + role="assistant", + content=response["text"], + ) + + # # Log the entry to supabase + entry = ModelAPILogEntry( + user_id=fetch_api_key_info(token), + model_id="41a2869c-5f8d-403f-83bb-1f06c56bad47", + input_tokens=count_tokens(request.messages, tokenizer, request.model), + output_tokens=count_tokens(response["text"], tokenizer, request.model), + all_cost=calculate_pricing( + texts=[message.content], tokenizer=tokenizer, rate_per_million=15.0 + ), + input_cost=calculate_pricing( + texts=[message.content], tokenizer=tokenizer, rate_per_million=15.0 + ), + output_cost=calculate_pricing( + texts=response["text"], tokenizer=tokenizer, rate_per_million=15.0 + ) + * 5, + messages=request.messages, + # temperature=request.temperature, + top_p=request.top_p, + # echo=request.echo, + stream=request.stream, + repetition_penalty=request.repetition_penalty, + max_tokens=request.max_tokens, + ) + + # Log the entry to supabase + log_to_supabase(entry=entry) + + # ChatCompletionResponseChoice + logger.debug(f"==== message ====\n{message}") + choice_data = ChatCompletionResponseChoice( + index=0, + message=message, + ) + + # task_usage = UsageInfo.model_validate(response["usage"]) + task_usage = UsageInfo.parse_obj(response["usage"]) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + out = ChatCompletionResponse( + model=request.model, + choices=[choice_data], + object="chat.completion", + usage=usage, + ) + + return out + except Exception as e: + logger.error(f"Error: {e}") + raise HTTPException(status_code=500, detail="Internal Server Error") + + +if __name__ == "__main__": + uvicorn.run( + app, + host="0.0.0.0", + port=int(os.environ.get("MODEL_API_PORT", 8000)), + # workers=5, + log_level="info", + use_colors=True, + # reload=True, + ) diff --git a/servers/cogvlm/cogvlm.py b/servers/cogvlm/cogvlm.py index b1efded..0a020dc 100644 --- a/servers/cogvlm/cogvlm.py +++ b/servers/cogvlm/cogvlm.py @@ -89,7 +89,7 @@ torch_dtype=torch_type, low_cpu_mem_usage=True, quantization_config=bnb_config, -)#.eval() +) # .eval() model = prepare_model_for_ddp_inference(model) diff --git a/swarms_cloud/schema/auto_swarm_schema.py b/swarms_cloud/schema/auto_swarm_schema.py new file mode 100644 index 0000000..93c8d0b --- /dev/null +++ b/swarms_cloud/schema/auto_swarm_schema.py @@ -0,0 +1,70 @@ +import uuid +from pydantic import BaseModel +from typing import Optional, Sequence, Dict, List + +swarm_id = uuid.uuid4() + + +class 
diff --git a/servers/cogvlm/cogvlm.py b/servers/cogvlm/cogvlm.py
index b1efded..0a020dc 100644
--- a/servers/cogvlm/cogvlm.py
+++ b/servers/cogvlm/cogvlm.py
@@ -89,7 +89,7 @@
     torch_dtype=torch_type,
     low_cpu_mem_usage=True,
     quantization_config=bnb_config,
-)#.eval()
+)  # .eval()
 
 model = prepare_model_for_ddp_inference(model)
 
diff --git a/swarms_cloud/schema/auto_swarm_schema.py b/swarms_cloud/schema/auto_swarm_schema.py
new file mode 100644
index 0000000..93c8d0b
--- /dev/null
+++ b/swarms_cloud/schema/auto_swarm_schema.py
@@ -0,0 +1,71 @@
+import uuid
+from typing import Dict, List, Optional, Sequence
+
+from pydantic import BaseModel, Field
+
+
+class AutoSwarmSchemaResponse(BaseModel):
+    """
+    Represents the schema for an auto swarm.
+
+    Attributes:
+        id (str): The ID of the swarm.
+        api_key (Optional[str]): The API key for the swarm.
+        swarm_name (Optional[str]): The name of the swarm.
+        num_of_agents (Optional[int]): The number of agents in the swarm.
+        messages (Optional[Dict[str, str]]): The messages for the swarm.
+        num_loops (Optional[int]): The number of loops for the swarm.
+        streaming (Optional[bool]): Indicates if the swarm is streaming.
+        tasks (Optional[Sequence[str]]): The tasks for the swarm.
+        max_tokens (Optional[int]): The maximum number of tokens for the swarm.
+        documents (Optional[Sequence[str]]): The documents for the swarm.
+        response_format (Optional[str]): The response format for the swarm.
+        stopping_token (Optional[List[str]]): The stopping tokens for the swarm.
+        n (Optional[int]): The number of choices for the swarm.
+    """
+
+    # Fresh id per instance; a module-level uuid would be shared by all instances.
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    api_key: Optional[str] = None
+    swarm_name: Optional[str] = None
+    num_of_agents: Optional[int] = None
+    messages: Optional[Dict[str, str]] = None
+    num_loops: Optional[int] = 1
+    streaming: Optional[bool] = False
+    tasks: Optional[Sequence[str]] = None
+    max_tokens: Optional[int] = 32096
+    documents: Optional[Sequence[str]] = None
+    response_format: Optional[str] = None
+    stopping_token: Optional[List[str]] = Field(default_factory=list)
+    n: Optional[int] = 1
+
+
+class AutoSwarmResponse(BaseModel):
+    """
+    Represents the response for an auto swarm.
+
+    Attributes:
+        id (str): The ID of the auto swarm.
+        swarm_name (Optional[str]): The name of the auto swarm.
+        num_of_agents (Optional[int]): The number of agents in the auto swarm.
+        messages (Optional[Dict[str, str]]): Additional messages related to the auto swarm.
+        num_loops (Optional[int]): The number of loops to run the auto swarm (default: 1).
+        streaming (Optional[bool]): Indicates if the auto swarm should be streamed (default: False).
+        tasks (Optional[Sequence[str]]): The tasks to be performed by the auto swarm.
+        max_tokens (Optional[int]): The maximum number of tokens to generate for each task.
+        response_format (Optional[str]): The format of the response.
+        stopping_token (Optional[List[str]]): The stopping token(s) for each task.
+        number_of_choices (Optional[int]): The number of choices to generate for each task (default: 1).
+    """
+
+    id: str
+    swarm_name: Optional[str] = None
+    num_of_agents: Optional[int] = None
+    messages: Optional[Dict[str, str]] = None
+    num_loops: Optional[int] = 1
+    streaming: Optional[bool] = False
+    tasks: Optional[Sequence[str]] = None
+    max_tokens: Optional[int] = None
+    response_format: Optional[str] = None
+    stopping_token: Optional[List[str]] = None
+    number_of_choices: Optional[int] = 1
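Reviewer note: a quick sanity-check sketch for the new schema module. It shows that each AutoSwarmSchemaResponse instance now gets its own id and that defaults serialize cleanly; the .json() call assumes Pydantic v1, consistent with the parse_obj()/.dict() usage in api.py.

    from swarms_cloud.schema.auto_swarm_schema import (
        AutoSwarmResponse,
        AutoSwarmSchemaResponse,
    )

    a = AutoSwarmSchemaResponse(swarm_name="test-swarm", tasks=["summarize"])
    b = AutoSwarmSchemaResponse()
    assert a.id != b.id  # per-instance ids via default_factory

    resp = AutoSwarmResponse(id=a.id, swarm_name=a.swarm_name, tasks=a.tasks)
    print(resp.json())  # Pydantic v1-style serialization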