Skip to content

Commit

Permalink
fix rope device for long sequence (OpenNMT#2514)
Browse files Browse the repository at this point in the history
* fix rope device for long sequence
* restore device removed by mistake
  • Loading branch information
vince62s authored Nov 13, 2023
1 parent f3059a5 commit fd2f145
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions onmt/modules/multi_headed_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def forward(
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head, maxseqlen=(seqlen + 2048)
)
).to(self.rope.device)
rope = self.rope[start_pos : start_pos + seqlen]
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
Expand Down Expand Up @@ -465,8 +465,8 @@ def forward(
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head, maxseqlen=(seqlen + 2048)
)
rope = self.rope[start_pos : start_pos + seqlen]
).to(self.rope.device)
rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
)
Expand Down

0 comments on commit fd2f145

Please sign in to comment.