Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BFCL] Add empower functions models and the supporting handler #630

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,18 @@
"Microsoft",
"MIT",
],
"empower-dev/llama3-empower-functions-small-v1.1": [
"Empower-Functions-Small-v1.1 (FC)",
"https://huggingface.co/empower-dev/llama3-empower-functions-small-v1.1",
"Empower.dev",
"apache-2.0"
],
"empower-dev/llama3-empower-functions-large-v1.1": [
"Empower-Functions-Large-v1.1 (FC)",
"https://huggingface.co/empower-dev/llama3-empower-functions-large-v1.1",
"Empower.dev",
"apache-2.0"
]
}

INPUT_PRICE_PER_MILLION_TOKEN = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,6 @@
"THUDM/glm-4-9b-chat",
"ibm-granite/granite-20b-functioncalling",
"yi-large-fc",
"empower-dev/llama3-empower-functions-small-v1.1",
"empower-dev/llama3-empower-functions-large-v1.1",
]
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from bfcl.model_handler.oss_model.deepseek import DeepseekHandler
from bfcl.model_handler.oss_model.empower import EmpowerHandler
from bfcl.model_handler.oss_model.gemma import GemmaHandler
from bfcl.model_handler.oss_model.glaive import GlaiveHandler
from bfcl.model_handler.oss_model.glm import GLMHandler
Expand Down Expand Up @@ -92,7 +93,9 @@
"ibm-granite/granite-20b-functioncalling": GraniteHandler,
# "MadeAgents/Hammer-7b": HammerHandler, # TODO: Update handler once they have a multi-turn format
"THUDM/glm-4-9b-chat": GLMHandler,

"empower-dev/llama3-empower-functions-small-v1.1": EmpowerHandler,
"empower-dev/llama3-empower-functions-large-v1.1": EmpowerHandler,

# Deprecated/outdated models, no longer on the leaderboard
# "gorilla-openfunctions-v0": GorillaHandler,
# "gpt-4o-2024-05-13": OpenAIHandler,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from bfcl.model_handler.oss_model.base_oss_handler import OSSHandler
from bfcl.model_handler.model_style import ModelStyle
import json
from bfcl.model_handler.utils import (
convert_to_tool,
)
from bfcl.model_handler.constant import (
GORILLA_TO_OPENAPI,
)


class EmpowerHandler(OSSHandler):
    """Handler for the empower-dev llama3-empower-functions models.

    Formats conversations into the Llama-3 chat template using Empower's
    message tags (<u> user turn, <r> tool results, <f> function-call
    output, <c> conversational output) and decodes the model's JSON
    function-call responses.
    """

    def __init__(self, model_name, temperature) -> None:
        super().__init__(model_name, temperature)

    def _preprocess_messages(self, messages):
        """Drop system messages and merge consecutive 'tool' messages.

        Consecutive tool responses are combined into a single 'tool'
        message whose content is a JSON list of the decoded responses.
        Returns a new list; the input list itself is not modified.
        """
        # System instructions are folded into the first user message by
        # _format_prompt, so system messages are dropped here.
        messages = [m for m in messages if m["role"] != "system"]

        result = []
        pending_tool_outputs = []  # consecutive tool results awaiting merge

        def flush_tool_outputs():
            # Emit the accumulated tool outputs as one combined message.
            if pending_tool_outputs:
                result.append(
                    {
                        "role": "tool",
                        "content": json.dumps(pending_tool_outputs, indent=2),
                    }
                )
                pending_tool_outputs.clear()

        for message in messages:
            if message["role"] == "tool":
                # Tool contents arrive as JSON strings; decode them so the
                # merged list nests real JSON instead of escaped strings.
                pending_tool_outputs.append(json.loads(message["content"]))
            else:
                flush_tool_outputs()
                result.append(message)
        flush_tool_outputs()

        return result

    def _format_prompt(self, messages, functions):
        """Render messages into the Llama-3 chat template string.

        The first message is prefixed with the available functions; later
        tool messages are re-rolled as user turns tagged <r>, and plain
        user turns are tagged <u>.
        """
        formatted_prompt = "<|begin_of_text|>"

        for idx, message in enumerate(self._preprocess_messages(messages)):
            # Work on a shallow copy so the caller's message dicts are not
            # mutated. The original implementation modified them in place,
            # which made repeated prompt formatting prepend tags twice.
            message = dict(message)

            if idx == 0:
                tools = convert_to_tool(
                    functions, GORILLA_TO_OPENAPI, ModelStyle.OSSMODEL
                )
                message["content"] = (
                    "In this environment you have access to a set of functions defined in the JSON format you can use to address user's requests, use them if needed.\nFunctions:\n"
                    + json.dumps(tools, indent=2)
                    + "\n\n"
                    + "User Message:\n"
                    + message["content"]
                )
            elif message["role"] == "tool":
                # Tool results are presented to the model as user turns
                # tagged with <r>.
                message["role"] = "user"
                message["content"] = "<r>" + message["content"]
            elif message["role"] == "user" and not message["content"].startswith(
                ("<r>", "<u>")
            ):
                message["content"] = "<u>" + message["content"]

            formatted_prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}<|eot_id|>"

        formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"

        return formatted_prompt

    def decode_ast(self, result, language="Python"):
        """Decode raw model output into [{function_name: {arg: value}}].

        Function-call outputs are prefixed with <f>; anything else
        (e.g. <c> conversational output) yields no calls.
        """
        if not result.startswith("<f>"):
            return []

        # Strip the leading <f> tag; the remainder is a JSON list of calls.
        invoked_functions = json.loads(result[3:])
        return [
            {call["name"]: call.get("arguments", {})}
            for call in invoked_functions
        ]

    def decode_execute(self, result):
        """Convert decoded calls into executable strings 'name(k=v, ...)'."""
        execution_list = []
        for function_call in self.decode_ast(result):
            for func_name, params in function_call.items():
                arg_str = ",".join(f"{k}={repr(v)}" for k, v in params.items())
                execution_list.append(f"{func_name}({arg_str})")
        return execution_list