diff --git a/.env.example b/.env.example
index 1003c5b..ac4aa81 100644
--- a/.env.example
+++ b/.env.example
@@ -1,4 +1,4 @@
-EVAL_PORT=8000
+MODEL_API_PORT=8000
 SERVER="http://localhost:8000"
 USE_GPU=True
 WORLD_SIZE=4
diff --git a/scripts/entry_point.sh b/scripts/entry_point.sh
index 3a72106..f97d8ef 100644
--- a/scripts/entry_point.sh
+++ b/scripts/entry_point.sh
@@ -1,8 +1,8 @@
 #!/bin/bash
 
 # Default values for environment variables
-export EVAL_PORT=${EVAL_PORT:-8000}
-export SERVER=${SERVER:-"http://localhost:${EVAL_PORT}"}
+export MODEL_API_PORT=${MODEL_API_PORT:-8000}
+export SERVER=${SERVER:-"http://localhost:${MODEL_API_PORT}"}
 export USE_GPU=${USE_GPU:-True}
 export WORLD_SIZE=${WORLD_SIZE:-4}
 export PSG_CONNECTION_STRING=${PSG_CONNECTION_STRING:-""}
@@ -12,4 +12,4 @@ export SUPABASE_KEY=${SUPABASE_KEY:-""}
 export HF_HUB_ENABLE_HF_TRANSFER=${HF_HUB_ENABLE_HF_TRANSFER:-True}
 
 # Run the application
-exec python3.10 -m uvicorn cogvlm:app --host 0.0.0.0 --port $EVAL_PORT
+exec python3.10 -m uvicorn cogvlm:app --host 0.0.0.0 --port $MODEL_API_PORT
diff --git a/servers/cogvlm/cogvlm.py b/servers/cogvlm/cogvlm.py
index e40d485..ebdc8f9 100644
--- a/servers/cogvlm/cogvlm.py
+++ b/servers/cogvlm/cogvlm.py
@@ -1,10 +1,8 @@
 import base64
-import gc
 import os
-import time
 from contextlib import asynccontextmanager
 from io import BytesIO
-from typing import List, Literal, Optional, Tuple, Union
+from typing import List, Optional, Tuple
 
 import torch
 import uvicorn
@@ -12,7 +10,6 @@
 from fastapi.middleware.cors import CORSMiddleware
 from loguru import logger
 from PIL import Image
-from pydantic import BaseModel, Field
 from sse_starlette.sse import EventSourceResponse
 from transformers import (
     AutoModelForCausalLM,
@@ -26,11 +23,21 @@
 from swarms_cloud.auth_with_swarms_cloud import authenticate_user, fetch_api_key_info
 from swarms_cloud.calculate_pricing import calculate_pricing, count_tokens
 from swarms_cloud.log_api_request_to_supabase import ModelAPILogEntry, log_to_supabase
-
-# from swarms_cloud.supabase_logger import SupabaseLogger
-
-# Supabase logger
-# supabase_logger = SupabaseLogger("swarm_cloud_usage")
+from swarms_cloud.schema.cog_vlm_schemas import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+    ChatMessageInput,
+    ChatMessageResponse,
+    DeltaMessage,
+    ImageUrlContent,
+    ModelCard,
+    ModelList,
+    TextContent,
+    UsageInfo,
+)
+from swarms_cloud.utils.count_cores_for_workers import calculate_workers
 
 # Environment variables
 MODEL_PATH = os.environ.get("COGVLM_MODEL_PATH", "THUDM/cogvlm-chat-hf")
@@ -39,53 +46,6 @@
 QUANT_ENABLED = os.environ.get("QUANT_ENABLED", True)
 
 
-# Model and tokenizer
-tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
-
-if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
-    torch_type = torch.bfloat16
-else:
-    torch_type = torch.float16
-
-print(f"========Use torch type as:{torch_type} with device:{DEVICE}========\n\n")
-
-quantization_config = {
-    "load_in_4bit": True,
-    "bnb_4bit_use_double_quant": True,
-    "bnb_4bit_compute_dtype": torch_type,
-}
-
-bnb_config = BitsAndBytesConfig(**quantization_config)
-
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_PATH,
-    # load_in_4bit=True,
-    trust_remote_code=True,
-    torch_dtype=torch_type,
-    low_cpu_mem_usage=True,
-    # attn_implementation="flash_attention_2",
-    quantization_config=bnb_config,
-).eval()
-
-# Torch type
-if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
-    torch_type = torch.bfloat16
-else:
-    torch_type = torch.float16
-
-
-if os.environ.get("QUANT_ENABLED"):
-    QUANT_ENABLED = True
-else:
-    with torch.cuda.device(DEVICE):
-        __, total_bytes = torch.cuda.mem_get_info()
-        total_gb = total_bytes / (1 << 30)
-        if total_gb < 40:
-            QUANT_ENABLED = True
-        else:
-            QUANT_ENABLED = False
-
-
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """
@@ -98,8 +58,11 @@ async def lifespan(app: FastAPI):
         torch.cuda.ipc_collect()
 
 
-app = FastAPI(lifespan=lifespan)
+# Create a FastAPI app
+app = FastAPI(lifespan=lifespan, debug=True)
+
+# Load the middleware to handle CORS
 
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -109,103 +72,54 @@ async def lifespan(app: FastAPI):
 )
 
 
-class ModelCard(BaseModel):
-    """
-    A Pydantic model representing a model card, which provides metadata about a machine learning model.
-    It includes fields like model ID, owner, and creation time.
-    """
-
-    id: str
-    object: str = "model"
-    created: int = Field(default_factory=lambda: int(time.time()))
-    owned_by: str = "owner"
-    root: Optional[str] = None
-    parent: Optional[str] = None
-    permission: Optional[list] = None
-
-
-class ModelList(BaseModel):
-    object: str = "list"
-    data: List[ModelCard] = []
-
-
-class ImageUrl(BaseModel):
-    url: str
-
-
-class TextContent(BaseModel):
-    type: Literal["text"]
-    text: str
-
-
-class ImageUrlContent(BaseModel):
-    type: Literal["image_url"]
-    image_url: ImageUrl
-
-
-ContentItem = Union[TextContent, ImageUrlContent]
-
-
-class ChatMessageInput(BaseModel):
-    role: Literal["user", "assistant", "system"]
-    content: Union[str, List[ContentItem]]
-    name: Optional[str] = None
-
+# On startup
+@app.on_event("startup")
+async def load_model():
+    global model, tokenizer, torch_type, QUANT_ENABLED
 
-class ChatMessageResponse(BaseModel):
-    role: Literal["assistant"]
-    content: str = None
-    name: Optional[str] = None
+    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
 
+    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
+        torch_type = torch.bfloat16
+    else:
+        torch_type = torch.float16
 
-class DeltaMessage(BaseModel):
-    role: Optional[Literal["user", "assistant", "system"]] = None
-    content: Optional[str] = None
-
-
-class ChatCompletionRequest(BaseModel):
-    model: str
-    messages: List[ChatMessageInput]
-    temperature: Optional[float] = 0.8
-    top_p: Optional[float] = 0.8
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-    # Additional parameters
-    repetition_penalty: Optional[float] = 1.0
-
-
-class ChatCompletionResponseChoice(BaseModel):
-    index: int
-    message: ChatMessageResponse
-
-
-class ChatCompletionResponseStreamChoice(BaseModel):
-    index: int
-    delta: DeltaMessage
-
-
-class UsageInfo(BaseModel):
-    prompt_tokens: int = 0
-    total_tokens: int = 0
-    completion_tokens: Optional[int] = 0
-
-
-class ChatCompletionResponse(BaseModel):
-    model: str
-    object: Literal["chat.completion", "chat.completion.chunk"]
-    choices: List[
-        Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
-    ]
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    usage: Optional[UsageInfo] = None
+    print(f"========Use torch type as:{torch_type} with device:{DEVICE}========\n\n")
 
+    quantization_config = {
+        "load_in_4bit": True,
+        "bnb_4bit_use_double_quant": True,
+        "bnb_4bit_compute_dtype": torch_type,
+    }
 
-@app.on_event("shutdown")
-async def shutdown_event():
-    print("Application shutdown, cleaning up artifacts")
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
+    bnb_config = BitsAndBytesConfig(**quantization_config)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        # load_in_4bit=True,
+        trust_remote_code=True,
+        torch_dtype=torch_type,
+        low_cpu_mem_usage=True,
+        # attn_implementation="flash_attention_2",
+        quantization_config=bnb_config,
+    ).eval()
+
+    # Torch type
+    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
+        torch_type = torch.bfloat16
+    else:
+        torch_type = torch.float16
+
+    if os.environ.get("QUANT_ENABLED"):
+        QUANT_ENABLED = True
+    else:
+        with torch.cuda.device(DEVICE):
+            __, total_bytes = torch.cuda.mem_get_info()
+            total_gb = total_bytes / (1 << 30)
+            if total_gb < 40:
+                QUANT_ENABLED = True
+            else:
+                QUANT_ENABLED = False
 
 
@@ -525,9 +439,23 @@ def generate_stream_cogvlm(
     yield ret
 
 
-gc.collect()
-torch.cuda.empty_cache()
+@app.on_event("shutdown")
+async def shutdown_event():
+    print("Application shutdown, cleaning up artifacts")
+    try:
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    except Exception as e:
+        print(f"Error during shutdown: {e}")
 
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=int(os.environ.get("MODEL_API_PORT", 8000)),
+        workers=calculate_workers(),
+        log_level="info",
+        use_colors=True,
+    )
diff --git a/swarms_cloud/schema/cog_vlm_schemas.py b/swarms_cloud/schema/cog_vlm_schemas.py
new file mode 100644
index 0000000..434753a
--- /dev/null
+++ b/swarms_cloud/schema/cog_vlm_schemas.py
@@ -0,0 +1,95 @@
+import time
+from typing import List, Literal, Optional, Union
+
+from pydantic import BaseModel, Field
+
+
+class ModelCard(BaseModel):
+    """
+    A Pydantic model representing a model card, which provides metadata about a machine learning model.
+    It includes fields like model ID, owner, and creation time.
+    """
+
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "owner"
+    root: Optional[str] = None
+    parent: Optional[str] = None
+    permission: Optional[list] = None
+
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[ModelCard] = []
+
+
+class ImageUrl(BaseModel):
+    url: str
+
+
+class TextContent(BaseModel):
+    type: Literal["text"]
+    text: str
+
+
+class ImageUrlContent(BaseModel):
+    type: Literal["image_url"]
+    image_url: ImageUrl
+
+
+ContentItem = Union[TextContent, ImageUrlContent]
+
+
+class ChatMessageInput(BaseModel):
+    role: Literal["user", "assistant", "system"]
+    content: Union[str, List[ContentItem]]
+    name: Optional[str] = None
+
+
+class ChatMessageResponse(BaseModel):
+    role: Literal["assistant"]
+    content: Optional[str] = None
+    name: Optional[str] = None
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[Literal["user", "assistant", "system"]] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[ChatMessageInput]
+    temperature: Optional[float] = 0.8
+    top_p: Optional[float] = 0.8
+    max_tokens: Optional[int] = None
+    stream: Optional[bool] = False
+    # Additional parameters
+    repetition_penalty: Optional[float] = 1.0
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessageResponse
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+
+
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class ChatCompletionResponse(BaseModel):
+    model: str
+    object: Literal["chat.completion", "chat.completion.chunk"]
+    choices: List[
+        Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
+    ]
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    usage: Optional[UsageInfo] = None
diff --git a/swarms_cloud/utils/count_cores_for_workers.py b/swarms_cloud/utils/count_cores_for_workers.py
new file mode 100644
index 0000000..1b63f23
--- /dev/null
+++ b/swarms_cloud/utils/count_cores_for_workers.py
@@ -0,0 +1,14 @@
+import multiprocessing
+
+
+def calculate_workers():
+    """
+    Calculates the number of workers based on the number of CPU cores.
+
+    Returns:
+        int: The number of workers.
+    """
+    cores = multiprocessing.cpu_count()
+    workers = 2 * cores + 1
+
+    return workers
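
Reviewer note, not part of the diff: a minimal sanity-check sketch for the new module layout. It assumes the package is importable from the repo root; the model id is a placeholder, and serialization works under either pydantic v1 or v2.

# Sketch only: exercise the relocated schemas and the new worker helper.
from swarms_cloud.schema.cog_vlm_schemas import ChatCompletionRequest, ChatMessageInput
from swarms_cloud.utils.count_cores_for_workers import calculate_workers

request = ChatCompletionRequest(
    model="cogvlm-chat-17b",  # placeholder model id, not taken from the diff
    messages=[ChatMessageInput(role="user", content="Describe this image.")],
    stream=False,
)
# Serialize with whichever pydantic major version is installed.
payload = request.model_dump() if hasattr(request, "model_dump") else request.dict()
print(payload)
print(calculate_workers())  # 2 * cpu_count() + 1 on the host running the server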
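
A second sketch, also not part of the diff: a stdlib-only client check against a server started via scripts/entry_point.sh, reading the port from the renamed MODEL_API_PORT variable. Only the /v1/models route is assumed here because it appears in the diff; the response shape follows the ModelList/ModelCard schemas.

# Sketch only: requires the CogVLM server to be running locally.
import json
import os
import urllib.request

port = int(os.environ.get("MODEL_API_PORT", 8000))
with urllib.request.urlopen(f"http://localhost:{port}/v1/models") as resp:
    models = json.load(resp)

# Per ModelList/ModelCard: {"object": "list", "data": [{"id": "...", "object": "model", ...}]}
print(models)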