diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 7b80bd61b..a528cfa13 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -15,13 +15,13 @@ concurrency: cancel-in-progress: true jobs: - test_perplexity: + test_perplexity_vmfb: timeout-minutes: 1000 - name: "Evaluation Tests - perplexity" + name: "Evaluation Tests - perplexity_vmfb" strategy: matrix: version: [3.11] - runs-on: [llama-mi300] + runs-on: [llama-mi300x-3] fail-fast: false runs-on: ${{matrix.runs-on}} defaults: @@ -58,5 +58,58 @@ jobs: -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ - - name: Run perplexity test - run: pytest -n 4 -v -s sharktank/tests/evaluate/perplexity_test.py --longrun + # Try with the latest nightly releases, not what iree-turbine pins. + # We could also pin to a known working or stable version. + # This should eventually stabilize. Do the best we can for now. + pip install -f https://iree.dev/pip-release-links.html --upgrade \ + iree-compiler \ + iree-runtime \ + "numpy<2.0" + - name: Run perplexity test with vmfb + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun --iree-device='hip://7' --iree-hip-target='gfx942' --llama3-8b-f16-model-path=/data/llama-3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama-3.1/8b/tokenizer_config.json + + test_perplexity_torch: + timeout-minutes: 1000 + name: "Evaluation Tests - perplexity_torch" + strategy: + matrix: + version: [3.11] + runs-on: [llama-mi300x-3] + fail-fast: false + runs-on: ${{matrix.runs-on}} + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }} + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@v3 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@v3 + + - name: Cache Pip Packages + uses: actions/cache@v4 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + + - name: Install sharktank deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. 
+ pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + + - name: Run perplexity test in eager mode + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama-3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama-3.1/8b/tokenizer_config.json diff --git a/sharktank/conftest.py b/sharktank/conftest.py index a5583b711..2076c39eb 100644 --- a/sharktank/conftest.py +++ b/sharktank/conftest.py @@ -72,20 +72,19 @@ def pytest_addoption(parser): help="Enable long and slow tests", ) + # TODO: Remove all hardcoded paths in CI tests parser.addoption( "--llama3-8b-tokenizer-path", type=Path, action="store", - default="/data/extra/models/llama3.1_8B/tokenizer_config.json", help="Llama3.1 8b tokenizer path, defaults to 30F CI system path", ) parser.addoption( - "--llama3-8b-f16-gguf-path", + "--llama3-8b-f16-model-path", type=Path, action="store", - default="/data/extra/models/llama3.1_8B/llama8b_f16.gguf", - help="Llama3.1 8b gguf model path, defaults to 30F CI system path", + help="Llama3.1 8b model path, defaults to 30F CI system path", ) parser.addoption( @@ -100,16 +99,14 @@ def pytest_addoption(parser): "--llama3-405b-tokenizer-path", type=Path, action="store", - default="/data/extra/models/llama3.1_405B/tokenizer_config.json", help="Llama3.1 405b tokenizer path, defaults to 30F CI system path", ) parser.addoption( - "--llama3-405b-f16-gguf-path", + "--llama3-405b-f16-model-path", type=Path, action="store", - default="/data/extra/models/llama3.1_405B/llama405b_fp16.gguf", - help="Llama3.1 405b gguf model path, defaults to 30F CI system path", + help="Llama3.1 405b model path, defaults to 30F CI system path", ) parser.addoption( @@ -121,20 +118,49 @@ def pytest_addoption(parser): ) parser.addoption( - "--baseline-perplexity-score-json", + "--baseline-perplexity-scores", type=Path, action="store", default="sharktank/tests/evaluate/baseline_perplexity_scores.json", - help="Llama3.1 8B & 405B model baseline perplexity scores json", + help="Llama3.1 8B & 405B model baseline perplexity scores", + ) + + parser.addoption( + "--iree-device", + type=str, + action="store", + help="List an IREE device from iree-run-module --list_devices", ) parser.addoption( "--iree-hip-target", action="store", - default="gfx942", help="Specify the iree-hip target version (e.g., gfx942)", ) + parser.addoption( + "--iree-hal-target-backends", + action="store", + default="rocm", + help="Specify the iree-hal target backend (e.g., rocm)", + ) + + parser.addoption( + "--tensor-parallelism-size", + action="store", + type=int, + default=1, + help="Number of devices for tensor parallel sharding", + ) + + parser.addoption( + "--bs", + action="store", + type=int, + default=4, + help="Batch size for mlir export", + ) + def set_fixture_from_cli_option( request: FixtureRequest, @@ -183,13 +209,32 @@ def iree_hip_target_type(request: FixtureRequest) -> Optional[str]: @pytest.fixture(scope="class") -def get_model_path(request: FixtureRequest): +def tensor_parallelism_size(request: FixtureRequest) -> Optional[str]: + return set_fixture_from_cli_option( + request, "tensor_parallelism_size", "tensor_parallelism_size" + ) + + +@pytest.fixture(scope="class") +def baseline_perplexity_scores(request: FixtureRequest) -> 
Optional[str]: + return set_fixture_from_cli_option( + request, "baseline_perplexity_scores", "baseline_perplexity_scores" + ) + + +@pytest.fixture(scope="class") +def batch_size(request: FixtureRequest) -> Optional[str]: + return set_fixture_from_cli_option(request, "bs", "batch_size") + + +@pytest.fixture(scope="class") +def get_model_artifacts(request: FixtureRequest): model_path = {} model_path["llama3_8b_tokenizer_path"] = set_fixture_from_cli_option( request, "--llama3-8b-tokenizer-path", "llama3_8b_tokenizer" ) - model_path["llama3_8b_f16_gguf_path"] = set_fixture_from_cli_option( - request, "--llama3-8b-f16-gguf-path", "llama3_8b_f16_model" + model_path["llama3_8b_f16_model_path"] = set_fixture_from_cli_option( + request, "--llama3-8b-f16-model-path", "llama3_8b_f16_model" ) model_path["llama3_8b_fp8_model_path"] = set_fixture_from_cli_option( request, "--llama3-8b-fp8-model-path", "llama3_8b_fp8_model" @@ -197,13 +242,24 @@ def get_model_path(request: FixtureRequest): model_path["llama3_405b_tokenizer_path"] = set_fixture_from_cli_option( request, "--llama3-405b-tokenizer-path", "llama3_405b_tokenizer" ) - model_path["llama3_405b_f16_gguf_path"] = set_fixture_from_cli_option( - request, "--llama3-405b-f16-gguf-path", "llama3_405b_f16_model" + model_path["llama3_405b_f16_model_path"] = set_fixture_from_cli_option( + request, "--llama3-405b-f16-model-path", "llama3_405b_f16_model" ) model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option( request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model" ) - model_path["baseline_perplexity_score_json"] = set_fixture_from_cli_option( - request, "--baseline-perplexity-score-json", "baseline_perplexity_score_json" - ) return model_path + + +@pytest.fixture(scope="class") +def get_iree_flags(request: FixtureRequest): + model_path = {} + model_path["iree_device"] = set_fixture_from_cli_option( + request, "--iree-device", "iree_device" + ) + model_path["iree_hip_target"] = set_fixture_from_cli_option( + request, "--iree-hip-target", "iree_hip_target" + ) + model_path["iree_hal_target_backends"] = set_fixture_from_cli_option( + request, "--iree-hal-target-backends", "iree_hal_target_backends" + ) diff --git a/sharktank/sharktank/evaluate/perplexity.py b/sharktank/sharktank/evaluate/perplexity_torch.py similarity index 97% rename from sharktank/sharktank/evaluate/perplexity.py rename to sharktank/sharktank/evaluate/perplexity_torch.py index aa9d35dcc..fc3aa5fca 100644 --- a/sharktank/sharktank/evaluate/perplexity.py +++ b/sharktank/sharktank/evaluate/perplexity_torch.py @@ -42,10 +42,10 @@ logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s") ) -__all__ = ["Perplexity", "run_perplexity"] +__all__ = ["Perplexity_torch", "run_perplexity_torch"] -class Perplexity: +class Perplexity_torch: """ Perplexity (PPL) is one of the most common metrics for evaluating language models. 
It is defined as the exponentiated average negative log-likelihood of a sequence, @@ -59,8 +59,6 @@ def __init__( device, kv_cache_type, ): - self.batch_size = 16 - self.device = device self.kv_cache_type = kv_cache_type self.activation_dtype = torch.float32 @@ -173,6 +171,8 @@ def get_logits(self): (self.token_ids != 0).int().detach().clone().to(self.device) ) + self.bs = len(self.test_prompts) + is_first_token = True start = 0 for i in tqdm( @@ -263,8 +263,6 @@ def compute_perplexity(self): def get_perplexity(self, test_prompts): self.test_prompts = test_prompts - self.bs = len(self.test_prompts) - self.get_logits() self.out_logits = self.out_logits[..., :-1, :].contiguous() @@ -282,7 +280,7 @@ def get_perplexity(self, test_prompts): return self.compute_perplexity() -def run_perplexity( +def run_perplexity_torch( dataset, tokenizer, device, @@ -290,7 +288,7 @@ def run_perplexity( tensor_parallelism_size, attention_kernel, ): - perplexity = Perplexity(device=device, kv_cache_type=kv_cache_type) + perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type) perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel) test_prompts = perplexity.get_prompts() @@ -326,7 +324,7 @@ def main(argv): dataset = cli.get_input_dataset(args) tokenizer = cli.get_tokenizer(args) - ppl = run_perplexity( + ppl = run_perplexity_torch( dataset=dataset, tokenizer=tokenizer, device=device, diff --git a/sharktank/sharktank/evaluate/perplexity_vmfb.py b/sharktank/sharktank/evaluate/perplexity_vmfb.py new file mode 100644 index 000000000..fedf7c1c9 --- /dev/null +++ b/sharktank/sharktank/evaluate/perplexity_vmfb.py @@ -0,0 +1,453 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import sys +import logging +import json +import time +import random +from datetime import timedelta +from tqdm import tqdm + +import numpy as np + +from datasets import load_dataset + +import torch +from torch.nn import CrossEntropyLoss + +from sharktank.models.llama.llama import * +from sharktank.models.mixtral.mixtral import * +from sharktank.models.grok.grok import * + +from ..models.llama.sharding import shard_theta + +from sharktank.layers import * +from sharktank.types import * + +from sharktank.utils import cli +from sharktank.utils.vmfb_runner import * +from sharktank.utils.load_llm import * +from sharktank.utils.create_cache import * +from sharktank.utils.export_artifacts import * + +log_levels = { + "info": logging.INFO, + "debug": logging.DEBUG, +} +logger = logging.getLogger("eval") + +logger.setLevel(log_levels["info"]) + +logger.root.handlers[0].setFormatter( + logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s") +) + +__all__ = ["Perplexity", "run_perplexity"] + + +class Perplexity: + """ + Perplexity (PPL) is one of the most common metrics for evaluating language models. + It is defined as the exponentiated average negative log-likelihood of a sequence, + calculated with exponent base `e`. 
+ + For more information, see https://huggingface.co/docs/transformers/perplexity + """ + + def __init__( + self, + torch_device, + iree_device, + iree_hip_target, + iree_hal_target_backends, + kv_cache_type, + tensor_parallelism_size, + attention_kernel, + ): + self.torch_device = torch_device + self.iree_device = iree_device + self.iree_hip_target = iree_hip_target + self.iree_hal_target_backends = iree_hal_target_backends + self.kv_cache_type = kv_cache_type + self.activation_dtype = torch.float32 + self.attention_dtype = torch.float32 + self.tensor_parallelism_size = tensor_parallelism_size + self.attention_kernel = attention_kernel + + def timeit(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + seconds = end - start + time_taken = abs(timedelta(seconds=round(seconds))) + + if seconds < 1: + time_taken = f" {seconds * 1000} ms" + + func_name = func.__name__ + if func_name == "get_perplexity": + func_name = f"Total time to calculate perplexity" + elif func_name == "compile_model": + func_name = f"Total time to export and compile" + logger.info(f" {func_name}: {time_taken}") + return result + + return wrapper + + def print_token_comparison(self, i): + if i <= self.max_prompt_length: + batch_predicted_token_id = [[i[-1]] for i in self.batch.results] + batch_predicted_token = self.generator.tokenizer.decode( + batch_predicted_token_id + ) + logger.debug(f"Predicted:") + logger.debug(f"{batch_predicted_token}") + logger.debug(f"{batch_predicted_token_id}") + + expected_token_id = self.token_ids[:, i + 1 : i + 2].tolist() + expected_token = self.generator.tokenizer.decode(expected_token_id) + logger.debug(f"Expected:") + logger.debug(f"{expected_token}") + logger.debug(f"{expected_token_id}") + + @timeit + def compile_model(self, weight_path_str): + self.weight_path_str = weight_path_str + + logger.info(f"Compiling: {self.weight_path_str}") + + export_artifacts = ExportArtifacts( + irpa_path=self.weight_path_str, + batch_size=self.bs, + iree_hip_target=self.iree_hip_target, + iree_hal_target_backends=self.iree_hal_target_backends, + attention_kernel=self.attention_kernel, + tensor_parallelism_size=self.tensor_parallelism_size, + ) + vmfb_path = export_artifacts.get_artifacts() + return vmfb_path + + @timeit + def load_model(self, weight_path, tokenizer, vmfb_path): + + config = LlamaModelConfig( + hp=configs.LlamaHParams.from_gguf_props(weight_path.properties), + block_seq_stride=16, + kv_cache_type=self.kv_cache_type, + device=self.torch_device, + activation_dtype=self.activation_dtype, + attention_dtype=self.attention_dtype, + tensor_parallelism_size=self.tensor_parallelism_size, + ) + + if config.tensor_parallelism_size > 1: + weight_path.root_theta = shard_theta(weight_path.root_theta, config) + + theta = weight_path.root_theta + + if config.hp.expert_count: + if config.hp.model_arch == "grok": + model = PagedGrokModelV1(theta, config) + else: + model = PagedMixtralModelV1(theta, config) + else: + model = PagedLlamaModelV1(theta, config) + + self.generator = TorchGenerator(model, tokenizer) + + self.runner = vmfbRunner( + device=self.iree_device, + vmfb_path=vmfb_path, + external_weight_path=self.weight_path_str, + ) + + @timeit + def get_prompts(self): + test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[ + "text" + ] + num_test_prompts = 219 + + random.seed(0) + test_prompts = random.sample(test_prompts, num_test_prompts) + + # Ignore prompts that are: empty, less than 20 tokens or a title. 
+ test_prompts = [ + s.replace("\n", "").rstrip() + for s in test_prompts + if s != "" and len(s.split()) >= 20 and s.count("=") < 2 + ] + + self.bs = len(test_prompts) + + return test_prompts + + def prefill_vmfb(self, token_batch, i): + + logger.debug(f"Prefill:") + + logger.debug("Input:") + logger.debug(f"{self.generator.tokenizer.decode(token_batch)}") + + token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens( + token_ids=token_batch.tolist(), + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + logger.debug(f"{token_batch}") + + token_batch = torch.tensor(token_batch, device=self.torch_device) + self.seq_lens_batch = torch.tensor(seq_lens_batch, device=self.torch_device) + + self.batch = self.generator.begin_eval_batch( + token_batch=token_batch, + seq_lens_batch=self.seq_lens_batch, + bs=self.bs, + ) + + seq_block_ids = self.batch.pad_block_ids() + prefill_logits = self.runner.ctx.modules.module[f"prefill_bs{self.bs}"]( + token_batch, + self.seq_lens_batch, + seq_block_ids, + self.batch.cache_state[0].to(torch.float16), + ) + + prefill_logits = torch.tensor(prefill_logits[:, :, :]) + + tokens = torch.tensor( + self.generator.model.extract_tokens_from_logits( + prefill_logits, seq_lens_batch + ) + ).unsqueeze(1) + self.batch.add_result_token(tokens) + + self.print_token_comparison(i) + return prefill_logits + + def decode_vmfb(self, token_batch, i): + logger.debug("Decode:") + + logger.debug("Input:") + logger.debug(f"{self.generator.tokenizer.decode(token_batch)}") + logger.debug(f"{token_batch.tolist()}") + + start_positions = self.seq_lens_batch.clone() + self.seq_lens_batch.add_(1) + self.batch.allocate_seq_block_ids() + seq_block_ids = self.batch.pad_block_ids() + + decode_logits = self.runner.ctx.modules.module[f"decode_bs{self.bs}"]( + token_batch, + self.seq_lens_batch, + start_positions, + seq_block_ids, + self.batch.cache_state[0].to(torch.float16), + ) + + decode_logits = torch.tensor(decode_logits[:, :, :]) + + tokens = torch.tensor( + self.generator.model.extract_tokens_from_logits( + decode_logits, [1] * self.bs + ), + device=self.generator.model.device, + ).unsqueeze(1) + self.batch.add_result_token(tokens) + self.print_token_comparison(i) + return decode_logits + + @timeit + def get_logits(self): + + token_ids, seq_lens = self.generator.tokenizer.encode( + self.test_prompts, + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + logger.info(f" Prompts for Evaluation:") + for idx, prompt in enumerate(self.test_prompts): + logger.info( + f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" + ) + + self.max_prompt_length = max(seq_lens) + + self.token_ids = torch.tensor(token_ids, device=self.torch_device) + self.attention_mask = ( + (self.token_ids != 0).int().detach().clone().to(self.torch_device) + ) + + is_first_token = True + start = 0 + for i in tqdm( + range(start, self.max_prompt_length - 1), + desc="eval: Calculating logits", + ): + logger.debug(f"Iteration: {i}") + + if is_first_token: + + token_batch = self.token_ids[:, : i + 1] + + prefill_logits = self.prefill_vmfb(token_batch, i) + self.out_logits = prefill_logits[:, 0:1, :] + + is_first_token = False + + else: + token_batch = self.token_ids[:, i : i + 1] + + decode_logits = self.decode_vmfb(token_batch, i) + self.out_logits = torch.cat((self.out_logits, decode_logits), 1) + + pad_logits_shape = self.token_ids.shape[1] - self.out_logits.shape[1] + + self.pad_logits = torch.zeros( + self.out_logits.shape[0], pad_logits_shape, 
self.out_logits.shape[2] + ) + + self.out_logits = torch.cat((self.out_logits, self.pad_logits), 1).to( + self.torch_device + ) + + @timeit + def compute_perplexity(self): + loss_fct = CrossEntropyLoss(reduction="none") + + ## perplexity = e ^ (sum(losses) / num_tokenized_tokens) + crossentropy_loss = ( + loss_fct(self.out_logits.transpose(1, 2), self.token_ids) + * self.attention_mask + ).sum(1) + crossentropy_loss = torch.tensor(crossentropy_loss.tolist()) + perplexity_batch = torch.exp( + crossentropy_loss / self.attention_mask.sum(1) + ).tolist() + + perplexity_batch = [round(ppl, 6) for ppl in perplexity_batch] + + return { + "perplexities": perplexity_batch, + "mean_perplexity": round(np.mean(perplexity_batch), 6), + } + + @timeit + def get_perplexity(self, test_prompts): + + self.test_prompts = test_prompts + + self.get_logits() + + self.out_logits = self.out_logits[..., :-1, :].contiguous() + self.token_ids = self.token_ids[..., 1:].contiguous() + self.attention_mask = self.attention_mask[..., 1:].contiguous() + + logger.debug(f"Final Logits shape: {self.out_logits.shape}") + logger.debug(f"Token ids: {self.token_ids}, \n{self.token_ids.shape}") + logger.debug( + f"Mask shape: {self.attention_mask}, \n{self.attention_mask.shape}" + ) + + assert self.token_ids.shape == self.out_logits.shape[0:2] + + return self.compute_perplexity() + + +def run_perplexity( + weight_path, + weight_path_str, + tokenizer, + torch_device, + iree_device, + iree_hip_target, + iree_hal_target_backends, + kv_cache_type, + tensor_parallelism_size, + attention_kernel, +): + perplexity = Perplexity( + torch_device=torch_device, + iree_device=iree_device, + iree_hip_target=iree_hip_target, + iree_hal_target_backends=iree_hal_target_backends, + kv_cache_type=kv_cache_type, + tensor_parallelism_size=tensor_parallelism_size, + attention_kernel=attention_kernel, + ) + + test_prompts = perplexity.get_prompts() + logger.info(f" Total test prompts: {len(test_prompts)}") + + vmfb_path = perplexity.compile_model(weight_path_str) + perplexity.load_model(weight_path, tokenizer, vmfb_path) + ppl = perplexity.get_perplexity(test_prompts) + + return ppl + + +def main(argv): + parser = cli.create_parser() + parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") + parser.add_argument("--torch-device", help="Torch device (or default)") + parser.add_argument("--iree-device", help="List an IREE device, eg: 'hip://0'") + parser.add_argument( + "--iree-hip-target", + action="store", + default="gfx942", + help="Specify the iree-hip target version (e.g., gfx942)", + ) + parser.add_argument( + "--iree-hal-target-backends", + action="store", + default="rocm", + help="Specify the iree-hal target backends (e.g., rocm)", + ) + parser.add_argument( + "--attention-kernel", + type=str, + default="decomposed", + choices=["decomposed", "torch_sdpa"], + ) + parser.add_argument( + "--tensor-parallelism-size", + type=int, + default=1, + help="Number of devices for tensor parallel sharding", + ) + + cli.add_tokenizer_options(parser) + cli.add_input_dataset_options(parser) + args = cli.parse(parser, args=argv) + + torch_device = torch.device(args.torch_device) if args.torch_device else None + iree_device = args.iree_device + kv_cache_type = args.kv_cache_type + weight_path = cli.get_input_dataset(args) + tokenizer = cli.get_tokenizer(args) + weight_path_str = str(args.irpa_file) + + ppl = run_perplexity( + weight_path=weight_path, + weight_path_str=weight_path_str, + tokenizer=tokenizer, + torch_device=torch_device, + 
iree_device=iree_device, + iree_hip_target=args.iree_hip_target, + iree_hal_target_backends=args.iree_hal_target_backends, + kv_cache_type=kv_cache_type, + tensor_parallelism_size=args.tensor_parallelism_size, + attention_kernel=args.attention_kernel, + ) + + logger.info(f"\n{json.dumps(ppl, indent=2)}") + return ppl + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py new file mode 100644 index 000000000..b7e7bb2d4 --- /dev/null +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -0,0 +1,162 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import subprocess +import logging +import time +from pathlib import Path +from datetime import timedelta + +import iree.compiler as ireec + +logger = logging.getLogger("eval") + +logger.setLevel(logging.INFO) + +logger.root.handlers[0].setFormatter( + logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s") +) + + +class ExportArtifacts: + def __init__( + self, + irpa_path: str, + batch_size: int, + iree_hip_target: str, + iree_hal_target_backends: str, + attention_kernel: str, + tensor_parallelism_size: int, + ): + self.sharktank_dir = str( + Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent + ) + self.irpa_path = irpa_path + self.batch_size = batch_size + self.iree_hip_target = iree_hip_target + self.iree_hal_target_backends = iree_hal_target_backends + self.attention_kernel = attention_kernel + self.tensor_parallelism_size = tensor_parallelism_size + + def timeit(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + seconds = end - start + time_taken = abs(timedelta(seconds=round(seconds))) + + if seconds < 1: + time_taken = f" {seconds * 1000} ms" + + func_name = func.__name__ + logger.info(f" {func_name}: {time_taken}") + return result + + return wrapper + + @timeit + def export_to_mlir( + self, + mlir_path: str, + json_path: str, + ): + export_args = [ + "python3", + "-m", + "sharktank.examples.export_paged_llm_v1", + "--irpa-file", + self.irpa_path, + "--output-mlir", + mlir_path, + "--output-config", + json_path, + "--bs", + str(self.batch_size), + ] + if self.attention_kernel == "decomposed": + export_args.append("--attention-kernel") + export_args.append(self.attention_kernel) + elif self.attention_kernel == "torch_sdpa": + raise NotImplementedError("attention_kernel torch_sdpa not implemented yet") + + cwd = self.sharktank_dir + cmd = subprocess.list2cmdline(export_args) + + logger.info(f"Exporting mlir:\n" f"cd {cwd} && {cmd}") + + proc = subprocess.run(export_args, capture_output=True, cwd=cwd, text=True) + if proc.returncode != 0: + logger.error( + f"Error exporting mlir with export_paged_llm_v1.py\n" + f"{proc.stdout+proc.stderr}" + ) + else: + logger.info(f"Exported to mlir successfully:\n" f"{proc.stdout}") + + return proc.returncode + + @timeit + def compile_to_vmfb( + self, + mlir_path, + vmfb_path, + ): + # TODO: Control flag to enable multiple backends + compile_flags = ["--iree-hip-target=" + self.iree_hip_target] + + try: + ireec.compile_file( + input_file=mlir_path, + target_backends=[self.iree_hal_target_backends], + extra_args=compile_flags, + output_file=vmfb_path, + ) + except Exception as error: + logger.error(f"Error running 
iree-compile:\n" f"{error}") + else: + logger.info(f"Compiled to vmfb successfully:\n" f"{vmfb_path}") + + def create_file(self, suffix, prefix): + file_path = Path(prefix).with_suffix(suffix) + f = open(file_path, "w") + return file_path + + def get_artifacts(self): + + self.dir_path = self.sharktank_dir + "/" + "tmp_perplexity_ci_artifacts/" + temp_dir = Path(self.dir_path) + temp_dir.mkdir(parents=True, exist_ok=True) + + model_name = ( + str(self.irpa_path).split("/")[-1].split(".")[0] + + "_" + + self.attention_kernel + ) + mlir_path = str( + self.create_file(suffix=".mlir", prefix=self.dir_path + model_name) + ) + json_path = str( + self.create_file(suffix=".json", prefix=self.dir_path + model_name) + ) + vmfb_path = str( + self.create_file(suffix=".vmfb", prefix=self.dir_path + model_name) + ) + + if self.attention_kernel == "decomposed": + returncode = self.export_to_mlir( + mlir_path=mlir_path, + json_path=json_path, + ) + + if returncode == 0: + self.compile_to_vmfb( + mlir_path=mlir_path, + vmfb_path=vmfb_path, + ) + + return vmfb_path diff --git a/sharktank/sharktank/utils/vmfb_runner.py b/sharktank/sharktank/utils/vmfb_runner.py new file mode 100644 index 000000000..cdbf96c9d --- /dev/null +++ b/sharktank/sharktank/utils/vmfb_runner.py @@ -0,0 +1,82 @@ +from iree import runtime as ireert +from iree.runtime._binding import create_hal_driver + + +class vmfbRunner: + def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=None): + + # If an extra plugin is requested, add a global flag to load the plugin + # and create the driver using the non-caching creation function, as + # the caching creation function may ignore the flag. + if extra_plugin: + ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}") + haldriver = create_hal_driver(device) + + # No plugin requested: create the driver with the caching create + # function. 
+ else: + haldriver = ireert.get_driver(device) + if "://" in device: + try: + device_idx = int(device.split("://")[-1]) + device_uri = None + except: + device_idx = None + device_uri = device.split("://")[-1] + else: + device_idx = 0 + device_uri = None + if device_uri: + if not any(x in device for x in ["cpu", "task"]): + allocators = ["caching"] + haldevice = haldriver.create_device_by_uri( + device_uri, allocators=allocators + ) + else: + haldevice = haldriver.create_device_by_uri(device_uri) + else: + hal_device_id = haldriver.query_available_devices()[device_idx]["device_id"] + if not any(x in device for x in ["cpu", "task"]): + allocators = ["caching"] + haldevice = haldriver.create_device( + hal_device_id, allocators=allocators + ) + else: + haldevice = haldriver.create_device(hal_device_id) + + self.config = ireert.Config(device=haldevice) + mods = [] + if not isinstance(vmfb_path, list): + vmfb_path = [vmfb_path] + for path in vmfb_path: + mods.append(ireert.VmModule.mmap(self.config.vm_instance, path)) + vm_modules = [ + *mods, + ireert.create_hal_module(self.config.vm_instance, self.config.device), + ] + + # TODO: Enable multiple weight files + if external_weight_path: + index = ireert.ParameterIndex() + if not isinstance(external_weight_path, list): + external_weight_path = [external_weight_path] + for i, path in enumerate(external_weight_path): + if path in ["", None]: + continue + index.load(path) + # TODO: extend scope + param_module = ireert.create_io_parameters_module( + self.config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.insert(i, param_module) + del param_module + del index + + self.ctx = ireert.SystemContext( + vm_modules=vm_modules, + config=self.config, + ) + + def unload(self): + self.ctx = None + self.config = None diff --git a/sharktank/tests/evaluate/baseline_perplexity_scores.json b/sharktank/tests/evaluate/baseline_perplexity_scores.json index 45515566e..d9d0d454b 100644 --- a/sharktank/tests/evaluate/baseline_perplexity_scores.json +++ b/sharktank/tests/evaluate/baseline_perplexity_scores.json @@ -209,5 +209,110 @@ 1.915619 ], "mean_perplexity": 6.060831 + }, + "llama3_8B_f16_decomposed_vmfb": { + "perplexities": [ + 21419.466797, + 21546.818359, + 14827.014648, + 16375.65918, + 8945.300781, + 9944.508789, + 16438.810547, + 10728.957031, + 9669.796875, + 14450.475586, + 27094.927734, + 8578.132812, + 22942.267578, + 8198.905273, + 4902.405762, + 14073.242188, + 11952.408203, + 9045.265625, + 7347.615234, + 14579.709961, + 20511.626953, + 15005.15332, + 15205.226562, + 22462.205078, + 17937.900391, + 11057.017578, + 11663.111328, + 11390.241211, + 7898.138672, + 7637.557129, + 10265.848633, + 16729.228516, + 5744.851074, + 7046.032227, + 7316.122559, + 7153.626953, + 8192.285156, + 5918.197266, + 12119.681641, + 13367.679688, + 6873.890137, + 7742.501953, + 13619.378906, + 7469.197754, + 8517.003906, + 5852.495605, + 21839.90625, + 13266.838867, + 45137.652344, + 13815.619141, + 14725.118164, + 14006.322266, + 27869.220703, + 8008.710449, + 6843.859863, + 10156.393555, + 7417.569824, + 17133.203125, + 4873.34668, + 8810.631836, + 13012.022461, + 10515.050781, + 6490.756348, + 6884.498535, + 13199.611328, + 9676.604492, + 2992.313965, + 12557.617188, + 13808.018555, + 12141.337891, + 10426.229492, + 16427.511719, + 13736.017578, + 9114.052734, + 14844.96875, + 11502.46875, + 6369.100098, + 10188.533203, + 5520.150391, + 10693.388672, + 4136.566895, + 12878.518555, + 6268.281738, + 17126.113281, + 10425.692383, + 42463.15625, + 
21795.568359, + 6170.659668, + 17573.275391, + 6537.691406, + 8774.048828, + 14328.767578, + 35863.398438, + 10549.089844, + 5560.846191, + 8987.045898, + 6189.242188, + 13732.914062, + 10735.333984, + 12495.99707 + ], + "mean_perplexity": 12543.547432 } } diff --git a/sharktank/tests/evaluate/perplexity_test.py b/sharktank/tests/evaluate/perplexity_torch_test.py similarity index 61% rename from sharktank/tests/evaluate/perplexity_test.py rename to sharktank/tests/evaluate/perplexity_torch_test.py index faf3a263f..042132f20 100644 --- a/sharktank/tests/evaluate/perplexity_test.py +++ b/sharktank/tests/evaluate/perplexity_torch_test.py @@ -8,19 +8,20 @@ import pytest import json -from sharktank.evaluate import perplexity +from sharktank.evaluate import perplexity_torch longrun = pytest.mark.skipif("not config.getoption('longrun')") -@pytest.mark.usefixtures("get_model_path") +@pytest.mark.usefixtures( + "get_model_artifacts", "tensor_parallelism_size", "baseline_perplexity_scores" +) class PerplexityTest(unittest.TestCase): def setUp(self): self.current_perplexity_all = {} self.delta = 5e-1 self.tensor_parallelism_size = 8 - - with open(self.baseline_perplexity_score_json, "r") as f: + with open(self.baseline_perplexity_scores, "r") as f: self.baseline_perplexity = json.load(f) @longrun @@ -31,44 +32,54 @@ def test_llama3_8B_f16_decomposed(self): model_name = "llama3_8B_f16_decomposed" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity.main( + current_perplexity = perplexity_torch.main( [ - f"--gguf-file={self.llama3_8b_f16_model}", + f"--irpa-file={self.llama3_8b_f16_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", ] ) + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + self.assertAlmostEqual( baseline_perplexity["mean_perplexity"], current_perplexity["mean_perplexity"], delta=self.delta, - msg=f"Perplexity is deviating more than {self.delta}", + msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @pytest.mark.xfail( reason="Non-decomposed attention is not supported yet", ) @longrun - def test_llama3_8B_f16_non_decomposed(self): + def test_llama3_8B_f16(self): # Llama 3.1 8B non-decomposed - model_name = "llama3_8B_f16_non_decomposed" + model_name = "llama3_8B_f16" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity.main( + current_perplexity = perplexity_torch.main( [ - f"--gguf-file={self.llama3_8b_f16_model}", + f"--irpa-file={self.llama3_8b_f16_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", f"--attention-kernel=torch_sdpa", ] ) + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + self.assertAlmostEqual( baseline_perplexity["mean_perplexity"], current_perplexity["mean_perplexity"], delta=self.delta, - msg=f"Perplexity is deviating more than {self.delta}", + msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @pytest.mark.xfail( @@ -82,46 +93,59 @@ def test_llama3_8B_fp8_decomposed(self): model_name = "llama3_8B_fp8_decomposed" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity.main( + current_perplexity = perplexity_torch.main( [ - f"--gguf-file={self.llama3_8b_fp8_model}", + f"--irpa-file={self.llama3_8b_fp8_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", ] ) + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - 
baseline_perplexity["mean_perplexity"] + ) + self.assertAlmostEqual( baseline_perplexity["mean_perplexity"], current_perplexity["mean_perplexity"], delta=self.delta, - msg=f"Perplexity is deviating more than {self.delta}", + msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @pytest.mark.xfail( reason="Non-decomposed attention is not supported yet", ) @longrun - def test_llama3_8B_fp8_non_decomposed(self): + def test_llama3_8B_fp8(self): # Llama 3.1 8B non-decomposed - model_name = "llama3_8B_fp8_non_decomposed" + model_name = "llama3_8B_fp8" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity.main( + current_perplexity = perplexity_torch.main( [ - f"--gguf-file={self.llama3_8b_fp8_model}", + f"--irpa-file={self.llama3_8b_fp8_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", f"--attention-kernel=torch_sdpa", ] ) + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + self.assertAlmostEqual( baseline_perplexity["mean_perplexity"], current_perplexity["mean_perplexity"], delta=self.delta, - msg=f"Perplexity is deviating more than {self.delta}", + msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) + @pytest.mark.xfail( + reason="Sharding needs to be fixed", + ) @longrun def test_llama3_405B_f16_decomposed(self): @@ -130,46 +154,56 @@ def test_llama3_405B_f16_decomposed(self): model_name = "llama3_405B_f16_decomposed" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity.main( + current_perplexity = perplexity_torch.main( [ - f"--gguf-file={self.llama3_405b_f16_model}", + f"--irpa-file={self.llama3_405b_f16_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", ] ) + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + self.assertAlmostEqual( baseline_perplexity["mean_perplexity"], current_perplexity["mean_perplexity"], delta=self.delta, - msg=f"Perplexity is deviating more than {self.delta}", + msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @pytest.mark.xfail( reason="Non-decomposed attention is not supported yet", ) @longrun - def test_llama3_405B_f16_non_decomposed(self): + def test_llama3_405B_f16(self): # Llama 3.1 405B non-decomposed - model_name = "llama3_405B_f16_non_decomposed" + model_name = "llama3_405B_f16" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity.main( + current_perplexity = perplexity_torch.main( [ - f"--gguf-file={self.llama3_405b_f16_model}", + f"--irpa-file={self.llama3_405b_f16_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=torch_sdpa", ] ) + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + self.assertAlmostEqual( baseline_perplexity["mean_perplexity"], current_perplexity["mean_perplexity"], delta=self.delta, - msg=f"Perplexity is deviating more than {self.delta}", + msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @pytest.mark.xfail( @@ -183,46 +217,56 @@ def test_llama3_405B_fp8_decomposed(self): model_name = "llama3_405B_fp8_decomposed" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity.main( + current_perplexity = perplexity_torch.main( 
[ - f"--gguf-file={self.llama3_405b_fp8_model}", + f"--irpa-file={self.llama3_405b_fp8_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", ] ) + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + self.assertAlmostEqual( baseline_perplexity["mean_perplexity"], current_perplexity["mean_perplexity"], delta=self.delta, - msg=f"Perplexity is deviating more than {self.delta}", + msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @pytest.mark.xfail( reason="Non-decomposed attention is not supported yet", ) @longrun - def test_llama3_405B_fp8_non_decomposed(self): + def test_llama3_405B_fp8(self): # Llama 3.1 405B non-decomposed - model_name = "llama3_405B_fp8_non_decomposed" + model_name = "llama3_405B_fp8" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity.main( + current_perplexity = perplexity_torch.main( [ - f"--gguf-file={self.llama3_405b_fp8_model}", + f"--irpa-file={self.llama3_405b_fp8_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=torch_sdpa", ] ) + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + self.assertAlmostEqual( baseline_perplexity["mean_perplexity"], current_perplexity["mean_perplexity"], delta=self.delta, - msg=f"Perplexity is deviating more than {self.delta}", + msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) diff --git a/sharktank/tests/evaluate/perplexity_vmfb_test.py b/sharktank/tests/evaluate/perplexity_vmfb_test.py new file mode 100644 index 000000000..93ffbe61c --- /dev/null +++ b/sharktank/tests/evaluate/perplexity_vmfb_test.py @@ -0,0 +1,309 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +import pytest +import json + +from sharktank.evaluate import perplexity_vmfb + +longrun = pytest.mark.skipif("not config.getoption('longrun')") + + +@pytest.mark.usefixtures( + "get_model_artifacts", + "get_iree_flags", + "tensor_parallelism_size", + "baseline_perplexity_scores", +) +class PerplexityTest(unittest.TestCase): + def setUp(self): + self.current_perplexity_all = {} + self.delta = 5e-1 + self.tensor_parallelism_size = 8 + with open(self.baseline_perplexity_scores, "r") as f: + self.baseline_perplexity = json.load(f) + + @longrun + def test_llama3_8B_f16_decomposed(self): + + # Llama 3.1 8B decomposed + + model_name = "llama3_8B_f16_decomposed_vmfb" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_vmfb.main( + [ + f"--irpa-file={self.llama3_8b_f16_model}", + f"--tokenizer-config-json={self.llama3_8b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size=1", + f"--attention-kernel=decomposed", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="Non-decomposed attention is not supported yet", + ) + @longrun + def test_llama3_8B_f16(self): + + # Llama 3.1 8B non-decomposed + + model_name = "llama3_8B_f16_vmfb" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_vmfb.main( + [ + f"--irpa-file={self.llama3_8b_f16_model}", + f"--tokenizer-config-json={self.llama3_8b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size=1", + f"--attention-kernel=torch_sdpa", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="FP8 model is unsupported", + ) + @longrun + def test_llama3_8B_fp8_decomposed(self): + + # Llama 3.1 8B decomposed + + model_name = "llama3_8B_fp8_decomposed_vmfb" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_vmfb.main( + [ + f"--irpa-file={self.llama3_8b_fp8_model}", + f"--tokenizer-config-json={self.llama3_8b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size=1", + f"--attention-kernel=decomposed", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="FP8 model is unsupported", + ) + @longrun + def 
test_llama3_8B_fp8(self): + + # Llama 3.1 8B non-decomposed + + model_name = "llama3_8B_fp8_vmfb" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_vmfb.main( + [ + f"--irpa-file={self.llama3_8b_fp8_model}", + f"--tokenizer-config-json={self.llama3_8b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size=1", + f"--attention-kernel=torch_sdpa", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="Sharding is unsupported", + ) + @longrun + def test_llama3_405B_f16_decomposed(self): + + # Llama 3.1 405B decomposed + + model_name = "llama3_405B_f16_decomposed_vmfb" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_vmfb.main( + [ + f"--irpa-file={self.llama3_405b_f16_model}", + f"--tokenizer-config-json={self.llama3_405b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size={self.tensor_parallelism_size}", + f"--attention-kernel=decomposed", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="Non-decomposed attention is not supported yet", + ) + @longrun + def test_llama3_405B_f16(self): + + # Llama 3.1 405B non-decomposed + + model_name = "llama3_405B_f16_vmfb" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_vmfb.main( + [ + f"--irpa-file={self.llama3_405b_f16_model}", + f"--tokenizer-config-json={self.llama3_405b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size={self.tensor_parallelism_size}", + f"--attention-kernel=torch_sdpa", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="FP8 model is unsupported", + ) + @longrun + def test_llama3_405B_fp8_decomposed(self): + + # Llama 3.1 405B decomposed + + model_name = "llama3_405B_fp8_decomposed_vmfb" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_vmfb.main( + [ + f"--irpa-file={self.llama3_405b_fp8_model}", + f"--tokenizer-config-json={self.llama3_405b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size={self.tensor_parallelism_size}", + f"--attention-kernel=decomposed", + 
] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="FP8 model is unsupported", + ) + @longrun + def test_llama3_405B_fp8(self): + + # Llama 3.1 405B non-decomposed + + model_name = "llama3_405B_fp8_vmfb" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_vmfb.main( + [ + f"--irpa-file={self.llama3_405b_fp8_model}", + f"--tokenizer-config-json={self.llama3_405b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size={self.tensor_parallelism_size}", + f"--attention-kernel=torch_sdpa", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + +if __name__ == "__main__": + unittest.main()
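
A note on the metric the new files compute: `Perplexity.compute_perplexity` in both `perplexity_torch.py` and `perplexity_vmfb.py` reduces to the exponentiated, padding-masked mean negative log-likelihood per prompt. A minimal standalone sketch of that math follows; the function name and tensor shapes are illustrative only and are not part of the patch.

# Sketch of the perplexity math used by compute_perplexity above:
# PPL = exp(sum of per-token NLL / number of non-pad tokens), per sequence.
import torch
from torch.nn import CrossEntropyLoss


def perplexity_from_logits(out_logits, token_ids, attention_mask):
    # out_logits: [bs, seq_len, vocab]; token_ids, attention_mask: [bs, seq_len]
    loss_fct = CrossEntropyLoss(reduction="none")
    nll = loss_fct(out_logits.transpose(1, 2), token_ids) * attention_mask
    # One perplexity value per prompt; mean_perplexity averages these.
    return torch.exp(nll.sum(1) / attention_mask.sum(1))

The CI job added in ci_eval.yaml exercises the vmfb path end to end via `pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun`, passing `--iree-device`, `--iree-hip-target`, and the model/tokenizer paths explicitly now that the hardcoded conftest defaults have been removed.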