-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
# Description Create a nightly workflow for SGLang Benchmark test that enables running a Shortfin server and benchmarking from SGLang, using the `bench_serving` script. ## `bench_serving` Invocations The bench_serving script is run with various `request-rate` arguments; the output filename embeds the request rate (`shortfin_<num_prompt>_<request_rate>.jsonl`): - python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer=<tokenizer_path> `--request-rate 1` --output-file <tmp_dir>/shortfin_10_1.jsonl - python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer=<tokenizer_path> `--request-rate 2` --output-file <tmp_dir>/shortfin_10_2.jsonl - python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer=<tokenizer_path> `--request-rate 4` --output-file <tmp_dir>/shortfin_10_4.jsonl - python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer=<tokenizer_path> `--request-rate 8` --output-file <tmp_dir>/shortfin_10_8.jsonl - python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer=<tokenizer_path> `--request-rate 16` --output-file <tmp_dir>/shortfin_10_16.jsonl - python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer=<tokenizer_path> `--request-rate 32` --output-file <tmp_dir>/shortfin_10_32.jsonl After the test is finished running, we upload the html output from pytest to gh-pages. The subdirectory is set to `./llm/sglang`, so the results should be accessible from the browser at `/llm/sglang/index.html` in gh-pages. This also includes a refactor of the existing integration test. Most of the methods for downloading a model/tokenizer, exporting to mlir, compiling to vmfb, and starting a shortfin server have been moved to a `utils.py` file.
- Loading branch information
Showing
14 changed files
with
615 additions
and
223 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

name: SGLang Llama Benchmarking Tests

on:
  workflow_dispatch:
  schedule:
    # Weekdays at 4:00 AM UTC = 9:00 PM PST.
    - cron: "0 4 * * 1-5"

concurrency:
  # A PR number if a pull request and otherwise the commit hash. This cancels
  # queued and in-progress runs for the same PR (presubmit) or commit
  # (postsubmit). The workflow name is prepended to avoid conflicts between
  # different workflows.
  group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
  cancel-in-progress: true

jobs:
  sglang_bench_serve:
    name: "SGLang Serving Benchmark Tests"
    strategy:
      matrix:
        version: [3.11]
      fail-fast: false
    runs-on: llama-mi300x-3
    defaults:
      run:
        shell: bash
    env:
      PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
    steps:
      - name: Get Current Date
        id: date
        # `::set-output` is deprecated (and disabled on current runners);
        # write step outputs to $GITHUB_OUTPUT instead.
        run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"

      - name: "Setting up Python"
        id: setup_python
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: ${{matrix.version}}

      - name: "Checkout Code"
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Cache Pip Packages
        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
        id: cache-pip
        with:
          path: ${{ env.PIP_CACHE_DIR }}
          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}

      - name: Install pip deps
        run: |
          python -m pip install --no-compile --upgrade pip
          # Note: We install in three steps in order to satisfy requirements
          # from non default locations first. Installing the PyTorch CPU
          # wheels saves multiple minutes and a lot of bandwidth on runner setup.
          pip install --no-compile -r pytorch-cpu-requirements.txt
          pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
            -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
          pip install --no-compile -r requirements.txt -e sharktank/ shortfin/

          # Try with the latest nightly releases, not what iree-turbine pins.
          # We could also pin to a known working or stable version.
          # This should eventually stabilize. Do the best we can for now.
          pip install -f https://iree.dev/pip-release-links.html --upgrade \
            iree-base-compiler==2.9.0rc20241108 \
            iree-base-runtime==2.9.0rc20241108 \
            "numpy<2.0"

      - name: Install SGLang
        run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"

      # The test launches the shortfin server itself and then runs the SGLang
      # serving benchmark against it, so name the step for what it does.
      - name: Run SGLang Benchmark Tests
        run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html

      - name: Deploy to GitHub Pages
        uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
        with:
          github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
          publish_dir: ./out/llm/sglang
          destination_dir: ./llm/sglang
          keep_files: true
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import json | ||
import os | ||
import pytest | ||
import sys | ||
|
||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) | ||
from integration_tests.llm.utils import compile_model, export_paged_llm_v1 | ||
|
||
|
||
@pytest.fixture(scope="module")
def pre_process_model(request, tmp_path_factory):
    """Export, configure, and compile a model once for the whole test module.

    Expects ``request.param`` (indirect parametrization) to provide:
      - ``model_path``: path to the model weights (.irpa)
      - ``settings``: device/compile flags dict passed to ``compile_model``
      - ``batch_sizes``: batch sizes to export for prefill/decode

    Returns the temp directory containing ``model.mlir``, ``config.json``,
    and ``model.vmfb``.
    """
    model_path = request.param["model_path"]
    settings = request.param["settings"]
    batch_sizes = request.param["batch_sizes"]

    # Create a single temp dir for all artifacts. (The original called
    # mktemp twice, leaking an unused "sglang_benchmark_test" directory.)
    tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")
    mlir_path = tmp_dir / "model.mlir"
    config_path = tmp_dir / "config.json"
    vmfb_path = tmp_dir / "model.vmfb"

    export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)

    # Shortfin server config. NOTE(review): head counts, block count, and
    # max_seq_len are hard-coded for llama3.1-8b — confirm they stay in sync
    # with the exported model.
    config = {
        "module_name": "module",
        "module_abi_version": 1,
        "max_seq_len": 131072,
        "attn_head_count": 8,
        "attn_head_dim": 128,
        "prefill_batch_sizes": batch_sizes,
        "decode_batch_sizes": batch_sizes,
        "transformer_block_count": 32,
        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
    }
    with open(config_path, "w") as file:
        json.dump(config, file)

    compile_model(mlir_path, vmfb_path, settings)

    return tmp_dir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import json | ||
