Commit 7f99dfa

yapf
rsuderman committed Apr 24, 2024
1 parent 63b7ef8

Showing 2 changed files with 6 additions and 2 deletions.
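
The commit title indicates these hunks were produced by yapf, a Python auto-formatter; no manual logic changes are involved. As a hedged sketch (the exact command and style configuration used for this commit are not recorded here), the same kind of reflow can be reproduced through yapf's Python API:

# Hedged sketch: reflowing an over-long call with yapf's Python API.
# The "pep8" style config is an assumption, not taken from this repository.
from yapf.yapflib.yapf_api import FormatCode

src = (
    "attn_output = torch.nn.functional."
    "scaled_dot_product_attention(xq, keys, values, attention_mask)\n"
)
# FormatCode returns (formatted_source, changed).
formatted, changed = FormatCode(src, style_config="pep8")
print(formatted)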
sharktank/sharktank/models/llama/llama.py: 3 additions & 1 deletion

@@ -340,7 +340,9 @@ def forward(
         values = xv.transpose(1, 2)

         # Flash attention.
-        attn_output = torch.nn.functional.scaled_dot_product_attention(xq, keys, values, attention_mask)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            xq, keys, values, attention_mask
+        )
         attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1)

         # Project.
sharktank/sharktank/models/llama/llama_ref.py: 3 additions & 1 deletion

@@ -209,7 +209,9 @@ def forward(
         values = values.transpose(1, 2)

         # Flash attention.
-        attn_output = torch.nn.functional.scaled_dot_product_attention(xq, keys, values, attention_mask)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            xq, keys, values, attention_mask
+        )
         attn_output = attn_output.transpose(1, 2).reshape(bs, q_len, -1)

         # Project.
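
For reference, torch.nn.functional.scaled_dot_product_attention is PyTorch's fused attention entry point; it can dispatch to a flash-attention kernel, which matches the "# Flash attention." comment in both files. The commit only reflows the call across lines, with no behavioral change. Below is a minimal sketch of the call in isolation; bs, n_heads, seq_len, and head_dim are made-up illustrative sizes, not values from the model.

import torch

bs, n_heads, seq_len, head_dim = 2, 8, 16, 64

# Q/K/V in (batch, heads, seq, head_dim) layout, matching the
# transpose(1, 2) applied just before the call in both diffs.
xq = torch.randn(bs, n_heads, seq_len, head_dim)
keys = torch.randn(bs, n_heads, seq_len, head_dim)
values = torch.randn(bs, n_heads, seq_len, head_dim)

# Additive float mask, broadcastable to (bs, n_heads, seq_len, seq_len);
# all zeros means "attend everywhere".
attention_mask = torch.zeros(seq_len, seq_len)

attn_output = torch.nn.functional.scaled_dot_product_attention(
    xq, keys, values, attention_mask
)
print(attn_output.shape)  # torch.Size([2, 8, 16, 64])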
