[BFCL] Adds support for parallel inference and batching #498

Open
wants to merge 3 commits into main
31 changes: 25 additions & 6 deletions berkeley-function-call-leaderboard/model_handler/handler.py
@@ -1,6 +1,6 @@
from model_handler.model_style import ModelStyle
import json, os

import aiofiles

class BaseHandler:
model_name: str
@@ -24,20 +24,23 @@ def decode_execute(self, result):
# This method takes raw model output and converts it to standard execute checker input.
pass

def write(self, result, file_to_open):
# This method is used to write the result to the file.
# Make the write function async
async def write(self, result, file_to_open):
# Ensure the result directories exist
if not os.path.exists("./result"):
os.mkdir("./result")
if not os.path.exists("./result/" + self.model_name):
os.mkdir("./result/" + self.model_name)
with open(

# Use aiofiles to write asynchronously
async with aiofiles.open(
"./result/"
+ self.model_name
+ "/"
+ file_to_open.replace(".json", "_result.json"),
"a+",
mode='a+'
) as f:
f.write(json.dumps(result) + "\n")
await f.write(json.dumps(result) + "\n")

def load_result(self, test_category):
# This method is used to load the result from the file.
@@ -48,3 +51,19 @@ def load_result(self, test_category):
for line in f:
result_list.append(json.loads(line))
return result_list

# Open the result file and sort its entries by idx;
# return the indices that are saved in the file, in ascending order.
def sort_results(self, file_to_open):
path = "./result/" + self.model_name + "/" + file_to_open.replace(".json", "_result.json")
# If the file doesn't exist yet, there is nothing to sort.
if not os.path.exists(path):
return None
with open(path, mode='r') as f:
lines = f.readlines()
results = [json.loads(line) for line in lines]
sorted_results = sorted(results, key=lambda x: x['idx'])
with open(path, mode='w') as f:
for result in sorted_results:
f.write(json.dumps(result) + "\n")
return [result["idx"] for result in sorted_results]
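
A minimal sketch of how the new async write path and sort_results are meant to interact, assuming a trivial BaseHandler subclass run from the berkeley-function-call-leaderboard directory; the DummyHandler class, file name, and payloads below are illustrative, not part of this diff. Concurrent writes append result lines in whatever order they finish, and sort_results then rewrites the file ordered by idx and returns the saved indices:

import asyncio
from model_handler.handler import BaseHandler

class DummyHandler(BaseHandler):
    # Illustrative subclass: only model_name is needed for write/sort_results.
    def __init__(self):
        self.model_name = "dummy-model"

async def demo():
    handler = DummyHandler()
    # Writes can finish in any order when awaited concurrently.
    await asyncio.gather(
        handler.write({"idx": 1, "result": "b"}, "example_test.json"),
        handler.write({"idx": 0, "result": "a"}, "example_test.json"),
    )
    # sort_results reorders the file by idx and returns the saved indices.
    print(handler.sort_results("example_test.json"))  # expected: [0, 1]

asyncio.run(demo())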
115 changes: 79 additions & 36 deletions berkeley-function-call-leaderboard/openfunctions_evaluation.py
@@ -3,7 +3,9 @@
from model_handler.handler_map import handler_map
from model_handler.model_style import ModelStyle
from model_handler.constant import USE_COHERE_OPTIMIZATION

import aiohttp
import asyncio
from functools import wraps, partial

def get_args():
parser = argparse.ArgumentParser()
@@ -18,6 +20,7 @@ def get_args():
parser.add_argument("--max-tokens", type=int, default=1200)
parser.add_argument("--num-gpus", default=1, type=int)
parser.add_argument("--timeout", default=60, type=int)
parser.add_argument('--batch-size', type=int, default=1, help='Batch size for processing (default: 1)')

args = parser.parse_args()
return args
@@ -39,9 +42,21 @@ def get_args():
"sql": "gorilla_openfunctions_v1_test_sql.json",
}

def make_async(func):
@wraps(func)
async def run(*args, loop=None, executor=None, **kwargs):
if loop is None:
loop = asyncio.get_running_loop()
pfunc = partial(func, *args, **kwargs)
return await loop.run_in_executor(executor, pfunc)
return run

# Automatically wraps the handler to make handler.inference async
def build_handler(model_name, temperature, top_p, max_tokens):
handler = handler_map[model_name](model_name, temperature, top_p, max_tokens)

if not asyncio.iscoroutinefunction(handler.inference):
handler.inference = make_async(handler.inference)
return handler


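A self-contained sketch of the pattern make_async relies on: run_in_executor pushes a blocking callable onto a thread pool so the event loop can await it alongside other work. The slow_inference function is a hypothetical stand-in for a handler's synchronous inference method, not something from this diff:

import asyncio
import time
from functools import wraps, partial

def make_async(func):
    @wraps(func)
    async def run(*args, loop=None, executor=None, **kwargs):
        if loop is None:
            loop = asyncio.get_running_loop()
        # Run the blocking callable on a thread-pool executor so the event loop stays free.
        return await loop.run_in_executor(executor, partial(func, *args, **kwargs))
    return run

def slow_inference(prompt):
    # Hypothetical blocking call standing in for handler.inference.
    time.sleep(1)
    return f"response to {prompt}"

async def demo():
    inference = make_async(slow_inference)
    # Both one-second calls overlap, so the gather finishes in roughly one second.
    print(await asyncio.gather(inference("q1"), inference("q2")))

asyncio.run(demo())

With this in place, handlers whose inference is already a coroutine are left untouched by build_handler, while synchronous ones become awaitable without any per-handler changes.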
@@ -55,19 +70,37 @@ def load_file(test_category):
return test_cate, files_to_open


if __name__ == "__main__":
async def fetch_and_process(session, index, test_case, handler, test_category, file_to_open):
user_question, functions = test_case["question"], test_case["function"]
if isinstance(functions, (dict, str)):
functions = [functions]

result, metadata = await handler.inference(user_question, functions, test_category)

result_to_write = {
"idx": index,
"result": result,
"input_token_count": metadata["input_tokens"],
"output_token_count": metadata["output_tokens"],
"latency": metadata["latency"],
}
await handler.write(result_to_write, file_to_open)

async def main():
args = get_args()
if USE_COHERE_OPTIMIZATION and "command-r-plus" in args.model:
args.model = args.model + "-optimized"

handler = build_handler(args.model, args.temperature, args.top_p, args.max_tokens)

if handler.model_style == ModelStyle.OSSMODEL:
result = handler.inference(
result = await handler.inference(
question_file="eval_data_total.json",
test_category=args.test_category,
num_gpus=args.num_gpus,
)
for res in result[0]:
handler.write(res, "result.json")
await handler.write(res, "result.json")
else:
test_cate, files_to_open = load_file(args.test_category)
for test_category, file_to_open in zip(test_cate, files_to_open):
@@ -76,35 +109,45 @@ def load_file(test_category):
with open("./data/" + file_to_open) as f:
for line in f:
test_cases.append(json.loads(line))
num_existing_result = 0 # if the result file already exists, skip the test cases that have been tested.
if os.path.exists(
"./result/"
+ args.model.replace("/", "_")
+ "/"
+ file_to_open.replace(".json", "_result.json")
):
with open(
"./result/"
+ args.model.replace("/", "_")
+ "/"
+ file_to_open.replace(".json", "_result.json")
) as f:
for line in f:
num_existing_result += 1
for index, test_case in enumerate(tqdm(test_cases)):
if index < num_existing_result:
continue
user_question, functions = test_case["question"], test_case["function"]
if type(functions) is dict or type(functions) is str:
functions = [functions]
result, metadata = handler.inference(
user_question, functions, test_category
)
result_to_write = {
"idx": index,
"result": result,
"input_token_count": metadata["input_tokens"],
"output_token_count": metadata["output_tokens"],
"latency": metadata["latency"],
}
handler.write(result_to_write, file_to_open)

# Sort the results that already exist for this test category
# and get back the indices that have been saved so far, in order.
indices = handler.sort_results(file_to_open)

# Filter out test cases that have already been completed,
# inserting None as a placeholder for each completed case.
filtered_test_cases = []
if indices is not None:
for test_idx, test_case in enumerate(test_cases):
if test_idx not in indices:
filtered_test_cases.append(test_case)
else:
filtered_test_cases.append(None)
test_cases = filtered_test_cases

async with aiohttp.ClientSession() as session:
batch_size = args.batch_size # Number of test cases to process concurrently in each batch
tasks = []
# Create a tqdm progress bar for the entire dataset
progress_bar = tqdm(total=len(test_cases), desc="Processing test cases")

for start_index in range(0, len(test_cases), batch_size):
end_index = min(start_index + batch_size, len(test_cases))
for index in range(start_index, end_index):
test_case = test_cases[index]
# A None entry means this test case has already been completed.
if test_case is None:
progress_bar.update(1) # Update for skipped items
continue
task = asyncio.create_task(fetch_and_process(session, index, test_case, handler, test_category, file_to_open))
task.add_done_callback(lambda _: progress_bar.update(1)) # Update progress when task is done
tasks.append(task)
await asyncio.gather(*tasks)
tasks.clear()
progress_bar.close()
# Sort the results, since async entries can be written out of order.
handler.sort_results(file_to_open)


if __name__ == "__main__":
asyncio.run(main())
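
The new driver launches test cases in fixed-size batches: every case in a batch becomes an asyncio task, the batch is awaited with gather, and only then does the next batch start, so at most --batch-size requests are in flight at once. A stripped-down sketch of that control flow, with fake_inference standing in for a real model call (illustrative only):

import asyncio
import random

async def fake_inference(idx):
    # Stand-in for one model call; variable latency makes completions interleave.
    await asyncio.sleep(random.uniform(0.1, 0.3))
    return idx

async def run_batched(items, batch_size):
    results = []
    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        # Launch one task per item in the batch and wait for the whole batch;
        # gather preserves argument order, so results stay in submission order.
        results.extend(await asyncio.gather(*(fake_inference(i) for i in batch)))
    return results

print(asyncio.run(run_batched(list(range(10)), batch_size=4)))

Compared with a semaphore-bounded worker pool, this pauses at each batch boundary while the slowest request finishes, but it keeps the bookkeeping simple and matches the per-file result writing and sorting above.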
1 change: 1 addition & 0 deletions berkeley-function-call-leaderboard/requirements.txt
@@ -11,3 +11,4 @@ numpy
cohere~=5.2.5
tree-sitter-java==0.21.0
tree-sitter-javascript==0.21.4
aiofiles