diff --git a/send_local_request_to_cogvlm.py b/send_local_request_to_cogvlm.py
index ec18c52..0825c5d 100644
--- a/send_local_request_to_cogvlm.py
+++ b/send_local_request_to_cogvlm.py
@@ -14,6 +14,7 @@
 # Swarms Cloud API key
 swarms_cloud_api_key = os.getenv("SWARMS_CLOUD_API_KEY")
 
+
 # Convert image to Base64
 def image_to_base64(image_path):
     with Image.open(image_path) as image:
diff --git a/servers/cogvlm/cogvlm.py b/servers/cogvlm/cogvlm.py
index 96715dd..2b51ed4 100644
--- a/servers/cogvlm/cogvlm.py
+++ b/servers/cogvlm/cogvlm.py
@@ -47,6 +47,7 @@
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 QUANT_ENABLED = os.environ.get("QUANT_ENABLED", True)
 
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """
@@ -72,41 +73,37 @@ async def lifespan(app: FastAPI):
     allow_headers=["*"],
 )
 
-@app.on_event("startup")
-async def load_model():
-    global model, tokenizer, torch_type
-
-    # Load the tokenizer and model
-    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
+# Load the tokenizer and model
+tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
 
-    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
-        torch_type = torch.bfloat16
-    else:
-        torch_type = torch.float16
+if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
+    torch_type = torch.bfloat16
+else:
+    torch_type = torch.float16
 
-    print(f"========Use torch type as:{torch_type} with device:{DEVICE}========\n\n")
+print(f"========Use torch type as:{torch_type} with device:{DEVICE}========\n\n")
 
-    quantization_config = {
-        "load_in_4bit": True,
-        "bnb_4bit_use_double_quant": True,
-        "bnb_4bit_compute_dtype": torch_type,
-    }
+quantization_config = {
+    "load_in_4bit": True,
+    "bnb_4bit_use_double_quant": True,
+    "bnb_4bit_compute_dtype": torch_type,
+}
 
-    bnb_config = BitsAndBytesConfig(**quantization_config)
+bnb_config = BitsAndBytesConfig(**quantization_config)
 
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_PATH,
-        trust_remote_code=True,
-        torch_dtype=torch_type,
-        low_cpu_mem_usage=True,
-        quantization_config=bnb_config,
-    ).eval()
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_PATH,
+    trust_remote_code=True,
+    torch_dtype=torch_type,
+    low_cpu_mem_usage=True,
+    quantization_config=bnb_config,
+).eval()
 
-    # Torch type
-    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
-        torch_type = torch.bfloat16
-    else:
-        torch_type = torch.float16
+# Torch type
+if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
+    torch_type = torch.bfloat16
+else:
+    torch_type = torch.float16
 
 
 @app.get("/v1/models", response_model=ModelList)
@@ -125,8 +122,6 @@ async def list_models():
 async def create_chat_completion(
     request: ChatCompletionRequest, token: str = Depends(authenticate_user)
 ):
-    global model, tokenizer, torch_type
-
     try:
         if len(request.messages) < 1 or request.messages[-1].role == "assistant":
             raise HTTPException(status_code=400, detail="Invalid request")