diff --git a/berkeley-function-call-leaderboard/bfcl/model_handler/oss_model/base_oss_handler.py b/berkeley-function-call-leaderboard/bfcl/model_handler/oss_model/base_oss_handler.py
index 1140a926d..5db614880 100644
--- a/berkeley-function-call-leaderboard/bfcl/model_handler/oss_model/base_oss_handler.py
+++ b/berkeley-function-call-leaderboard/bfcl/model_handler/oss_model/base_oss_handler.py
@@ -26,6 +26,39 @@ def __init__(self, model_name, temperature, dtype="bfloat16") -> None:
         self.dtype = dtype
         self.client = OpenAI(base_url=f"http://localhost:{VLLM_PORT}/v1", api_key="EMPTY")
 
+        def get_max_tokens(model_name):
+            # Llama Family
+            if "Llama" in model_name:
+                if "70B" in model_name:
+                    return 8192
+                else:
+                    return 4096
+            # GLM Family
+            elif "glm" in model_name:
+                return 2048
+            # Phi Family
+            elif "Phi" in model_name:
+                if "4k" in model_name:
+                    return 4000
+                elif "128k" in model_name:
+                    return 128000
+                else:
+                    return 4096
+            # Hermes Family
+            elif "Hermes" in model_name:
+                return 4096
+            # Qwen Family
+            elif "Qwen" in model_name:
+                return 2048
+            # Salesforce Family
+            elif "Salesforce" in model_name:
+                return 4096
+            # Default token limit
+            else:
+                return 4096
+        # Assigning max_tokens based on the model name
+        self.max_tokens = get_max_tokens(self.model_name_huggingface)
+
     def inference(self, test_entry: dict, include_debugging_log: bool):
         """
         OSS models have a different inference method.
@@ -210,14 +243,14 @@ def _query_prompting(self, inference_data: dict):
                 temperature=self.temperature,
                 prompt=formatted_prompt,
                 stop_token_ids=self.stop_token_ids,
-                max_tokens=4096,  # TODO: Is there a better way to handle this?
+                max_tokens=self.max_tokens,
             )
         else:
             api_response = self.client.completions.create(
                 model=self.model_name_huggingface,
                 temperature=self.temperature,
                 prompt=formatted_prompt,
-                max_tokens=4096,
+                max_tokens=self.max_tokens,
             )
 
         return api_response
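
For context, a minimal standalone sketch of the substring dispatch the new helper performs. The model IDs in the demo are illustrative HF-style names, not necessarily the exact strings BFCL passes as model_name_huggingface:

# Standalone sketch of the max-token heuristic added in __init__ above.
# The model IDs below are illustrative assumptions, not values taken from BFCL.
def get_max_tokens(model_name: str) -> int:
    if "Llama" in model_name:
        return 8192 if "70B" in model_name else 4096
    elif "glm" in model_name:
        return 2048
    elif "Phi" in model_name:
        if "4k" in model_name:
            return 4000
        elif "128k" in model_name:
            return 128000
        return 4096
    elif "Hermes" in model_name:
        return 4096
    elif "Qwen" in model_name:
        return 2048
    elif "Salesforce" in model_name:
        return 4096
    return 4096  # default limit for unrecognized models


if __name__ == "__main__":
    for name in (
        "meta-llama/Meta-Llama-3-70B-Instruct",  # -> 8192
        "microsoft/Phi-3-mini-128k-instruct",    # -> 128000
        "Qwen/Qwen2-7B-Instruct",                # -> 2048
        "some/unknown-model",                    # -> 4096 (default)
    ):
        print(f"{name}: max_tokens={get_max_tokens(name)}")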