Sglang benchmark test (#476)
# Description

Create a nightly workflow for the SGLang benchmark test that runs a
Shortfin server and benchmarks it from SGLang using the `bench_serving`
script.

## `bench_serving` Invocations
The `bench_serving` script is run with a range of `--request-rate` values
(see the Python sketch after this list):

- `python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer=<tokenizer_path> --request-rate 1 --output-file <tmp_dir>/shortfin_10_1.jsonl`
- `python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer=<tokenizer_path> --request-rate 2 --output-file <tmp_dir>/shortfin_10_2.jsonl`
- `python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer=<tokenizer_path> --request-rate 4 --output-file <tmp_dir>/shortfin_10_4.jsonl`
- `python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer=<tokenizer_path> --request-rate 8 --output-file <tmp_dir>/shortfin_10_8.jsonl`
- `python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer=<tokenizer_path> --request-rate 16 --output-file <tmp_dir>/shortfin_10_16.jsonl`
- `python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer=<tokenizer_path> --request-rate 32 --output-file <tmp_dir>/shortfin_10_32.jsonl`
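
For reference, a minimal Python sketch of the same sweep. It simply shells out to the commands above; `<tokenizer_path>` and `<tmp_dir>` remain placeholders, as in the list:

```python
# Hedged sketch: drives the six bench_serving invocations listed above.
import subprocess

for rate in [1, 2, 4, 8, 16, 32]:
    subprocess.run(
        [
            "python", "-m", "sglang.bench_serving",
            "--backend", "shortfin",
            "--num-prompt", "10",
            "--base-url", "http://localhost:8000",
            "--tokenizer", "<tokenizer_path>",
            "--request-rate", str(rate),
            "--output-file", f"<tmp_dir>/shortfin_10_{rate}.jsonl",
        ],
        check=True,
    )
```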

After the test finishes, we upload the HTML report from pytest to the
`gh-pages` branch. The destination directory is set to `./llm/sglang`, so the
results should be accessible in a browser at `/llm/sglang/index.html` on the
`gh-pages` site; a quick sanity check follows.
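
Once Pages has deployed, fetching the report URL should succeed. The exact site host depends on the org/repo, so the URL below is an assumption:

```python
# Hypothetical URL shape; substitute the real GitHub Pages host for this repo.
import urllib.request

url = "https://<org>.github.io/<repo>/llm/sglang/index.html"
with urllib.request.urlopen(url) as resp:
    print(resp.status)  # expect 200 once the deployment has propagated
```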

This also includes a refactor of the existing integration tests. Most of the
helpers for downloading a model/tokenizer, exporting to MLIR, compiling to a
VMFB, and starting a Shortfin server have been moved into a shared `utils.py`
file; a rough sketch of how they compose follows.
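
A hedged sketch of how the refactored helpers fit together. The signatures are inferred from the call sites in the diffs below, not a definitive API, and all paths are placeholders:

```python
# Assumes the helper signatures inferred from conftest.py and the test below.
from pathlib import Path

from integration_tests.llm.utils import (
    compile_model,
    export_paged_llm_v1,
    find_available_port,
    start_llm_server,
)

tmp_dir = Path("/tmp/llm_benchmark")            # placeholder scratch dir
model_path = Path("/data/llama3.1/8b/llama8b_f16.irpa")
tokenizer_path = Path("/data/llama3.1/8b/tokenizer.json")
mlir_path = tmp_dir / "model.mlir"
config_path = tmp_dir / "config.json"
vmfb_path = tmp_dir / "model.vmfb"
device_settings = {
    "device_flags": ["--iree-hal-target-backends=rocm"],
    "device": "hip",
}

# Export the model to MLIR plus a server config, then compile to a VMFB.
export_paged_llm_v1(mlir_path, config_path, model_path, [1, 4])
compile_model(mlir_path, vmfb_path, device_settings)

# Start a Shortfin server on a free port.
port = find_available_port()
server = start_llm_server(
    port, tokenizer_path, config_path, vmfb_path, model_path,
    device_settings, timeout=30,
)
```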
stbaione authored Nov 14, 2024
1 parent 46420ec commit 86bd384
Showing 14 changed files with 615 additions and 223 deletions.
88 changes: 88 additions & 0 deletions .github/workflows/ci-sglang-benchmark.yml
@@ -0,0 +1,88 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

name: SGLang Llama Benchmarking Tests

on:
  workflow_dispatch:
  schedule:
    # Weekdays at 4:00 AM UTC = 8:00 PM PST.
    - cron: "0 4 * * 1-5"

concurrency:
  # A PR number if a pull request and otherwise the commit hash. This cancels
  # queued and in-progress runs for the same PR (presubmit) or commit
  # (postsubmit). The workflow name is prepended to avoid conflicts between
  # different workflows.
  group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
  cancel-in-progress: true

jobs:
  sglang_bench_serve:
    name: "SGLang Serving Benchmark Tests"
    strategy:
      matrix:
        version: [3.11]
      fail-fast: false
    runs-on: llama-mi300x-3
    defaults:
      run:
        shell: bash
    env:
      PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
    steps:
      - name: Get Current Date
        id: date
        run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"

- name: "Setting up Python"
id: setup_python
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{matrix.version}}

- name: "Checkout Code"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

- name: Cache Pip Packages
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}

      - name: Install pip deps
        run: |
          python -m pip install --no-compile --upgrade pip
          # Note: We install in three steps in order to satisfy requirements
          # from non default locations first. Installing the PyTorch CPU
          # wheels saves multiple minutes and a lot of bandwidth on runner setup.
          pip install --no-compile -r pytorch-cpu-requirements.txt
          pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
            -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
          pip install --no-compile -r requirements.txt -e sharktank/ shortfin/

          # Try with the latest nightly releases, not what iree-turbine pins.
          # We could also pin to a known working or stable version.
          # This should eventually stabilize. Do the best we can for now.
          pip install -f https://iree.dev/pip-release-links.html --upgrade \
            iree-base-compiler==2.9.0rc20241108 \
            iree-base-runtime==2.9.0rc20241108 \
            "numpy<2.0"

      - name: Install SGLang
        run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"

      - name: Run SGLang Benchmark Tests
        run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html

      - name: Deploy to GitHub Pages
        uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
        with:
          github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
          publish_dir: ./out/llm/sglang
          destination_dir: ./llm/sglang
          keep_files: true
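
To reproduce the benchmark step outside CI, the same pytest entry point can be invoked directly; a minimal sketch, assuming the pip deps above are installed and the model files exist under `/data`:

```python
# Runs the same pytest command as the "Run SGLang Benchmark Tests" step above.
import subprocess

subprocess.run(
    [
        "pytest", "-v",
        "app_tests/benchmark_tests/llm/sglang_benchmark_test.py",
        "--log-cli-level=INFO",
        "--html=out/llm/sglang/index.html",
    ],
    check=True,
)
```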
2 changes: 1 addition & 1 deletion .github/workflows/ci-shark-platform.yml
@@ -72,4 +72,4 @@ jobs:
           iree-base-runtime
       - name: Run LLM Integration Tests
-        run: pytest -v build_tools/integration_tests/llm --log-cli-level=INFO
+        run: pytest -v app_tests/integration_tests/llm --log-cli-level=INFO
Empty file added app_tests/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions app_tests/benchmark_tests/llm/conftest.py
@@ -0,0 +1,47 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import json
import os
import pytest
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from integration_tests.llm.utils import compile_model, export_paged_llm_v1


@pytest.fixture(scope="module")
def pre_process_model(request, tmp_path_factory):
    model_path = request.param["model_path"]
    settings = request.param["settings"]
    batch_sizes = request.param["batch_sizes"]

    tmp_dir = tmp_path_factory.mktemp("llm_benchmark_test")
    mlir_path = tmp_dir / "model.mlir"
    config_path = tmp_dir / "config.json"
    vmfb_path = tmp_dir / "model.vmfb"

    export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)

    config = {
        "module_name": "module",
        "module_abi_version": 1,
        "max_seq_len": 131072,
        "attn_head_count": 8,
        "attn_head_dim": 128,
        "prefill_batch_sizes": batch_sizes,
        "decode_batch_sizes": batch_sizes,
        "transformer_block_count": 32,
        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
    }
    with open(config_path, "w") as file:
        json.dump(config, file)

    compile_model(mlir_path, vmfb_path, settings)

    return tmp_dir
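
The fixture above is consumed through pytest's indirect parametrization. For readers unfamiliar with the pattern, a minimal standalone illustration (the names here are invented for the example):

```python
import pytest


@pytest.fixture
def pre_process(request):
    # With indirect=True, each parametrized value arrives as request.param.
    return request.param["name"].upper()


@pytest.mark.parametrize("pre_process", [{"name": "llama"}], indirect=True)
def test_upper(pre_process):
    assert pre_process == "LLAMA"
```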
108 changes: 108 additions & 0 deletions app_tests/benchmark_tests/llm/sglang_benchmark_test.py
@@ -0,0 +1,108 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import json
import logging
import multiprocessing
import os
from pathlib import Path
import pytest
import time
from unittest.mock import patch

pytest.importorskip("sglang")
from sglang import bench_serving

from utils import SGLangBenchmarkArgs

from integration_tests.llm.utils import (
    find_available_port,
    start_llm_server,
)

logger = logging.getLogger(__name__)

device_settings = {
    "device_flags": [
        "--iree-hal-target-backends=rocm",
        "--iree-hip-target=gfx942",
    ],
    "device": "hip",
}

# TODO: Download on demand instead of assuming files exist at this path
MODEL_PATH = Path("/data/llama3.1/8b/llama8b_f16.irpa")
TOKENIZER_DIR = Path("/data/llama3.1/8b/")


@pytest.mark.parametrize("request_rate", [1, 2, 4, 8, 16, 32])
@pytest.mark.parametrize(
    "pre_process_model",
    [
        {
            "model_path": MODEL_PATH,
            "settings": device_settings,
            "batch_sizes": [1, 4],
        }
    ],
    indirect=True,
)
def test_sglang_benchmark_server(request_rate, pre_process_model):
    # TODO: Remove when multi-device is fixed
    os.environ["ROCR_VISIBLE_DEVICES"] = "1"

    tmp_dir = pre_process_model

    config_path = tmp_dir / "config.json"
    vmfb_path = tmp_dir / "model.vmfb"
    tokenizer_path = TOKENIZER_DIR / "tokenizer.json"

    # Start the shortfin llm server.
    port = find_available_port()
    server_process = start_llm_server(
        port,
        tokenizer_path,
        config_path,
        vmfb_path,
        MODEL_PATH,
        device_settings,
        timeout=30,
    )

    # Run and collect the SGLang serving benchmark.
    benchmark_args = SGLangBenchmarkArgs(
        backend="shortfin",
        num_prompt=10,
        base_url=f"http://localhost:{port}",
        tokenizer=TOKENIZER_DIR,
        request_rate=request_rate,
    )
    output_file = (
        tmp_dir
        / f"{benchmark_args.backend}_{benchmark_args.num_prompt}_{benchmark_args.request_rate}.jsonl"
    )
    benchmark_args.output_file = output_file

    logger.info("Running SGLang Benchmark with the following args:")
    logger.info(benchmark_args)
    try:
        start = time.time()
        # The patch is applied before the child process starts, so on
        # fork-based start methods bench_serving's prints are routed to the
        # logger inside the child as well.
        with patch.object(bench_serving, "print", side_effect=logger.info):
            benchmark_process = multiprocessing.Process(
                target=bench_serving.run_benchmark,
                args=(benchmark_args.as_namespace(),),
            )
            benchmark_process.start()
            benchmark_process.join()

        logger.info(f"Benchmark run completed in {time.time() - start:.2f} seconds")
    except Exception as e:
        logger.error(e)

    server_process.terminate()
    server_process.wait()
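
The `patch.object(bench_serving, "print", side_effect=logger.info)` call above reroutes the benchmark's console output through the logger. A standalone sketch of the pattern, using a stand-in object instead of the real `bench_serving` module:

```python
# Illustration only: redirect an object's print attribute through logging.
import logging
import types
from unittest.mock import patch

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("demo")

fake_module = types.SimpleNamespace(print=print)  # stand-in for bench_serving


def run_benchmark():
    fake_module.print("request-rate sweep finished")


with patch.object(fake_module, "print", side_effect=log.info):
    run_benchmark()  # the message is routed to log.info instead of stdout
```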
55 changes: 55 additions & 0 deletions app_tests/benchmark_tests/llm/utils.py
@@ -0,0 +1,55 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from argparse import Namespace
from dataclasses import dataclass
from pathlib import Path


@dataclass
class SGLangBenchmarkArgs:
    base_url: str
    num_prompt: int
    request_rate: int
    tokenizer: str | Path

    seed: int = 1
    extra_request_body: str | None = None
    output_file: str | Path | None = None
    port: int = 8000
    backend: str = "shortfin"

    def as_namespace(self) -> Namespace:
        return Namespace(
            num_prompts=self.num_prompt,
            base_url=self.base_url,
            tokenizer=str(self.tokenizer),
            request_rate=self.request_rate,
            backend=self.backend,
            output_file=self.output_file,
            seed=self.seed,
            extra_request_body=self.extra_request_body,
            port=self.port,
            model=None,
            dataset_name="sharegpt",
            random_input_len=None,
            random_output_len=None,
            dataset_path="",
            sharegpt_output_len=None,
            multi=False,
            disable_tqdm=False,
            disable_stream=False,
            disable_ignore_eos=False,
        )

    def __repr__(self):
        return (
            f"Backend: {self.backend}\n"
            f"Base URL: {self.base_url}\n"
            f"Num Prompt: {self.num_prompt}\n"
            f"Tokenizer: {self.tokenizer}\n"
            f"Request Rate: {self.request_rate}"
        )
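
For completeness, this is roughly how the dataclass is consumed by the test above; the paths here are illustrative only:

```python
# Illustrative values; the real test derives these from fixtures and /data.
from pathlib import Path

args = SGLangBenchmarkArgs(
    backend="shortfin",
    num_prompt=10,
    base_url="http://localhost:8000",
    tokenizer=Path("/data/llama3.1/8b/"),
    request_rate=4,
)
args.output_file = Path("/tmp") / f"{args.backend}_{args.num_prompt}_{args.request_rate}.jsonl"

# as_namespace() produces the argparse-style Namespace bench_serving expects.
namespace = args.as_namespace()
```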
