diff --git a/.github/workflows/test.yaml b/.github/workflows/ci-shark-platform.yml
similarity index 65%
rename from .github/workflows/test.yaml
rename to .github/workflows/ci-shark-platform.yml
index 503e44e8a..d9f4a35da 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/ci-shark-platform.yml
@@ -4,13 +4,14 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-name: Integration Tests
+name: CI - shark-platform
 
 on:
   workflow_dispatch:
-  schedule:
-    # Weekdays at 13:00 UTC = 05:00 PST / 06:00 PDT.
-    - cron: "5 4 * * 1-5"
+  pull_request:
+  push:
+    branches:
+      - main
 
 concurrency:
   # A PR number if a pull request and otherwise the commit hash. This cancels
@@ -21,14 +22,13 @@ concurrency:
   cancel-in-progress: true
 
 jobs:
-  test_llama:
-    name: "Integration Tests - llama"
+  test_shortfin_llm_server:
+    name: "Integration Tests - Shortfin LLM Server"
     strategy:
       matrix:
         version: [3.11]
-        os: [ubuntu-latest, windows-latest]
       fail-fast: false
-    runs-on: ${{matrix.os}}
+    runs-on: nodai-amdgpu-mi250-x86-64
     defaults:
       run:
         shell: bash
@@ -70,31 +70,5 @@ jobs:
             iree-runtime \
             "numpy<2.0"
 
-      - name: Run llama test
-        run: ./build_tools/integration_tests/llama_export_compile_serve.sh
-
-  test_punet:
-    name: "Integration Tests - punet"
-    runs-on: nodai-amdgpu-mi250-x86-64
-    env:
-      VENV_DIR: ${{ github.workspace }}/.venv
-    steps:
-      - name: "Checkout Code"
-        uses: actions/checkout@v3
-
-      - name: "Setup Python venv"
-        run: python3 -m venv ${VENV_DIR}
-
-      - name: Install pip deps
-        run: |
-          source ${VENV_DIR}/bin/activate
-          python -m pip install --no-compile --upgrade pip
-          pip install --no-compile -r pytorch-rocm-requirements.txt
-          pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
-            -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
-          pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ shortfin/
-
-      - name: Run punet tests
-        run: |
-          source ${VENV_DIR}/bin/activate
-          pytest -v sharktank/ -m model_punet
+      - name: Run LLM Integration Tests
+        run: pytest -v build_tools/integration_tests/llm --log-cli-level=INFO
diff --git a/build_tools/integration_tests/llama_export_compile_serve.sh b/build_tools/integration_tests/llama_export_compile_serve.sh
deleted file mode 100755
index edd54b688..000000000
--- a/build_tools/integration_tests/llama_export_compile_serve.sh
+++ /dev/null
@@ -1,47 +0,0 @@
-#!/bin/bash
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-set -xeuo pipefail
-
-# Assume that the environment is already set up:
-#   * Python venv set up with requirements, sharktank, and shortfin
-#   * iree-compile and iree-run-module on $PATH
-#   * authenticated with `huggingface-cli login`
-
-# Input variables.
-#   Default model: https://huggingface.co/SlyEcho/open_llama_3b_v2_gguf
-#   Default tokenizer: https://huggingface.co/openlm-research/open_llama_3b_v2
-TEMP_DIR="${TEMP_DIR:-/tmp/sharktank/llama}"
-HUGGING_FACE_MODEL_NAME="${HUGGING_FACE_MODEL_NAME:-SlyEcho/open_llama_3b_v2_gguf}"
-HUGGING_FACE_MODEL_FILE="${HUGGING_FACE_MODEL_FILE:-open-llama-3b-v2-f16.gguf}"
-HUGGING_FACE_TOKENIZER="${HUGGING_FACE_TOKENIZER:-openlm-research/open_llama_3b_v2}"
-
-# Derived variables.
-LOCAL_GGUF_FILE="${TEMP_DIR}/${HUGGING_FACE_MODEL_FILE}"
-LOCAL_MLIR_FILE="${TEMP_DIR}/model.mlir"
-LOCAL_CONFIG_FILE="${TEMP_DIR}/config.json"
-LOCAL_VMFB_FILE="${TEMP_DIR}/model.vmfb"
-
-mkdir -p ${TEMP_DIR}
-
-huggingface-cli download --local-dir ${TEMP_DIR} ${HUGGING_FACE_MODEL_NAME} ${HUGGING_FACE_MODEL_FILE}
-
-python -m sharktank.examples.export_paged_llm_v1 \
-  --gguf-file="${LOCAL_GGUF_FILE}" \
-  --output-mlir="${LOCAL_MLIR_FILE}" \
-  --output-config="${LOCAL_CONFIG_FILE}"
-
-iree-compile "${LOCAL_MLIR_FILE}" \
-  --iree-hal-target-backends=llvm-cpu \
-  --iree-llvmcpu-target-cpu-features=host \
-  -o ${LOCAL_VMFB_FILE}
-
-python -m shortfin.llm.impl.service_v1_cli \
-  --tokenizer="${HUGGING_FACE_TOKENIZER}" \
-  --config="${LOCAL_CONFIG_FILE}" \
-  --vmfb="${LOCAL_VMFB_FILE}" \
-  --gguf="${LOCAL_GGUF_FILE}"
diff --git a/build_tools/integration_tests/llm/conftest.py b/build_tools/integration_tests/llm/conftest.py
new file mode 100644
index 000000000..1bc014e63
--- /dev/null
+++ b/build_tools/integration_tests/llm/conftest.py
@@ -0,0 +1,206 @@
+import json
+import logging
+import os
+from pathlib import Path
+import pytest
+import requests
+import shutil
+import subprocess
+import time
+
+pytest.importorskip("transformers")
+from transformers import AutoTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture(scope="module")
+def model_test_dir(request, tmp_path_factory):
+    """Prepare model artifacts for starting the LLM server.
+
+    Args:
+        request (FixtureRequest): The following params are accepted:
+            - repo_id (str): The Hugging Face repo ID.
+            - model_file (str): The model file to download.
+            - tokenizer_id (str): The tokenizer ID to download.
+            - settings (dict): The settings for sharktank export.
+            - batch_sizes (list): The batch sizes to use for the model.
+        tmp_path_factory (TempPathFactory): Temp dir to save artifacts to.
+
+    Yields:
+        Tuple[Path, Path]: The paths to the Hugging Face home and the temp dir.
+    """
+    logger.info("Preparing model artifacts...")
+
+    repo_id = request.param["repo_id"]
+    model_file = request.param["model_file"]
+    tokenizer_id = request.param["tokenizer_id"]
+    settings = request.param["settings"]
+    batch_sizes = request.param["batch_sizes"]
+
+    tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test")
+    hf_home = os.environ.get("HF_HOME", None)
+    hf_home = Path(hf_home) if hf_home is not None else tmp_dir
+    try:
+        # Download model if it doesn't exist
+        model_path = hf_home / model_file
+        logger.info(f"Preparing model_path: {model_path}..")
+        if not os.path.exists(model_path):
+            logger.info(
+                f"Downloading model {repo_id} {model_file} from Hugging Face..."
+            )
+            subprocess.run(
+                f"huggingface-cli download --local-dir {hf_home} {repo_id} {model_file}",
+                shell=True,
+                check=True,
+            )
+            logger.info(f"Model downloaded to {model_path}")
+        else:
+            logger.info("Using cached model")
+
+        # Set up tokenizer if it doesn't exist
+        tokenizer_path = hf_home / "tokenizer.json"
+        logger.info(f"Preparing tokenizer_path: {tokenizer_path}...")
+        if not os.path.exists(tokenizer_path):
+            logger.info(f"Downloading tokenizer {tokenizer_id} from Hugging Face...")
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_id,
+            )
+            tokenizer.save_pretrained(hf_home)
+            logger.info(f"Tokenizer saved to {tokenizer_path}")
+        else:
+            logger.info("Using cached tokenizer")
+
+        # Export model
+        mlir_path = tmp_dir / "model.mlir"
+        config_path = tmp_dir / "config.json"
+        bs_string = ",".join(map(str, batch_sizes))
+        logger.info(
+            "Exporting model with following settings:\n"
+            f"  MLIR Path: {mlir_path}\n"
+            f"  Config Path: {config_path}\n"
+            f"  Batch Sizes: {bs_string}"
+        )
+        subprocess.run(
+            [
+                "python",
+                "-m",
+                "sharktank.examples.export_paged_llm_v1",
+                f"--gguf-file={model_path}",
+                f"--output-mlir={mlir_path}",
+                f"--output-config={config_path}",
+                f"--bs={bs_string}",
+            ],
+            check=True,
+        )
+        logger.info(f"Model successfully exported to {mlir_path}")
+
+        # Compile model
+        vmfb_path = tmp_dir / "model.vmfb"
+        logger.info(f"Compiling model to {vmfb_path}")
+        subprocess.run(
+            [
+                "iree-compile",
+                mlir_path,
+                "-o",
+                vmfb_path,
+            ]
+            + settings["device_flags"],
+            check=True,
+        )
+        logger.info(f"Model successfully compiled to {vmfb_path}")
+
+        # Write config if it doesn't exist
+        edited_config_path = tmp_dir / "edited_config.json"
+        config = {
+            "module_name": "module",
+            "module_abi_version": 1,
+            "max_seq_len": 2048,
+            "attn_head_count": 32,
+            "attn_head_dim": 100,
+            "prefill_batch_sizes": batch_sizes,
+            "decode_batch_sizes": batch_sizes,
+            "transformer_block_count": 26,
+            "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
+        }
+        logger.info(f"Saving edited config to: {edited_config_path}\n")
+        logger.info(f"Config: {json.dumps(config, indent=2)}")
+        with open(edited_config_path, "w") as f:
+            json.dump(config, f)
+        logger.info("Model artifacts setup successfully")
+        yield hf_home, tmp_dir
+    finally:
+        shutil.rmtree(tmp_dir)
+
+
+@pytest.fixture(scope="module")
+def available_port(port=8000, max_port=8100):
+    import socket
+
+    logger.info(f"Finding available port in range {port}-{max_port}...")
+
+    starting_port = port
+
+    while port < max_port:
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind(("localhost", port))
+                s.close()
+            logger.info(f"Found available port: {port}")
+            return port
+        except socket.error:
+            port += 1
+
+    raise IOError(f"No available ports found within range {starting_port}-{max_port}")
+
+
+def wait_for_server(url, timeout=10):
+    logger.info(f"Waiting for server to start at {url}...")
+    start = time.time()
+    while time.time() - start < timeout:
+        try:
+            requests.get(f"{url}/health")
+            logger.info("Server successfully started")
+            return
+        except requests.exceptions.ConnectionError:
+            time.sleep(1)
+    raise TimeoutError(f"Server did not start within {timeout} seconds")
+
+
+@pytest.fixture(scope="module")
+def llm_server(request, model_test_dir, available_port):
+    """Start the LLM server.
+
+    Args:
+        request (FixtureRequest): The following params are accepted:
+            - model_file (str): The model file to download.
+            - settings (dict): The settings for starting the server.
+        model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir.
+        available_port (int): The available port to start the server on.
+
+    Yields:
+        subprocess.Popen: The server process that was started.
+    """
+    logger.info("Starting LLM server...")
+    # Start the server
+    hf_home, tmp_dir = model_test_dir
+    model_file = request.param["model_file"]
+    settings = request.param["settings"]
+    server_process = subprocess.Popen(
+        [
+            "python",
+            "-m",
+            "shortfin_apps.llm.server",
+            f"--tokenizer={hf_home / 'tokenizer.json'}",
+            f"--model_config={tmp_dir / 'edited_config.json'}",
+            f"--vmfb={tmp_dir / 'model.vmfb'}",
+            f"--parameters={hf_home / model_file}",
+            f"--device={settings['device']}",
+        ]
+    )
+    # Wait for server to start
+    wait_for_server(f"http://localhost:{available_port}")
+    yield server_process
+    # Teardown: kill the server
+    server_process.terminate()
+    server_process.wait()
diff --git a/build_tools/integration_tests/llm/cpu_llm_server_test.py b/build_tools/integration_tests/llm/cpu_llm_server_test.py
new file mode 100644
index 000000000..1b27e12da
--- /dev/null
+++ b/build_tools/integration_tests/llm/cpu_llm_server_test.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import os
+import pytest
+import requests
+import uuid
+
+logger = logging.getLogger(__name__)
+
+CPU_SETTINGS = {
+    "device_flags": [
+        "-iree-hal-target-backends=llvm-cpu",
+        "--iree-llvmcpu-target-cpu=host",
+    ],
+    "device": "local-task",
+}
+IREE_HIP_TARGET = os.environ.get("IREE_HIP_TARGET", "gfx1100")
+gpu_settings = {
+    "device_flags": [
+        "-iree-hal-target-backends=rocm",
+        f"--iree-hip-target={IREE_HIP_TARGET}",
+    ],
+    "device": "hip",
+}
+
+
+def do_generate(prompt, port):
+    logger.info("Generating request...")
+    headers = {"Content-Type": "application/json"}
+    # Create a GenerateReqInput-like structure
+    data = {
+        "text": prompt,
+        "sampling_params": {"max_tokens": 50, "temperature": 0.7},
+        "rid": uuid.uuid4().hex,
+        "return_logprob": False,
+        "logprob_start_len": -1,
+        "top_logprobs_num": 0,
+        "return_text_in_logprobs": False,
+        "stream": False,
+    }
+    logger.info("Prompt text:")
+    logger.info(data["text"])
+    BASE_URL = f"http://localhost:{port}"
+    response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data)
+    logger.info(f"Generate endpoint status code: {response.status_code}")
+    if response.status_code == 200:
+        logger.info("Generated text:")
+        data = response.text
+        assert data.startswith("data: ")
+        data = data[6:]
+        assert data.endswith("\n\n")
+        data = data[:-2]
+        return data
+    else:
+        response.raise_for_status()
+
+
+@pytest.mark.parametrize(
+    "model_test_dir,llm_server",
+    [
+        (
+            {
+                "repo_id": "SlyEcho/open_llama_3b_v2_gguf",
+                "model_file": "open-llama-3b-v2-f16.gguf",
+                "tokenizer_id": "openlm-research/open_llama_3b_v2",
+                "settings": CPU_SETTINGS,
+                "batch_sizes": [1, 4],
+            },
+            {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS},
+        )
+    ],
+    indirect=True,
+)
+def test_llm_server(llm_server, available_port):
+    # Here you would typically make requests to your server
+    # and assert on the responses
+    assert llm_server.poll() is None
+    output = do_generate("1 2 3 4 5 ", available_port)
+    logger.info(output)
+    assert output.startswith("6 7 8")
diff --git a/shortfin/build_tools/python_lsan_suppressions.txt b/shortfin/build_tools/python_lsan_suppressions.txt
index 498f1ea72..f3ac58064 100644
--- a/shortfin/build_tools/python_lsan_suppressions.txt
+++ b/shortfin/build_tools/python_lsan_suppressions.txt
@@ -8,3 +8,4 @@ leak:import_find_and_load
 leak:pyo3::pyclass::create_type_object
 leak:ufunc
 leak:pydantic_core
+leak:sentencepiece
diff --git a/shortfin/requirements-tests.txt b/shortfin/requirements-tests.txt
index 668023a1e..c04c97af2 100644
--- a/shortfin/requirements-tests.txt
+++ b/shortfin/requirements-tests.txt
@@ -11,6 +11,8 @@ wheel
 # Deps needed for shortfin_apps.llm
 dataclasses-json
 tokenizers
+huggingface_hub[cli]
+sentencepiece
 
 # Deps needed for shortfin_apps.sd
 pillow
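
Note: outside of pytest, the /generate endpoint exercised by these tests can be smoke-tested with a small standalone client that mirrors the payload built in do_generate() above. This is only a sketch: the host (localhost), port (8000), and prompt are assumptions, and a server must already be running; the tests themselves pick a free port via the available_port fixture and launch the server through the llm_server fixture.

    # Minimal sketch of a client for the shortfin LLM server's /generate endpoint.
    # Mirrors do_generate() from cpu_llm_server_test.py; host, port, and prompt are
    # illustrative assumptions, not part of the patch.
    import uuid

    import requests


    def generate(prompt: str, port: int = 8000) -> str:
        payload = {
            "text": prompt,
            "sampling_params": {"max_tokens": 50, "temperature": 0.7},
            "rid": uuid.uuid4().hex,
            "return_logprob": False,
            "logprob_start_len": -1,
            "top_logprobs_num": 0,
            "return_text_in_logprobs": False,
            "stream": False,
        }
        response = requests.post(
            f"http://localhost:{port}/generate",
            headers={"Content-Type": "application/json"},
            json=payload,
        )
        response.raise_for_status()
        # Responses are framed as "data: <text>\n\n", matching the assertions in do_generate().
        return response.text.removeprefix("data: ").removesuffix("\n\n")


    if __name__ == "__main__":
        print(generate("1 2 3 4 5 "))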
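
Note: cpu_llm_server_test.py defines gpu_settings but only parametrizes the CPU configuration. A hypothetical GPU variant (not part of this patch) could be appended to that file roughly as follows; it reuses the same fixtures so iree-compile targets ROCm and the server runs on a hip device. A ROCm-capable runner is assumed, and the expected "6 7 8" continuation is an assumption carried over from the CPU test.

    # Hypothetical GPU parametrization sketch; relies on pytest, gpu_settings, and
    # do_generate() already being in scope in cpu_llm_server_test.py.
    @pytest.mark.parametrize(
        "model_test_dir,llm_server",
        [
            (
                {
                    "repo_id": "SlyEcho/open_llama_3b_v2_gguf",
                    "model_file": "open-llama-3b-v2-f16.gguf",
                    "tokenizer_id": "openlm-research/open_llama_3b_v2",
                    "settings": gpu_settings,
                    "batch_sizes": [1, 4],
                },
                {"model_file": "open-llama-3b-v2-f16.gguf", "settings": gpu_settings},
            )
        ],
        indirect=True,
    )
    def test_llm_server_gpu(llm_server, available_port):
        # Server process should still be alive before issuing the request.
        assert llm_server.poll() is None
        output = do_generate("1 2 3 4 5 ", available_port)
        assert output.startswith("6 7 8")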