More pre-commit
KyleHerndon committed Oct 20, 2024
1 parent 9e0deb1 commit 5b44929
Showing 1 changed file with 6 additions and 2 deletions.
sharktank/sharktank/layers/paged_llama_attention_block.py: 8 changes (6 additions, 2 deletions)
@@ -182,9 +182,13 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
             # self.trace_tensor("attn_mask", attention_mask)
             attn_weights = attn_weights + attention_mask

-            attn_weights = ops.softmax(ops.to(attn_weights, dtype=torch.float32), dim=-1)
+            attn_weights = ops.softmax(
+                ops.to(attn_weights, dtype=torch.float32), dim=-1
+            )
             attn_weights = ops.to(attn_weights, dtype=xq.dtype)
-            attn_output = ops.matmul(attn_weights, values)  # (bs, heads, slen, head_dim)
+            attn_output = ops.matmul(
+                attn_weights, values
+            )  # (bs, heads, slen, head_dim)
         else:
             is_causal = attention_mask is None and batch_seq_len == 1
             attn_output = torch.nn.functional.scaled_dot_product_attention(
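For reference, here is a minimal standalone sketch of the decomposed attention path this hunk reformats, using plain torch in place of sharktank's ops wrappers. The function name, signature, and explicit scaling are illustrative assumptions, not the repository's API; only the mask-add, float32 softmax, dtype cast, and value matmul mirror the diff above.

import torch

def decomposed_attention(xq, keys, values, attention_mask=None):
    # Hypothetical stand-in for the decomposed branch in the diff;
    # plain torch calls replace sharktank's ops.softmax / ops.to / ops.matmul.
    # xq, keys, values: (bs, heads, seq_len, head_dim)
    scale = xq.shape[-1] ** -0.5  # assumed standard 1/sqrt(head_dim) scaling
    attn_weights = torch.matmul(xq, keys.transpose(-2, -1)) * scale
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    # Compute softmax in float32 for numerical stability, then cast back
    # to the query dtype, as the reformatted lines above do.
    attn_weights = torch.softmax(attn_weights.to(torch.float32), dim=-1)
    attn_weights = attn_weights.to(xq.dtype)
    return torch.matmul(attn_weights, values)  # (bs, heads, slen, head_dim)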
