From fd2f1452517d85d5dca58189dd18b9572f5aef8a Mon Sep 17 00:00:00 2001 From: Vincent Nguyen Date: Mon, 13 Nov 2023 13:20:13 +0100 Subject: [PATCH] fix rope device for long sequence (#2514) * fix rope device for long sequence * restore device removed by mistake --- onmt/modules/multi_headed_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 9eb8d465c5..963ca2d78b 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -421,7 +421,7 @@ def forward( if seqlen > self.rope.size(0): self.rope = rotaryembeddings( self.dim_per_head, maxseqlen=(seqlen + 2048) - ) + ).to(self.rope.device) rope = self.rope[start_pos : start_pos + seqlen] query, key = apply_rotary_emb( query, key, rope, interleave=self.rotary_interleave @@ -465,8 +465,8 @@ def forward( if seqlen > self.rope.size(0): self.rope = rotaryembeddings( self.dim_per_head, maxseqlen=(seqlen + 2048) - ) - rope = self.rope[start_pos : start_pos + seqlen] + ).to(self.rope.device) + rope = self.rope[start_pos : start_pos + seqlen].to(query.device) query, key = apply_rotary_emb( query, key, rope, interleave=self.rotary_interleave )