Add huggingface handler #618

Open
wants to merge 3 commits into base: main
1 change: 1 addition & 0 deletions berkeley-function-call-leaderboard/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
log.txt
Collaborator

This should be removed.

Original file line number Diff line number Diff line change
@@ -532,7 +532,19 @@
"https://huggingface.co/Salesforce/xLAM-7b-fc-r",
"Salesforce",
"cc-by-nc-4.0",
]
],
"google/gemma-2-2b-it": [
"Gemma-2-2.6B-it (Prompt)",
"https://huggingface.co/google/gemma-2b-it",
"Google",
"gemma-terms-of-use",
],
"microsoft/Phi-3.5-mini-instruct": [
"Phi-3.5-mini-instruct (Prompt)",
"https://huggingface.co/microsoft/Phi-3.5-mini-instruct",
"Microsoft",
"MIT",
]
}

INPUT_PRICE_PER_MILLION_TOKEN = {
Original file line number Diff line number Diff line change
@@ -19,9 +19,10 @@ class OpenAIHandler(BaseHandler):
def __init__(self, model_name, temperature=0.001, top_p=1, max_tokens=1000) -> None:
super().__init__(model_name, temperature, top_p, max_tokens)
self.model_style = ModelStyle.OpenAI
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def inference(self, prompt, functions, test_category):
# Moved here so the eval checker can run without an OPENAI_API_KEY, since the key isn't used at that point
Collaborator

This is not necessary. eval_checker can run without an OpenAI API key. If OPENAI_API_KEY is not supplied, it is simply None and won't break the evaluation pipeline.
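For context, the diff defers client construction from __init__ to inference(), so merely constructing the handler does not require credentials. A minimal sketch of that deferred-construction pattern (the LazyClientHandler name and method body are illustrative, not from this PR):

import os
from openai import OpenAI

class LazyClientHandler:
    def __init__(self, model_name):
        self.model_name = model_name
        self._client = None  # not created until the first request

    def inference(self, prompt):
        if self._client is None:
            # os.getenv returns None when OPENAI_API_KEY is unset.
            self._client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        return self._client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
        )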

self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Chatting model
if "FC" not in self.model_name:
functions = func_doc_language_specific_pre_processing(
Original file line number Diff line number Diff line change
@@ -20,6 +20,8 @@
from bfcl.model_handler.glm_handler import GLMHandler
from bfcl.model_handler.yi_handler import YiHandler
from bfcl.model_handler.xlam_handler import xLAMHandler
from bfcl.model_handler.huggingface_handler import HuggingFaceHandler
from bfcl.model_handler.phi_handler import PhiHandler

handler_map = {
"gorilla-openfunctions-v0": GorillaHandler,
@@ -72,6 +74,7 @@
"gemini-1.5-pro-preview-0409": GeminiHandler,
"gemini-1.5-pro-preview-0514": GeminiHandler,
"gemini-1.5-flash-preview-0514": GeminiHandler,
"google/gemma-2-2b-it": HuggingFaceHandler,
"google/gemma-7b-it": GemmaHandler,
"glaiveai/glaive-function-calling-v1": GlaiveHandler,
"deepseek-ai/deepseek-coder-6.7b-instruct": DeepseekHandler,
@@ -96,5 +99,6 @@
"THUDM/glm-4-9b-chat": GLMHandler,
"yi-large-fc": YiHandler,
"Salesforce/xLAM-1b-fc-r": xLAMHandler,
"Salesforce/xLAM-7b-fc-r": xLAMHandler
"Salesforce/xLAM-7b-fc-r": xLAMHandler,
"microsoft/Phi-3.5-mini-instruct": PhiHandler,
}
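For reference, a rough sketch of how an entry in this map would be resolved and instantiated (the caller code is not part of this diff; the constructor arguments follow the handler signatures added here):

handler_cls = handler_map["google/gemma-2-2b-it"]  # -> HuggingFaceHandler for this key
handler = handler_cls("google/gemma-2-2b-it", temperature=0, top_p=1, max_tokens=1000)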
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
import re, time

from bfcl.model_handler.constant import DEFAULT_SYSTEM_PROMPT
from bfcl.model_handler.handler import BaseHandler
from bfcl.model_handler.model_style import ModelStyle
from bfcl.model_handler.utils import (
ast_parse,
combine_consecutive_user_prompr,
convert_system_prompt_into_user_prompt,
func_doc_language_specific_pre_processing,
system_prompt_pre_processing_chat_model,
)
from transformers import ( # type: ignore
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
pipeline,
PreTrainedTokenizerFast,
)


class HuggingFaceHandler(BaseHandler):
def __init__(
self,
model_name,
temperature=0,
top_p=1,
max_tokens=1000,
add_generation_prompt=True,
system_prompt_support=False,
attn_implementation=None,
) -> None:
super().__init__(model_name, temperature, top_p, max_tokens)
self.model_style = ModelStyle.OSSMODEL
self.system_prompt_support = system_prompt_support
self.add_generation_prompt = add_generation_prompt
self.attn_implementation = attn_implementation
self.model_name = model_name

def _format_prompt(self, prompt, function, test_category, tokenizer):
if isinstance(prompt, str):
return prompt
elif isinstance(prompt, list):
if self.system_prompt_support:
msg_list = prompt
# If the model does not support system prompt, we need to convert the system prompt into user prompt.
else:
msg_list = convert_system_prompt_into_user_prompt(prompt)
msg_list = combine_consecutive_user_prompr(msg_list)
return tokenizer.apply_chat_template(
msg_list,
tokenize=False,
add_generation_prompt=self.add_generation_prompt,
)
else:
raise NotImplementedError(f"Unsupported prompt type {type(prompt)}")

def process_input(
self,
test_question,
format_prompt_func,
tokenizer,
include_system_prompt=True,
):
prompts = []
for question in test_question:
test_category = question["id"].rsplit("_", 1)[0]
functions = func_doc_language_specific_pre_processing(
question["function"], test_category
)
# Only the chat model needs the system prompt; also some require special formatting
if include_system_prompt:
question["question"] = system_prompt_pre_processing_chat_model(
question["question"], DEFAULT_SYSTEM_PROMPT, functions
)

formatted_prompt = format_prompt_func(
question["question"], functions, test_category, tokenizer
)
prompts.append(formatted_prompt)

return prompts

def inference(
self,
test_question,
num_gpus,
gpu_memory_utilization,
format_prompt_func=None,
stop_token_ids=None,
max_model_len=None,
include_system_prompt=True,
):

model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map="auto",
torch_dtype="auto",
trust_remote_code=True,
attn_implementation=self.attn_implementation,
)
tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
padding_size="left",
truncation=True,
trust_remote_code=True,
)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = (
model.config.pad_token_id
or tokenizer.eos_token_id
or model.config.eos_token_id
)
prompts = self.process_input(
test_question,
(
format_prompt_func
if format_prompt_func is not None
else self._format_prompt
),
tokenizer,
include_system_prompt=include_system_prompt,
)

# Build the text-generation pipeline once and reuse it for every prompt.
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
)

final_ans_jsons = []
for prompt in prompts:

message = prompt
print("Prompt: ", message)
output = pipe(
message,
max_new_tokens=self.max_tokens,
temperature=self.temperature,
top_p=self.top_p,
return_full_text=False,
)
result = output[0]["generated_text"]
print("Generation: ", result)

final_ans_jsons.append(result)

return final_ans_jsons, prompts

def decode_ast(self, result, language="Python"):
result = result.strip()
PRETTY_FORMAT_PATTERN = r"```\n?(python|json|tool_code)\n?(.*)\n?```"
pattern = r"\[(.*)\]"
# Searching for the pattern in the input text

unformatted = re.search(PRETTY_FORMAT_PATTERN, result, re.DOTALL)
if unformatted:
removed_formatting = unformatted.group(2)
match = re.search(pattern, removed_formatting, re.DOTALL)
if match:
raw_input = match.group(1)
else:
raw_input = removed_formatting
else:
raw_input = result
raw = raw_input.strip()
func = "[" + raw + "]"
decoded_output = ast_parse(func, language=language)

return decoded_output

def decode_execute(self, result):
result = result.strip()
PRETTY_FORMAT_PATTERN = r"```\n?(python|json|tool_code)\n?(.*)\n?```"
pattern = r"\[(.*)\]"
# Searching for the pattern in the input text

unformatted = re.search(PRETTY_FORMAT_PATTERN, result, re.DOTALL)
if unformatted:
removed_formatting = unformatted.group(2)
match = re.search(pattern, removed_formatting, re.DOTALL)
if match:
raw_input = match.group(1)
else:
raw_input = removed_formatting
else:
raw_input = result
raw = raw_input.strip()
func = "[" + raw + "]"
decoded_output = ast_parse(func)  # `language` is not defined in this method; use the parser's default

execution_list = []
for function_call in decoded_output:
for key, value in function_call.items():
execution_list.append(
f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})"
)
return execution_list
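To illustrate what decode_ast and decode_execute expect as input, a small standalone sketch of the fence-stripping step on a hypothetical model output:

import re

# Hypothetical model output wrapped in a markdown code fence.
sample = "```python\n[get_weather(city='Berlin'), get_time(zone='CET')]\n```"

PRETTY_FORMAT_PATTERN = r"```\n?(python|json|tool_code)\n?(.*)\n?```"
fenced = re.search(PRETTY_FORMAT_PATTERN, sample, re.DOTALL)
inner = fenced.group(2) if fenced else sample
print(inner)  # [get_weather(city='Berlin'), get_time(zone='CET')]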
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import json
from tree_sitter import Language, Parser
import tree_sitter_java

JAVA_LANGUAGE = Language(tree_sitter_java.language(), "java")

parser = Parser()
parser.set_language(JAVA_LANGUAGE)


def parse_java_function_call(source_code):
# Moved in here to avoid dependency errors if we aren't running categories that use this function
import tree_sitter_java
JAVA_LANGUAGE = Language(tree_sitter_java.language(), "java")

parser = Parser()
parser.set_language(JAVA_LANGUAGE)
tree = parser.parse(bytes(source_code, "utf8"))
root_node = tree.root_node
sexp_result = root_node.sexp()
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import json
from tree_sitter import Language, Parser
import tree_sitter_javascript

JS_LANGUAGE = Language(tree_sitter_javascript.language(), "javascript")

parser = Parser()
parser.set_language(JS_LANGUAGE)
def parse_javascript_function_call(source_code):
# Moved in here to avoid dependency errors if we aren't running categories that use this function
import tree_sitter_javascript

JS_LANGUAGE = Language(tree_sitter_javascript.language(), "javascript")

def parse_javascript_function_call(source_code):
parser = Parser()
parser.set_language(JS_LANGUAGE)
# Parse the source code
tree = parser.parse(bytes(source_code, "utf8"))
root_node = tree.root_node
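The comment in these parser changes describes a deferred-import pattern: the grammar module is only imported when the function is actually called, so environments without it can still import the package. A self-contained sketch of the same idea (the function name is illustrative; the tree-sitter calls mirror the ones in this diff):

from tree_sitter import Language, Parser

def parse_javascript_source(source_code):
    # Deferred import: only needed when this parser is actually used.
    import tree_sitter_javascript

    js_language = Language(tree_sitter_javascript.language(), "javascript")
    parser = Parser()
    parser.set_language(js_language)
    return parser.parse(bytes(source_code, "utf8"))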
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from bfcl.model_handler.huggingface_handler import HuggingFaceHandler


class PhiHandler(HuggingFaceHandler):
def __init__(self, model_name, temperature=0, top_p=1, max_tokens=1000, add_generation_prompt=False, system_prompt_support=True, attn_implementation="flash_attention_2") -> None:
super().__init__(model_name, temperature, top_p, max_tokens, add_generation_prompt, system_prompt_support, attn_implementation)

def inference(
self,
test_question,
num_gpus,
gpu_memory_utilization,
stop_token_ids=None,
max_model_len=None,
include_system_prompt=True,
):
return super().inference(
test_question=test_question,
num_gpus=num_gpus,
gpu_memory_utilization=gpu_memory_utilization,
include_system_prompt=include_system_prompt,
)
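A brief usage note: PhiHandler only pins defaults (system prompt supported, no generation prompt, flash_attention_2) and forwards everything else to HuggingFaceHandler. A sketch of constructing it (the model key is the handler_map entry added in this PR; the other arguments are the defaults):

handler = PhiHandler("microsoft/Phi-3.5-mini-instruct", temperature=0, top_p=1, max_tokens=1000)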

Original file line number Diff line number Diff line change
@@ -259,6 +259,14 @@ def resolve_ast_call(elem):
func_parts.append(func_part.id)
func_name = ".".join(reversed(func_parts))
args_dict = {}
# Handle calls whose arguments are passed as a single unnamed dictionary argument
for arg in elem.args:
if isinstance(arg, ast.Dict):
for key, value in zip(arg.keys, arg.values):
if isinstance(key, ast.Constant):
arg_name = key.value
output = resolve_ast_by_type(value)
args_dict[arg_name] = output
for arg in elem.keywords:
output = resolve_ast_by_type(arg.value)
args_dict[arg.arg] = output
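For illustration, a standalone sketch of the call shape this new branch handles, where the arguments arrive as one unnamed dictionary literal rather than as keyword arguments (the function and argument names are made up):

import ast

call = ast.parse("get_weather({'city': 'Berlin', 'unit': 'C'})", mode="eval").body

args_dict = {}
for arg in call.args:
    if isinstance(arg, ast.Dict):
        for key, value in zip(arg.keys, arg.values):
            if isinstance(key, ast.Constant):
                # literal_eval stands in here for resolve_ast_by_type on simple literal values
                args_dict[key.value] = ast.literal_eval(value)

print(args_dict)  # {'city': 'Berlin', 'unit': 'C'}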