[sharktank] Evaluation - Add Perplexity test for vmfb (#306)
Add Perplexity test for vmfb
archana-ramalingam authored Oct 29, 2024
1 parent f2b1a01 commit 072be20
Showing 9 changed files with 1,331 additions and 69 deletions.
63 changes: 58 additions & 5 deletions .github/workflows/ci_eval.yaml
@@ -15,13 +15,13 @@ concurrency:
cancel-in-progress: true

jobs:
test_perplexity:
test_perplexity_vmfb:
timeout-minutes: 1000
name: "Evaluation Tests - perplexity"
name: "Evaluation Tests - perplexity_vmfb"
strategy:
matrix:
version: [3.11]
runs-on: [llama-mi300]
runs-on: [llama-mi300x-3]
fail-fast: false
runs-on: ${{matrix.runs-on}}
defaults:
@@ -58,5 +58,58 @@ jobs:
-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
- name: Run perplexity test
run: pytest -n 4 -v -s sharktank/tests/evaluate/perplexity_test.py --longrun
# Try with the latest nightly releases, not what iree-turbine pins.
# We could also pin to a known working or stable version.
# This should eventually stabilize. Do the best we can for now.
pip install -f https://iree.dev/pip-release-links.html --upgrade \
iree-compiler \
iree-runtime \
"numpy<2.0"
- name: Run perplexity test with vmfb
run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun --iree-device='hip://7' --iree-hip-target='gfx942' --llama3-8b-f16-model-path=/data/llama-3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama-3.1/8b/tokenizer_config.json

test_perplexity_torch:
timeout-minutes: 1000
name: "Evaluation Tests - perplexity_torch"
strategy:
matrix:
version: [3.11]
runs-on: [llama-mi300x-3]
fail-fast: false
runs-on: ${{matrix.runs-on}}
defaults:
run:
shell: bash
env:
PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }}
steps:
- name: "Setting up Python"
id: setup_python
uses: actions/setup-python@v3
with:
python-version: ${{matrix.version}}

- name: "Checkout Code"
uses: actions/checkout@v3

- name: Cache Pip Packages
uses: actions/cache@v4
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}

- name: Install sharktank deps
run: |
python -m pip install --no-compile --upgrade pip
# Note: We install in three steps in order to satisfy requirements
# from non default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --no-compile -r pytorch-cpu-requirements.txt
pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
- name: Run perplexity test in eager mode
run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama-3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama-3.1/8b/tokenizer_config.json
94 changes: 75 additions & 19 deletions sharktank/conftest.py
@@ -72,20 +72,19 @@ def pytest_addoption(parser):
help="Enable long and slow tests",
)

# TODO: Remove all hardcoded paths in CI tests
parser.addoption(
"--llama3-8b-tokenizer-path",
type=Path,
action="store",
default="/data/extra/models/llama3.1_8B/tokenizer_config.json",
help="Llama3.1 8b tokenizer path, defaults to 30F CI system path",
)

parser.addoption(
"--llama3-8b-f16-gguf-path",
"--llama3-8b-f16-model-path",
type=Path,
action="store",
default="/data/extra/models/llama3.1_8B/llama8b_f16.gguf",
help="Llama3.1 8b gguf model path, defaults to 30F CI system path",
help="Llama3.1 8b model path, defaults to 30F CI system path",
)

parser.addoption(
@@ -100,16 +99,14 @@ def pytest_addoption(parser):
"--llama3-405b-tokenizer-path",
type=Path,
action="store",
default="/data/extra/models/llama3.1_405B/tokenizer_config.json",
help="Llama3.1 405b tokenizer path, defaults to 30F CI system path",
)

parser.addoption(
"--llama3-405b-f16-gguf-path",
"--llama3-405b-f16-model-path",
type=Path,
action="store",
default="/data/extra/models/llama3.1_405B/llama405b_fp16.gguf",
help="Llama3.1 405b gguf model path, defaults to 30F CI system path",
help="Llama3.1 405b model path, defaults to 30F CI system path",
)

parser.addoption(
@@ -121,20 +118,49 @@
)

parser.addoption(
"--baseline-perplexity-score-json",
"--baseline-perplexity-scores",
type=Path,
action="store",
default="sharktank/tests/evaluate/baseline_perplexity_scores.json",
help="Llama3.1 8B & 405B model baseline perplexity scores json",
help="Llama3.1 8B & 405B model baseline perplexity scores",
)

parser.addoption(
"--iree-device",
type=str,
action="store",
help="List an IREE device from iree-run-module --list_devices",
)

parser.addoption(
"--iree-hip-target",
action="store",
default="gfx942",
help="Specify the iree-hip target version (e.g., gfx942)",
)

parser.addoption(
"--iree-hal-target-backends",
action="store",
default="rocm",
help="Specify the iree-hal target backend (e.g., rocm)",
)

parser.addoption(
"--tensor-parallelism-size",
action="store",
type=int,
default=1,
help="Number of devices for tensor parallel sharding",
)

parser.addoption(
"--bs",
action="store",
type=int,
default=4,
help="Batch size for mlir export",
)
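
For reference, options registered through pytest_addoption like the ones above are read back with pytest's standard config API. The sketch below is illustrative only and not part of this commit; the test name is made up.

# Illustrative sketch: reading the new CLI options via pytest's config API.
# The flag names come from the addoption calls above; this test is hypothetical.
def test_show_iree_options(request):
    device = request.config.getoption("--iree-device")  # no default, may be None
    hip_target = request.config.getoption("--iree-hip-target")  # defaults to "gfx942"
    backends = request.config.getoption("--iree-hal-target-backends")  # defaults to "rocm"
    print(device, hip_target, backends)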


def set_fixture_from_cli_option(
request: FixtureRequest,
@@ -183,27 +209,57 @@ def iree_hip_target_type(request: FixtureRequest) -> Optional[str]:


@pytest.fixture(scope="class")
def get_model_path(request: FixtureRequest):
def tensor_parallelism_size(request: FixtureRequest) -> Optional[str]:
return set_fixture_from_cli_option(
request, "tensor_parallelism_size", "tensor_parallelism_size"
)


@pytest.fixture(scope="class")
def baseline_perplexity_scores(request: FixtureRequest) -> Optional[str]:
return set_fixture_from_cli_option(
request, "baseline_perplexity_scores", "baseline_perplexity_scores"
)


@pytest.fixture(scope="class")
def batch_size(request: FixtureRequest) -> Optional[str]:
return set_fixture_from_cli_option(request, "bs", "batch_size")


@pytest.fixture(scope="class")
def get_model_artifacts(request: FixtureRequest):
model_path = {}
model_path["llama3_8b_tokenizer_path"] = set_fixture_from_cli_option(
request, "--llama3-8b-tokenizer-path", "llama3_8b_tokenizer"
)
model_path["llama3_8b_f16_gguf_path"] = set_fixture_from_cli_option(
request, "--llama3-8b-f16-gguf-path", "llama3_8b_f16_model"
model_path["llama3_8b_f16_model_path"] = set_fixture_from_cli_option(
request, "--llama3-8b-f16-model-path", "llama3_8b_f16_model"
)
model_path["llama3_8b_fp8_model_path"] = set_fixture_from_cli_option(
request, "--llama3-8b-fp8-model-path", "llama3_8b_fp8_model"
)
model_path["llama3_405b_tokenizer_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-tokenizer-path", "llama3_405b_tokenizer"
)
model_path["llama3_405b_f16_gguf_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-f16-gguf-path", "llama3_405b_f16_model"
model_path["llama3_405b_f16_model_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-f16-model-path", "llama3_405b_f16_model"
)
model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model"
)
model_path["baseline_perplexity_score_json"] = set_fixture_from_cli_option(
request, "--baseline-perplexity-score-json", "baseline_perplexity_score_json"
)
return model_path


@pytest.fixture(scope="class")
def get_iree_flags(request: FixtureRequest):
model_path = {}
model_path["iree_device"] = set_fixture_from_cli_option(
request, "--iree-device", "iree_device"
)
model_path["iree_hip_target"] = set_fixture_from_cli_option(
request, "--iree-hip-target", "iree_hip_target"
)
model_path["iree_hal_target_backends"] = set_fixture_from_cli_option(
request, "--iree-hal-target-backends", "iree_hal_target_backends"
)
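
Taken together, get_model_artifacts returns a dict of model and tokenizer paths, while get_iree_flags collects the IREE device and target options. A hypothetical class-based test consuming the artifacts fixture could look like the sketch below; it is illustrative only, and the actual tests exercised by the workflow above (under sharktank/tests/evaluate/) may be wired differently.

# Hypothetical consumer sketch; not part of this commit.
class TestModelArtifactsExample:
    def test_f16_model_path_is_configured(self, get_model_artifacts):
        # get_model_artifacts returns the model_path dict built in the fixture above.
        assert "llama3_8b_f16_model_path" in get_model_artifacts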
@@ -42,10 +42,10 @@
logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s")
)

__all__ = ["Perplexity", "run_perplexity"]
__all__ = ["Perplexity_torch", "run_perplexity_torch"]


class Perplexity:
class Perplexity_torch:
"""
Perplexity (PPL) is one of the most common metrics for evaluating language models.
It is defined as the exponentiated average negative log-likelihood of a sequence,
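
(For reference, the metric described in this docstring is the exponential of the mean per-token negative log-likelihood; the standalone sketch below illustrates that computation and is not code from this commit.)

# Standalone reference sketch of the metric described in the docstring above;
# not part of this diff.
import torch

def perplexity_from_logits(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    # logits: [batch, seq_len, vocab]; token_ids: [batch, seq_len], int64
    log_probs = torch.log_softmax(logits, dim=-1)
    nll = -log_probs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
    return torch.exp(nll.mean(dim=-1))  # one perplexity value per sequence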
@@ -59,8 +59,6 @@ def __init__(
device,
kv_cache_type,
):
self.batch_size = 16

self.device = device
self.kv_cache_type = kv_cache_type
self.activation_dtype = torch.float32
@@ -173,6 +171,8 @@ def get_logits(self):
(self.token_ids != 0).int().detach().clone().to(self.device)
)

self.bs = len(self.test_prompts)

is_first_token = True
start = 0
for i in tqdm(
@@ -263,8 +263,6 @@ def compute_perplexity(self):
def get_perplexity(self, test_prompts):

self.test_prompts = test_prompts
self.bs = len(self.test_prompts)

self.get_logits()

self.out_logits = self.out_logits[..., :-1, :].contiguous()
@@ -282,15 +280,15 @@ def get_perplexity(self, test_prompts):
return self.compute_perplexity()


def run_perplexity(
def run_perplexity_torch(
dataset,
tokenizer,
device,
kv_cache_type,
tensor_parallelism_size,
attention_kernel,
):
perplexity = Perplexity(device=device, kv_cache_type=kv_cache_type)
perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type)

perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel)
test_prompts = perplexity.get_prompts()
@@ -326,7 +324,7 @@ def main(argv):
dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)

ppl = run_perplexity(
ppl = run_perplexity_torch(
dataset=dataset,
tokenizer=tokenizer,
device=device,