From d746298aef9d1944edf50317e7fb985ee01fa9e1 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Sun, 21 Apr 2024 16:27:11 -0700
Subject: [PATCH] [llama3] Enable a llama3 model.

* I had an NYI assert for GQA; this fixes it by repeating KV heads up to the
  query head count.
* Just grabs some pre-made weights of unknown provenance to test for gross
  correctness.
---
 sharktank/sharktank/models/llama/llama.py | 20 ++++++++++---
 sharktank/sharktank/types/theta.py        |  4 +--
 sharktank/sharktank/utils/hf_datasets.py  | 34 +++++++++++++++++++++++
 3 files changed, 52 insertions(+), 6 deletions(-)

diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py
index 6829e0ed2..a90b889ed 100644
--- a/sharktank/sharktank/models/llama/llama.py
+++ b/sharktank/sharktank/models/llama/llama.py
@@ -305,10 +305,6 @@ def forward(
             xq=xq, xk=xk, mask=embedding_batch_mask
         )
 
-        # TODO: Some model variants do some form of kv repetition to expand the
-        # count of kv heads to the count of attention heads used by the q.
-        assert self.head_count == self.head_count_kv, "NYI: KV expansion"
-
         # Full sequence length.
         kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride
 
@@ -334,6 +330,22 @@ def forward(
         else:
             raise NotImplementedError(f"Unsupported KV cache type: {type(self.cache)}")
 
+        # Expand kv heads for GQA.
+        gqa_n_rep = self.head_count // self.head_count_kv
+        assert gqa_n_rep > 0
+        if gqa_n_rep > 1:
+
+            def repeat_kv(x: torch.Tensor) -> torch.Tensor:
+                bs, slen, n_kv_heads, head_dim = x.shape
+                return (
+                    x.unsqueeze(-2)
+                    .expand(bs, slen, n_kv_heads, gqa_n_rep, head_dim)
+                    .reshape(bs, slen, n_kv_heads * gqa_n_rep, head_dim)
+                )
+
+            xk = repeat_kv(xk)
+            xv = repeat_kv(xv)
+
         # Tranpose into [bs, heads, sl, dim]
         xq = xq.transpose(1, 2)
         keys = xk.transpose(1, 2)
diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py
index 2a5095368..3db3c59f4 100644
--- a/sharktank/sharktank/types/theta.py
+++ b/sharktank/sharktank/types/theta.py
@@ -537,11 +537,11 @@ def _dataset_load_helper(
 ) -> Dataset:
     path = Path(path)
     suffixes = path.suffixes
-    if file_type == "gguf" or suffixes == [".gguf"]:
+    if file_type == "gguf" or suffixes == [".gguf"] or suffixes[-1] == ".gguf":
         from . import gguf_interop
 
         return gguf_interop.load_file(path)
-    elif file_type == "irpa" or suffixes == [".irpa"]:
+    elif file_type == "irpa" or suffixes == [".irpa"] or suffixes[-1] == ".irpa":
         return _dataset_load_irpa(path, mmap=mmap)
     else:
         raise IOError(
diff --git a/sharktank/sharktank/utils/hf_datasets.py b/sharktank/sharktank/utils/hf_datasets.py
index 08bdfb458..842b46cdb 100644
--- a/sharktank/sharktank/utils/hf_datasets.py
+++ b/sharktank/sharktank/utils/hf_datasets.py
@@ -83,6 +83,40 @@ def alias_dataset(from_name: str, to_name: str):
 # Dataset definitions
 ################################################################################
 
+Dataset(
+    "QuantFactory/Llama-3-8B_q4_1_gguf",
+    (
+        RemoteFile(
+            "gguf",
+            "QuantFactory/Meta-Llama-3-8B-GGUF",
+            "Meta-Llama-3-8B.Q4_1.gguf",
+        ),
+        RemoteFile(
+            "tokenizer_config.json",
+            "NousResearch/Meta-Llama-3-8B",
+            "tokenizer_config.json",
+            extra_filenames=["tokenizer.json"],
+        ),
+    ),
+).alias_to("llama3_8B_q4_1")
+
+Dataset(
+    "QuantFactory/Llama-3-8B_q8_0_gguf",
+    (
+        RemoteFile(
+            "gguf",
+            "QuantFactory/Meta-Llama-3-8B-GGUF",
+            "Meta-Llama-3-8B.Q8_0.gguf",
+        ),
+        RemoteFile(
+            "tokenizer_config.json",
+            "NousResearch/Meta-Llama-3-8B",
+            "tokenizer_config.json",
+            extra_filenames=["tokenizer.json"],
+        ),
+    ),
+).alias_to("llama3_8B_q8_0")
+
 Dataset(
     "SlyEcho/open_llama_3b_v2_f16_gguf",
     (
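
As a quick sanity check of the GQA expansion added in llama.py, here is a minimal
standalone sketch. The head counts (32 attention heads, 8 KV heads, head_dim 128)
are illustrative assumptions rather than values read from a model config, and
repeat_kv here simply mirrors the helper introduced in the patch.

    # Standalone sketch of the GQA KV-head expansion (illustrative shapes only).
    import torch


    def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        # [bs, sl, n_kv_heads, head_dim] -> [bs, sl, n_kv_heads * n_rep, head_dim]
        bs, slen, n_kv_heads, head_dim = x.shape
        return (
            x.unsqueeze(-2)
            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
        )


    # Assumed llama3-8B-like head counts: 32 attention heads over 8 KV heads.
    head_count, head_count_kv, head_dim = 32, 8, 128
    xk = torch.rand(2, 16, head_count_kv, head_dim)
    xk_rep = repeat_kv(xk, head_count // head_count_kv)
    assert xk_rep.shape == (2, 16, head_count, head_dim)
    # Each original KV head is tiled into n_rep adjacent output heads.
    assert torch.equal(xk_rep[:, :, 0], xk[:, :, 0])
    assert torch.equal(xk_rep[:, :, 3], xk[:, :, 0])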