[sharktank] Evaluation - Add Perplexity test #233
Merged
Commits (38)
- aee0d58 Add 'datasets' package to load golden dataset (archana-ramalingam)
- 1ee8594 Isolate padding function in tokenizer (archana-ramalingam)
- 0103293 Add utility function to load/run LLMs for evaluation pipeline (archana-ramalingam)
- 1dfdbc6 Add perplexity test (archana-ramalingam)
- 14050f5 Cleanup (archana-ramalingam)
- b7c75f3 delete file (archana-ramalingam)
- a44a8a2 Add perplexity test (archana-ramalingam)
- 8034432 Fix dataset loading (archana-ramalingam)
- a26b17d Update page_cache_size (archana-ramalingam)
- cd079a7 add run_perplexity and prompts (archana-ramalingam)
- df84163 Merge branch 'main' into perplexity-test (archana-ramalingam)
- 3e0871e Shift logits and change activation dtype (archana-ramalingam)
- 4a74107 Add Grok model (archana-ramalingam)
- 64e812d Remove decode and run prefill on every turn (archana-ramalingam)
- 29e6031 Change activation dtype to enable quantized models (archana-ramalingam)
- 9c168f3 Add timing wrapper (archana-ramalingam)
- 38590bb Add instructions to run evaluation-perplexity (archana-ramalingam)
- 2bf8739 Add prompts text file (archana-ramalingam)
- 70f6ba5 Add logging + cleanup (archana-ramalingam)
- 848da59 Add CI perplexity test (archana-ramalingam)
- 7e49580 Update prompt file path (archana-ramalingam)
- ec6968f Remove unit tests for nightly (archana-ramalingam)
- f7667ec Add relative path + push attention_mask to device (archana-ramalingam)
- 7054141 Remove debug changes (archana-ramalingam)
- 70a7b10 Merge branch 'main' into perplexity-test (archana-ramalingam)
- 134c77f Update dtype to F32 for compatibility across torch versions (archana-ramalingam)
- 27f4e15 Merge branch 'main' into perplexity-test (archana-ramalingam)
- 8a0a081 Add decode (archana-ramalingam)
- e47fe4a Fix padding logits (archana-ramalingam)
- b15c06d Add local model path (archana-ramalingam)
- 6da2b38 Add CI test for evaluation (archana-ramalingam)
- 0afe63f Add perplexity calculated from prefill logits only (archana-ramalingam)
- e4ccb10 Merge branch 'main' into perplexity-test (archana-ramalingam)
- 478f1a1 Add CI tests for perplexity (archana-ramalingam)
- 96458a8 Clean up (archana-ramalingam)
- b4e3635 Merge branch 'perplexity-test' of https://github.com/nod-ai/sharktank… (archana-ramalingam)
- 4af53c3 Clean up (archana-ramalingam)
- e2eb98c Update argument (archana-ramalingam)
@@ -0,0 +1,222 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import math

import torch

from sharktank.layers import *
from sharktank.types import *
from sharktank.models.llama.llama import *

from ..utils.debugging import trace_tensor
from ..utils.tokenizer import InferenceTokenizer


class TorchGenerator:
    """Generator that runs directly on the Torch model."""

    def __init__(
        self,
        model: PagedLlamaModelV1,
        tokenizer: InferenceTokenizer,
        page_cache_size: int = 128,
        # Need to look at the model more for this.
        end_token: int = 2,
    ):
        self.model = model
        self.tokenizer = tokenizer
        if model.cache.is_paged:
            self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
        else:
            self.shared_cache_state = None
        self.free_pages = list(range(1, page_cache_size))
        self.end_token = end_token

    @property
    def block_seq_stride(self) -> int:
        return self.model.cache.block_seq_stride

    def begin_batch(self, prompts: list[str], add_start_token: bool):
        token_ids, seq_lens = self.tokenizer.encode(
            prompts,
            pad_to_multiple_of=self.model.cache.pad_sequence_stride,
            add_start_token=add_start_token,
        )
        token_ids = torch.tensor(token_ids, device=self.model.device)
        seq_lens = torch.tensor(seq_lens, device=self.model.device)
        if self.shared_cache_state is not None:
            cache_state = self.shared_cache_state
        else:
            cache_state = self.model.cache.direct.allocate(bs=len(prompts))
        return Batch(self, token_ids, seq_lens, cache_state)

    def begin_eval_batch(
        self,
        token_batch: torch.Tensor,
        seq_lens_batch: torch.Tensor,
        bs: int,
    ):
        if self.shared_cache_state is not None:
            cache_state = self.shared_cache_state
        else:
            cache_state = self.model.cache.direct.allocate(bs=bs)
        return Batch(self, token_batch, seq_lens_batch, cache_state)

    def alloc_page(self) -> int:
        if self.model.cache.is_direct:
            # We don't allocate block ids for the direct cache.
            return 0

        return self.free_pages.pop()

    def release_page(self, index: int):
        if self.model.cache.is_direct:
            return
        self.free_pages.append(index)


class Batch:
    def __init__(
        self,
        parent: TorchGenerator,
        token_ids: torch.Tensor,
        seq_lens: torch.Tensor,
        cache_state: list[torch.Tensor],
    ):
        self.bs = token_ids.shape[0]
        # assert seq_lens.shape[0] == self.bs
        self.parent = parent
        self.token_ids = token_ids
        self.seq_lens = seq_lens
        self.cache_state = cache_state
        self.results: list[list[int]] = [[] for _ in range(self.bs)]
        self.done_result_indices: set[int] = set()

        # Assemble the batch.
        seq_stride = self.parent.block_seq_stride
        self.seq_block_ids: list[list[int]] = []
        for seq_len in self.seq_lens:
            blocks_needed = (
                int(math.ceil(seq_len / seq_stride)) if seq_stride > 0 else 0
            )
            row = []
            for _ in range(blocks_needed):
                row.append(self.parent.alloc_page())
            self.seq_block_ids.append(row)

    @property
    def done(self) -> bool:
        return len(self.done_result_indices) == self.bs

    def detokenize(self) -> list[str]:
        return self.parent.tokenizer.decode(self.results)

    def print_current_results(self):
        results = self.detokenize()
        for i, s in enumerate(results):
            seq_len = int(self.seq_lens[i])
            print(f"  {i}({len(self.results[i])}, {seq_len}): {s}")

    def add_result_token(self, tokens: torch.Tensor):
        for i in range(self.bs):
            token = tokens[i][0]
            if token == self.parent.end_token:
                self.done_result_indices.add(i)
            if i in self.done_result_indices:
                continue
            token = int(tokens[i, 0])
            self.results[i].append(token)

    def allocate_seq_block_ids(self):
        for i in range(self.bs):
            sl = int(self.seq_lens[i])
            if (sl % self.parent.block_seq_stride) == 0:
                needed_blocks = sl // self.parent.block_seq_stride + 1
            else:
                needed_blocks = math.ceil(sl / self.parent.block_seq_stride)
            block_ids_row = self.seq_block_ids[i]
            while len(block_ids_row) < needed_blocks:
                block_ids_row.append(self.parent.alloc_page())

    def prefill(self):
        model = self.parent.model
        attention_mask = model.attention_mask(
            model.input_mask(self.seq_lens, self.token_ids.shape[1])
        )
        seq_block_ids_tensor = self.pad_block_ids()
        trace_tensor("prefill.token_ids", self.token_ids)
        trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor)
        trace_tensor("prefill.attention_mask", attention_mask)

        self.prefill_logits = model.prefill(
            self.token_ids,
            attention_mask=attention_mask,
            seq_block_ids=seq_block_ids_tensor,
            cache_state=self.cache_state,
        )

        # TODO: Generalize the sampling and don't make it swap on/off cpu.
        # TODO: Normalize the output of extract_tokens_from_logits into
        # tensor [bs, 1].
        tokens = torch.tensor(
            model.extract_tokens_from_logits(self.prefill_logits, self.seq_lens)
        ).unsqueeze(1)
        self.add_result_token(tokens)
        self.next_tokens = tokens.to(device=model.device)
        return self.cache_state

    def decode(self, cache_state):
        self.cache_state = cache_state
        model = self.parent.model
        start_positions = self.seq_lens.clone()
        self.seq_lens.add_(1)
        self.allocate_seq_block_ids()
        # TODO: Allocate more blocks on overflow.
        seq_block_ids_tensor = self.pad_block_ids()
        decode_attention_mask = model.decode_attention_mask(
            model.input_mask(
                self.seq_lens,
                seq_block_ids_tensor.shape[1] * self.parent.block_seq_stride,
            )
        )
        trace_tensor("decode.token_ids", self.token_ids)
        trace_tensor("decode.start_positions", start_positions)
        trace_tensor("decode.seq_block_ids", seq_block_ids_tensor)
        trace_tensor("decode.attention_mask", decode_attention_mask)

        self.decode_logits = model.decode(
            self.token_ids,
            attention_mask=decode_attention_mask,
            start_positions=start_positions,
            seq_block_ids=seq_block_ids_tensor,
            cache_state=self.cache_state,
        )

        # TODO: Normalize the output of extract_tokens_from_logits into
        # tensor [bs, 1].
        tokens = torch.tensor(
            model.extract_tokens_from_logits(self.decode_logits, [1] * self.bs),
            device=self.parent.model.device,
        ).unsqueeze(1)
        self.add_result_token(tokens)

        return self.cache_state

    def pad_block_ids(self) -> torch.Tensor:
        max_length = max(len(r) for r in self.seq_block_ids)
        rows = [r + (max_length - len(r)) * [0] for r in self.seq_block_ids]
        return torch.tensor(rows, device=self.parent.model.device)
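
For context, here is a minimal sketch of how this generator could drive the perplexity calculation the commits describe ("Shift logits and change activation dtype", "Add perplexity calculated from prefill logits only"). The begin_eval_batch/prefill calls come from the file above; the cross-entropy and padding-mask details are an illustration under those assumptions, not the PR's exact run_perplexity implementation.

import torch
import torch.nn.functional as F

def perplexity_from_prefill(generator: TorchGenerator, token_ids, seq_lens):
    # Run prefill once over the whole (padded) prompt batch and keep the logits.
    batch = generator.begin_eval_batch(
        token_batch=token_ids, seq_lens_batch=seq_lens, bs=token_ids.shape[0]
    )
    batch.prefill()
    logits = batch.prefill_logits  # [bs, seq_len, vocab]

    # Shift so the logits at position t predict the token at position t + 1.
    shift_logits = logits[:, :-1, :]
    shift_labels = token_ids[:, 1:]
    nll = F.cross_entropy(
        shift_logits.transpose(1, 2), shift_labels, reduction="none"
    )  # [bs, seq_len - 1]

    # Ignore padded positions beyond each prompt's true length.
    positions = torch.arange(shift_labels.shape[1], device=nll.device)
    mask = positions[None, :] < (seq_lens.to(nll.device)[:, None] - 1)
    mean_nll = (nll * mask).sum(dim=1) / mask.sum(dim=1)
    return torch.exp(mean_nll)  # per-prompt perplexity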
Review comment: We should steer towards putting requirements in the subproject requirements files instead of this top-level file, especially if this is a test-only requirement.
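
For illustration, one way to act on this suggestion would be to declare the test-only dependency (the 'datasets' package added in this PR) in a per-project requirements file rather than the repository root. The path below is a hypothetical example, not an existing file in the repository:

# Hypothetical subproject test requirements file, e.g. sharktank/requirements-tests.txt
datasets  # only needed by the perplexity evaluation tests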