[llama3] Enable a llama3 model.
* I had an NYI assert for GQA. This fixes that.
* Just grabs some pre-made weights of unknown provenance to test for gross correctness.
stellaraccident committed Apr 21, 2024
1 parent 714dedf commit d746298
Showing 3 changed files with 52 additions and 6 deletions.
20 changes: 16 additions & 4 deletions sharktank/sharktank/models/llama/llama.py
@@ -305,10 +305,6 @@ def forward(
             xq=xq, xk=xk, mask=embedding_batch_mask
         )
 
-        # TODO: Some model variants do some form of kv repetition to expand the
-        # count of kv heads to the count of attention heads used by the q.
-        assert self.head_count == self.head_count_kv, "NYI: KV expansion"
-
         # Full sequence length.
         kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride
@@ -334,6 +330,22 @@ def forward(
         else:
             raise NotImplementedError(f"Unsupported KV cache type: {type(self.cache)}")
 
+        # Expand kv heads for GQA.
+        gqa_n_rep = self.head_count // self.head_count_kv
+        assert gqa_n_rep > 0
+        if gqa_n_rep > 1:
+
+            def repeat_kv(x: torch.Tensor) -> torch.Tensor:
+                bs, slen, n_kv_heads, head_dim = x.shape
+                return (
+                    x.unsqueeze(-2)
+                    .expand(bs, slen, n_kv_heads, gqa_n_rep, head_dim)
+                    .reshape(bs, slen, n_kv_heads * gqa_n_rep, head_dim)
+                )
+
+            xk = repeat_kv(xk)
+            xv = repeat_kv(xv)
+
         # Transpose into [bs, heads, sl, dim]
         xq = xq.transpose(1, 2)
         keys = xk.transpose(1, 2)
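For reference, the expand/reshape pattern added above is the standard grouped-query-attention kv-head expansion and is equivalent to `torch.repeat_interleave` along the head dimension. A minimal standalone sketch (not part of the commit; the 32-query/8-kv head split matches Llama 3 8B, but the batch and sequence sizes here are made up):

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each kv head n_rep times: [bs, sl, n_kv, hd] -> [bs, sl, n_kv * n_rep, hd]."""
    bs, slen, n_kv_heads, head_dim = x.shape
    return (
        x.unsqueeze(-2)
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

# Llama 3 8B uses 32 query heads over 8 kv heads, so n_rep = 4.
xk = torch.randn(2, 16, 8, 128)  # [bs, seq_len, n_kv_heads, head_dim]
expanded = repeat_kv(xk, n_rep=4)
assert expanded.shape == (2, 16, 32, 128)
# Same result as repeating each kv head in place along the head dimension.
assert torch.equal(expanded, torch.repeat_interleave(xk, repeats=4, dim=2))
```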
4 changes: 2 additions & 2 deletions sharktank/sharktank/types/theta.py
@@ -537,11 +537,11 @@ def _dataset_load_helper(
 ) -> Dataset:
     path = Path(path)
     suffixes = path.suffixes
-    if file_type == "gguf" or suffixes == [".gguf"]:
+    if file_type == "gguf" or suffixes == [".gguf"] or suffixes[-1] == ".gguf":
         from . import gguf_interop
 
         return gguf_interop.load_file(path)
-    elif file_type == "irpa" or suffixes == [".irpa"]:
+    elif file_type == "irpa" or suffixes == [".irpa"] or suffixes[-1] == ".irpa":
         return _dataset_load_irpa(path, mmap=mmap)
     else:
         raise IOError(
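The motivation for the added `suffixes[-1]` check: `pathlib.Path.suffixes` returns every dot-separated suffix, so quantized GGUF filenames like the ones registered below never equal `[".gguf"]` exactly. A quick illustration using only the standard library (note that `suffixes[-1]` would still raise `IndexError` for an extensionless path, which the commit does not guard against):

```python
from pathlib import Path

# Quantization tags become extra suffixes, defeating an exact-list match.
path = Path("Meta-Llama-3-8B.Q4_1.gguf")
print(path.suffixes)                 # ['.Q4_1', '.gguf']
print(path.suffixes == [".gguf"])    # False: the old check misses this file
print(path.suffixes[-1] == ".gguf")  # True: the new check catches it
```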
34 changes: 34 additions & 0 deletions sharktank/sharktank/utils/hf_datasets.py
@@ -83,6 +83,40 @@ def alias_dataset(from_name: str, to_name: str):
 # Dataset definitions
 ################################################################################
 
+Dataset(
+    "QuantFactory/Llama-3-8B_q4_1_gguf",
+    (
+        RemoteFile(
+            "gguf",
+            "QuantFactory/Meta-Llama-3-8B-GGUF",
+            "Meta-Llama-3-8B.Q4_1.gguf",
+        ),
+        RemoteFile(
+            "tokenizer_config.json",
+            "NousResearch/Meta-Llama-3-8B",
+            "tokenizer_config.json",
+            extra_filenames=["tokenizer.json"],
+        ),
+    ),
+).alias_to("llama3_8B_q4_1")
+
+Dataset(
+    "QuantFactory/Llama-3-8B_q8_0_gguf",
+    (
+        RemoteFile(
+            "gguf",
+            "QuantFactory/Meta-Llama-3-8B-GGUF",
+            "Meta-Llama-3-8B.Q8_0.gguf",
+        ),
+        RemoteFile(
+            "tokenizer_config.json",
+            "NousResearch/Meta-Llama-3-8B",
+            "tokenizer_config.json",
+            extra_filenames=["tokenizer.json"],
+        ),
+    ),
+).alias_to("llama3_8B_q8_0")
+
 Dataset(
     "SlyEcho/open_llama_3b_v2_f16_gguf",
     (
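With the registrations above, the quantized Llama 3 weights should be fetchable by alias. A sketch of intended usage, assuming the module's pre-existing `get_dataset` lookup and `Dataset.download()` method (neither is shown in this diff; check `hf_datasets.py` for the exact names and signatures):

```python
from sharktank.utils import hf_datasets

# Assumed API: resolve the alias registered above, then pull the GGUF
# weights and tokenizer files from the Hugging Face hub.
dataset = hf_datasets.get_dataset("llama3_8B_q4_1")
paths = dataset.download()
print(paths)
```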
