diff --git a/.github/workflows/ci-sglang-benchmark.yml b/.github/workflows/ci-sglang-benchmark.yml
new file mode 100644
index 000000000..d890d972c
--- /dev/null
+++ b/.github/workflows/ci-sglang-benchmark.yml
@@ -0,0 +1,88 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: SGLang Llama Benchmarking Tests
+
+on:
+  workflow_dispatch:
+  schedule:
+    # Weekdays at 4:00 AM UTC = 8:00 PM PST.
+    - cron: "0 4 * * 1-5"
+
+concurrency:
+  # A PR number if a pull request and otherwise the commit hash. This cancels
+  # queued and in-progress runs for the same PR (presubmit) or commit
+  # (postsubmit). The workflow name is prepended to avoid conflicts between
+  # different workflows.
+  group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+  cancel-in-progress: true
+
+jobs:
+  sglang_bench_serve:
+    name: "SGLang Serving Benchmark Tests"
+    strategy:
+      matrix:
+        version: [3.11]
+      fail-fast: false
+    runs-on: llama-mi300x-3
+    defaults:
+      run:
+        shell: bash
+    env:
+      PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+    steps:
+      - name: Get Current Date
+        id: date
+        run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"
+
+      - name: "Setting up Python"
+        id: setup_python
+        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+        with:
+          python-version: ${{matrix.version}}
+
+      - name: "Checkout Code"
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+      - name: Cache Pip Packages
+        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+        id: cache-pip
+        with:
+          path: ${{ env.PIP_CACHE_DIR }}
+          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
+
+      - name: Install pip deps
+        run: |
+          python -m pip install --no-compile --upgrade pip
+          # Note: We install in three steps in order to satisfy requirements
+          # from non-default locations first. Installing the PyTorch CPU
+          # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+          pip install --no-compile -r pytorch-cpu-requirements.txt
+          pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
+            -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+          pip install --no-compile -r requirements.txt -e sharktank/ shortfin/
+
+          # Try with the latest nightly releases, not what iree-turbine pins.
+          # We could also pin to a known working or stable version.
+          # This should eventually stabilize. Do the best we can for now.
+          pip install -f https://iree.dev/pip-release-links.html --upgrade \
+            iree-base-compiler==2.9.0rc20241108 \
+            iree-base-runtime==2.9.0rc20241108 \
+            "numpy<2.0"
+
+      - name: Install SGLang
+        run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"
+
+      - name: Run SGLang Serving Benchmarks
+        run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html
+
+      - name: Deploy to GitHub Pages
+        uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
+        with:
+          github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
+          publish_dir: ./out/llm/sglang
+          destination_dir: ./llm/sglang
+          keep_files: true
diff --git a/.github/workflows/ci-shark-platform.yml b/.github/workflows/ci-shark-platform.yml
index 6741f7ea0..dc2f4646a 100644
--- a/.github/workflows/ci-shark-platform.yml
+++ b/.github/workflows/ci-shark-platform.yml
@@ -72,4 +72,4 @@ jobs:
           iree-base-runtime

       - name: Run LLM Integration Tests
-        run: pytest -v build_tools/integration_tests/llm --log-cli-level=INFO
+        run: pytest -v app_tests/integration_tests/llm --log-cli-level=INFO
diff --git a/app_tests/__init__.py b/app_tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/app_tests/benchmark_tests/__init__.py b/app_tests/benchmark_tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/app_tests/benchmark_tests/llm/conftest.py b/app_tests/benchmark_tests/llm/conftest.py
new file mode 100644
index 000000000..aac66ca0f
--- /dev/null
+++ b/app_tests/benchmark_tests/llm/conftest.py
@@ -0,0 +1,46 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import os
+import pytest
+import sys
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+from integration_tests.llm.utils import compile_model, export_paged_llm_v1
+
+
+@pytest.fixture(scope="module")
+def pre_process_model(request, tmp_path_factory):
+    tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")
+
+    model_path = request.param["model_path"]
+    settings = request.param["settings"]
+    batch_sizes = request.param["batch_sizes"]
+
+    mlir_path = tmp_dir / "model.mlir"
+    config_path = tmp_dir / "config.json"
+    vmfb_path = tmp_dir / "model.vmfb"
+
+    export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)
+
+    config = {
+        "module_name": "module",
+        "module_abi_version": 1,
+        "max_seq_len": 131072,
+        "attn_head_count": 8,
+        "attn_head_dim": 128,
+        "prefill_batch_sizes": batch_sizes,
+        "decode_batch_sizes": batch_sizes,
+        "transformer_block_count": 32,
+        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
+    }
+    with open(config_path, "w") as file:
+        json.dump(config, file)
+
+    compile_model(mlir_path, vmfb_path, settings)
+
+    return tmp_dir
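
The `pre_process_model` fixture above is fed through pytest's indirect parametrization: a test lists the fixture name in `@pytest.mark.parametrize(..., indirect=True)` and the dict it supplies becomes `request.param`. A minimal sketch of the pattern (the model path and compile flags are placeholders; the real values appear in sglang_benchmark_test.py below):

    import pytest
    from pathlib import Path

    example_params = {
        "model_path": Path("/data/llama3.1/8b/llama8b_f16.irpa"),
        "settings": {"device_flags": ["--iree-hal-target-backends=rocm"], "device": "hip"},
        "batch_sizes": [1, 4],
    }

    @pytest.mark.parametrize("pre_process_model", [example_params], indirect=True)
    def test_example(pre_process_model):
        tmp_dir = pre_process_model  # contains model.mlir, config.json, and model.vmfb
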
diff --git a/app_tests/benchmark_tests/llm/sglang_benchmark_test.py b/app_tests/benchmark_tests/llm/sglang_benchmark_test.py
new file mode 100644
index 000000000..8027fcea7
--- /dev/null
+++ b/app_tests/benchmark_tests/llm/sglang_benchmark_test.py
@@ -0,0 +1,108 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import logging
+import multiprocessing
+import os
+from pathlib import Path
+import pytest
+import time
+from unittest.mock import patch
+
+pytest.importorskip("sglang")
+from sglang import bench_serving
+
+from utils import SGLangBenchmarkArgs
+
+from integration_tests.llm.utils import (
+    find_available_port,
+    start_llm_server,
+)
+
+logger = logging.getLogger(__name__)
+
+device_settings = {
+    "device_flags": [
+        "--iree-hal-target-backends=rocm",
+        "--iree-hip-target=gfx942",
+    ],
+    "device": "hip",
+}
+
+# TODO: Download on demand instead of assuming files exist at this path
+MODEL_PATH = Path("/data/llama3.1/8b/llama8b_f16.irpa")
+TOKENIZER_DIR = Path("/data/llama3.1/8b/")
+
+
+@pytest.mark.parametrize("request_rate", [1, 2, 4, 8, 16, 32])
+@pytest.mark.parametrize(
+    "pre_process_model",
+    [
+        (
+            {
+                "model_path": MODEL_PATH,
+                "settings": device_settings,
+                "batch_sizes": [1, 4],
+            }
+        )
+    ],
+    indirect=True,
+)
+def test_sglang_benchmark_server(request_rate, pre_process_model):
+    # TODO: Remove when multi-device is fixed
+    os.environ["ROCR_VISIBLE_DEVICES"] = "1"
+
+    tmp_dir = pre_process_model
+
+    config_path = tmp_dir / "config.json"
+    vmfb_path = tmp_dir / "model.vmfb"
+    tokenizer_path = TOKENIZER_DIR / "tokenizer.json"
+
+    # Start shortfin llm server
+    port = find_available_port()
+    server_process = start_llm_server(
+        port,
+        tokenizer_path,
+        config_path,
+        vmfb_path,
+        MODEL_PATH,
+        device_settings,
+        timeout=30,
+    )
+
+    # Run and collect SGLang Serving Benchmark
+    benchmark_args = SGLangBenchmarkArgs(
+        backend="shortfin",
+        num_prompt=10,
+        base_url=f"http://localhost:{port}",
+        tokenizer=TOKENIZER_DIR,
+        request_rate=request_rate,
+    )
+    output_file = (
+        tmp_dir
+        / f"{benchmark_args.backend}_{benchmark_args.num_prompt}_{benchmark_args.request_rate}.jsonl"
+    )
+    benchmark_args.output_file = output_file
+
+    logger.info("Running SGLang Benchmark with the following args:")
+    logger.info(benchmark_args)
+    try:
+        start = time.time()
+        with patch.object(bench_serving, "print", side_effect=logger.info):
+            benchmark_process = multiprocessing.Process(
+                target=bench_serving.run_benchmark,
+                args=(benchmark_args.as_namespace(),),
+            )
+            benchmark_process.start()
+            benchmark_process.join()
+
+        logger.info(f"Benchmark run completed in {time.time() - start} seconds")
+    except Exception as e:
+        logger.error(e)
+
+    server_process.terminate()
+    server_process.wait()
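
Each benchmark run writes its metrics to the `output_file` JSONL path built above. A rough sketch of collecting those results after a run; the exact record fields depend on the sglang `bench_serving` output schema and are assumptions here:

    import json
    from pathlib import Path

    def summarize_results(result_dir: Path):
        # Matches the f"{backend}_{num_prompt}_{request_rate}.jsonl" naming used in the test.
        for result_file in sorted(result_dir.glob("shortfin_*.jsonl")):
            for line in result_file.read_text().splitlines():
                record = json.loads(line)
                # "request_throughput" and "mean_ttft_ms" are assumed field names.
                print(result_file.name, record.get("request_throughput"), record.get("mean_ttft_ms"))
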
diff --git a/app_tests/benchmark_tests/llm/utils.py b/app_tests/benchmark_tests/llm/utils.py
new file mode 100644
index 000000000..c217720cb
--- /dev/null
+++ b/app_tests/benchmark_tests/llm/utils.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from argparse import Namespace
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass
+class SGLangBenchmarkArgs:
+    base_url: str
+    num_prompt: int
+    request_rate: int
+    tokenizer: str | Path
+
+    seed: int = 1
+    extra_request_body: str | None = None
+    output_file: str | Path | None = None
+    port: int = 8000
+    backend: str = "shortfin"
+
+    def as_namespace(self) -> Namespace:
+        return Namespace(
+            num_prompts=self.num_prompt,
+            base_url=self.base_url,
+            tokenizer=str(self.tokenizer),
+            request_rate=self.request_rate,
+            backend=self.backend,
+            output_file=self.output_file,
+            seed=self.seed,
+            extra_request_body=self.extra_request_body,
+            port=self.port,
+            model=None,
+            dataset_name="sharegpt",
+            random_input_len=None,
+            random_output_len=None,
+            dataset_path="",
+            sharegpt_output_len=None,
+            multi=False,
+            disable_tqdm=False,
+            disable_stream=False,
+            disable_ignore_eos=False,
+        )
+
+    def __repr__(self):
+        return (
+            f"Backend: {self.backend}\n"
+            f"Base URL: {self.base_url}\n"
+            f"Num Prompt: {self.num_prompt}\n"
+            f"Tokenizer: {self.tokenizer}\n"
+            f"Request Rate: {self.request_rate}"
+        )
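
`as_namespace()` packs the dataclass into the argparse-style namespace that `bench_serving.run_benchmark` consumes, as sglang_benchmark_test.py above does, so a benchmark can be driven programmatically. A small usage sketch (paths are placeholders):

    from pathlib import Path

    args = SGLangBenchmarkArgs(
        backend="shortfin",
        num_prompt=10,
        base_url="http://localhost:8000",
        tokenizer=Path("/data/llama3.1/8b/"),
        request_rate=4,
    )
    args.output_file = Path("/tmp/shortfin_10_4.jsonl")
    namespace = args.as_namespace()  # handed to bench_serving.run_benchmark(...)
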
+ """ + logger.info("Preparing model artifacts...") + + repo_id = request.param["repo_id"] + model_file = request.param["model_file"] + tokenizer_id = request.param["tokenizer_id"] + settings = request.param["settings"] + batch_sizes = request.param["batch_sizes"] + + tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test") + hf_home = os.environ.get("HF_HOME", None) + hf_home = Path(hf_home) if hf_home is not None else tmp_dir + try: + # Download model if it doesn't exist + model_path = hf_home / model_file + download_huggingface_model(hf_home, repo_id, model_file) + + # Set up tokenizer if it doesn't exist + download_tokenizer(hf_home, tokenizer_id) + + # Export model + mlir_path = tmp_dir / "model.mlir" + config_path = tmp_dir / "config.json" + export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes) + + # Compile model + vmfb_path = tmp_dir / "model.vmfb" + compile_model(mlir_path, vmfb_path, settings) + + # Write config + edited_config_path = tmp_dir / "edited_config.json" + config = { + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": 2048, + "attn_head_count": 32, + "attn_head_dim": 100, + "prefill_batch_sizes": batch_sizes, + "decode_batch_sizes": batch_sizes, + "transformer_block_count": 26, + "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, + } + logger.info(f"Saving edited config to: {edited_config_path}\n") + logger.info(f"Config: {json.dumps(config, indent=2)}") + with open(edited_config_path, "w") as f: + json.dump(config, f) + logger.info("Model artifacts setup successfully") + yield hf_home, tmp_dir + finally: + shutil.rmtree(tmp_dir) + + +@pytest.fixture(scope="module") +def available_port(): + return find_available_port() + + +@pytest.fixture(scope="module") +def llm_server(request, model_test_dir, available_port): + """Start the LLM server. + + Args: + request (FixtureRequest): The following params are accepted: + - model_file (str): The model file to download. + - settings (dict): The settings for starting the server. + model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir. + available_port (int): The available port to start the server on. + + Yields: + subprocess.Popen: The server process that was started. 
+ """ + logger.info("Starting LLM server...") + hf_home, tmp_dir = model_test_dir + model_file = request.param["model_file"] + settings = request.param["settings"] + + tokenizer_path = hf_home / "tokenizer.json" + config_path = tmp_dir / "edited_config.json" + vmfb_path = tmp_dir / "model.vmfb" + parameters_path = hf_home / model_file + + # Start llm server + server_process = start_llm_server( + available_port, + tokenizer_path, + config_path, + vmfb_path, + parameters_path, + settings, + ) + yield server_process + # Teardown: kill the server + server_process.terminate() + server_process.wait() diff --git a/build_tools/integration_tests/llm/cpu_llm_server_test.py b/app_tests/integration_tests/llm/cpu_llm_server_test.py similarity index 98% rename from build_tools/integration_tests/llm/cpu_llm_server_test.py rename to app_tests/integration_tests/llm/cpu_llm_server_test.py index 4d4ec5540..e7d0792d8 100644 --- a/build_tools/integration_tests/llm/cpu_llm_server_test.py +++ b/app_tests/integration_tests/llm/cpu_llm_server_test.py @@ -10,7 +10,7 @@ import requests import uuid -from utils import AccuracyValidationException +from .utils import AccuracyValidationException logger = logging.getLogger(__name__) diff --git a/app_tests/integration_tests/llm/utils.py b/app_tests/integration_tests/llm/utils.py new file mode 100644 index 000000000..b8b5ae60f --- /dev/null +++ b/app_tests/integration_tests/llm/utils.py @@ -0,0 +1,180 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import multiprocessing +import os +import subprocess +import sys +import time + +import requests +from transformers import AutoTokenizer + +logger = logging.getLogger("__name__") + + +class AccuracyValidationException(RuntimeError): + pass + + +def download_huggingface_model(local_dir, repo_id, model_file): + model_path = local_dir / model_file + logger.info(f"Preparing model_path: {model_path}..") + if not os.path.exists(model_path): + logger.info(f"Downloading model {repo_id} {model_file} from Hugging Face...") + subprocess.run( + f"huggingface-cli download --local-dir {local_dir} {repo_id} {model_file}", + shell=True, + check=True, + ) + logger.info(f"Model downloaded to {model_path}") + else: + logger.info("Using cached model") + + +def download_tokenizer(local_dir, tokenizer_id): + # Set up tokenizer if it doesn't exist + tokenizer_path = local_dir / "tokenizer.json" + logger.info(f"Preparing tokenizer_path: {tokenizer_path}...") + if not os.path.exists(tokenizer_path): + logger.info(f"Downloading tokenizer {tokenizer_id} from Hugging Face...") + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_id, + ) + tokenizer.save_pretrained(local_dir) + logger.info(f"Tokenizer saved to {tokenizer_path}") + else: + logger.info("Using cached tokenizer") + + +def export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes): + bs_string = ",".join(map(str, batch_sizes)) + logger.info( + "Exporting model with following settings:\n" + f" MLIR Path: {mlir_path}\n" + f" Config Path: {config_path}\n" + f" Batch Sizes: {bs_string}" + ) + subprocess.run( + [ + "python", + "-m", + "sharktank.examples.export_paged_llm_v1", + f"--{model_path.suffix.strip('.')}-file={model_path}", + f"--output-mlir={mlir_path}", + f"--output-config={config_path}", + f"--bs={bs_string}", + ], + check=True, + ) + logger.info(f"Model successfully 
diff --git a/build_tools/integration_tests/llm/conftest.py b/build_tools/integration_tests/llm/conftest.py
deleted file mode 100644
index 1103065bc..000000000
--- a/build_tools/integration_tests/llm/conftest.py
+++ /dev/null
@@ -1,212 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import json
-import logging
-import os
-from pathlib import Path
-import pytest
-import requests
-import shutil
-import subprocess
-import time
-
-pytest.importorskip("transformers")
-from transformers import AutoTokenizer
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.fixture(scope="module")
-def model_test_dir(request, tmp_path_factory):
-    """Prepare model artifacts for starting the LLM server.
-
-    Args:
-        request (FixtureRequest): The following params are accepted:
-            - repo_id (str): The Hugging Face repo ID.
-            - model_file (str): The model file to download.
-            - tokenizer_id (str): The tokenizer ID to download.
-            - settings (dict): The settings for sharktank export.
-            - batch_sizes (list): The batch sizes to use for the model.
-        tmp_path_factory (TempPathFactory): Temp dir to save artifacts to.
-
-    Yields:
-        Tuple[Path, Path]: The paths to the Hugging Face home and the temp dir.
-    """
-    logger.info("Preparing model artifacts...")
-
-    repo_id = request.param["repo_id"]
-    model_file = request.param["model_file"]
-    tokenizer_id = request.param["tokenizer_id"]
-    settings = request.param["settings"]
-    batch_sizes = request.param["batch_sizes"]
-
-    tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test")
-    hf_home = os.environ.get("HF_HOME", None)
-    hf_home = Path(hf_home) if hf_home is not None else tmp_dir
-    try:
-        # Download model if it doesn't exist
-        model_path = hf_home / model_file
-        logger.info(f"Preparing model_path: {model_path}..")
-        if not os.path.exists(model_path):
-            logger.info(
-                f"Downloading model {repo_id} {model_file} from Hugging Face..."
-            )
-            subprocess.run(
-                f"huggingface-cli download --local-dir {hf_home} {repo_id} {model_file}",
-                shell=True,
-                check=True,
-            )
-            logger.info(f"Model downloaded to {model_path}")
-        else:
-            logger.info("Using cached model")
-
-        # Set up tokenizer if it doesn't exist
-        tokenizer_path = hf_home / "tokenizer.json"
-        logger.info(f"Preparing tokenizer_path: {tokenizer_path}...")
-        if not os.path.exists(tokenizer_path):
-            logger.info(f"Downloading tokenizer {tokenizer_id} from Hugging Face...")
-            tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer_id,
-            )
-            tokenizer.save_pretrained(hf_home)
-            logger.info(f"Tokenizer saved to {tokenizer_path}")
-        else:
-            logger.info("Using cached tokenizer")
-
-        # Export model
-        mlir_path = tmp_dir / "model.mlir"
-        config_path = tmp_dir / "config.json"
-        bs_string = ",".join(map(str, batch_sizes))
-        logger.info(
-            "Exporting model with following settings:\n"
-            f"  MLIR Path: {mlir_path}\n"
-            f"  Config Path: {config_path}\n"
-            f"  Batch Sizes: {bs_string}"
-        )
-        subprocess.run(
-            [
-                "python",
-                "-m",
-                "sharktank.examples.export_paged_llm_v1",
-                f"--gguf-file={model_path}",
-                f"--output-mlir={mlir_path}",
-                f"--output-config={config_path}",
-                f"--bs={bs_string}",
-            ],
-            check=True,
-        )
-        logger.info(f"Model successfully exported to {mlir_path}")
-
-        # Compile model
-        vmfb_path = tmp_dir / "model.vmfb"
-        logger.info(f"Compiling model to {vmfb_path}")
-        subprocess.run(
-            [
-                "iree-compile",
-                mlir_path,
-                "-o",
-                vmfb_path,
-            ]
-            + settings["device_flags"],
-            check=True,
-        )
-        logger.info(f"Model successfully compiled to {vmfb_path}")
-
-        # Write config if it doesn't exist
-        edited_config_path = tmp_dir / "edited_config.json"
-        config = {
-            "module_name": "module",
-            "module_abi_version": 1,
-            "max_seq_len": 2048,
-            "attn_head_count": 32,
-            "attn_head_dim": 100,
-            "prefill_batch_sizes": batch_sizes,
-            "decode_batch_sizes": batch_sizes,
-            "transformer_block_count": 26,
-            "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
-        }
-        logger.info(f"Saving edited config to: {edited_config_path}\n")
-        logger.info(f"Config: {json.dumps(config, indent=2)}")
-        with open(edited_config_path, "w") as f:
-            json.dump(config, f)
-        logger.info("Model artifacts setup successfully")
-        yield hf_home, tmp_dir
-    finally:
-        shutil.rmtree(tmp_dir)
-
-
-@pytest.fixture(scope="module")
-def available_port(port=8000, max_port=8100):
-    import socket
-
-    logger.info(f"Finding available port in range {port}-{max_port}...")
-
-    starting_port = port
-
-    while port < max_port:
-        try:
-            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-                s.bind(("localhost", port))
-                s.close()
-                logger.info(f"Found available port: {port}")
-                return port
-        except socket.error:
-            port += 1
-
-    raise IOError(f"No available ports found within range {starting_port}-{max_port}")
-
-
-def wait_for_server(url, timeout=10):
-    logger.info(f"Waiting for server to start at {url}...")
-    start = time.time()
-    while time.time() - start < timeout:
-        try:
-            requests.get(f"{url}/health")
-            logger.info("Server successfully started")
-            return
-        except requests.exceptions.ConnectionError:
-            time.sleep(1)
-    raise TimeoutError(f"Server did not start within {timeout} seconds")
-
-
-@pytest.fixture(scope="module")
-def llm_server(request, model_test_dir, available_port):
-    """Start the LLM server.
-
-    Args:
-        request (FixtureRequest): The following params are accepted:
-            - model_file (str): The model file to download.
-            - settings (dict): The settings for starting the server.
-        model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir.
-        available_port (int): The available port to start the server on.
-
-    Yields:
-        subprocess.Popen: The server process that was started.
-    """
-    logger.info("Starting LLM server...")
-    # Start the server
-    hf_home, tmp_dir = model_test_dir
-    model_file = request.param["model_file"]
-    settings = request.param["settings"]
-    server_process = subprocess.Popen(
-        [
-            "python",
-            "-m",
-            "shortfin_apps.llm.server",
-            f"--tokenizer_json={hf_home / 'tokenizer.json'}",
-            f"--model_config={tmp_dir / 'edited_config.json'}",
-            f"--vmfb={tmp_dir / 'model.vmfb'}",
-            f"--parameters={hf_home / model_file}",
-            f"--device={settings['device']}",
-        ]
-    )
-    # Wait for server to start
-    wait_for_server(f"http://localhost:{available_port}")
-    yield server_process
-    # Teardown: kill the server
-    server_process.terminate()
-    server_process.wait()
diff --git a/build_tools/integration_tests/llm/utils.py b/build_tools/integration_tests/llm/utils.py
deleted file mode 100644
index b31a3e416..000000000
--- a/build_tools/integration_tests/llm/utils.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-
-class AccuracyValidationException(RuntimeError):
-    pass
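
With the tests relocated from build_tools/ to app_tests/, both suites can still be run locally through pytest's Python entry point; a quick sketch mirroring the CI invocations:

    import pytest

    # Same arguments the CI workflows pass on the command line.
    pytest.main(["-v", "app_tests/integration_tests/llm", "--log-cli-level=INFO"])
    pytest.main(["-v", "app_tests/benchmark_tests/llm/sglang_benchmark_test.py", "--log-cli-level=INFO"])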