Support external models with vllm, huggingface, openai, gemini and claude (#97)

* support external models: vllm huggingface openai
* fix prompt for internlm
* support gemini claude
* fix_bugs
* fix_model_bugs
* update_requirement
* add requirements in README
* fixed by suggestions
* removed comments
1 parent 5e35567, commit b7e199e
Showing 10 changed files with 578 additions and 2 deletions.
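For orientation, here is a minimal sketch (not part of this diff) of how the runner classes added below might be selected by model name. The helper make_runner and the flat module name external are assumptions for illustration; the dispatch mirrors the model-name checks in external_parser.

from external import (ClaudeRunner, GeminiRunner, HFTacticGenerator,
                      OpenAIRunner, VLLMTacticGenerator)


def make_runner(**kwargs):
    # Dispatch on the model string, mirroring pre_process_input below.
    model = kwargs["model"]
    if "claude" in model:
        return ClaudeRunner(**kwargs)
    if "gemini" in model:
        return GeminiRunner(**kwargs)
    if model.startswith("gpt-"):
        return OpenAIRunner(**kwargs)
    # Anything else is treated as a Hugging Face checkpoint, e.g. "internlm/...".
    return HFTacticGenerator(**kwargs)


runner = make_runner(model="claude-3-opus", temperature=0.9, max_tokens=1024, top_p=0.9)
print(runner.generate("n : ℕ\n⊢ gcd n n = n"))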
@@ -0,0 +1,5 @@
from .oai_runner import OpenAIRunner
from .hf_runner import HFTacticGenerator
from .vllm_runner import VLLMTacticGenerator
from .claude_runner import ClaudeRunner
from .gemini_runner import GeminiRunner
@@ -0,0 +1,59 @@
import os
from typing import List, Tuple

try:
    from anthropic import Anthropic
except ImportError:
    # The anthropic package is optional; ClaudeRunner is unusable without it.
    pass
from .external_parser import *


class ClaudeRunner(Generator, Transformer):
    client = Anthropic(api_key=os.getenv("ANTHROPIC_KEY"))

    def __init__(self, **args):
        self.client_kwargs: dict = {
            "model": args["model"],
            "temperature": args["temperature"],
            "max_tokens": args["max_tokens"],
            "top_p": args["top_p"],
        }
        self.name = self.client_kwargs["model"]

    def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]:
        prompt = pre_process_input(self.name, input + target_prefix)

        # Claude 3 models are served through the Messages API; model, temperature,
        # max_tokens and top_p from client_kwargs are all valid arguments to it.
        response = self.client.messages.create(
            messages=[{"role": "user", "content": prompt}],
            **self.client_kwargs,
        )
        content = response.content[0].text

        # The API returns a single completion, so it gets a placeholder score of 1.0.
        results = [(post_process_output(self.name, content), 1.0)]
        return choices_dedup(results)


if __name__ == "__main__":
    generation_kwargs = {
        "model": "claude-3-opus",  # the Messages API may require a dated id, e.g. "claude-3-opus-20240229"
        "temperature": 0.9,
        "max_tokens": 1024,
        "top_p": 0.9,
    }

    model = ClaudeRunner(**generation_kwargs)
    print(model.generate("n : ℕ\n⊢ gcd n n = n"))
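Note: ClaudeRunner reads its API key from the ANTHROPIC_KEY environment variable (rather than the SDK's default ANTHROPIC_API_KEY), and the client is created as a class attribute, so that variable has to be exported before the module is imported.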
@@ -0,0 +1,81 @@
import torch
import numpy as np
from typing import List, Tuple
from abc import ABC, abstractmethod


def get_cuda_if_available():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def pre_process_input(model_name, input):
    # Build a model-specific prompt asking for a single tactic for the given goal state.
    if model_name == "internlm/internlm2-math-plus-1_8b":
        prompt = "My LEAN 4 state is:\n```lean\n" + input + \
            "```\nPlease predict a possible tactic to help me prove the theorem."
        prompt = f"""<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"""
    elif model_name == "gpt-3.5-turbo" or model_name == "gpt-4-turbo-preview":
        prompt = "Here is a theorem you need to prove in Lean:\n" + \
            input + "\nNow you should suggest one line tactic in lean code:"
    elif "gemini" in model_name or "claude" in model_name:
        prompt = "Here is a theorem you need to prove in Lean:\n" + \
            input + "\nNow you should suggest one line tactic in lean code:"
    else:
        raise NotImplementedError(f"External model '{model_name}' not supported")
    return prompt


def post_process_output(model_name, output):
    # Extract the first line inside the last ```lean fence of the model output.
    if model_name == "internlm/internlm2-math-plus-1_8b":
        result = output.split(
            "assistant")[-1].split("lean")[-1].split("```")[0].split("\n")[1]
    elif model_name == "gpt-3.5-turbo" or model_name == "gpt-4-turbo-preview":
        result = output.split("lean")[-1].split("```")[0].split("\n")[1]
    elif "gemini" in model_name or "claude" in model_name:
        result = output.split("lean")[-1].split("```")[0].split("\n")[1]
    else:
        raise NotImplementedError(f"External model '{model_name}' not supported")
    return result


def choices_dedup(output_list: List[Tuple[str, float]]) -> List[Tuple[str, float]]:
    # Keep the highest score seen for each distinct tactic, then sort by score, descending.
    unique_data = {}
    for item in output_list:
        if item[0] not in unique_data or item[1] > unique_data[item[0]]:
            unique_data[item[0]] = item[1]
    sorted_data = sorted(unique_data.items(), key=lambda x: x[1], reverse=True)
    return sorted_data


class Generator(ABC):
    @abstractmethod
    def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]:
        pass


class Encoder(ABC):
    @abstractmethod
    def encode(self, input: str) -> np.ndarray:
        pass


class Transformer:
    # Mixin providing device handling for runners that hold a local model.
    def cuda(self) -> None:
        self.model.cuda()

    def cpu(self) -> None:
        self.model.cpu()

    @property
    def device(self) -> torch.device:
        return self.model.device
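For illustration (not part of the diff), choices_dedup keeps the highest score seen for each distinct tactic and returns the pairs sorted by score in descending order:

# Hypothetical call, for illustration only:
choices_dedup([("simp", 0.2), ("rfl", 0.9), ("simp", 0.5)])
# -> [("rfl", 0.9), ("simp", 0.5)]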
@@ -0,0 +1,85 @@
import os
from typing import List, Tuple

try:
    import google.generativeai as genai
    from google.generativeai import GenerationConfig
except ImportError:
    # The google-generativeai package is optional; GeminiRunner is unusable without it.
    pass
from .external_parser import *


class GeminiRunner(Generator, Transformer):
    # genai.configure() returns None; it just registers the API key globally.
    genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
    safety_settings = [
        {
            "category": "HARM_CATEGORY_HARASSMENT",
            "threshold": "BLOCK_NONE",
        },
        {
            "category": "HARM_CATEGORY_HATE_SPEECH",
            "threshold": "BLOCK_NONE",
        },
        {
            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "threshold": "BLOCK_NONE",
        },
        {
            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
            "threshold": "BLOCK_NONE",
        },
    ]

    def __init__(self, **args):
        self.client_kwargs: dict = {
            "model": args["model"],
            "temperature": args["temperature"],
            "max_tokens": args["max_tokens"],
            "top_p": args["top_p"],
        }
        self.name = self.client_kwargs["model"]

        self.client = genai.GenerativeModel(args["model"])
        self.generation_config = GenerationConfig(
            candidate_count=1,
            max_output_tokens=args["max_tokens"],
            temperature=args["temperature"],
            top_p=args["top_p"],
        )

    def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]:
        prompt = pre_process_input(self.name, input + target_prefix)

        response = self.client.generate_content(
            prompt,
            generation_config=self.generation_config,
            safety_settings=GeminiRunner.safety_settings,
        )

        # Only one candidate is requested, so it gets a placeholder score of 1.0.
        results = [(post_process_output(self.name, response.text), 1.0)]
        return choices_dedup(results)


if __name__ == "__main__":
    generation_kwargs = {
        "model": "gemini-1.0-pro",
        "temperature": 0.9,
        "max_tokens": 1024,
        "top_p": 0.9,
    }

    model = GeminiRunner(**generation_kwargs)
    print(model.generate("n : ℕ\n⊢ gcd n n = n"))
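As with the Claude runner, the Gemini client is configured from an environment variable (GOOGLE_API_KEY), and all four safety categories are set to BLOCK_NONE, presumably so that tactic suggestions are not withheld by content filtering.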
@@ -0,0 +1,96 @@
import torch
from loguru import logger
from typing import List, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer

from .external_parser import *


class HFTacticGenerator(Generator, Transformer):
    def __init__(self, **args) -> None:
        self.name = args["model"]
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.name, trust_remote_code=True)
        device = args["device"]
        if device == "auto":
            device = get_cuda_if_available()
        else:
            device = torch.device(device)
        logger.info(f"Loading {self.name} on {device}")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.name, trust_remote_code=True).to(device)

        self.generation_args: dict = {
            "do_sample": args["do_sample"],
            "temperature": args["temperature"],  # chat default is 0.8
            "max_new_tokens": args["max_new_tokens"],
            "top_p": args["top_p"],  # chat default is 0.8
            # "length_penalty": args["length_penalty"],
            "num_return_sequences": args["num_return_sequences"],
            # "num_beams": self.num_return_sequences,
            # Beam search is avoided here: with LLMs it yields few distinct tactics,
            # so sampling is used to keep the suggestions diverse.
            "output_scores": args["output_scores"],
            "output_logits": args["output_logits"],
            "return_dict_in_generate": args["return_dict_in_generate"],
        }

    def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]:
        prompt = pre_process_input(self.name, input + target_prefix)

        self.model = self.model.eval()

        tokenized_input = self.tokenizer(prompt, return_tensors="pt")
        eos_token_id = [self.tokenizer.eos_token_id,
                        self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
        outputs = self.model.generate(
            tokenized_input.input_ids.to(self.device),
            eos_token_id=eos_token_id,
            **self.generation_args
        )
        response = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True)
        result = []
        index = 0
        for out, score in zip(response, outputs.scores):
            out = post_process_output(self.name, out)
            # Rough ranking score: log-sum-exp of the logits of sequence `index`
            # at generation step `index`.
            result.append((out, score[index].exp().sum().log().cpu().item()))
            index += 1
        result = choices_dedup(result)
        return result


if __name__ == "__main__":
    generation_kwargs = {
        "model": "internlm/internlm2-math-plus-1_8b",
        "temperature": 0.6,
        "max_new_tokens": 1024,
        "top_p": 0.9,
        "length_penalty": 0,
        "num_return_sequences": 64,
        "do_sample": True,
        "output_scores": True,
        "output_logits": False,
        "return_dict_in_generate": True,
        "device": "auto",
    }
    model = HFTacticGenerator(**generation_kwargs)
    model.cuda()
    print(model.generate("n : ℕ\n⊢ gcd n n = n"))
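As a possible refinement (not part of this commit), per-sequence log-probabilities could be computed with transformers' compute_transition_scores instead of the per-step log-sum-exp heuristic in HFTacticGenerator.generate. The sketch below reuses the model and goal from the example above purely for illustration.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "internlm/internlm2-math-plus-1_8b"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

inputs = tokenizer("n : ℕ\n⊢ gcd n n = n", return_tensors="pt")
outputs = model.generate(
    inputs.input_ids,
    do_sample=True,
    max_new_tokens=64,
    num_return_sequences=4,
    output_scores=True,
    return_dict_in_generate=True,
)
# Log-probability of each sampled token, one row per returned sequence.
# Summing over steps gives a sequence-level log-probability for ranking
# (tokens emitted after an early EOS would ideally be masked out first).
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)
sequence_logprobs = transition_scores.sum(dim=-1)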