[sharktank] Evaluation - Add Perplexity test #233

Merged
merged 38 commits on Oct 16, 2024
Changes from 8 commits
Commits
38 commits
aee0d58
Add 'datasets' package to load golden dataset
archana-ramalingam Sep 27, 2024
1ee8594
Isolate padding function in tokenizer
archana-ramalingam Sep 27, 2024
0103293
Add utility function to load/run LLMs for evaluation pipeline
archana-ramalingam Sep 27, 2024
1dfdbc6
Add perplexity test
archana-ramalingam Sep 27, 2024
14050f5
Cleanup
archana-ramalingam Sep 27, 2024
b7c75f3
delete file
archana-ramalingam Sep 27, 2024
a44a8a2
Add perplexity test
archana-ramalingam Sep 27, 2024
8034432
Fix dataset loading
archana-ramalingam Sep 27, 2024
a26b17d
Update page_cache_size
archana-ramalingam Sep 27, 2024
cd079a7
add run_perplexity and prompts
archana-ramalingam Sep 27, 2024
df84163
Merge branch 'main' into perplexity-test
archana-ramalingam Sep 28, 2024
3e0871e
Shift logits and change activation dtype
archana-ramalingam Sep 30, 2024
4a74107
Add Grok model
archana-ramalingam Sep 30, 2024
64e812d
Remove decode and run prefill on every turn
archana-ramalingam Oct 1, 2024
29e6031
Change activation dtype to enable quantized models
archana-ramalingam Oct 1, 2024
9c168f3
Add timing wrapper
archana-ramalingam Oct 1, 2024
38590bb
Add instructions to run evaluation-perplexity
archana-ramalingam Oct 2, 2024
2bf8739
Add prompts text file
archana-ramalingam Oct 2, 2024
70f6ba5
Add logging + cleanup
archana-ramalingam Oct 2, 2024
848da59
Add CI perplexity test
archana-ramalingam Oct 2, 2024
7e49580
Update prompt file path
archana-ramalingam Oct 2, 2024
ec6968f
Remove unit tests for nightly
archana-ramalingam Oct 2, 2024
f7667ec
Add relative path + push attention_mask to device
archana-ramalingam Oct 3, 2024
7054141
Remove debug changes
archana-ramalingam Oct 3, 2024
70a7b10
Merge branch 'main' into perplexity-test
archana-ramalingam Oct 3, 2024
134c77f
Update dtype to F32 for compatibility across torch versions
archana-ramalingam Oct 4, 2024
27f4e15
Merge branch 'main' into perplexity-test
archana-ramalingam Oct 7, 2024
8a0a081
Add decode
archana-ramalingam Oct 10, 2024
e47fe4a
Fix padding logits
archana-ramalingam Oct 11, 2024
b15c06d
Add local model path
archana-ramalingam Oct 11, 2024
6da2b38
Add CI test for evaluation
archana-ramalingam Oct 11, 2024
0afe63f
Add perplexity calculated from prefill logits only
archana-ramalingam Oct 15, 2024
e4ccb10
Merge branch 'main' into perplexity-test
archana-ramalingam Oct 16, 2024
478f1a1
Add CI tests for perplexity
archana-ramalingam Oct 16, 2024
96458a8
Clean up
archana-ramalingam Oct 16, 2024
b4e3635
Merge branch 'perplexity-test' of https://github.com/nod-ai/sharktank…
archana-ramalingam Oct 16, 2024
4af53c3
Clean up
archana-ramalingam Oct 16, 2024
e2eb98c
Update argument
archana-ramalingam Oct 16, 2024
1 change: 1 addition & 0 deletions requirements.txt
@@ -7,6 +7,7 @@ onnx==1.15.0
huggingface-hub==0.22.2
transformers==4.40.0
sentencepiece==0.2.0
datasets==3.0.0
Review comment (Member):
We should steer towards putting requirements in the subproject requirements files instead of this top level file, especially if this is a test-only requirement
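
For example, a test-only dependency like this one could live in a subproject requirements file rather than at the top level. A minimal sketch — the file path below is illustrative, not part of this PR:

# e.g. sharktank/requirements-tests.txt (hypothetical location for test-only dependencies)
datasets==3.0.0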


# It is expected that you have installed a PyTorch version/variant specific
# to your needs, so we only include a minimum version spec.
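The new datasets dependency is used to load the golden dataset for the perplexity evaluation (see the commit history above). A minimal sketch of such loading, assuming wikitext purely as an illustrative corpus — the PR's actual dataset and split may differ:

from datasets import load_dataset

# Illustrative corpus/split choice; the golden dataset used by the PR may differ.
test_set = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
prompts = [t for t in test_set["text"] if t.strip()][:16]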
222 changes: 222 additions & 0 deletions sharktank/sharktank/utils/load_llm.py
@@ -0,0 +1,222 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import math

import torch

from sharktank.layers import *
from sharktank.types import *
from sharktank.models.llama.llama import *

from ..utils.debugging import trace_tensor
from ..utils.tokenizer import InferenceTokenizer


class TorchGenerator:
"""Generator that runs directly on the Torch model."""

def __init__(
self,
model: PagedLlamaModelV1,
tokenizer: InferenceTokenizer,
page_cache_size: int = 128,
# Need to look at the model more for this.
end_token: int = 2,
):
self.model = model
self.tokenizer = tokenizer
if model.cache.is_paged:
self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
else:
self.shared_cache_state = None
        self.free_pages = list(range(1, page_cache_size))
self.end_token = end_token

@property
def block_seq_stride(self) -> int:
return self.model.cache.block_seq_stride

def begin_batch(self, prompts: list[str], add_start_token: bool):
token_ids, seq_lens = self.tokenizer.encode(
prompts,
pad_to_multiple_of=self.model.cache.pad_sequence_stride,
add_start_token=add_start_token,
)
token_ids = torch.tensor(token_ids, device=self.model.device)
seq_lens = torch.tensor(seq_lens, device=self.model.device)
if self.shared_cache_state is not None:
cache_state = self.shared_cache_state
else:
cache_state = self.model.cache.direct.allocate(bs=len(prompts))
return Batch(self, token_ids, seq_lens, cache_state)

def begin_eval_batch(
self,
        token_batch: torch.Tensor,
        seq_lens_batch: torch.Tensor,
        bs: int,
    ):
if self.shared_cache_state is not None:
cache_state = self.shared_cache_state
else:
cache_state = self.model.cache.direct.allocate(bs=bs)
return Batch(self, token_batch, seq_lens_batch, cache_state)

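    # Page bookkeeping for the paged KV cache: alloc_page hands out ids from
    # free_pages as sequences grow and release_page returns them. With the
    # direct cache, block ids are unused, so alloc_page returns a dummy id
    # and release_page is a no-op.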
def alloc_page(self) -> int:
if self.model.cache.is_direct:
# We don't allocate block ids for the direct cache.
return 0

return self.free_pages.pop()

def release_page(self, index: int):
if self.model.cache.is_direct:
return
self.free_pages.append(index)


class Batch:
def __init__(
self,
parent: TorchGenerator,
token_ids: torch.Tensor,
seq_lens: torch.Tensor,
cache_state: list[torch.Tensor],
):
self.bs = token_ids.shape[0]
# assert seq_lens.shape[0] == self.bs
self.parent = parent
self.token_ids = token_ids
self.seq_lens = seq_lens
self.cache_state = cache_state
self.results: list[list[int]] = [[] for _ in range(self.bs)]
self.done_result_indices: set[int] = set()

# Assemble the batch.
seq_stride = self.parent.block_seq_stride
self.seq_block_ids: list[list[int]] = []
for seq_len in self.seq_lens:
blocks_needed = (
int(math.ceil(seq_len / seq_stride)) if seq_stride > 0 else 0
)
row = []
for _ in range(blocks_needed):
row.append(self.parent.alloc_page())
self.seq_block_ids.append(row)

@property
def done(self) -> bool:
return len(self.done_result_indices) == self.bs

def detokenize(self) -> list[str]:
return self.parent.tokenizer.decode(self.results)

def print_current_results(self):
results = self.detokenize()
for i, s in enumerate(results):
seq_len = int(self.seq_lens[i])
print(f" {i}({len(self.results[i])}, {seq_len}): {s}")

    def add_result_token(self, tokens: torch.Tensor):
        for i in range(self.bs):
            token = int(tokens[i, 0])
            if token == self.parent.end_token:
                self.done_result_indices.add(i)
            if i in self.done_result_indices:
                continue
            self.results[i].append(token)

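    # When a sequence length lands exactly on a block boundary, one extra
    # block is reserved so the next decoded token has a slot to land in.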
def allocate_seq_block_ids(self):
for i in range(self.bs):
sl = int(self.seq_lens[i])
if (sl % self.parent.block_seq_stride) == 0:
needed_blocks = sl // self.parent.block_seq_stride + 1
else:
needed_blocks = math.ceil(sl / self.parent.block_seq_stride)
block_ids_row = self.seq_block_ids[i]
while len(block_ids_row) < needed_blocks:
block_ids_row.append(self.parent.alloc_page())

def prefill(self):
model = self.parent.model
attention_mask = model.attention_mask(
model.input_mask(self.seq_lens, self.token_ids.shape[1])
)
seq_block_ids_tensor = self.pad_block_ids()
# print(f":: Invoke prefill:")
trace_tensor("prefill.token_ids", self.token_ids)
trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor)
trace_tensor("prefill.attention_mask", attention_mask)
# print("prefill.token_ids", self.token_ids)
# print("prefill.seq_block_ids", seq_block_ids_tensor)
# print("prefill.attention_mask", attention_mask.shape)
# print("prefill.cache_state", self.cache_state[0].shape)
self.prefill_logits = model.prefill(
self.token_ids,
attention_mask=attention_mask,
seq_block_ids=seq_block_ids_tensor,
cache_state=self.cache_state,
)

# TODO: Generalize the sampling and don't make it swap on/off cpu.
# TODO: Normalize the output of extract_tokens_from_logits into
# tensor [bs, 1].
tokens = torch.tensor(
model.extract_tokens_from_logits(self.prefill_logits, self.seq_lens)
).unsqueeze(1)
# print(f":: Prefill results:\n{tokens.tolist()}")
self.add_result_token(tokens)
self.next_tokens = tokens.to(device=model.device)
return self.cache_state

def decode(self, cache_state):
self.cache_state = cache_state
model = self.parent.model
start_positions = self.seq_lens.clone()
self.seq_lens.add_(1)
self.allocate_seq_block_ids()
# TODO: Allocate more blocks on overflow.
seq_block_ids_tensor = self.pad_block_ids()
decode_attention_mask = model.decode_attention_mask(
model.input_mask(
self.seq_lens,
seq_block_ids_tensor.shape[1] * self.parent.block_seq_stride,
)
)
trace_tensor("decode.token_ids", self.token_ids)
trace_tensor("decode.start_positions", start_positions)
trace_tensor("decode.seq_block_ids", seq_block_ids_tensor)
trace_tensor("decode.attention_mask", decode_attention_mask)
# print("decode.token_ids", self.token_ids)

self.decode_logits = model.decode(
self.token_ids,
attention_mask=decode_attention_mask,
start_positions=start_positions,
seq_block_ids=seq_block_ids_tensor,
cache_state=self.cache_state,
)

# print("decode", len(self.decode_logits))
# trace_tensor("decode.logits", self.decode_logits)
# # TODO: Normalize the output of extract_tokens_from_logits into
# # tensor [bs, 1].
tokens = torch.tensor(
model.extract_tokens_from_logits(self.decode_logits, [1] * self.bs),
device=self.parent.model.device,
).unsqueeze(1)
self.add_result_token(tokens)
# self.next_tokens = tokens

return self.cache_state

def pad_block_ids(self) -> torch.Tensor:
max_length = max(len(r) for r in self.seq_block_ids)
rows = [r + (max_length - len(r)) * [0] for r in self.seq_block_ids]
return torch.tensor(rows, device=self.parent.model.device)
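
For orientation, a minimal usage sketch of TorchGenerator and Batch that computes perplexity from the prefill logits via token-level cross-entropy. It assumes model (a PagedLlamaModelV1), tokenizer (an InferenceTokenizer), and a prompts list are already loaded, and it ignores padded positions for brevity — this illustrates how the pieces compose rather than reproducing the PR's exact evaluation loop:

import torch
import torch.nn.functional as F

generator = TorchGenerator(model, tokenizer)
token_ids, seq_lens = tokenizer.encode(
    prompts,
    pad_to_multiple_of=model.cache.pad_sequence_stride,
    add_start_token=True,
)
token_ids = torch.tensor(token_ids, device=model.device)
seq_lens = torch.tensor(seq_lens, device=model.device)

batch = generator.begin_eval_batch(token_ids, seq_lens, bs=token_ids.shape[0])
batch.prefill()

# Shift so that the logits at position t predict the token at position t + 1.
logits = batch.prefill_logits[:, :-1, :].float()
targets = token_ids[:, 1:]
# A real implementation would mask padded positions before averaging.
loss = F.cross_entropy(logits.transpose(1, 2), targets, reduction="mean")
print("perplexity:", torch.exp(loss).item())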
42 changes: 29 additions & 13 deletions sharktank/sharktank/utils/tokenizer.py
@@ -22,34 +22,49 @@ class InferenceTokenizer(ABC):
"""Simple inference tokenizer."""

def encode(
self, texts: list[str], pad_to_multiple_of: int = 1, pad_token: int = 0
self,
texts: list[str],
pad_to_multiple_of: int = 1,
add_start_token: bool = True,
) -> tuple[list[list[int]]]:
"""Encodes a list of texts into a padded list of tokens.

Returns a list of list of tokens and a list of unpadded lengths.
"""
raw_rows = self._encode(texts)
raw_rows = self._encode(texts, add_start_token)
raw_rows, lengths = self.pad_tokens(
token_ids=raw_rows, pad_to_multiple_of=pad_to_multiple_of
)
return raw_rows, lengths

def decode(self, tokens: Union[list[list[int]]], lens: Optional[list[int]] = None):
"""Decodes a list of tokens."""
if lens is not None:
tokens = list(tokens)
for i, row_length in enumerate(lens):
tokens[i] = tokens[i][0:row_length]
return self._decode(tokens)

def pad_tokens(
self,
token_ids: list[list[int]],
pad_to_multiple_of: int,
pad_token: int = 0,
):
max_length = 0
lengths: list[int] = []
for row in raw_rows:
for row in token_ids:
lengths.append(len(row))
max_length = max(max_length, len(row))
if pad_to_multiple_of > 1:
max_length = int(
pad_to_multiple_of * math.ceil(max_length / pad_to_multiple_of)
)
for row in raw_rows:
for row in token_ids:
pad_count = max_length - len(row)
row.extend(pad_count * [pad_token])
return raw_rows, lengths

def decode(self, tokens: Union[list[list[int]]], lens: Optional[list[int]] = None):
"""Decodes a list of tokens."""
if lens is not None:
tokens = list(tokens)
for i, row_length in enumerate(lens):
tokens[i] = tokens[i][0:row_length]
return self._decode(tokens)
return token_ids, lengths

@abstractmethod
def _encode(self, texts: list[str]) -> list[list[int]]:
@@ -76,9 +91,10 @@ class _TransformersTokenizer(InferenceTokenizer):
def __init__(self, t: AutoTokenizer):
self._t = t

def _encode(self, texts: list[str]) -> list[list[int]]:
def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]:
        results = self._t.batch_encode_plus(
texts,
add_special_tokens=add_start_token,
padding=False,
truncation=False,
)
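
To illustrate the refactored pad_tokens helper above: every row is padded with pad_token up to the overall maximum row length, rounded up to the requested multiple, and the original (unpadded) lengths are returned alongside. A small sketch, assuming tok is any already-constructed InferenceTokenizer instance (for example the transformers-backed tokenizer in this module):

rows, lens = tok.pad_tokens(
    token_ids=[[1, 5, 9], [1, 5, 9, 12, 7]],
    pad_to_multiple_of=4,
)
# The max row length 5 rounds up to 8, so every row is padded to 8 tokens:
# rows == [[1, 5, 9, 0, 0, 0, 0, 0], [1, 5, 9, 12, 7, 0, 0, 0]]
# lens == [3, 5]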