import logging | ||
import multiprocessing | ||
import os | ||
from pathlib import Path | ||
import pytest | ||
import time | ||
from unittest.mock import patch | ||
|
||
pytest.importorskip("sglang") | ||
from sglang import bench_serving | ||
|
||
from utils import SGLangBenchmarkArgs | ||
|
||
from integration_tests.llm.utils import ( | ||
find_available_port, | ||
start_llm_server, | ||
) | ||
|
||
# Module logger. Bug fix: the original called logging.getLogger("__name__")
# with the string literal, which creates a logger literally named "__name__"
# instead of one named after this module.
logger = logging.getLogger(__name__)

# IREE compile/runtime flags targeting MI300x (gfx942) via the HIP HAL driver.
device_settings = {
    "device_flags": [
        "--iree-hal-target-backends=rocm",
        "--iree-hip-target=gfx942",
    ],
    "device": "hip",
}

# TODO: Download on demand instead of assuming files exist at this path
MODEL_PATH = Path("/data/llama3.1/8b/llama8b_f16.irpa")
TOKENIZER_DIR = Path("/data/llama3.1/8b/")
|
||
|
||
@pytest.mark.parametrize("request_rate", [1, 2, 4, 8, 16, 32])
@pytest.mark.parametrize(
    "pre_process_model",
    [
        (
            {
                "model_path": MODEL_PATH,
                "settings": device_settings,
                "batch_sizes": [1, 4],
            }
        )
    ],
    indirect=True,
)
def test_sglang_benchmark_server(request_rate, pre_process_model):
    """Run SGLang's `bench_serving` against a freshly started shortfin server.

    The `pre_process_model` fixture returns a directory containing the
    compiled model (`model.vmfb`) and its server config (`config.json`).
    """
    # TODO: Remove when multi-device is fixed
    os.environ["ROCR_VISIBLE_DEVICES"] = "1"

    tmp_dir = pre_process_model
    config_path = tmp_dir / "config.json"
    vmfb_path = tmp_dir / "model.vmfb"
    tokenizer_path = TOKENIZER_DIR / "tokenizer.json"

    # Start shortfin llm server
    port = find_available_port()
    server_process = start_llm_server(
        port,
        tokenizer_path,
        config_path,
        vmfb_path,
        MODEL_PATH,
        device_settings,
        timeout=30,
    )

    try:
        # Run and collect SGLang Serving Benchmark
        benchmark_args = SGLangBenchmarkArgs(
            backend="shortfin",
            num_prompt=10,
            base_url=f"http://localhost:{port}",
            tokenizer=TOKENIZER_DIR,
            request_rate=request_rate,
        )
        benchmark_args.output_file = (
            tmp_dir
            / f"{benchmark_args.backend}_{benchmark_args.num_prompt}_{benchmark_args.request_rate}.jsonl"
        )

        logger.info("Running SGLang Benchmark with the following args:")
        logger.info(benchmark_args)

        start = time.time()
        # Route bench_serving's print() output through our logger so it shows
        # up in pytest's captured log output.
        with patch.object(bench_serving, "print", side_effect=logger.info):
            benchmark_process = multiprocessing.Process(
                target=bench_serving.run_benchmark,
                args=(benchmark_args.as_namespace(),),
            )
            benchmark_process.start()
            benchmark_process.join()

        # Surface benchmark failures. The original wrapped everything in
        # `except Exception: logger.info(e)`, which let the test pass even
        # when the benchmark crashed.
        assert (
            benchmark_process.exitcode == 0
        ), f"Benchmark process exited with code {benchmark_process.exitcode}"

        logger.info(f"Benchmark run completed in {str(time.time() - start)} seconds")
    finally:
        # Always shut the server down, even if the benchmark raised; the
        # original leaked the server process on failure.
        server_process.terminate()
        server_process.wait()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from argparse import Namespace | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
|
||
|
||
@dataclass | ||
class SGLangBenchmarkArgs: | ||
base_url: str | ||
num_prompt: int | ||
request_rate: int | ||
tokenizer: str | Path | ||
|
||
seed: int = 1 | ||
extra_request_body: str | None = None | ||
output_file: str | Path | None = None | ||
port: int = 8000 | ||
backend: str = "shortfin" | ||
|
||
def as_namespace(self) -> Namespace: | ||
return Namespace( | ||
num_prompts=self.num_prompt, | ||
base_url=self.base_url, | ||
tokenizer=str(self.tokenizer), | ||
request_rate=self.request_rate, | ||
backend=self.backend, | ||
output_file=self.output_file, | ||
seed=self.seed, | ||
extra_request_body=self.extra_request_body, | ||
port=8000, | ||
model=None, | ||
dataset_name="sharegpt", | ||
random_input_len=None, | ||
random_output_len=None, | ||
dataset_path="", | ||
sharegpt_output_len=None, | ||
multi=False, | ||
disable_tqdm=False, | ||
disable_stream=False, | ||
disable_ignore_eos=False, | ||
) | ||
|
||
def __repr__(self): | ||
return ( | ||
f"Backend: {self.backend}\n" | ||
f"Base URL: {self.base_url}\n" | ||
f"Num Prompt: {self.num_prompt}\n" | ||
f"Tokenizer: {self.tokenizer}\n" | ||
f"Request Rate: {self.request_rate}" | ||
) |
Empty file.
Empty file.
Oops, something went wrong.