From a18297d8367c469e0d1c399bce82e4e77a3de01f Mon Sep 17 00:00:00 2001
From: Cedar
Date: Tue, 29 Oct 2024 15:54:02 -0700
Subject: [PATCH] Revert "Add Back LLM Server Test (#358)"

This reverts commit f6d54f3604b4625cdd1c58f3f34d7cab8c34b517.
---
 .../shortfin_apps/llm/components/service.py |   4 +-
 shortfin/requirements-tests.txt             |   1 -
 .../tests/apps/llm/cpu_llm_server_test.py   | 173 ------------------
 3 files changed, 2 insertions(+), 176 deletions(-)
 delete mode 100644 shortfin/tests/apps/llm/cpu_llm_server_test.py

diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py
index 646d186f8..1e6245d53 100644
--- a/shortfin/python/shortfin_apps/llm/components/service.py
+++ b/shortfin/python/shortfin_apps/llm/components/service.py
@@ -73,7 +73,7 @@ def start(self):
                 )
             ]
             + self.inference_modules,
-            devices=self.sysman.ls.devices,
+            fiber=self.main_fiber,
             trace_execution=False,
         )
         # Resolve prefill entrypoints.
@@ -393,7 +393,7 @@ async def run(self):
                 "".join([f"\n  {i}: {ary.shape}" for i, ary in enumerate(args)]),
             )
             # Invoke. Logits are of shape [bs, bsl, d].
-            (logits,) = await fn(*args, fiber=self.fiber)
+            (logits,) = await fn(*args)
 
             # Return results.
             for i in range(req_count):
diff --git a/shortfin/requirements-tests.txt b/shortfin/requirements-tests.txt
index ec40ec6b3..668023a1e 100644
--- a/shortfin/requirements-tests.txt
+++ b/shortfin/requirements-tests.txt
@@ -11,7 +11,6 @@ wheel
 # Deps needed for shortfin_apps.llm
 dataclasses-json
 tokenizers
-sentencepiece
 
 # Deps needed for shortfin_apps.sd
 pillow
