Add in some missing grok specific model structure and constants
KyleHerndon committed Sep 9, 2024
1 parent b5f535d commit 4095db0
Showing 2 changed files with 8 additions and 0 deletions.
4 changes: 4 additions & 0 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -37,6 +37,7 @@ def __init__(
head_count_kv: int,
rms_epsilon: float,
use_hf: bool = False,
use_grok: bool = False,
):
super().__init__(theta)
self.add_module(
@@ -53,6 +54,7 @@ def __init__(
self.head_dim = head_dim
self.head_count_kv = head_count_kv
self.use_hf = use_hf
self.use_grok = use_grok

def forward(
self,
@@ -141,6 +143,8 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:

# Flash attention.
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if self.use_grok:
attn_weights = 30 * torch.tanh(attn_weights * (0.08838834764831845 / 30.0))
self.assert_not_nan(attn_weights)

# Apply attention mask.
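The new use_grok branch soft-caps the attention logits in the Grok-1 style: the scaled QK^T logits are passed through a scaled tanh so their magnitude stays below 30. The constant 0.08838834764831845 equals 1/sqrt(128), which appears to correspond to Grok-1's attention output multiplier for its 128-dimensional heads. Below is a minimal standalone sketch of the same transform; the variable names and tensor shapes are illustrative assumptions, not the repository's API.

```python
import torch

# Grok-style attention logit soft-cap, mirroring the added line above.
# (Illustrative sketch; names and shapes are assumptions, not repository code.)
ATTN_MULTIPLIER = 0.08838834764831845  # == 1 / sqrt(128)
MAX_ATTN_VALUE = 30.0                  # logits are smoothly bounded to (-30, 30)

def soft_cap_attention_logits(attn_weights: torch.Tensor) -> torch.Tensor:
    # 30 * tanh(x * 0.0883... / 30): tanh saturates at +/-1, so the result
    # stays strictly inside (-MAX_ATTN_VALUE, MAX_ATTN_VALUE).
    return MAX_ATTN_VALUE * torch.tanh(attn_weights * (ATTN_MULTIPLIER / MAX_ATTN_VALUE))

# Example: even extreme logits end up capped below 30 in magnitude.
raw = torch.randn(1, 2, 4, 4) * 1000.0  # [batch, heads, q_len, kv_len]
capped = soft_cap_attention_logits(raw)
assert capped.abs().max() < MAX_ATTN_VALUE
```

Because tanh(x) is approximately x near zero, small logits pass through almost unchanged while large ones are squashed; the cap is applied before the attention mask and softmax, as in the hunk above.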
4 changes: 4 additions & 0 deletions sharktank/sharktank/models/grok/grok.py
@@ -147,6 +147,8 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
head_dim=hp.attn_head_dim,
head_count_kv=hp.attention_head_count_kv,
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
use_hf=True,
use_grok=True,
)
)
self.attn_blocks.append(
@@ -250,6 +252,7 @@ def decode(
)

h = self.token_embedding(tokens)
h *= 78.38367176906169
self.trace_tensor("mixtral.token_embedding", h)

# Iterate over attention blocks.
@@ -278,4 +281,5 @@ def decode(

h = self.output_norm(h)
logits = self.output_lm_head(h)
logits = logits * 0.5773502691896257
return logits
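The two scalar constants added to grok.py's decode path look like Grok-1's embedding and output multipliers: token embeddings are scaled up by 78.38367176906169 right after lookup, and the final logits are scaled down by 0.5773502691896257, which equals 1/sqrt(3). The sketch below shows where the two multipliers sit in a forward pass; the callables (token_embedding, blocks, output_norm, lm_head) are hypothetical stand-ins, not the repository's API.

```python
import torch

# Where the Grok embedding/output multipliers sit in a decode pass (sketch only;
# token_embedding, blocks, output_norm and lm_head are hypothetical stand-ins).
EMBEDDING_MULTIPLIER = 78.38367176906169  # scales embeddings up after lookup
OUTPUT_MULTIPLIER = 0.5773502691896257    # == 1 / sqrt(3), scales logits down

def decode_logits(token_embedding, blocks, output_norm, lm_head,
                  tokens: torch.Tensor) -> torch.Tensor:
    h = token_embedding(tokens) * EMBEDDING_MULTIPLIER  # h *= 78.383... in the diff
    for block in blocks:                                 # attention / MoE blocks
        h = block(h)
    h = output_norm(h)
    return lm_head(h) * OUTPUT_MULTIPLIER                # logits * 0.577... in the diff
```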
