Skip to content

Commit

Permalink
Update app.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Souvik2biswas authored Sep 30, 2024
1 parent a451356 commit c6542ef
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import verovio
model_name = "ucaslcl/GOT-OCR2_0"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True).eval().cuda()
#model = AutoModel.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True).eval().cuda()
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True, device_map='cpu', use_safetensors=True).eval()

UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"
Expand All @@ -32,7 +33,8 @@ def image_to_base64(image):
return base64.b64encode(buffered.getvalue()).decode()

q_model_name = "Qwen/Qwen2-VL-2B-Instruct"
q_model = Qwen2VLForConditionalGeneration.from_pretrained(q_model_name, torch_dtype="auto").cuda().eval()
#q_model = Qwen2VLForConditionalGeneration.from_pretrained(q_model_name, torch_dtype="auto").cuda().eval()
q_model = Qwen2VLForConditionalGeneration.from_pretrained(q_model_name, torch_dtype="auto").eval()
q_processor = AutoProcessor.from_pretrained(q_model_name, trust_remote_code=True)

def get_qwen_op(image_file, model, processor):
Expand All @@ -53,7 +55,8 @@ def get_qwen_op(image_file, model, processor):
}
]
text_prompt = q_processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = q_processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt").to("cuda")
#inputs = q_processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt").to("cuda")
inputs = q_processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
inputs = {k: v.to(torch.float32) if torch.is_floating_point(v) else v for k, v in inputs.items()}

generation_config = {
Expand Down

0 comments on commit c6542ef

Please sign in to comment.