Commit

[REVERT]
Kye committed Mar 21, 2024
1 parent 559186c commit 2e2a128
Showing 2 changed files with 27 additions and 31 deletions.
1 change: 1 addition & 0 deletions send_local_request_to_cogvlm.py
@@ -14,6 +14,7 @@
 # Swarms Cloud API key
 swarms_cloud_api_key = os.getenv("SWARMS_CLOUD_API_KEY")
 
+
 # Convert image to Base64
 def image_to_base64(image_path):
     with Image.open(image_path) as image:
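The viewer truncates image_to_base64 just after the Image.open call. For context, a minimal sketch of how such a helper is commonly completed follows; the in-memory JPEG re-encoding and the io/base64 usage are illustrative assumptions, not code taken from this commit.

import base64
import io

from PIL import Image


def image_to_base64(image_path):
    # Open the image, re-encode it into an in-memory buffer, and return
    # the buffer's bytes as a Base64 string suitable for a JSON payload.
    with Image.open(image_path) as image:
        buffer = io.BytesIO()
        image.convert("RGB").save(buffer, format="JPEG")  # JPEG is an assumed format choice
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


# Hypothetical usage: encode a local file before embedding it in a request body.
encoded = image_to_base64("example.jpg")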
57 changes: 26 additions & 31 deletions servers/cogvlm/cogvlm.py
@@ -47,6 +47,7 @@
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 QUANT_ENABLED = os.environ.get("QUANT_ENABLED", True)
 
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """
@@ -72,41 +73,37 @@ async def lifespan(app: FastAPI):
     allow_headers=["*"],
 )
 
-@app.on_event("startup")
-async def load_model():
-    global model, tokenizer, torch_type
-
-    # Load the tokenizer and model
-    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
+# Load the tokenizer and model
+tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
 
-    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
-        torch_type = torch.bfloat16
-    else:
-        torch_type = torch.float16
+if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
+    torch_type = torch.bfloat16
+else:
+    torch_type = torch.float16
 
-    print(f"========Use torch type as:{torch_type} with device:{DEVICE}========\n\n")
+print(f"========Use torch type as:{torch_type} with device:{DEVICE}========\n\n")
 
-    quantization_config = {
-        "load_in_4bit": True,
-        "bnb_4bit_use_double_quant": True,
-        "bnb_4bit_compute_dtype": torch_type,
-    }
+quantization_config = {
+    "load_in_4bit": True,
+    "bnb_4bit_use_double_quant": True,
+    "bnb_4bit_compute_dtype": torch_type,
+}
 
-    bnb_config = BitsAndBytesConfig(**quantization_config)
+bnb_config = BitsAndBytesConfig(**quantization_config)
 
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_PATH,
-        trust_remote_code=True,
-        torch_dtype=torch_type,
-        low_cpu_mem_usage=True,
-        quantization_config=bnb_config,
-    ).eval()
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_PATH,
+    trust_remote_code=True,
+    torch_dtype=torch_type,
+    low_cpu_mem_usage=True,
+    quantization_config=bnb_config,
+).eval()
 
-    # Torch type
-    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
-        torch_type = torch.bfloat16
-    else:
-        torch_type = torch.float16
+# Torch type
+if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
+    torch_type = torch.bfloat16
+else:
+    torch_type = torch.float16
 
 
 @app.get("/v1/models", response_model=ModelList)
@@ -125,8 +122,6 @@ async def list_models():
 async def create_chat_completion(
     request: ChatCompletionRequest, token: str = Depends(authenticate_user)
 ):
-    global model, tokenizer, torch_type
-
     try:
         if len(request.messages) < 1 or request.messages[-1].role == "assistant":
             raise HTTPException(status_code=400, detail="Invalid request")
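Taken out of diff form, the module-level loading path this revert restores amounts to the sketch below. It is a minimal, self-contained version with placeholder values assumed for MODEL_PATH and TOKENIZER_PATH, since both constants are defined earlier in servers/cogvlm/cogvlm.py and do not appear in the hunks above.

import os

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, LlamaTokenizer

# Placeholder defaults; the real MODEL_PATH and TOKENIZER_PATH live elsewhere in the file.
MODEL_PATH = os.environ.get("MODEL_PATH", "THUDM/cogvlm-chat-hf")
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "lmsys/vicuna-7b-v1.5")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# The tokenizer is loaded at import time, mirroring the module-level code the revert restores.
tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)

# Prefer bfloat16 on GPUs with compute capability >= 8 (Ampere or newer), float16 otherwise.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch_type = torch.bfloat16
else:
    torch_type = torch.float16

# 4-bit quantization via bitsandbytes to reduce the model's GPU memory footprint.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch_type,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    torch_dtype=torch_type,
    low_cpu_mem_usage=True,
    quantization_config=bnb_config,
).eval()

Because this now runs at import time rather than inside the deleted @app.on_event("startup") handler, model, tokenizer, and torch_type already exist as module globals when FastAPI starts serving requests, which is also why the global declaration in create_chat_completion could be dropped.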
