
CPU LLM Integration Test (#373)
- Adds an integration test for the Shortfin CPU LLM Server and the Sharktank export script
stbaione authored Nov 5, 2024
1 parent 22535ba commit 3dcca1f
Showing 6 changed files with 304 additions and 83 deletions.
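
For reference, the new test suite can be run locally with the same command the workflow invokes (assuming the sharktank and shortfin packages and their test requirements are installed):

    pytest -v build_tools/integration_tests/llm --log-cli-level=INFO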
@@ -4,13 +4,14 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

-name: Integration Tests
+name: CI - shark-platform

on:
workflow_dispatch:
-  schedule:
-    # Weekdays at 13:00 UTC = 05:00 PST / 06:00 PDT.
-    - cron: "5 4 * * 1-5"
+  pull_request:
+  push:
+    branches:
+      - main

concurrency:
# A PR number if a pull request and otherwise the commit hash. This cancels
@@ -21,14 +22,13 @@ concurrency:
cancel-in-progress: true

jobs:
-  test_llama:
-    name: "Integration Tests - llama"
+  test_shortfin_llm_server:
+    name: "Integration Tests - Shortfin LLM Server"
strategy:
matrix:
version: [3.11]
-        os: [ubuntu-latest, windows-latest]
fail-fast: false
-    runs-on: ${{matrix.os}}
+    runs-on: nodai-amdgpu-mi250-x86-64
defaults:
run:
shell: bash
@@ -70,31 +70,5 @@ jobs:
iree-runtime \
"numpy<2.0"
-      - name: Run llama test
-        run: ./build_tools/integration_tests/llama_export_compile_serve.sh
-
-  test_punet:
-    name: "Integration Tests - punet"
-    runs-on: nodai-amdgpu-mi250-x86-64
-    env:
-      VENV_DIR: ${{ github.workspace }}/.venv
-    steps:
-      - name: "Checkout Code"
-        uses: actions/checkout@v3
-
-      - name: "Setup Python venv"
-        run: python3 -m venv ${VENV_DIR}
-
-      - name: Install pip deps
-        run: |
-          source ${VENV_DIR}/bin/activate
-          python -m pip install --no-compile --upgrade pip
-          pip install --no-compile -r pytorch-rocm-requirements.txt
-          pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
-            -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
-          pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ shortfin/
-      - name: Run punet tests
-        run: |
-          source ${VENV_DIR}/bin/activate
-          pytest -v sharktank/ -m model_punet
+      - name: Run LLM Integration Tests
+        run: pytest -v build_tools/integration_tests/llm --log-cli-level=INFO
47 changes: 0 additions & 47 deletions build_tools/integration_tests/llama_export_compile_serve.sh

This file was deleted.

206 changes: 206 additions & 0 deletions build_tools/integration_tests/llm/conftest.py
@@ -0,0 +1,206 @@
import json
import logging
import os
from pathlib import Path
import pytest
import requests
import shutil
import subprocess
import time

pytest.importorskip("transformers")
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)


@pytest.fixture(scope="module")
def model_test_dir(request, tmp_path_factory):
"""Prepare model artifacts for starting the LLM server.
Args:
request (FixtureRequest): The following params are accepted:
- repo_id (str): The Hugging Face repo ID.
- model_file (str): The model file to download.
- tokenizer_id (str): The tokenizer ID to download.
- settings (dict): The settings for sharktank export.
- batch_sizes (list): The batch sizes to use for the model.
tmp_path_factory (TempPathFactory): Temp dir to save artifacts to.
Yields:
Tuple[Path, Path]: The paths to the Hugging Face home and the temp dir.
"""
logger.info("Preparing model artifacts...")

repo_id = request.param["repo_id"]
model_file = request.param["model_file"]
tokenizer_id = request.param["tokenizer_id"]
settings = request.param["settings"]
batch_sizes = request.param["batch_sizes"]

tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test")
hf_home = os.environ.get("HF_HOME", None)
hf_home = Path(hf_home) if hf_home is not None else tmp_dir
try:
# Download model if it doesn't exist
model_path = hf_home / model_file
logger.info(f"Preparing model_path: {model_path}..")
if not os.path.exists(model_path):
logger.info(
f"Downloading model {repo_id} {model_file} from Hugging Face..."
)
subprocess.run(
f"huggingface-cli download --local-dir {hf_home} {repo_id} {model_file}",
shell=True,
check=True,
)
logger.info(f"Model downloaded to {model_path}")
else:
logger.info("Using cached model")

# Set up tokenizer if it doesn't exist
tokenizer_path = hf_home / "tokenizer.json"
logger.info(f"Preparing tokenizer_path: {tokenizer_path}...")
if not os.path.exists(tokenizer_path):
logger.info(f"Downloading tokenizer {tokenizer_id} from Hugging Face...")
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_id,
)
tokenizer.save_pretrained(hf_home)
logger.info(f"Tokenizer saved to {tokenizer_path}")
else:
logger.info("Using cached tokenizer")

# Export model
mlir_path = tmp_dir / "model.mlir"
config_path = tmp_dir / "config.json"
bs_string = ",".join(map(str, batch_sizes))
logger.info(
"Exporting model with following settings:\n"
f" MLIR Path: {mlir_path}\n"
f" Config Path: {config_path}\n"
f" Batch Sizes: {bs_string}"
)
subprocess.run(
[
"python",
"-m",
"sharktank.examples.export_paged_llm_v1",
f"--gguf-file={model_path}",
f"--output-mlir={mlir_path}",
f"--output-config={config_path}",
f"--bs={bs_string}",
],
check=True,
)
logger.info(f"Model successfully exported to {mlir_path}")

# Compile model
vmfb_path = tmp_dir / "model.vmfb"
logger.info(f"Compiling model to {vmfb_path}")
subprocess.run(
[
"iree-compile",
mlir_path,
"-o",
vmfb_path,
]
+ settings["device_flags"],
check=True,
)
logger.info(f"Model successfully compiled to {vmfb_path}")

        # Write an edited config with the fields the server expects
edited_config_path = tmp_dir / "edited_config.json"
config = {
"module_name": "module",
"module_abi_version": 1,
"max_seq_len": 2048,
"attn_head_count": 32,
"attn_head_dim": 100,
"prefill_batch_sizes": batch_sizes,
"decode_batch_sizes": batch_sizes,
"transformer_block_count": 26,
"paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
}
logger.info(f"Saving edited config to: {edited_config_path}\n")
logger.info(f"Config: {json.dumps(config, indent=2)}")
with open(edited_config_path, "w") as f:
json.dump(config, f)
logger.info("Model artifacts setup successfully")
yield hf_home, tmp_dir
finally:
shutil.rmtree(tmp_dir)


@pytest.fixture(scope="module")
def available_port(port=8000, max_port=8100):
import socket

logger.info(f"Finding available port in range {port}-{max_port}...")

starting_port = port

while port < max_port:
try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                # bind() raises socket.error if the port is in use; the
                # context manager closes the socket, so no explicit close().
                s.bind(("localhost", port))
            logger.info(f"Found available port: {port}")
            return port
except socket.error:
port += 1

raise IOError(f"No available ports found within range {starting_port}-{max_port}")


def wait_for_server(url, timeout=10):
logger.info(f"Waiting for server to start at {url}...")
start = time.time()
while time.time() - start < timeout:
try:
requests.get(f"{url}/health")
logger.info("Server successfully started")
return
except requests.exceptions.ConnectionError:
time.sleep(1)
raise TimeoutError(f"Server did not start within {timeout} seconds")


@pytest.fixture(scope="module")
def llm_server(request, model_test_dir, available_port):
"""Start the LLM server.
Args:
request (FixtureRequest): The following params are accepted:
            - model_file (str): The model file whose parameters the server loads.
- settings (dict): The settings for starting the server.
model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir.
available_port (int): The available port to start the server on.
Yields:
subprocess.Popen: The server process that was started.
"""
logger.info("Starting LLM server...")
# Start the server
hf_home, tmp_dir = model_test_dir
model_file = request.param["model_file"]
settings = request.param["settings"]
server_process = subprocess.Popen(
[
"python",
"-m",
"shortfin_apps.llm.server",
f"--tokenizer={hf_home / 'tokenizer.json'}",
f"--model_config={tmp_dir / 'edited_config.json'}",
f"--vmfb={tmp_dir / 'model.vmfb'}",
f"--parameters={hf_home / model_file}",
f"--device={settings['device']}",
]
)
# Wait for server to start
wait_for_server(f"http://localhost:{available_port}")
yield server_process
# Teardown: kill the server
server_process.terminate()
server_process.wait()
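
For context, these fixtures are driven from a test module through pytest's indirect parametrization. The sketch below shows one way to wire them together; the model, tokenizer, flag, and endpoint values are illustrative assumptions, not the actual test file added by this commit.

import pytest
import requests

# Assumed CPU settings; the real values live in the commit's test module.
CPU_SETTINGS = {
    "device_flags": ["--iree-hal-target-backends=llvm-cpu"],
    "device": "local-task",
}

# Assumed small GGUF model and matching tokenizer.
MODEL_PARAMS = {
    "repo_id": "SlyEcho/open_llama_3b_v2_gguf",
    "model_file": "open-llama-3b-v2-f16.gguf",
    "tokenizer_id": "openlm-research/open_llama_3b_v2",
    "settings": CPU_SETTINGS,
    "batch_sizes": [1, 4],
}


@pytest.mark.parametrize(
    "model_test_dir,llm_server",
    [
        (
            MODEL_PARAMS,
            {"model_file": MODEL_PARAMS["model_file"], "settings": CPU_SETTINGS},
        )
    ],
    indirect=True,
)
def test_llm_server(llm_server, available_port):
    # llm_server has already passed the wait_for_server() health check.
    assert llm_server.poll() is None  # the server process is still running
    # The /generate endpoint and payload shape are assumptions about the
    # shortfin LLM server API, included only to sketch the request flow.
    response = requests.post(
        f"http://localhost:{available_port}/generate",
        json={"text": "1 2 3 4 5 ", "sampling_params": {"max_tokens": 50}},
    )
    assert response.status_code == 200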