Kye Gomez authored and Kye Gomez committed on Jun 18, 2024
1 parent 8a6c915 · commit 0b547da
Showing 9 changed files with 279 additions and 34 deletions.
@@ -0,0 +1,177 @@
import asyncio
import os
from typing import List, Optional

import tiktoken
from fastapi import Body, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from swarms import Agent, Anthropic, GPT4o, GPT4VisionAPI, OpenAIChat
from swarms.utils.loguru_logger import logger

from swarms_cloud.schema.cog_vlm_schemas import (
    ChatCompletionResponse,
    ModelCard,
    ModelList,
    UsageInfo,
)

async def count_tokens(
    text: str,
):
    try:
        # Get the encoding for the specific model
        # (encoding_for_model maps a model name to its tokenizer;
        # get_encoding expects an encoding name such as "o200k_base")
        encoding = tiktoken.encoding_for_model("gpt-4o")

        # Encode the text
        tokens = encoding.encode(text)

        # Count the tokens
        token_count = len(tokens)

        return token_count
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

async def model_router(model_name: str):
    """
    Switch to the specified model.

    Parameters:
    - model_name (str): The name of the model to switch to.

    Returns:
    - The instantiated LLM.

    Raises:
    - HTTPException: If the model name is not recognized.
    """
    # Logic to switch to the specified model
    if model_name == "OpenAIChat":
        llm = OpenAIChat()
    elif model_name == "GPT4o":
        llm = GPT4o(openai_api_key=os.getenv("OPENAI_API_KEY"))
    elif model_name == "GPT4VisionAPI":
        llm = GPT4VisionAPI()
    elif model_name == "Anthropic":
        llm = Anthropic(anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"))
    else:
        # Invalid model name: fail fast instead of returning an unbound name
        raise HTTPException(
            status_code=400, detail=f"Invalid model name: {model_name}"
        )

    return llm

# Define the input model using Pydantic
class AgentInput(BaseModel):
    agent_name: str = "Swarm Agent"
    system_prompt: Optional[str] = None
    agent_description: Optional[str] = None
    model_name: str = "OpenAIChat"
    max_loops: int = 1
    autosave: bool = False
    dynamic_temperature_enabled: bool = False
    dashboard: bool = False
    verbose: bool = False
    streaming_on: bool = True
    saved_state_path: Optional[str] = None
    sop: Optional[str] = None
    sop_list: Optional[List[str]] = None
    user_name: str = "User"
    retry_attempts: int = 3
    context_length: int = 8192
    task: Optional[str] = None


# Define the output model using Pydantic
class AgentOutput(BaseModel):
    agent: AgentInput
    completions: ChatCompletionResponse

# Create a FastAPI app
app = FastAPI(debug=True)

# Load the middleware to handle CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/v1/models", response_model=ModelList) | ||
async def list_models(): | ||
""" | ||
An endpoint to list available models. It returns a list of model cards. | ||
This is useful for clients to query and understand what models are available for use. | ||
""" | ||
model_card = ModelCard( | ||
id="cogvlm-chat-17b" | ||
) # can be replaced by your model id like cogagent-chat-18b | ||
return ModelList(data=[model_card]) | ||
|
||
|
||
@app.post("v1/agent/completions", response_model=AgentOutput) | ||
async def agent_completions(agent_input: AgentInput = Body(...)): | ||
logger.info(f"Received request: {agent_input}") | ||
|
||
llm = model_router(agent_input.model_name) | ||
|
||
agent = Agent( | ||
agent_name=agent_input.agent_name, | ||
system_prompt=agent_input.system_prompt, | ||
agent_description=agent_input.agent_description, | ||
llm=llm, | ||
max_loops=agent_input.max_loops, | ||
autosave=agent_input.autosave, | ||
dynamic_temperature_enabled=agent_input.dynamic_temperature_enabled, | ||
dashboard=agent_input.dashboard, | ||
verbose=agent_input.verbose, | ||
streaming_on=agent_input.streaming_on, | ||
saved_state_path=agent_input.saved_state_path, | ||
sop=agent_input.sop, | ||
sop_list=agent_input.sop_list, | ||
user_name=agent_input.user_name, | ||
retry_attempts=agent_input.retry_attempts, | ||
context_length=agent_input.context_length, | ||
) | ||
|
||
# Run the agent | ||
completions = await agent.run(agent_input.task) | ||
|
||
all_input_tokens, output_tokens = await asyncio.gather( | ||
count_tokens(agent.short_memory.return_history_as_string()), | ||
count_tokens(completions), | ||
) | ||
|
||
return AgentOutput( | ||
agent=agent_input, | ||
completions=ChatCompletionResponse( | ||
choices=[ | ||
{ | ||
"index": 0, | ||
"message": { | ||
"role": agent_input.agent_name, | ||
"content": completions, | ||
"name": None, | ||
}, | ||
} | ||
], | ||
stream_choices=None, | ||
usage_info=UsageInfo( | ||
prompt_tokens=all_input_tokens, | ||
completion_tokens=output_tokens, | ||
total_tokens=all_input_tokens + output_tokens, | ||
), | ||
), | ||
) |
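For reference, a client-side call to the new completions endpoint might look like the following sketch. The host and port are assumptions for a local deployment; the payload fields mirror the AgentInput model above, and the "completions" key mirrors AgentOutput.

import requests

# Hypothetical local deployment; adjust the base URL to your server.
payload = {
    "agent_name": "Swarm Agent",
    "model_name": "OpenAIChat",
    "task": "Summarize the benefits of solar energy.",
    "max_loops": 1,
}

response = requests.post(
    "http://localhost:8000/v1/agent/completions",
    json=payload,
    timeout=120,
)
response.raise_for_status()
print(response.json()["completions"])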
@@ -1,25 +1,10 @@
-from dotenv import load_dotenv
-from openai import OpenAI
+from swarms import llama3Hosted

-load_dotenv()
-openai_api_key = "sk-9c34d01b0095c16b987d925402fb283972ec64548828ca8ae321930e4c45745d"
-
-openai_api_base = "https://api.swarms.world/v1"
-model = "Meta-Llama-3-8B-Instruct"
-
-client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)
-# Note that this model expects the image to come before the main text
-chat_response = client.chat.completions.create(
-    model=model,
-    messages=[
-        {
-            "role": "user",
-            "content": [
-                {"type": "text", "text": "What's in this image?"},
-            ],
-        }
-    ],
-)
-print("Chat response:", chat_response)
+llama3 = llama3Hosted(
+    model="meta-llama/Meta-Llama-3-8B",
+    base_url="http://199.204.135.78:8090/v1/chat/completions",
+    temperature=0.1,
+    max_tokens=3400,
+)
+
+out = llama3.run("what is your name?")
+print(out)
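Since the new base_url points at an OpenAI-compatible vLLM server (the one started by the shell script below, on port 8090), the same model can also be queried over plain HTTP without the swarms wrapper. A minimal sketch, assuming the host above is reachable:

import requests

# Sketch: send an OpenAI-style chat payload straight to the vLLM server.
resp = requests.post(
    "http://199.204.135.78:8090/v1/chat/completions",
    json={
        "model": "meta-llama/Meta-Llama-3-8B",
        "messages": [{"role": "user", "content": "what is your name?"}],
        "temperature": 0.1,
        "max_tokens": 3400,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])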
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry]
 name = "swarms-cloud"
-version = "0.2.5"
+version = "0.2.6"
 description = "Swarms Cloud - Pytorch"
 license = "MIT"
 authors = ["Kye Gomez <[email protected]>"]

@@ -24,7 +24,7 @@ classifiers = [
 [tool.poetry.dependencies]
 python = "^3.10"
 swarms = "*"
-fastapi = "0.110.1"
+fastapi = "*"
 skypilot = "*"
 torch = "*"
 einops = "*"

@@ -34,7 +34,6 @@ transformers = "*"
 sse-starlette = "2.1.0"
 uvicorn = "*"
 shortuuid = "*"
-xformers = "*"

 [tool.poetry.group.lint.dependencies]
 ruff = ">=0.1.6,<0.4.0"
@@ -0,0 +1,69 @@
#!/bin/bash

# Environment Variables
export MODEL_NAME=meta-llama/Meta-Llama-3-8B
export HF_TOKEN=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 # Change to your own huggingface token.
export HF_HUB_ENABLE_HF_TRANSFER=True

# Setup
conda activate vllm
if [ $? -ne 0 ]; then
    conda create -n vllm python=3.10 -y
    conda activate vllm
fi

pip install vllm==0.4.0.post1
pip install gradio openai
pip install flash-attn
pip install hf_transfer

# Function to print colored log statements
log() {
    local GREEN='\033[0;32m'
    local NC='\033[0m' # No Color
    echo -e "${GREEN}[LOG] $1${NC}"
}

# Run VLM
conda activate vllm
log "Starting vllm api server..."
export PATH=$PATH:/sbin

python3 -u -m vllm.entrypoints.openai.api_server \
    --port 8090 \
    --model $MODEL_NAME \
    --trust-remote-code --tensor-parallel-size 4 \
    --gpu-memory-utilization 0.95 \
    --max-num-seqs 64 \
    >> /var/log/vllm_api.log 2>&1 &

# Check if the VLLM server process launched ($? here only reflects
# whether the background process was spawned, not server health)
if [ $? -eq 0 ]; then
    log "VLLM API server started successfully."
else
    log "Failed to start VLLM API server."
    exit 1
fi

# Run Gradio
log "Starting gradio server..."
git clone https://github.com/vllm-project/vllm.git || true

# Point the chatbot at the vllm server started above (port 8090)
python3 vllm/examples/gradio_openai_chatbot_webserver.py \
    -m $MODEL_NAME \
    --port 8811 \
    --model-url http://localhost:8090/v1 \
    --stop-token-ids 128009,128001 \
    >> /var/log/gradio_server.log 2>&1 &

# Check if the Gradio server process launched
if [ $? -eq 0 ]; then
    log "Gradio server started successfully."
else
    log "Failed to start Gradio server."
    exit 1
fi
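Once the script is running, a quick smoke test is to list the models the vLLM server exposes, using the OpenAI client installed above. A sketch, assuming the server runs on the same machine (vLLM accepts any placeholder API key):

from openai import OpenAI

# Point the OpenAI client at the local vLLM OpenAI-compatible server.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8090/v1")

for model in client.models.list().data:
    print(model.id)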