
[TEST ONLY] Drop scaled_dot_product_attention
ghstack-source-id: 5f8656957e939e0727b571f61145306405663cbd
Pull Request resolved: #6367
helunwencser committed Oct 18, 2024
1 parent 7493aae commit f1bb895
Showing 1 changed file with 5 additions and 1 deletion.
examples/models/llama/llama_transformer.py (5 additions, 1 deletion)
@@ -7,6 +7,7 @@
 
 # Please refer to README.md in the same folder for more information.
 
+import math
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, Optional, Tuple
@@ -251,7 +252,10 @@ def forward(
 
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
+        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+        scores = scores + attn_mask
+        scores = F.softmax(scores.float(), dim=-1).type_as(q)
+        y = torch.matmul(scores, v)
 
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
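For context: the added lines are the unfused reference formulation of scaled dot-product attention. Below is a minimal sketch, not part of the commit, that checks this manual path against the fused F.scaled_dot_product_attention call being removed; the tensor shapes and the additive causal mask are illustrative assumptions.

import math

import torch
import torch.nn.functional as F

# Illustrative shapes: batch, heads, sequence length, head dimension.
bsz, n_heads, seqlen, head_dim = 1, 4, 8, 16
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)

# Additive causal mask: 0 on and below the diagonal, -inf above it.
attn_mask = torch.full((seqlen, seqlen), float("-inf")).triu(diagonal=1)

# Manual attention, mirroring the lines added in this commit.
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
scores = scores + attn_mask
scores = F.softmax(scores.float(), dim=-1).type_as(q)
y_manual = torch.matmul(scores, v)

# Fused kernel that the commit drops.
y_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

assert torch.allclose(y_manual, y_sdpa, atol=1e-5)

With dropout_p=0.0 and a float additive mask, the two paths compute the same result up to floating-point tolerance.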
