[BFCL] Add dynamic max_token handling for locally hosted models #693

Closed · wants to merge 2 commits

@@ -26,6 +26,39 @@ def __init__(self, model_name, temperature, dtype="bfloat16") -> None:
        self.dtype = dtype
        self.client = OpenAI(base_url=f"http://localhost:{VLLM_PORT}/v1", api_key="EMPTY")

+        def get_max_tokens(model_name):
+            # Llama Family
+            if "Llama" in model_name:
+                if "70B" in model_name:
+                    return 8192
+                else:
+                    return 4096
+            # GLM Family
+            elif "glm" in model_name:
+                return 2048
+            # Phi Family
+            elif "Phi" in model_name:
+                if "4k" in model_name:
+                    return 4000
+                elif "128k" in model_name:
+                    return 128000
+                else:
+                    return 4096
+            # Hermes Family
+            elif "Hermes" in model_name:
+                return 4096
+            # Qwen Family
+            elif "Qwen" in model_name:
+                return 2048
+            # Salesforce Family
+            elif "Salesforce" in model_name:
+                return 4096
+            # Default token limit
+            else:
+                return 4096
+        # Assigning max_tokens based on the model name
+        self.max_tokens = get_max_tokens(self.model_name_huggingface)

    def inference(self, test_entry: dict, include_debugging_log: bool):
        """
        OSS models have a different inference method.
@@ -210,14 +243,14 @@ def _query_prompting(self, inference_data: dict):
                temperature=self.temperature,
                prompt=formatted_prompt,
                stop_token_ids=self.stop_token_ids,
-                max_tokens=4096,  # TODO: Is there a better way to handle this?
+                max_tokens=self.max_tokens,
            )
        else:
            api_response = self.client.completions.create(
                model=self.model_name_huggingface,
                temperature=self.temperature,
                prompt=formatted_prompt,
-                max_tokens=4096,
+                max_tokens=self.max_tokens,
            )

        return api_response
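For reference, a minimal sketch of how the new helper resolves token limits for a few representative Hugging Face-style model identifiers. The identifiers below are illustrative examples, not necessarily names BFCL ships with, and get_max_tokens is shown here as if it were callable standalone, whereas in the diff it is a local helper defined inside __init__:

# Hypothetical model identifiers; comments show the branch that matches and the result.
get_max_tokens("meta-llama/Meta-Llama-3-70B-Instruct")  # "Llama" and "70B"  -> 8192
get_max_tokens("meta-llama/Meta-Llama-3-8B-Instruct")   # "Llama", no "70B"  -> 4096
get_max_tokens("microsoft/Phi-3-mini-128k-instruct")    # "Phi" and "128k"   -> 128000
get_max_tokens("Qwen/Qwen2-7B-Instruct")                # "Qwen"             -> 2048
get_max_tokens("some-org/brand-new-model")              # no family match    -> 4096 (default)

The resolved value is stored on the handler as self.max_tokens and passed as max_tokens to both client.completions.create calls above, replacing the previously hardcoded 4096.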