
Commit

[CLEANUP]
Kye committed Mar 20, 2024
1 parent d6ac5ec commit 74eb7c4
Showing 5 changed files with 193 additions and 156 deletions.
2 changes: 1 addition & 1 deletion .env.example
@@ -1,4 +1,4 @@
-EVAL_PORT=8000
+MODEL_API_PORT=8000
SERVER="http://localhost:8000"
USE_GPU=True
WORLD_SIZE=4
6 changes: 3 additions & 3 deletions scripts/entry_point.sh
@@ -1,8 +1,8 @@
#!/bin/bash

# Default values for environment variables
-export EVAL_PORT=${EVAL_PORT:-8000}
-export SERVER=${SERVER:-"http://localhost:${EVAL_PORT}"}
+export MODEL_API_PORT=${MODEL_API_PORT:-8000}
+export SERVER=${SERVER:-"http://localhost:${MODEL_API_PORT}"}
export USE_GPU=${USE_GPU:-True}
export WORLD_SIZE=${WORLD_SIZE:-4}
export PSG_CONNECTION_STRING=${PSG_CONNECTION_STRING:-""}
@@ -12,4 +12,4 @@ export SUPABASE_KEY=${SUPABASE_KEY:-""}
export HF_HUB_ENABLE_HF_TRANSFER=${HF_HUB_ENABLE_HF_TRANSFER:-True}

# Run the application
-exec python3.10 -m uvicorn cogvlm:app --host 0.0.0.0 --port $EVAL_PORT
+exec python3.10 -m uvicorn cogvlm:app --host 0.0.0.0 --port $MODEL_API_PORT
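The rename only works if every consumer of the variable moves together: the .env template, this entry-point script, and the uvicorn call at the bottom of cogvlm.py below all have to agree on MODEL_API_PORT. As a minimal sketch of the application side (the resolve_port helper is hypothetical, not part of the repository), the value can be read from the environment and cast to int, since environment variables always arrive as strings:

import os

import uvicorn


def resolve_port(default: int = 8000) -> int:
    # Hypothetical helper: os.environ values are strings, so cast before
    # handing the port to uvicorn.
    return int(os.environ.get("MODEL_API_PORT", default))


if __name__ == "__main__":
    # "cogvlm:app" refers to the FastAPI instance defined in servers/cogvlm/cogvlm.py.
    uvicorn.run("cogvlm:app", host="0.0.0.0", port=resolve_port())

The exec line above achieves the same effect from the shell, where the expansion of $MODEL_API_PORT is parsed by uvicorn's CLI rather than passed as a Python object.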
232 changes: 80 additions & 152 deletions servers/cogvlm/cogvlm.py
@@ -1,18 +1,15 @@
import base64
-import gc
import os
-import time
from contextlib import asynccontextmanager
from io import BytesIO
-from typing import List, Literal, Optional, Tuple, Union
+from typing import List, Optional, Tuple

import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from PIL import Image
-from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import (
    AutoModelForCausalLM,
@@ -26,11 +23,21 @@
from swarms_cloud.auth_with_swarms_cloud import authenticate_user, fetch_api_key_info
from swarms_cloud.calculate_pricing import calculate_pricing, count_tokens
from swarms_cloud.log_api_request_to_supabase import ModelAPILogEntry, log_to_supabase
-
-# from swarms_cloud.supabase_logger import SupabaseLogger
-
-# Supabase logger
-# supabase_logger = SupabaseLogger("swarm_cloud_usage")
+from swarms_cloud.schema.cog_vlm_schemas import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+    ChatMessageInput,
+    ChatMessageResponse,
+    DeltaMessage,
+    ImageUrlContent,
+    ModelCard,
+    ModelList,
+    TextContent,
+    UsageInfo,
+)
+from swarms_cloud.utils.count_cores_for_workers import calculate_workers

# Environment variables
MODEL_PATH = os.environ.get("COGVLM_MODEL_PATH", "THUDM/cogvlm-chat-hf")
@@ -39,53 +46,6 @@
QUANT_ENABLED = os.environ.get("QUANT_ENABLED", True)


-# Model and tokenizer
-tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
-
-if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
-    torch_type = torch.bfloat16
-else:
-    torch_type = torch.float16
-
-print(f"========Use torch type as:{torch_type} with device:{DEVICE}========\n\n")
-
-quantization_config = {
-    "load_in_4bit": True,
-    "bnb_4bit_use_double_quant": True,
-    "bnb_4bit_compute_dtype": torch_type,
-}
-
-bnb_config = BitsAndBytesConfig(**quantization_config)
-
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_PATH,
-    # load_in_4bit=True,
-    trust_remote_code=True,
-    torch_dtype=torch_type,
-    low_cpu_mem_usage=True,
-    # attn_implementation="flash_attention_2",
-    quantization_config=bnb_config,
-).eval()
-
-# Torch type
-if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
-    torch_type = torch.bfloat16
-else:
-    torch_type = torch.float16
-
-
-if os.environ.get("QUANT_ENABLED"):
-    QUANT_ENABLED = True
-else:
-    with torch.cuda.device(DEVICE):
-        __, total_bytes = torch.cuda.mem_get_info()
-        total_gb = total_bytes / (1 << 30)
-        if total_gb < 40:
-            QUANT_ENABLED = True
-        else:
-            QUANT_ENABLED = False
-
-
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
@@ -98,8 +58,11 @@ async def lifespan(app: FastAPI):
        torch.cuda.ipc_collect()


-app = FastAPI(lifespan=lifespan)
+# Create a FastAPI app
+app = FastAPI(lifespan=lifespan, debug=True)


+# Load the middleware to handle CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
@@ -109,103 +72,54 @@ async def lifespan(app: FastAPI):
)


-class ModelCard(BaseModel):
-    """
-    A Pydantic model representing a model card, which provides metadata about a machine learning model.
-    It includes fields like model ID, owner, and creation time.
-    """
-
-    id: str
-    object: str = "model"
-    created: int = Field(default_factory=lambda: int(time.time()))
-    owned_by: str = "owner"
-    root: Optional[str] = None
-    parent: Optional[str] = None
-    permission: Optional[list] = None
-
-
-class ModelList(BaseModel):
-    object: str = "list"
-    data: List[ModelCard] = []
-
-
-class ImageUrl(BaseModel):
-    url: str
-
-
-class TextContent(BaseModel):
-    type: Literal["text"]
-    text: str
-
-
-class ImageUrlContent(BaseModel):
-    type: Literal["image_url"]
-    image_url: ImageUrl
-
-
-ContentItem = Union[TextContent, ImageUrlContent]
-
-
-class ChatMessageInput(BaseModel):
-    role: Literal["user", "assistant", "system"]
-    content: Union[str, List[ContentItem]]
-    name: Optional[str] = None

+# On startup
+@app.on_event("startup")
+async def load_model():
+    global model, tokenizer, torch_type, QUANT_ENABLED

-class ChatMessageResponse(BaseModel):
-    role: Literal["assistant"]
-    content: str = None
-    name: Optional[str] = None
+    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)

+    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
+        torch_type = torch.bfloat16
+    else:
+        torch_type = torch.float16

-class DeltaMessage(BaseModel):
-    role: Optional[Literal["user", "assistant", "system"]] = None
-    content: Optional[str] = None
-
-
-class ChatCompletionRequest(BaseModel):
-    model: str
-    messages: List[ChatMessageInput]
-    temperature: Optional[float] = 0.8
-    top_p: Optional[float] = 0.8
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-    # Additional parameters
-    repetition_penalty: Optional[float] = 1.0
-
-
-class ChatCompletionResponseChoice(BaseModel):
-    index: int
-    message: ChatMessageResponse
-
-
-class ChatCompletionResponseStreamChoice(BaseModel):
-    index: int
-    delta: DeltaMessage
-
-
-class UsageInfo(BaseModel):
-    prompt_tokens: int = 0
-    total_tokens: int = 0
-    completion_tokens: Optional[int] = 0
-
-
-class ChatCompletionResponse(BaseModel):
-    model: str
-    object: Literal["chat.completion", "chat.completion.chunk"]
-    choices: List[
-        Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
-    ]
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    usage: Optional[UsageInfo] = None
+    print(f"========Use torch type as:{torch_type} with device:{DEVICE}========\n\n")

+    quantization_config = {
+        "load_in_4bit": True,
+        "bnb_4bit_use_double_quant": True,
+        "bnb_4bit_compute_dtype": torch_type,
+    }

-@app.on_event("shutdown")
-async def shutdown_event():
-    print("Application shutdown, cleaning up artifacts")
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
+    bnb_config = BitsAndBytesConfig(**quantization_config)

+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        # load_in_4bit=True,
+        trust_remote_code=True,
+        torch_dtype=torch_type,
+        low_cpu_mem_usage=True,
+        # attn_implementation="flash_attention_2",
+        quantization_config=bnb_config,
+    ).eval()

+    # Torch type
+    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
+        torch_type = torch.bfloat16
+    else:
+        torch_type = torch.float16

+    if os.environ.get("QUANT_ENABLED"):
+        QUANT_ENABLED = True
+    else:
+        with torch.cuda.device(DEVICE):
+            __, total_bytes = torch.cuda.mem_get_info()
+            total_gb = total_bytes / (1 << 30)
+            if total_gb < 40:
+                QUANT_ENABLED = True
+            else:
+                QUANT_ENABLED = False


@app.get("/v1/models", response_model=ModelList)
@@ -525,9 +439,23 @@ def generate_stream_cogvlm(
        yield ret


-gc.collect()
-torch.cuda.empty_cache()
+@app.on_event("shutdown")
+async def shutdown_event():
+    print("Application shutdown, cleaning up artifacts")
+    try:
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    except Exception as e:
+        print(f"Error during shutdown: {e}")


if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=os.environ.get("MODEL_API_PORT", 8000),
+        workers=calculate_workers(),
+        log_level="info",
+        use_colors=True,
+    )
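For reference, the schemas now imported at the top of the file (ChatCompletionRequest, ChatMessageInput, TextContent, ImageUrlContent) imply an OpenAI-style chat payload with an added image_url content type. A minimal client sketch, assuming the chat route sits at /v1/chat/completions alongside the /v1/models route visible above and that the server listens on the default port 8000:

import requests  # third-party: pip install requests

BASE_URL = "http://localhost:8000"  # assumes MODEL_API_PORT was left at 8000

payload = {
    "model": "cogvlm-chat-hf",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
    "temperature": 0.8,
    "stream": False,
}

# Send the request and print the assistant's reply.
response = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=120)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])

The response fields mirror ChatCompletionResponse and ChatMessageResponse from the schema module, so the same snippet applies to any deployment that keeps that response shape.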