[COGVLM][Speedup]
Kye committed Mar 11, 2024
1 parent fd2e8c0 commit 58495f8
Showing 1 changed file with 52 additions and 53 deletions.
105 changes: 52 additions & 53 deletions servers/cogvlm.py
@@ -31,6 +31,58 @@
MODEL_PATH = os.environ.get("COGVLM_MODEL_PATH", "THUDM/cogvlm-chat-hf")
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "lmsys/vicuna-7b-v1.5")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Environment variables are strings; parse explicitly so QUANT_ENABLED=false
# actually disables quantization instead of remaining a truthy string.
QUANT_ENABLED = os.environ.get("QUANT_ENABLED", "true").lower() in ("1", "true", "yes")


# Model and tokenizer
tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)

if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch_type = torch.bfloat16
else:
    torch_type = torch.float16

print(f"========Use torch type as:{torch_type} with device:{DEVICE}========\n\n")

quantization_config = {
    "load_in_4bit": QUANT_ENABLED,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch_type,
}

bnb_config = BitsAndBytesConfig(**quantization_config)

if "cuda" in DEVICE:
if QUANT_ENABLED:
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
load_in_4bit=True,
trust_remote_code=True,
torch_dtype=torch_type,
low_cpu_mem_usage=True,
attn_implementation="flash_attention_2",
quantization_config=bnb_config,
).eval()
else:
model = (
AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
load_in_4bit=False,
trust_remote_code=True,
torch_dtype=torch_type,
low_cpu_mem_usage=True,
)
.to(DEVICE)
.eval()
)
else:
model = (
AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True)
.float()
.to(DEVICE)
.eval()
)


# Torch type
@@ -455,58 +507,5 @@ def generate_stream_cogvlm(
torch.cuda.empty_cache()


def main():
    global model, tokenizer
    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)

    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        torch_type = torch.bfloat16
    else:
        torch_type = torch.float16

    print(f"========Use torch type as:{torch_type} with device:{DEVICE}========\n\n")

    quantization_config = {
        "load_in_4bit": QUANT_ENABLED,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch_type,
    }

    bnb_config = BitsAndBytesConfig(**quantization_config)

    if "cuda" in DEVICE:
        if QUANT_ENABLED:
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_PATH,
                load_in_4bit=True,
                trust_remote_code=True,
                torch_dtype=torch_type,
                low_cpu_mem_usage=True,
                attn_implementation="flash_attention_2",
                quantization_config=bnb_config,
            ).eval()
        else:
            model = (
                AutoModelForCausalLM.from_pretrained(
                    MODEL_PATH,
                    load_in_4bit=False,
                    trust_remote_code=True,
                    torch_dtype=torch_type,
                    low_cpu_mem_usage=True,
                )
                .to(DEVICE)
                .eval()
            )
    else:
        (
            AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True)
            .float()
            .to(DEVICE)
            .eval()
        )


if __name__ == "__main__":
    main()
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
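
For context, here is a minimal, self-contained sketch of the 4-bit NF4 loading pattern that this commit hoists to module level. It is an illustration only, not part of the commit: "facebook/opt-350m" is a placeholder model id, and it assumes bitsandbytes is installed and a CUDA GPU is available.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Placeholder model id for illustration; the server above loads CogVLM instead.
model_id = "facebook/opt-350m"

# bf16 on Ampere (compute capability >= 8) or newer, fp16 otherwise,
# mirroring the dtype selection in the diff.
compute_dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
    device_map="auto",
).eval()

# Quick smoke test of the quantized model.
inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))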
