Remove test category
- test category is already added to each example during data loading
devanshamin committed Jul 8, 2024
1 parent 34a170a commit f736521
Showing 2 changed files with 9 additions and 13 deletions.
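Context for the change: each example is already tagged with its test category when the dataset is loaded, so the category no longer needs to be threaded through the inference path. Below is a minimal sketch of that tagging pattern, assuming a JSONL dataset; the helper is a hypothetical illustration, not BFCL's actual loader.

import json

# Hypothetical loader illustrating the commit's rationale: tag every example
# with its test category at load time, so downstream code can read
# example["test_category"] instead of receiving the category as an argument.
def load_data(file_path, test_category):
    examples = []
    with open(file_path) as f:
        for line in f:
            example = json.loads(line)
            example["test_category"] = test_category  # tagged during loading
            examples.append(example)
    return examples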
berkeley-function-call-leaderboard/bfcl/benchmark.py (16 changes: 7 additions & 9 deletions)
@@ -19,13 +19,11 @@ def main() -> None:
     model_handler = _get_model_handler(args)
     test_inputs = test_categories.load_data()
     if model_handler.model_style == ModelStyle.OSS_MODEL:
-        result = model_handler.inference(
-            inputs=test_inputs,
-            test_category=test_categories,
-            num_gpus=args.num_gpus,
-        )
-        for res in result[0]:
-            model_handler.write(res, "result.json")
+        responses = model_handler.inference(inputs=test_inputs, num_gpus=args.num_gpus)
+        file_name = test_categories.output_file_path.name.replace('.jsonl', '_result.jsonl')
+        model_handler.write(responses, file_name)
     else:
         raise NotImplementedError()
 
 
 def get_args() -> argparse.Namespace:
@@ -87,8 +85,8 @@ def _get_model_handler(args) -> BaseHandler:
     elif args.model_type == ModelType.PROPRIETARY:
         from bfcl.model_handler.proprietary_model import MODEL_TO_HANDLER_CLS
 
-    assert (handler_cls := MODEL_TO_HANDLER_CLS.get(args.model_name)), \
-        f'Invalid model name! Please select a {args.model_type.value} model from {tuple(MODEL_TO_HANDLER_CLS)}'
+    assert (handler_cls := MODEL_TO_HANDLER_CLS.get(args.model)), \
+        f'Invalid model name "{args.model}"! Please select a {args.model_type.value} model from {tuple(MODEL_TO_HANDLER_CLS)}'
 
     return handler_cls(
         model_name=args.model,
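As an aside, the assert above uses an assignment expression so the registry lookup and the validation happen in one step, and the new message names the offending model. A standalone sketch of the pattern follows; the registry entries are hypothetical placeholders, not BFCL's real handler classes.

# Standalone sketch of the lookup-and-assert pattern used above; the registry
# contents here are hypothetical stand-ins, not BFCL's actual handlers.
MODEL_TO_HANDLER_CLS = {'example-model-v1': object}

def get_handler_cls(model: str):
    # The walrus operator binds the lookup result, so a failed lookup raises
    # with the offending name while a hit returns the handler class directly.
    assert (handler_cls := MODEL_TO_HANDLER_CLS.get(model)), \
        f'Invalid model name "{model}"! Please select a model from {tuple(MODEL_TO_HANDLER_CLS)}'
    return handler_cls

One caveat of assert-based validation: it is stripped when Python runs with -O, so an explicit raise would survive optimization where this does not.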
@@ -38,14 +38,13 @@ def get_prompt(self, user_input, functions, test_category) -> str:
             user_input=user_input,
         )
 
-    def inference(self, inputs, test_category, num_gpus):
+    def inference(self, inputs, num_gpus):
         chunk_size = len(inputs) // num_gpus
         futures = []
         for i in range(0, len(inputs), chunk_size):
             futures.append(
                 self._batch_generate.remote(
                     inputs[i: i + chunk_size],
-                    test_category,
                     self.model_name,
                     self.sampling_params,
                     get_prompt_func=self.get_prompt,
@@ -79,9 +78,8 @@ def _batch_generate(
         get_prompt_func
     ):
         prompts = []
-        for line in inputs:
+        for _input in inputs:
+            test_category = _input["test_category"]
-            _input = line
             prompt = utils.augment_prompt_by_languge(_input["question"], test_category)
             functions = utils.language_specific_pre_processing(_input["function"], test_category, False)
             prompts.append(get_prompt_func(prompt, functions, test_category))
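One behavior of the dispatch loop in inference() worth noting: chunk_size uses floor division, so when len(inputs) is not a multiple of num_gpus the loop emits one extra, smaller trailing chunk, and therefore one more remote task than GPUs. A standalone sketch with the Ray remote call replaced by plain list slicing:

# Standalone sketch of the chunked dispatch in inference(), with the Ray
# remote call replaced by slicing so the chunk boundaries are easy to inspect.
def chunk_inputs(inputs, num_gpus):
    chunk_size = len(inputs) // num_gpus
    return [inputs[i:i + chunk_size] for i in range(0, len(inputs), chunk_size)]

print(chunk_inputs(list(range(10)), num_gpus=3))
# -> [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]  (four chunks for three GPUs)

The pattern also assumes len(inputs) >= num_gpus; otherwise chunk_size is 0 and range() raises a ValueError.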
