-
Notifications
You must be signed in to change notification settings - Fork 783
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Aaron Pham <[email protected]>
- Loading branch information
Showing
43 changed files
with
1,187 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -167,3 +167,7 @@ mlruns/ | |
.pdm-python | ||
.python-version | ||
.pdm-build/ | ||
|
||
# from training scripts | ||
model | ||
outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
import os | ||
import pathlib | ||
import shutil | ||
import subprocess | ||
import sys | ||
import tempfile | ||
import typing as t | ||
|
||
import yaml | ||
from deepmerge.merger import Merger | ||
|
||
import bentoml | ||
from _bentoml_sdk.service.config import ServiceConfig as Config | ||
from bentoml._internal.bento.build_config import BentoBuildConfig | ||
from bentoml._internal.bento.build_config import DockerOptions | ||
from bentoml._internal.bento.build_config import ModelSpec | ||
from bentoml._internal.utils import pkg | ||
from bentoml.exceptions import BentoMLException | ||
from bentoml.exceptions import MissingDependencyException | ||
|
||
from .mapping import RUNTIME_MAPPING as MAPPINGS | ||
from .mapping import get_extras | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
if t.TYPE_CHECKING: | ||
from transformers import PreTrainedModel | ||
from transformers import PreTrainedTokenizerFast | ||
|
||
# 'unsloth' is an optional extra: abort the import of this module with an
# actionable message when it cannot be found.
if pkg.find_spec("unsloth") is None:
    raise MissingDependencyException(
        "'unsloth' is required in order to use module 'bentoml.unsloth', "
        "install unsloth with 'pip install bentoml[unsloth]'."
    )
|
||
# Deep-merge policy used to layer per-model extras on top of the base build
# options: nested dicts are merged key by key, lists are concatenated, and
# every other value (or a type mismatch) is replaced by the incoming one.
merger = Merger(
    type_strategies=[(dict, "merge"), (list, "append")],
    fallback_strategies=["override"],
    type_conflict_strategies=["override"],
)
|
||
|
||
def replace_tag(tag: str) -> str:
    """Lower-case *tag* and turn every "/" into "--"."""
    lowered = tag.lower()
    # Splitting on "/" and re-joining with "--" substitutes each separator.
    return "--".join(lowered.split("/"))
