Skip to content

Commit

Permalink
feat: unsloth integrations
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <[email protected]>
  • Loading branch information
aarnphm committed Sep 4, 2024
1 parent d66d828 commit eb759c9
Show file tree
Hide file tree
Showing 43 changed files with 1,187 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,7 @@ mlruns/
.pdm-python
.python-version
.pdm-build/

# from training scripts
model
outputs
4 changes: 3 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ ci:
autoupdate_schedule: monthly
autofix_commit_msg: "ci: auto fixes from pre-commit.ci\n\nFor more information, see https://pre-commit.ci"
autoupdate_commit_msg: 'ci: pre-commit autoupdate [skip ci]'
skip: # exceeds tier max size
skip: # exceeds tier max size
- buf-format
- buf-lint
exclude: '(.*\.(css|js|svg))|(.*/(snippets|grpc|proto)/.*)$'
Expand All @@ -13,7 +13,9 @@ repos:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --show-fixes]
types_or: [python, pyi]
exclude: ^src/_bentoml_impl/frameworks/unsloth/train\.py$
- id: ruff-format
exclude: ^src/_bentoml_impl/frameworks/unsloth/train\.py$
types_or: [python, pyi]
files: '(src|tests|docs|examples|typings)/'
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ io = [
io-image = ["Pillow"]
io-pandas = ["pandas>=1", "pyarrow"]
triton = ["tritonclient>=2.29.0", "tritonclient[all]; sys_platform != 'darwin'"]
unsloth = [
"unsloth[huggingface] @ git+https://github.com/bentoml/unsloth.git@main",
"vllm>=0.5.5",
"fastapi"
]
grpc = [
"protobuf",
"grpcio",
Expand Down Expand Up @@ -238,6 +243,11 @@ testpaths = ["tests"]
line-length = 88
target-version = "py310"

[tool.ruff.format]
exclude = [
"src/_bentoml_impl/frameworks/unsloth/train.py",
]

[tool.ruff.lint]
# We ignore E501 (line too long) here because we keep user-visible strings on one line.
ignore = ["E501"]
Expand All @@ -250,6 +260,7 @@ exclude = [
"src/bentoml/_internal/external_typing",
"src/bentoml/grpc/v1alpha1",
"src/bentoml/grpc/v1",
"src/_bentoml_impl/frameworks/unsloth",
"tests/proto",
]

Expand Down
176 changes: 176 additions & 0 deletions src/_bentoml_impl/frameworks/unsloth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from __future__ import annotations

import logging
import os
import pathlib
import shutil
import subprocess
import sys
import tempfile
import typing as t

import yaml
from deepmerge.merger import Merger

import bentoml
from _bentoml_sdk.service.config import ServiceConfig as Config
from bentoml._internal.bento.build_config import BentoBuildConfig
from bentoml._internal.bento.build_config import DockerOptions
from bentoml._internal.bento.build_config import ModelSpec
from bentoml._internal.utils import pkg
from bentoml.exceptions import BentoMLException
from bentoml.exceptions import MissingDependencyException

from .mapping import RUNTIME_MAPPING as MAPPINGS
from .mapping import get_extras

# Module-level logger; build_bento uses it to report which model is being built.
logger = logging.getLogger(__name__)

# transformers is only needed for the annotations on build_bento, so it is
# imported for type checkers only and never at runtime.
if t.TYPE_CHECKING:
    from transformers import PreTrainedModel
    from transformers import PreTrainedTokenizerFast

# Fail at import time with an actionable message when the optional 'unsloth'
# dependency cannot be found.
if pkg.find_spec("unsloth") is None:
    raise MissingDependencyException(
        "'unsloth' is required in order to use module 'bentoml.unsloth', install unsloth with 'pip install bentoml[unsloth]'."
    )

# Shared deep-merge strategy; build_bento uses it to fold the per-model extras
# returned by get_extras() into the base build options.
merger = Merger(
    # merge dicts, append list
    [(dict, "merge"), (list, "append")],
    # override all other types
    ["override"],
    # override conflicting types
    ["override"],
)


def replace_tag(tag: str) -> str:
    """Return ``tag`` lowercased with every "/" replaced by "--"."""
    lowered = tag.lower()
    return "--".join(lowered.split("/"))


# Names accepted for the ``chat_template`` argument of build_bento. The chosen
# value is written verbatim into the generated service_config.yaml.
# NOTE(review): presumably each name corresponds to a file under
# template/chat_templates/*.jinja -- confirm against the template directory.
ChatTemplate = t.Literal[
    "alpaca",
    "amberchat",
    "chatml",
    "chatqa",
    "falcon-instruct",
    "gemma-it",
    "llama-2-chat",
    "llama-3-chat",
    "mistral-instruct",
    "openchat",
    "phi-3",
    "saiga",
    "solar-instruct",
    "vicuna",
    "zephyr",
]

# Model architectures, as read from ``model.config.model_type`` in build_bento.
# The value is used as the lookup key into the runtime mapping (MAPPINGS) and
# into get_extras().
ModelType = t.Literal["llama", "mistral", "gemma", "gemma2", "qwen2"]


def _overlay_config(
    defaults: t.Mapping[str, t.Any], overrides: t.Mapping[str, t.Any]
) -> dict[str, t.Any]:
    """Return a new dict holding ``defaults`` with ``overrides`` layered on top.

    Nested dicts are merged key by key; any other value in ``overrides``
    replaces the default. Neither input mapping is modified, and no dict
    object from ``defaults`` is reused in the result.
    """
    merged: dict[str, t.Any] = {
        key: _overlay_config(value, {}) if isinstance(value, dict) else value
        for key, value in defaults.items()
    }
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = _overlay_config(merged[key], value)
        else:
            merged[key] = value
    return merged


def build_bento(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerFast,
    /,
    model_name: str | None = None,
    *,
    chat_template: ChatTemplate,
    quantization_method: t.Literal["bitsandbytes"] | None = None,
    save_method: t.Literal["merged_16bit", "merged_4bit"] = "merged_16bit",
    service_config: Config | None = None,
    engine_config: dict[str, t.Any] | None = None,
) -> bentoml.Model:
    """Save a fine-tuned model to the BentoML model store and build a bento for it.

    The merged model is written via ``model.save_pretrained_merged``; the
    bundled service template is then copied into a temporary directory together
    with a generated ``service_config.yaml`` and ``bentofile.yaml``, and
    ``bentoml build`` is run on that directory in a subprocess.

    Args:
        model: The fine-tuned model. Its ``config.model_type`` selects the
            runtime defaults.
        tokenizer: Tokenizer saved alongside the model.
        model_name: Name of the model in the model store. Required when the
            model config carries no ``_commit_hash`` (treated as a local
            checkpoint); otherwise derived from ``config._name_or_path``.
        chat_template: Chat template name written to ``service_config.yaml``.
        quantization_method: When set, used for both the ``quantization`` and
            ``load_format`` engine options.
        save_method: Forwarded to ``model.save_pretrained_merged``.
        service_config: Service configuration overrides. Layered on top of the
            per-model defaults; the passed object is not modified.
        engine_config: Arguments to pass to AsyncEngineArgs. Layered on top of
            the per-model defaults; the passed object is not modified.

    Returns:
        The model created in the BentoML model store.

    Raises:
        BentoMLException: If ``model_name`` is missing for a local checkpoint,
            or the model type has no runtime mapping.
        subprocess.CalledProcessError: If ``bentoml build`` exits non-zero.
    """
    # A config without a commit hash is treated as a local checkpoint; in that
    # case model_name must be specified, otherwise it is derived from the model id.
    is_local = getattr(model.config, "_commit_hash", None) is None
    if is_local and model_name is None:
        raise BentoMLException(
            'Fine-tune from a local checkpoint requires specifying "model_name".'
        )
    model_name = model_name or replace_tag(model.config._name_or_path)

    model_type = t.cast(ModelType, model.config.model_type)
    if model_type not in MAPPINGS:
        supported = ", ".join(MAPPINGS)
        raise BentoMLException(
            f"Unsupported model type '{model_type}'. Supported model types: {supported}."
        )
    defaults = MAPPINGS[model_type]

    # Start from the per-model defaults and let caller-supplied values win.
    # Fresh dicts are built so neither the caller's arguments nor the shared
    # MAPPINGS table are mutated.
    service_config = _overlay_config(defaults["service_config"], service_config or {})
    engine_config = _overlay_config(defaults["engine_config"], engine_config or {})
    if quantization_method is not None:
        engine_config.update(
            {"quantization": quantization_method, "load_format": quantization_method}
        )

    with bentoml.models.create(model_name) as bentomodel:
        model.save_pretrained_merged(
            bentomodel.path, tokenizer, save_method=save_method
        )

    build_opts = dict(
        python=dict(
            packages=[
                "pyyaml",
                "vllm==0.5.5",
                "fastapi==0.111.0",
                "unsloth[huggingface] @ git+https://github.com/bentoml/unsloth.git@main",
            ],
            lock_packages=True,
        ),
        envs=[{"name": "HF_TOKEN"}],
    )
    # Merges in place: model-specific extras (if any) extend the base options.
    merger.merge(build_opts, get_extras().get(model_type, {}))

    logger.info(
        "Building bentos for %s, model_id=%s", model_type, model.config._name_or_path
    )

    with tempfile.TemporaryDirectory() as tmp:
        workdir = pathlib.Path(tmp)
        shutil.copytree(
            pathlib.Path(__file__).parent / "template", workdir, dirs_exist_ok=True
        )
        with (workdir / "service_config.yaml").open("w") as f:
            f.write(
                yaml.safe_dump(
                    dict(
                        chat_template=chat_template,
                        model_tag=str(bentomodel.tag),
                        engine_config=engine_config,
                        service_config=service_config,
                    )
                )
            )
        with (workdir / "bentofile.yaml").open("w") as f:
            BentoBuildConfig(
                service="service:VLLM",
                name=f"{model_name.replace('.', '-')}-service",
                include=[
                    "*.py",
                    "*.yaml",
                    "chat_templates/*.jinja",
                    "generation_configs/*.json",
                ],
                docker=DockerOptions(python_version="3.11", system_packages=["git"]),
                models=[ModelSpec.from_item(str(bentomodel.tag))],
                description="API Service for running Unsloth models, powered with BentoML and vLLM.",
                **build_opts,
            ).with_defaults().to_yaml(f)

        # Must run while the temporary directory still exists.
        subprocess.run(
            [
                sys.executable,
                "-m",
                "bentoml",
                "build",
                str(workdir),
            ],
            check=True,
            cwd=workdir,
            env=os.environ,
        )

    return bentomodel
49 changes: 49 additions & 0 deletions src/_bentoml_impl/frameworks/unsloth/mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Runtime defaults keyed by model architecture. Every architecture currently
# shares the same values; the comprehension still builds separate dict objects
# for each key.
RUNTIME_MAPPING = {
    architecture: {
        "service_config": {
            "traffic": {"timeout": 300},
            "resources": {"gpu": 1, "gpu_type": "nvidia-l4"},
        },
        "engine_config": {"max_model_len": 2048},
    }
    for architecture in ("llama", "mistral", "gemma", "gemma2", "qwen2")
}


def get_extras():
    """Return extra build options keyed by model type.

    A new dict is built on every call. Only "gemma2" has an entry: it sets the
    VLLM_ATTENTION_BACKEND environment variable to FLASHINFER and adds the
    flashinfer package together with its extra index URL.
    """
    flashinfer_envs = [{"name": "VLLM_ATTENTION_BACKEND", "value": "FLASHINFER"}]
    flashinfer_python = {
        "extra_index_url": ["https://flashinfer.ai/whl/cu121/torch2.3"],
        "packages": ["flashinfer==0.1.2+cu121torch2.3"],
    }
    return {"gemma2": {"envs": flashinfer_envs, "python": flashinfer_python}}
6 changes: 6 additions & 0 deletions src/_bentoml_impl/frameworks/unsloth/template/.bentoignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
__pycache__/
*.py[cod]
*$py.class
.ipynb_checkpoints
venv/
.venv/
2 changes: 2 additions & 0 deletions src/_bentoml_impl/frameworks/unsloth/template/.gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
* text=auto eol=lf
**/ui/** linguist-vendored=true
12 changes: 12 additions & 0 deletions src/_bentoml_impl/frameworks/unsloth/template/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Environments
venv/

# BentoML
bentoml/client_id

chattts/ChatTTS/
Loading

0 comments on commit eb759c9

Please sign in to comment.