Add rotary_interleave=false (Hugging face models) (OpenNMT#2507)
vince62s authored Nov 8, 2023
1 parent 3743fa3 commit 212141b
Showing 4 changed files with 117 additions and 21 deletions.
88 changes: 71 additions & 17 deletions onmt/decoders/transformer.py
@@ -42,6 +42,7 @@ def __init__(
use_ckpting=[],
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
):
"""
Args:
@@ -60,6 +61,9 @@ def __init__(
max_relative_positions (int):
Max distance between inputs in relative positions
representations
relative_positions_buckets (int):
Number of buckets when using relative position bias; see
https://github.com/google-research/text-to-text-transfer-transformer
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
full_context_alignment (bool):
whether enable an extra full context decoder forward for
@@ -69,9 +73,19 @@
pos_ffn_activation_fn (ActivationFunction):
activation function choice for PositionwiseFeedForward layer
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
num_kv (int): number of KV heads when different from the number of Q heads (multiquery)
add_ffnbias (bool): whether to add bias to the FF nn.Linear
parallel_residual (bool): Use parallel residual connections in each layer block, as used
by the GPT-J and GPT-NeoX models
shared_layer_norm (bool): When using parallel residual, share the input and post
attention layer norms.
layer_norm (string): type of layer normalization standard/rms
norm_eps (float): layer norm epsilon
use_ckpting (List): layers for which we checkpoint for backward
parallel_gpu (int): Number of GPUs for tensor parallelism
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary
embeddings are applied
"""
super(TransformerDecoderLayerBase, self).__init__()

@@ -83,6 +97,7 @@ def __init__(
dropout=attention_dropout,
max_relative_positions=max_relative_positions,
relative_positions_buckets=relative_positions_buckets,
rotary_interleave=rotary_interleave,
attn_type="self",
add_qkvbias=add_qkvbias,
num_kv=num_kv,
@@ -238,6 +253,7 @@ def __init__(
use_ckpting=[],
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
):
"""
Args:
@@ -266,6 +282,7 @@
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
)
self.context_attn = MultiHeadedAttention(
heads,
@@ -424,6 +441,7 @@ def from_opt(cls, opt, embeddings):
if opt.parallel_mode == "tensor_parallel"
else 1,
sliding_window=opt.sliding_window,
rotary_interleave=opt.rotary_interleave,
)

def init_state(self, src, enc_out, enc_final_hs):
@@ -486,8 +504,21 @@ class TransformerDecoder(TransformerDecoderBase):
alignment_layer (int): N° Layer to supervise with for alignment guiding
alignment_heads (int):
N. of cross attention heads to use for alignment guiding
pos_ffn_activation_fn (ActivationFunction):
activation function choice for PositionwiseFeedForward layer
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
num_kv (int): number of KV heads when different from the number of Q heads (multiquery)
add_ffnbias (bool): whether to add bias to the FF nn.Linear
parallel_residual (bool): Use parallel residual connections in each layer block, as used
by the GPT-J and GPT-NeoX models
shared_layer_norm (bool): When using parallel residual, share the input and post
attention layer norms.
layer_norm (string): type of layer normalization standard/rms
norm_eps (float): layer norm epsilon
use_ckpting (List): layers for which we checkpoint for backward
parallel_gpu (int): Number of GPUs for tensor parallelism
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
"""

def __init__(
@@ -518,6 +549,7 @@ def __init__(
use_ckpting=[],
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
):
super(TransformerDecoder, self).__init__(
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -548,6 +580,7 @@ def __init__(
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
)
for i in range(num_layers)
]
@@ -716,22 +749,41 @@ def _forward(
class TransformerLMDecoder(TransformerDecoderBase):
"""The Transformer decoder from GPT-2
Args:
num_layers (int): number of decoder layers.
d_model (int): size of the model
heads (int): number of heads
d_ff (int): size of the inner FF layer
copy_attn (bool): if using a separate copy attention
self_attn_type (str): type of self-attention scaled-dot, average
dropout (float): dropout in residual, self-attn(dot) and feed-forward
attention_dropout (float): dropout in context_attn (and self-attn(avg))
embeddings (onmt.modules.Embeddings):
embeddings to use, should have positional encodings
max_relative_positions (int):
Max distance between inputs in relative positions representations
relative_positions_buckets (int):
Number of buckets when using Relative positions bias
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
num_layers (int): number of decoder layers.
d_model (int): size of the model
heads (int): number of heads
d_ff (int): size of the inner FF layer
copy_attn (bool): if using a separate copy attention
self_attn_type (str): type of self-attention scaled-dot, average
dropout (float): dropout in residual, self-attn(dot) and feed-forward
attention_dropout (float): dropout in context_attn (and self-attn(avg))
embeddings (onmt.modules.Embeddings):
embeddings to use, should have positional encodings
max_relative_positions (int):
Max distance between inputs in relative positions representations
relative_positions_buckets (int):
Number of buckets when using Relative positions bias
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
full_context_alignment (bool):
whether enable an extra full context decoder forward for alignment
alignment_layer (int): N° Layer to supervise with for alignment guiding
alignment_heads (int):
N. of cross attention heads to use for alignment guiding
pos_ffn_activation_fn (ActivationFunction):
activation function choice for PositionwiseFeedForward layer
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
num_kv (int): number of KV heads when different from the number of Q heads (multiquery)
add_ffnbias (bool): whether to add bias to the FF nn.Linear
parallel_residual (bool): Use parallel residual connections in each layer block, as used
by the GPT-J and GPT-NeoX models
shared_layer_norm (bool): When using parallel residual, share the input and post
attention layer norms.
layer_norm (string): type of layer normalization standard/rms
norm_eps (float): layer norm epsilon
use_ckpting (List): layers for which we checkpoint for backward
parallel_gpu (int): Number of GPUs for tensor parallelism
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
"""

def __init__(
@@ -762,6 +814,7 @@ def __init__(
use_ckpting=[],
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
):
super(TransformerLMDecoder, self).__init__(
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -791,6 +844,7 @@ def __init__(
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
)
for i in range(num_layers)
]
16 changes: 16 additions & 0 deletions onmt/encoders/transformer.py
@@ -29,6 +29,17 @@ class TransformerEncoderLayer(nn.Module):
dropout (float): dropout probability(0-1.0).
pos_ffn_activation_fn (ActivationFunction):
activation function choice for PositionwiseFeedForward layer
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
num_kv (int): number of KV heads when different from the number of Q heads (multiquery)
add_ffnbias (bool): whether to add bias to the FF nn.Linear
parallel_residual (bool): Use parallel residual connections in each layer block, as used
by the GPT-J and GPT-NeoX models
layer_norm (string): type of layer normalization standard/rms
norm_eps (float): layer norm epsilon
use_ckpting (List): layers for which we checkpoint for backward
parallel_gpu (int): Number of GPUs for tensor parallelism
rotary_interleave (bool): Interleave the head dimensions when rotary
embeddings are applied
"""

def __init__(
@@ -49,6 +60,7 @@ def __init__(
norm_eps=1e-6,
use_ckpting=[],
parallel_gpu=1,
rotary_interleave=True,
):
super(TransformerEncoderLayer, self).__init__()

@@ -59,6 +71,7 @@ def __init__(
is_decoder=False,
max_relative_positions=max_relative_positions,
relative_positions_buckets=relative_positions_buckets,
rotary_interleave=rotary_interleave,
attn_type="self",
add_qkvbias=add_qkvbias,
num_kv=num_kv,
@@ -163,6 +176,7 @@ def __init__(
norm_eps=1e-6,
use_ckpting=[],
parallel_gpu=1,
rotary_interleave=True,
):
super(TransformerEncoder, self).__init__()

@@ -186,6 +200,7 @@
norm_eps=norm_eps,
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
rotary_interleave=rotary_interleave,
)
for i in range(num_layers)
]
@@ -223,6 +238,7 @@ def from_opt(cls, opt, embeddings):
parallel_gpu=opt.world_size
if opt.parallel_mode == "tensor_parallel"
else 1,
rotary_interleave=opt.rotary_interleave,
)

def forward(self, src, src_len=None):
20 changes: 17 additions & 3 deletions onmt/modules/multi_headed_attn.py
@@ -28,9 +28,17 @@ def rotaryembeddings(dim: int, maxseqlen=8192, base=10000):
return rope


def apply_rotary_emb(query, key, rope):
def apply_rotary_emb(query, key, rope, interleave=True):
query = query.transpose(1, 2)
key = key.transpose(1, 2)
if not interleave:
query = torch.cat(
(-query[..., query.shape[-1] // 2 :], query[..., : query.shape[-1] // 2]),
dim=-1,
)
key = torch.cat(
(-key[..., key.shape[-1] // 2 :], key[..., : key.shape[-1] // 2]), dim=-1
)
query_ = query.float().reshape(*query.shape[:-1], -1, 2)
query_ = torch.view_as_complex(query_)
key_ = key.float().reshape(*key.shape[:-1], -1, 2)
@@ -243,6 +251,7 @@ def __init__(
is_decoder: bool = True,
max_relative_positions: int = 0,
relative_positions_buckets: int = 0,
rotary_interleave: bool = True,
attn_type: str = None,
add_qkvbias=False,
num_kv=0,
@@ -336,6 +345,7 @@ def __init__(

if max_relative_positions == -1: # rotary embeddings
self.rope = rotaryembeddings(self.dim_per_head)
self.rotary_interleave = rotary_interleave

if max_relative_positions == -2: # alibi positional bias
self.alibi = AlibiPositionalBias(head_count)
@@ -403,7 +413,9 @@ def forward(
start_pos = step
seqlen = query.size(2)
rope = self.rope[start_pos : start_pos + seqlen]
query, key = apply_rotary_emb(query, key, rope=rope)
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
)

if self.layer_cache[1]["keys"].numel() != 0:
key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
@@ -440,7 +452,9 @@ def forward(
start_pos = 0
seqlen = query.size(2)
rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
query, key = apply_rotary_emb(query, key, rope=rope)
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
)

b, h, l, d = key.size()
if self.num_kv > 0:
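Note on the change above: the two layouts selected by the new rotary_interleave flag are equivalent up to a fixed permutation of the head dimension. With interleave=True the rotation acts on adjacent (even, odd) pairs, as in Meta's original Llama code; with interleave=False it acts on pairs formed from the two halves of the head dimension, as in Hugging Face checkpoints. The following is a minimal, self-contained sketch of that equivalence, not the repository's implementation — function names and toy sizes are illustrative; the base-10000 frequency schedule mirrors the rotaryembeddings default shown above.

import torch


def rope_interleaved(x, cos, sin):
    # Meta/Llama "original" convention: adjacent pairs (x0, x1), (x2, x3), ...
    # are rotated as complex numbers.
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out


def rope_half_split(x, cos, sin):
    # Hugging Face convention: pairs are (x_i, x_{i + d/2}), applied with the
    # usual "rotate_half" trick (negate and swap the two halves).
    d = x.shape[-1]
    cos2, sin2 = torch.cat([cos, cos], dim=-1), torch.cat([sin, sin], dim=-1)
    rotated = torch.cat([-x[..., d // 2:], x[..., :d // 2]], dim=-1)
    return x * cos2 + rotated * sin2


# Toy sizes and the standard base-10000 RoPE frequency schedule.
d, seqlen = 8, 4
inv_freq = 10000.0 ** (-torch.arange(0, d, 2).float() / d)   # (d/2,)
angles = torch.arange(seqlen).float()[:, None] * inv_freq    # (seqlen, d/2)
cos, sin = angles.cos(), angles.sin()

x = torch.randn(seqlen, d)
# perm maps the half-split ordering [x_0..x_{d/2-1}, x_{d/2}..x_{d-1}]
# to interleaved pairs [x_0, x_{d/2}, x_1, x_{d/2+1}, ...].
perm = torch.arange(d).view(2, d // 2).t().reshape(-1)

same = torch.allclose(rope_interleaved(x[..., perm], cos, sin),
                      rope_half_split(x, cos, sin)[..., perm], atol=1e-6)
print(same)  # True: the two conventions differ only by this permutation

Because only the pairing differs, applying the rotation in the convention the checkpoint was trained with — which is what the flag controls at the attention layer — avoids any reordering of the loaded weights.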
14 changes: 13 additions & 1 deletion onmt/opts.py
@@ -862,6 +862,17 @@ def model_opts(parser):
help="This setting enable relative position bias"
"more info: https://github.com/google-research/text-to-text-transfer-transformer",
)
group.add(
"--rotary_interleave",
"-rotary_interleave",
type=bool,
default=True,
help="Interleave the head dimensions when rotary"
" embeddings are applied."
" Otherwise the head dimensions are sliced in half."
"True = default Llama from Meta (original)"
"False = used by all Hugging face models",
)
group.add(
"--heads",
"-heads",
@@ -927,7 +938,8 @@
"-shared_layer_norm",
action="store_true",
help="Use a shared layer_norm in parallel residual attention"
"Note: must be true for Falcon 7B / false for Falcon 40B",
"Note: must be true for Falcon 7B / false for Falcon 40B"
"same for GPT-J and GPT-NeoX models",
)
# Alignement options
group = parser.add_argument_group("Model - Alignement")
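At the checkpoint level the same difference shows up as a per-head permutation of the query/key projection rows. A small illustrative sketch (assumed names and sizes, not converter code from this repository): a projection stored for the half-split layout can either be reordered once into the interleaved layout, or left untouched and used with the rotation convention the new flag selects.

import torch

d_model, d_head = 16, 8
Wq = torch.randn(d_head, d_model)   # one head's query projection, half-split layout
x = torch.randn(3, d_model)         # a few token embeddings

# Rows reordered as [0, d/2, 1, d/2+1, ...]: turns the half-split ordering
# into the adjacent (even, odd) pairs expected by the interleaved rotation.
perm = torch.arange(d_head).view(2, d_head // 2).t().reshape(-1)

q_half_split = x @ Wq.t()
q_interleaved = x @ Wq[perm].t()

# Same features, just reordered -- so one can either permute the weights once
# at conversion time, or keep them as-is and set rotary_interleave to False.
print(torch.allclose(q_interleaved, q_half_split[:, perm]))  # True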