|
||
|
||
# Chat template names accepted by the ``chat_template`` argument of
# ``build_bento``; the selected name is written to ``service_config.yaml``.
ChatTemplate = t.Literal[
    "alpaca",
    "amberchat",
    "chatml",
    "chatqa",
    "falcon-instruct",
    "gemma-it",
    "llama-2-chat",
    "llama-3-chat",
    "mistral-instruct",
    "openchat",
    "phi-3",
    "saiga",
    "solar-instruct",
    "vicuna",
    "zephyr",
]
|
||
# Values of ``model.config.model_type`` that ``build_bento`` casts to and then
# uses as the key into the runtime mapping.
ModelType = t.Literal["llama", "mistral", "gemma", "gemma2", "qwen2"]
|
||
|
||
def build_bento(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerFast,
    /,
    model_name: str | None = None,
    *,
    chat_template: ChatTemplate,
    quantization_method: t.Literal["bitsandbytes"] | None = None,
    save_method: t.Literal["merged_16bit", "merged_4bit"] = "merged_16bit",
    service_config: Config | None = None,
    engine_config: dict[str, t.Any]
    | None = None,  # arguments to pass to AsyncEngineArgs
) -> bentoml.Model:
    """Save a fine-tuned model to the BentoML model store and build a bento for it.

    The merged weights are written with ``model.save_pretrained_merged``, a
    service project is generated from the bundled ``template`` directory, and
    ``bentoml build`` is run on that project in a subprocess.

    Args:
        model: Fine-tuned model; must expose ``config`` and
            ``save_pretrained_merged``.
        tokenizer: Tokenizer passed to ``save_pretrained_merged``.
        model_name: Name of the stored model, also used as the prefix of the
            bento name. Required when ``model.config`` carries no
            ``_commit_hash``; otherwise derived from
            ``model.config._name_or_path`` when omitted.
        chat_template: Chat template name written to ``service_config.yaml``.
        quantization_method: When given, stored in the engine config as both
            ``quantization`` and ``load_format``.
        save_method: Forwarded to ``save_pretrained_merged``.
        service_config: Overrides applied on top of the per-model service
            defaults. The passed dict is not modified.
        engine_config: Overrides applied on top of the per-model engine
            defaults. The passed dict is not modified.

    Returns:
        The model entry created in the BentoML model store.

    Raises:
        BentoMLException: If the model is a local checkpoint and
            ``model_name`` is not given.
        KeyError: If ``model.config.model_type`` has no entry in the runtime
            mapping.
        subprocess.CalledProcessError: If ``bentoml build`` exits non-zero.
    """
    # If the model is local then model_name must be specified; otherwise it is
    # derived from the model id.
    is_local = getattr(model.config, "_commit_hash", None) is None
    if is_local and model_name is None:
        raise BentoMLException(
            'Fine-tune from a local checkpoint requires specifying "model_name".'
        )
    model_name = model_name or replace_tag(model.config._name_or_path)

    model_type = t.cast(ModelType, model.config.model_type)
    defaults = MAPPINGS[model_type]

    # Start from the per-model defaults and let caller-supplied values win.
    # Building new dicts also keeps the caller's arguments untouched.
    service_config = {**defaults["service_config"], **(service_config or {})}
    engine_config = {**defaults["engine_config"], **(engine_config or {})}
    if quantization_method is not None:
        # The explicit argument takes precedence over anything in engine_config.
        engine_config.update(
            {"quantization": quantization_method, "load_format": quantization_method}
        )

    with bentoml.models.create(model_name) as bentomodel:
        model.save_pretrained_merged(
            bentomodel.path, tokenizer, save_method=save_method
        )

    build_opts = dict(
        python=dict(
            packages=[
                "pyyaml",
                "vllm==0.5.5",
                "fastapi==0.111.0",
                "unsloth[huggingface] @ git+https://github.com/bentoml/unsloth.git@main",
            ],
            lock_packages=True,
        ),
        envs=[{"name": "HF_TOKEN"}],
    )
    # Layer model-specific packages/envs on top of the base build options.
    build_opts = merger.merge(build_opts, get_extras().get(model_type, {}))

    logger.info(
        "Building bentos for %s, model_id=%s", model_type, model.config._name_or_path
    )

    with tempfile.TemporaryDirectory() as tmp:
        workdir = pathlib.Path(tmp)
        shutil.copytree(
            pathlib.Path(__file__).parent / "template", workdir, dirs_exist_ok=True
        )
        with (workdir / "service_config.yaml").open("w") as f:
            yaml.safe_dump(
                dict(
                    chat_template=chat_template,
                    model_tag=str(bentomodel.tag),
                    engine_config=engine_config,
                    service_config=service_config,
                ),
                f,
            )
        with (workdir / "bentofile.yaml").open("w") as f:
            BentoBuildConfig(
                service="service:VLLM",
                name=f"{model_name.replace('.', '-')}-service",
                include=[
                    "*.py",
                    "*.yaml",
                    "chat_templates/*.jinja",
                    "generation_configs/*.json",
                ],
                docker=DockerOptions(python_version="3.11", system_packages=["git"]),
                models=[ModelSpec.from_item(str(bentomodel.tag))],
                description="API Service for running Unsloth models, powered with BentoML and vLLM.",
                **build_opts,
            ).with_defaults().to_yaml(f)

        # The build must finish before the temporary project is removed.
        subprocess.run(
            [
                sys.executable,
                "-m",
                "bentoml",
                "build",
                str(workdir),
            ],
            check=True,
            cwd=workdir,
            env=os.environ,
        )

    return bentomodel
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Default service and engine settings per model type. Every supported type
# currently shares the same values; the literals are evaluated once per key so
# each entry owns its own dict objects.
RUNTIME_MAPPING = {
    model_type: {
        "service_config": {
            "traffic": {"timeout": 300},
            "resources": {"gpu": 1, "gpu_type": "nvidia-l4"},
        },
        "engine_config": {"max_model_len": 2048},
    }
    for model_type in ("llama", "mistral", "gemma", "gemma2", "qwen2")
}
|
||
|
||
def get_extras():
    """Return extra build options keyed by model type.

    Only "gemma2" has an entry: it sets the ``VLLM_ATTENTION_BACKEND``
    environment variable to ``FLASHINFER`` and adds the flashinfer wheel
    together with its extra package index. A new structure is built on every
    call.
    """
    attention_backend_env = {"name": "VLLM_ATTENTION_BACKEND", "value": "FLASHINFER"}
    flashinfer_python = {
        "extra_index_url": ["https://flashinfer.ai/whl/cu121/torch2.3"],
        "packages": ["flashinfer==0.1.2+cu121torch2.3"],
    }
    gemma2_extras = {"envs": [attention_backend_env], "python": flashinfer_python}
    return {"gemma2": gemma2_extras}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
.ipynb_checkpoints | ||
venv/ | ||
.venv/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
* text=auto eol=lf | ||
**/ui/** linguist-vendored=true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# Environments | ||
venv/ | ||
|
||
# BentoML | ||
bentoml/client_id | ||
|
||
chattts/ChatTTS/ |
Oops, something went wrong.