Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s committed Nov 11, 2023
1 parent 6095641 commit a9080b6
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions onmt/modules/multi_headed_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ def apply_rotary_emb(query, key, rope, interleave):
query_ = torch.view_as_complex(query_)
key_ = key.float().reshape(*key.shape[:-1], -1, 2)
key_ = torch.view_as_complex(key_)
- rope = rope[:, :rope.size(1)//2].view(1, query_.size(1), 1, query_.size(3))
+ rope = rope[:, : rope.size(1) // 2].view(1, query_.size(1), 1, query_.size(3))
query_out = torch.view_as_real(query_ * rope).flatten(3)
key_out = torch.view_as_real(key_ * rope).flatten(3)
- return query_out.transpose(1, 2).type_as(query), key_out.transpose(1, 2).type_as(
- key
- )
+ return query_out.transpose(1, 2).type_as(query), key_out.transpose(
+ 1, 2
+ ).type_as(key)
else:
cos, sin = rope.real, rope.imag
q_embed = (query * cos) + (rotate_half(query) * sin)
Expand Down Expand Up @@ -419,7 +419,9 @@ def forward(
start_pos = step
seqlen = query.size(2)
if seqlen > self.rope.size(0):
- self.rope = rotaryembeddings(self.dim_per_head, maxseqlen=(seqlen + 2048))
+ self.rope = rotaryembeddings(
+ self.dim_per_head, maxseqlen=(seqlen + 2048)
+ )
rope = self.rope[start_pos : start_pos + seqlen]
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
Expand Down Expand Up @@ -461,7 +463,9 @@ def forward(
start_pos = 0
seqlen = query.size(2)
if seqlen > self.rope.size(0):
- self.rope = rotaryembeddings(self.dim_per_head, maxseqlen=(seqlen + 2048))
+ self.rope = rotaryembeddings(
+ self.dim_per_head, maxseqlen=(seqlen + 2048)
+ )
rope = self.rope[start_pos : start_pos + seqlen]
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
Expand Down

0 comments on commit a9080b6

Please sign in to comment.