From a9080b60c3b39fd066e1c6f66ebc40c136a10d4e Mon Sep 17 00:00:00 2001 From: vince62s Date: Sat, 11 Nov 2023 12:20:59 +0100 Subject: [PATCH] . --- onmt/modules/multi_headed_attn.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 99bb32a20f..9eb8d465c5 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -44,12 +44,12 @@ def apply_rotary_emb(query, key, rope, interleave): query_ = torch.view_as_complex(query_) key_ = key.float().reshape(*key.shape[:-1], -1, 2) key_ = torch.view_as_complex(key_) - rope = rope[:, :rope.size(1)//2].view(1, query_.size(1), 1, query_.size(3)) + rope = rope[:, : rope.size(1) // 2].view(1, query_.size(1), 1, query_.size(3)) query_out = torch.view_as_real(query_ * rope).flatten(3) key_out = torch.view_as_real(key_ * rope).flatten(3) - return query_out.transpose(1, 2).type_as(query), key_out.transpose(1, 2).type_as( - key - ) + return query_out.transpose(1, 2).type_as(query), key_out.transpose( + 1, 2 + ).type_as(key) else: cos, sin = rope.real, rope.imag q_embed = (query * cos) + (rotate_half(query) * sin) @@ -419,7 +419,9 @@ def forward( start_pos = step seqlen = query.size(2) if seqlen > self.rope.size(0): - self.rope = rotaryembeddings(self.dim_per_head, maxseqlen=(seqlen + 2048)) + self.rope = rotaryembeddings( + self.dim_per_head, maxseqlen=(seqlen + 2048) + ) rope = self.rope[start_pos : start_pos + seqlen] query, key = apply_rotary_emb( query, key, rope, interleave=self.rotary_interleave @@ -461,7 +463,9 @@ def forward( start_pos = 0 seqlen = query.size(2) if seqlen > self.rope.size(0): - self.rope = rotaryembeddings(self.dim_per_head, maxseqlen=(seqlen + 2048)) + self.rope = rotaryembeddings( + self.dim_per_head, maxseqlen=(seqlen + 2048) + ) rope = self.rope[start_pos : start_pos + seqlen] query, key = apply_rotary_emb( query, key, rope, interleave=self.rotary_interleave