Commit da2fe8f: Fixrotary (OpenNMT#2511)

* fix rotary embed with interleave false

Authored by vince62s on Nov 11, 2023
Parent: 212141b

Showing 1 changed file with 38 additions and 24 deletions.

onmt/modules/multi_headed_attn.py: 62 changes (38 additions, 24 deletions)
@@ -19,36 +19,42 @@
 # are both < 2048 tokens.
 
 
-def rotaryembeddings(dim: int, maxseqlen=8192, base=10000):
+def rotaryembeddings(dim: int, maxseqlen=2048, base=10000):
     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
     tmax = torch.arange(maxseqlen, device=inv_freq.device)
     rope = torch.outer(tmax, inv_freq).float()
     # rope is now matrix [maxseqlen, dim/2]
     rope = torch.polar(torch.ones_like(rope), rope)
+    rope = torch.cat((rope, rope), dim=1)
     return rope
 
 
-def apply_rotary_emb(query, key, rope, interleave=True):
-    query = query.transpose(1, 2)
-    key = key.transpose(1, 2)
-    if not interleave:
-        query = torch.cat(
-            (-query[..., query.shape[-1] // 2 :], query[..., : query.shape[-1] // 2]),
-            dim=-1,
-        )
-        key = torch.cat(
-            (-key[..., key.shape[-1] // 2 :], key[..., : key.shape[-1] // 2]), dim=-1
-        )
-    query_ = query.float().reshape(*query.shape[:-1], -1, 2)
-    query_ = torch.view_as_complex(query_)
-    key_ = key.float().reshape(*key.shape[:-1], -1, 2)
-    key_ = torch.view_as_complex(key_)
-    rope = rope.view(1, query_.size(1), 1, query_.size(3))
-    query_out = torch.view_as_real(query_ * rope).flatten(3)
-    key_out = torch.view_as_real(key_ * rope).flatten(3)
-    return query_out.transpose(1, 2).type_as(query), key_out.transpose(1, 2).type_as(
-        key
-    )
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_emb(query, key, rope, interleave):
+    if interleave:
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        query_ = query.float().reshape(*query.shape[:-1], -1, 2)
+        query_ = torch.view_as_complex(query_)
+        key_ = key.float().reshape(*key.shape[:-1], -1, 2)
+        key_ = torch.view_as_complex(key_)
+        rope = rope[:, : rope.size(1) // 2].view(1, query_.size(1), 1, query_.size(3))
+        query_out = torch.view_as_real(query_ * rope).flatten(3)
+        key_out = torch.view_as_real(key_ * rope).flatten(3)
+        return query_out.transpose(1, 2).type_as(query), key_out.transpose(
+            1, 2
+        ).type_as(key)
+    else:
+        cos, sin = rope.real, rope.imag
+        q_embed = (query * cos) + (rotate_half(query) * sin)
+        k_embed = (key * cos) + (rotate_half(key) * sin)
+        return q_embed.type_as(query), k_embed.type_as(key)
 
 
 # Help functions for max_relative positions
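
Note (illustration, not part of the commit): the hunk above is the core of the fix. With interleave False, apply_rotary_emb now applies cos/sin directly via rotate_half instead of routing through the complex interleaved path, and rotaryembeddings duplicates the frequency table along the feature axis so cos/sin span the full head dimension (the interleaved branch slices the first half back out). A minimal standalone sketch of the non-interleaved path, with made-up sizes (2 heads, head dim 8, sequence length 4) and random tensors:

import torch

# Illustrative sizes only, not taken from the commit.
dim, seqlen = 8, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(seqlen), inv_freq).float()  # [seqlen, dim/2]
rope = torch.polar(torch.ones_like(angles), angles)           # complex, [seqlen, dim/2]
rope = torch.cat((rope, rope), dim=1)                         # [seqlen, dim], as in the new rotaryembeddings

def rotate_half(x):
    # [-x2, x1]: negated second half of the head dim followed by the first half
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

q = torch.randn(1, 2, seqlen, dim)          # [batch, heads, seqlen, dim_per_head]
cos, sin = rope.real, rope.imag             # broadcast over batch and heads
q_rot = (q * cos) + (rotate_half(q) * sin)  # the interleave=False branch above
print(q_rot.shape)                          # torch.Size([1, 2, 4, 8])
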
@@ -412,6 +418,10 @@ def forward(
                 if self.max_relative_positions == -1:  # Rotary Embeddings
                     start_pos = step
                     seqlen = query.size(2)
+                    if seqlen > self.rope.size(0):
+                        self.rope = rotaryembeddings(
+                            self.dim_per_head, maxseqlen=(seqlen + 2048)
+                        )
                     rope = self.rope[start_pos : start_pos + seqlen]
                     query, key = apply_rotary_emb(
                         query, key, rope, interleave=self.rotary_interleave
@@ -444,14 +454,19 @@ def forward(
             key = self.maybe_ckpt(self.linear_keys, key)
             value = self.maybe_ckpt(self.linear_values, value)
             query = self.maybe_ckpt(self.linear_query, query)
+
             key = shape(key, self.dim_per_head)
             value = shape(value, self.dim_per_head)
             query = shape(query, self.dim_per_head)
 
             if self.max_relative_positions == -1:  # Rotary Embeddings
                 start_pos = 0
                 seqlen = query.size(2)
-                rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
+                if seqlen > self.rope.size(0):
+                    self.rope = rotaryembeddings(
+                        self.dim_per_head, maxseqlen=(seqlen + 2048)
+                    )
+                rope = self.rope[start_pos : start_pos + seqlen]
                 query, key = apply_rotary_emb(
                     query, key, rope, interleave=self.rotary_interleave
                 )
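
Note (illustration, not part of the commit): both branches of forward() now grow the precomputed rotary table on demand instead of allocating 8192 positions up front; the table starts at 2048 positions and is rebuilt with 2048 positions of headroom whenever a longer sequence arrives. A rough sketch of that check with made-up sizes, where rope_cache stands in for self.rope:

import torch

def rotaryembeddings(dim, maxseqlen=2048, base=10000):
    # same construction as the function patched above
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    tmax = torch.arange(maxseqlen, device=inv_freq.device)
    rope = torch.outer(tmax, inv_freq).float()
    rope = torch.polar(torch.ones_like(rope), rope)
    return torch.cat((rope, rope), dim=1)

dim_per_head = 64
rope_cache = rotaryembeddings(dim_per_head)      # [2048, 64]
seqlen = 3000                                    # longer than the cached table
if seqlen > rope_cache.size(0):
    rope_cache = rotaryembeddings(dim_per_head, maxseqlen=seqlen + 2048)
rope = rope_cache[:seqlen]                       # [3000, 64]
print(rope.shape)
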
@@ -472,7 +487,6 @@ def forward(
         # Ultimately flashv2 will be part of pytorch https://github.com/pytorch/pytorch/pull/105602
         # In the meantime: if vanilla tranformer or Rotary embeddings (not rel_pos, not alibi)
         # then use flash2 if seq len > 256 otherwise use xtransformer from pt2 uptream
-
         flash2 = (
             self.flash2
             and l > 256  # https://github.com/Dao-AILab/flash-attention/issues/591
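Note (illustration, not part of the commit): the comments in the last hunk describe how the module picks an attention kernel. A hedged sketch of that length-based choice, using only PyTorch's scaled_dot_product_attention for the short-sequence path; the flash-attention-2 kernel call itself is elided, and the shapes are illustrative:

import torch
import torch.nn.functional as F

b, h, l, d = 2, 8, 128, 64                       # illustrative shapes
q, k, v = (torch.randn(b, h, l, d) for _ in range(3))

# flash2 is preferred only past 256 tokens (see the flash-attention issue
# referenced in the hunk above); shorter sequences use PyTorch's fused SDPA.
if l > 256:
    out = None  # the flash-attn 2 kernel call is elided in this sketch
else:
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)                                 # torch.Size([2, 8, 128, 64])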