diff --git a/shortfin/tests/apps/llm/cpu_llm_server_test.py b/shortfin/tests/apps/llm/cpu_llm_server_test.py
deleted file mode 100644
index c69ad41b6..000000000
--- a/shortfin/tests/apps/llm/cpu_llm_server_test.py
+++ /dev/null
@@ -1,173 +0,0 @@
-import pytest
-import subprocess
-import time
-import requests
-import os
-import json
-import uuid
-import tempfile
-import shutil
-
-BATCH_SIZES = [1, 4]
-
-cpu_settings = {
-    "device_flags": ["-iree-hal-target-backends=llvm-cpu"],
-    "device": "local-task",
-}
-
-gpu_settings = {
-    "device_flags": ["-iree-hal-target-backends=rocm", "--iree-hip-target=gfx1100"],
-    "device": "hip",
-}
-
-settings = cpu_settings
-
-
-@pytest.fixture(scope="module")
-def model_test_dir():
-    tmp_dir = tempfile.mkdtemp()
-    try:
-        # Create necessary directories
-        os.makedirs(tmp_dir, exist_ok=True)
-
-        # Download model if it doesn't exist
-        model_path = os.path.join(tmp_dir, "open-llama-3b-v2-f16.gguf")
-        if not os.path.exists(model_path):
-            subprocess.run(
-                f"huggingface-cli download --local-dir {tmp_dir} SlyEcho/open_llama_3b_v2_gguf open-llama-3b-v2-f16.gguf",
-                shell=True,
-                check=True,
-            )
-
-        # Set up tokenizer if it doesn't exist
-        tokenizer_path = os.path.join(tmp_dir, "tokenizer.json")
-        if not os.path.exists(tokenizer_path):
-            tokenizer_setup = f"""
-from transformers import AutoTokenizer
-tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b_v2")
-tokenizer.save_pretrained("{tmp_dir}")
-"""
-            subprocess.run(["python", "-c", tokenizer_setup], check=True)
-
-        # Export model if it doesn't exist
-        mlir_path = os.path.join(tmp_dir, "model.mlir")
-        config_path = os.path.join(tmp_dir, "config.json")
-        if not os.path.exists(mlir_path) or not os.path.exists(config_path):
-            bs_string = ",".join(map(str, BATCH_SIZES))
-            subprocess.run(
-                [
-                    "python",
-                    "-m",
-                    "sharktank.examples.export_paged_llm_v1",
-                    f"--gguf-file={model_path}",
-                    f"--output-mlir={mlir_path}",
-                    f"--output-config={config_path}",
-                    f"--bs={bs_string}",
-                ],
-                check=True,
-            )
-
-        # Compile model if it doesn't exist
-        vmfb_path = os.path.join(tmp_dir, "model.vmfb")
-        if not os.path.exists(vmfb_path):
-            subprocess.run(
-                [
-                    "iree-compile",
-                    mlir_path,
-                    "-o",
-                    vmfb_path,
-                ]
-                + settings["device_flags"],
-                check=True,
-            )
-
-        # Write config if it doesn't exist
-        edited_config_path = os.path.join(tmp_dir, "edited_config.json")
-        if not os.path.exists(edited_config_path):
-            config = {
-                "module_name": "module",
-                "module_abi_version": 1,
-                "max_seq_len": 2048,
-                "attn_head_count": 32,
-                "attn_head_dim": 100,
-                "prefill_batch_sizes": BATCH_SIZES,
-                "decode_batch_sizes": BATCH_SIZES,
-                "transformer_block_count": 26,
-                "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
-            }
-            with open(edited_config_path, "w") as f:
-                json.dump(config, f)
-
-        yield tmp_dir
-    finally:
-        shutil.rmtree(tmp_dir)
-
-
-@pytest.fixture(scope="module")
-def llm_server(model_test_dir):
-    # Start the server
-    server_process = subprocess.Popen(
-        [
-            "python",
-            "-m",
-            "shortfin_apps.llm.server",
-            f"--tokenizer={os.path.join(model_test_dir, 'tokenizer.json')}",
-            f"--model_config={os.path.join(model_test_dir, 'edited_config.json')}",
-            f"--vmfb={os.path.join(model_test_dir, 'model.vmfb')}",
-            f"--parameters={os.path.join(model_test_dir, 'open-llama-3b-v2-f16.gguf')}",
-            f"--device={settings['device']}",
-        ]
-    )
-
-    # Wait for server to start
-    time.sleep(2)
-
-    yield server_process
-
-    # Teardown: kill the server
-    server_process.terminate()
-    server_process.wait()
-
-
-def do_generate(prompt):
-    headers = {"Content-Type": "application/json"}
-    # Create a GenerateReqInput-like structure
-    data = {
-        "text": prompt,
-        "sampling_params": {"max_tokens": 50, "temperature": 0.7},
-        "rid": uuid.uuid4().hex,
-        "return_logprob": False,
-        "logprob_start_len": -1,
-        "top_logprobs_num": 0,
-        "return_text_in_logprobs": False,
-        "stream": False,
-    }
-
-    print("Prompt text:")
-    print(data["text"])
-
-    BASE_URL = "http://localhost:8000"
-
-    response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data)
-    print(f"Generate endpoint status code: {response.status_code}")
-
-    if response.status_code == 200:
-        print("Generated text:")
-        data = response.text
-        assert data.startswith("data: ")
-        data = data[6:]
-        assert data.endswith("\n\n")
-        data = data[:-2]
-
-        return data
-    else:
-        response.raise_for_status()
-
-
-def test_llm_server(llm_server):
-    # Here you would typically make requests to your server
-    # and assert on the responses
-    assert llm_server.poll() is None
-    output = do_generate("1 2 3 4 5 ")
-    print(output)
-    assert output.startswith("6 7 8")