From da2fe8f86778734cfbeabc68175a5721610c975c Mon Sep 17 00:00:00 2001
From: Vincent Nguyen
Date: Sat, 11 Nov 2023 12:29:13 +0100
Subject: [PATCH] Fixrotary (#2511)

* fix rotary embed with interleave false
---
 onmt/modules/multi_headed_attn.py | 62 +++++++++++++++++++------------
 1 file changed, 38 insertions(+), 24 deletions(-)

diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py
index f7264340ce..9eb8d465c5 100644
--- a/onmt/modules/multi_headed_attn.py
+++ b/onmt/modules/multi_headed_attn.py
@@ -19,36 +19,42 @@
 # are both < 2048 tokens.
 
 
-def rotaryembeddings(dim: int, maxseqlen=8192, base=10000):
+def rotaryembeddings(dim: int, maxseqlen=2048, base=10000):
     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
     tmax = torch.arange(maxseqlen, device=inv_freq.device)
     rope = torch.outer(tmax, inv_freq).float()
     # rope is now matrix [maxseqlen, dim/2]
     rope = torch.polar(torch.ones_like(rope), rope)
+    rope = torch.cat((rope, rope), dim=1)
     return rope
 
 
-def apply_rotary_emb(query, key, rope, interleave=True):
-    query = query.transpose(1, 2)
-    key = key.transpose(1, 2)
-    if not interleave:
-        query = torch.cat(
-            (-query[..., query.shape[-1] // 2 :], query[..., : query.shape[-1] // 2]),
-            dim=-1,
-        )
-        key = torch.cat(
-            (-key[..., key.shape[-1] // 2 :], key[..., : key.shape[-1] // 2]), dim=-1
-        )
-    query_ = query.float().reshape(*query.shape[:-1], -1, 2)
-    query_ = torch.view_as_complex(query_)
-    key_ = key.float().reshape(*key.shape[:-1], -1, 2)
-    key_ = torch.view_as_complex(key_)
-    rope = rope.view(1, query_.size(1), 1, query_.size(3))
-    query_out = torch.view_as_real(query_ * rope).flatten(3)
-    key_out = torch.view_as_real(key_ * rope).flatten(3)
-    return query_out.transpose(1, 2).type_as(query), key_out.transpose(1, 2).type_as(
-        key
-    )
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_emb(query, key, rope, interleave):
+    if interleave:
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        query_ = query.float().reshape(*query.shape[:-1], -1, 2)
+        query_ = torch.view_as_complex(query_)
+        key_ = key.float().reshape(*key.shape[:-1], -1, 2)
+        key_ = torch.view_as_complex(key_)
+        rope = rope[:, : rope.size(1) // 2].view(1, query_.size(1), 1, query_.size(3))
+        query_out = torch.view_as_real(query_ * rope).flatten(3)
+        key_out = torch.view_as_real(key_ * rope).flatten(3)
+        return query_out.transpose(1, 2).type_as(query), key_out.transpose(
+            1, 2
+        ).type_as(key)
+    else:
+        cos, sin = rope.real, rope.imag
+        q_embed = (query * cos) + (rotate_half(query) * sin)
+        k_embed = (key * cos) + (rotate_half(key) * sin)
+        return q_embed.type_as(query), k_embed.type_as(key)
 
 
 # Help functions for max_relative positions
@@ -412,6 +418,10 @@ def forward(
             if self.max_relative_positions == -1:  # Rotary Embeddings
                 start_pos = step
                 seqlen = query.size(2)
+                if seqlen > self.rope.size(0):
+                    self.rope = rotaryembeddings(
+                        self.dim_per_head, maxseqlen=(seqlen + 2048)
+                    )
                 rope = self.rope[start_pos : start_pos + seqlen]
                 query, key = apply_rotary_emb(
                     query, key, rope, interleave=self.rotary_interleave
@@ -444,6 +454,7 @@ def forward(
             key = self.maybe_ckpt(self.linear_keys, key)
             value = self.maybe_ckpt(self.linear_values, value)
             query = self.maybe_ckpt(self.linear_query, query)
+
             key = shape(key, self.dim_per_head)
             value = shape(value, self.dim_per_head)
             query = shape(query, self.dim_per_head)
@@ -451,7 +462,11 @@ def forward(
             if self.max_relative_positions == -1:  # Rotary Embeddings
                 start_pos = 0
                 seqlen = query.size(2)
-                rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
+                if seqlen > self.rope.size(0):
+                    self.rope = rotaryembeddings(
+                        self.dim_per_head, maxseqlen=(seqlen + 2048)
+                    )
+                rope = self.rope[start_pos : start_pos + seqlen]
                 query, key = apply_rotary_emb(
                     query, key, rope, interleave=self.rotary_interleave
                 )
@@ -472,7 +487,6 @@ def forward(
         # Ultimately flashv2 will be part of pytorch https://github.com/pytorch/pytorch/pull/105602
         # In the meantime: if vanilla tranformer or Rotary embeddings (not rel_pos, not alibi)
         # then use flash2 if seq len > 256 otherwise use xtransformer from pt2 uptream
-
         flash2 = (
             self.flash2
            and l > 256  # https://github.com/Dao-AILab/flash-attention/issues/591
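---

A minimal usage sketch of the patched helpers, assuming a checkout of OpenNMT-py
with this patch applied is importable; the batch/head/sequence sizes below are
illustrative only and are not part of the patch.

    import torch
    from onmt.modules.multi_headed_attn import rotaryembeddings, apply_rotary_emb

    batch, heads, seqlen, dim_per_head = 2, 8, 16, 64
    query = torch.randn(batch, heads, seqlen, dim_per_head)
    key = torch.randn(batch, heads, seqlen, dim_per_head)

    # After this patch, rotaryembeddings() returns a complex tensor of shape
    # [maxseqlen, dim] (the two half-dim copies are concatenated along dim=1),
    # so both code paths can slice what they need from the same cache.
    rope = rotaryembeddings(dim_per_head)[:seqlen]

    # interleave=False: the rotate_half path added by this patch; cos/sin come
    # from rope.real / rope.imag and broadcast over [batch, heads, seqlen, dim].
    q_rot, k_rot = apply_rotary_emb(query, key, rope, interleave=False)

    # interleave=True: the original complex-multiplication path, now slicing the
    # first half of the widened rope tensor before the view/multiply.
    q_rot_i, k_rot_i = apply_rotary_emb(query, key, rope, interleave=True)

    assert q_rot.shape == k_rot.shape == q_rot_i.shape == query.shape

The two settings are not interchangeable: interleave=True pairs adjacent channels
via complex multiplication, while interleave=False rotates the first and second
halves of each head dimension (the rotate_half convention), so rotary_interleave
has to match the layout the checkpoint was trained with.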