From 212141bb3ab91717857ad39b1b2612848db81f1a Mon Sep 17 00:00:00 2001 From: Vincent Nguyen Date: Wed, 8 Nov 2023 15:27:41 +0100 Subject: [PATCH] Add rotary_interleave=false (Hugging face models) (#2507) --- onmt/decoders/transformer.py | 88 +++++++++++++++++++++++++------ onmt/encoders/transformer.py | 16 ++++++ onmt/modules/multi_headed_attn.py | 20 +++++-- onmt/opts.py | 14 ++++- 4 files changed, 117 insertions(+), 21 deletions(-) diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py index 84225124ba..9c78350460 100644 --- a/onmt/decoders/transformer.py +++ b/onmt/decoders/transformer.py @@ -42,6 +42,7 @@ def __init__( use_ckpting=[], parallel_gpu=1, sliding_window=0, + rotary_interleave=True, ): """ Args: @@ -60,6 +61,9 @@ def __init__( max_relative_positions (int): Max distance between inputs in relative positions representations + relative_positions_buckets (int): + relative position bias see + https://github.com/google-research/text-to-text-transfer-transformer aan_useffn (bool): Turn on the FFN layer in the AAN decoder full_context_alignment (bool): whether enable an extra full context decoder forward for @@ -69,9 +73,19 @@ def __init__( pos_ffn_activation_fn (ActivationFunction): activation function choice for PositionwiseFeedForward layer add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear + num_kv (int): number of heads for KV when different vs Q (multiquery) + add_ffnbias (bool): whether to add bias to the FF nn.Linear + parallel_residual (bool): Use parallel residual connections in each layer block, as used + by the GPT-J and GPT-NeoX models + shared_layer_norm (bool): When using parallel residual, share the input and post + attention layer norms. layer_norm (string): type of layer normalization standard/rms norm_eps (float): layer norm epsilon - + use_ckpting (List): layers for which we checkpoint for backward + parallel_gpu (int): Number of gpu for tensor parallelism + sliding_window (int): Width of the band mask and KV cache (cf Mistral Model) + rotary_interleave (bool): Interleave the head dimensions when rotary + embeddings are applied """ super(TransformerDecoderLayerBase, self).__init__() @@ -83,6 +97,7 @@ def __init__( dropout=attention_dropout, max_relative_positions=max_relative_positions, relative_positions_buckets=relative_positions_buckets, + rotary_interleave=rotary_interleave, attn_type="self", add_qkvbias=add_qkvbias, num_kv=num_kv, @@ -238,6 +253,7 @@ def __init__( use_ckpting=[], parallel_gpu=1, sliding_window=0, + rotary_interleave=True, ): """ Args: @@ -266,6 +282,7 @@ def __init__( use_ckpting=use_ckpting, parallel_gpu=parallel_gpu, sliding_window=sliding_window, + rotary_interleave=rotary_interleave, ) self.context_attn = MultiHeadedAttention( heads, @@ -424,6 +441,7 @@ def from_opt(cls, opt, embeddings): if opt.parallel_mode == "tensor_parallel" else 1, sliding_window=opt.sliding_window, + rotary_interleave=opt.rotary_interleave, ) def init_state(self, src, enc_out, enc_final_hs): @@ -486,8 +504,21 @@ class TransformerDecoder(TransformerDecoderBase): alignment_layer (int): N° Layer to supervise with for alignment guiding alignment_heads (int): N. 
of cross attention heads to use for alignment guiding + pos_ffn_activation_fn (ActivationFunction): + activation function choice for PositionwiseFeedForward layer add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear + num_kv (int): number of heads for KV when different vs Q (multiquery) + add_ffnbias (bool): whether to add bias to the FF nn.Linear + parallel_residual (bool): Use parallel residual connections in each layer block, as used + by the GPT-J and GPT-NeoX models + shared_layer_norm (bool): When using parallel residual, share the input and post + attention layer norms. layer_norm (string): type of layer normalization standard/rms + norm_eps (float): layer norm epsilon + use_ckpting (List): layers for which we checkpoint for backward + parallel_gpu (int): Number of gpu for tensor parallelism + sliding_window (int): Width of the band mask and KV cache (cf Mistral Model) + rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied """ def __init__( @@ -518,6 +549,7 @@ def __init__( use_ckpting=[], parallel_gpu=1, sliding_window=0, + rotary_interleave=True, ): super(TransformerDecoder, self).__init__( d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps @@ -548,6 +580,7 @@ def __init__( use_ckpting=use_ckpting, parallel_gpu=parallel_gpu, sliding_window=sliding_window, + rotary_interleave=rotary_interleave, ) for i in range(num_layers) ] @@ -716,22 +749,41 @@ def _forward( class TransformerLMDecoder(TransformerDecoderBase): """The Transformer decoder from GPT-2 Args: - num_layers (int): number of decoder layers. - d_model (int): size of the model - heads (int): number of heads - d_ff (int): size of the inner FF layer - copy_attn (bool): if using a separate copy attention - self_attn_type (str): type of self-attention scaled-dot, average - dropout (float): dropout in residual, self-attn(dot) and feed-forward - attention_dropout (float): dropout in context_attn (and self-attn(avg)) - embeddings (onmt.modules.Embeddings): - embeddings to use, should have positional encodings - max_relative_positions (int): - Max distance between inputs in relative positions representations - relative_positions_buckets (int): - Number of buckets when using Relative positions bias - aan_useffn (bool): Turn on the FFN layer in the AAN decoder - add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear + num_layers (int): number of decoder layers. + d_model (int): size of the model + heads (int): number of heads + d_ff (int): size of the inner FF layer + copy_attn (bool): if using a separate copy attention + self_attn_type (str): type of self-attention scaled-dot, average + dropout (float): dropout in residual, self-attn(dot) and feed-forward + attention_dropout (float): dropout in context_attn (and self-attn(avg)) + embeddings (onmt.modules.Embeddings): + embeddings to use, should have positional encodings + max_relative_positions (int): + Max distance between inputs in relative positions representations + relative_positions_buckets (int): + Number of buckets when using Relative positions bias + aan_useffn (bool): Turn on the FFN layer in the AAN decoder + full_context_alignment (bool): + whether enable an extra full context decoder forward for alignment + alignment_layer (int): N° Layer to supervise with for alignment guiding + alignment_heads (int): + N. 
of cross attention heads to use for alignment guiding + pos_ffn_activation_fn (ActivationFunction): + activation function choice for PositionwiseFeedForward layer + add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear + num_kv (int): number of heads for KV when different vs Q (multiquery) + add_ffnbias (bool): whether to add bias to the FF nn.Linear + parallel_residual (bool): Use parallel residual connections in each layer block, as used + by the GPT-J and GPT-NeoX models + shared_layer_norm (bool): When using parallel residual, share the input and post + attention layer norms. + layer_norm (string): type of layer normalization standard/rms + norm_eps (float): layer norm epsilon + use_ckpting (List): layers for which we checkpoint for backward + parallel_gpu (int): Number of gpu for tensor parallelism + sliding_window (int): Width of the band mask and KV cache (cf Mistral Model) + rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied """ def __init__( @@ -762,6 +814,7 @@ def __init__( use_ckpting=[], parallel_gpu=1, sliding_window=0, + rotary_interleave=True, ): super(TransformerLMDecoder, self).__init__( d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps @@ -791,6 +844,7 @@ def __init__( use_ckpting=use_ckpting, parallel_gpu=parallel_gpu, sliding_window=sliding_window, + rotary_interleave=rotary_interleave, ) for i in range(num_layers) ] diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py index 184d44881f..12998957dc 100644 --- a/onmt/encoders/transformer.py +++ b/onmt/encoders/transformer.py @@ -29,6 +29,17 @@ class TransformerEncoderLayer(nn.Module): dropout (float): dropout probability(0-1.0). pos_ffn_activation_fn (ActivationFunction): activation function choice for PositionwiseFeedForward layer + add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear + num_kv (int): number of heads for KV when different vs Q (multiquery) + add_ffnbias (bool): whether to add bias to the FF nn.Linear + parallel_residual (bool): Use parallel residual connections in each layer block, as used + by the GPT-J and GPT-NeoX models + layer_norm (string): type of layer normalization standard/rms + norm_eps (float): layer norm epsilon + use_ckpting (List): layers for which we checkpoint for backward + parallel_gpu (int): Number of gpu for tensor parallelism + rotary_interleave (bool): Interleave the head dimensions when rotary + embeddings are applied """ def __init__( @@ -49,6 +60,7 @@ def __init__( norm_eps=1e-6, use_ckpting=[], parallel_gpu=1, + rotary_interleave=True, ): super(TransformerEncoderLayer, self).__init__() @@ -59,6 +71,7 @@ def __init__( is_decoder=False, max_relative_positions=max_relative_positions, relative_positions_buckets=relative_positions_buckets, + rotary_interleave=rotary_interleave, attn_type="self", add_qkvbias=add_qkvbias, num_kv=num_kv, @@ -163,6 +176,7 @@ def __init__( norm_eps=1e-6, use_ckpting=[], parallel_gpu=1, + rotary_interleave=True, ): super(TransformerEncoder, self).__init__() @@ -186,6 +200,7 @@ def __init__( norm_eps=norm_eps, use_ckpting=use_ckpting, parallel_gpu=parallel_gpu, + rotary_interleave=rotary_interleave, ) for i in range(num_layers) ] @@ -223,6 +238,7 @@ def from_opt(cls, opt, embeddings): parallel_gpu=opt.world_size if opt.parallel_mode == "tensor_parallel" else 1, + rotary_interleave=opt.rotary_interleave, ) def forward(self, src, src_len=None): diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 
de5a6d085c..f7264340ce 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -28,9 +28,17 @@ def rotaryembeddings(dim: int, maxseqlen=8192, base=10000): return rope -def apply_rotary_emb(query, key, rope): +def apply_rotary_emb(query, key, rope, interleave=True): query = query.transpose(1, 2) key = key.transpose(1, 2) + if not interleave: + query = torch.cat( + (-query[..., query.shape[-1] // 2 :], query[..., : query.shape[-1] // 2]), + dim=-1, + ) + key = torch.cat( + (-key[..., key.shape[-1] // 2 :], key[..., : key.shape[-1] // 2]), dim=-1 + ) query_ = query.float().reshape(*query.shape[:-1], -1, 2) query_ = torch.view_as_complex(query_) key_ = key.float().reshape(*key.shape[:-1], -1, 2) @@ -243,6 +251,7 @@ def __init__( is_decoder: bool = True, max_relative_positions: int = 0, relative_positions_buckets: int = 0, + rotary_interleave: bool = True, attn_type: str = None, add_qkvbias=False, num_kv=0, @@ -336,6 +345,7 @@ def __init__( if max_relative_positions == -1: # rotary embeddings self.rope = rotaryembeddings(self.dim_per_head) + self.rotary_interleave = rotary_interleave if max_relative_positions == -2: # alibi positional bias self.alibi = AlibiPositionalBias(head_count) @@ -403,7 +413,9 @@ def forward( start_pos = step seqlen = query.size(2) rope = self.rope[start_pos : start_pos + seqlen] - query, key = apply_rotary_emb(query, key, rope=rope) + query, key = apply_rotary_emb( + query, key, rope, interleave=self.rotary_interleave + ) if self.layer_cache[1]["keys"].numel() != 0: key = torch.cat((self.layer_cache[1]["keys"], key), dim=2) @@ -440,7 +452,9 @@ def forward( start_pos = 0 seqlen = query.size(2) rope = self.rope[start_pos : start_pos + seqlen].to(query.device) - query, key = apply_rotary_emb(query, key, rope=rope) + query, key = apply_rotary_emb( + query, key, rope, interleave=self.rotary_interleave + ) b, h, l, d = key.size() if self.num_kv > 0: diff --git a/onmt/opts.py b/onmt/opts.py index ce2902c784..38df61d30f 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -862,6 +862,17 @@ def model_opts(parser): help="This setting enable relative position bias" "more info: https://github.com/google-research/text-to-text-transfer-transformer", ) + group.add( + "--rotary_interleave", + "-rotary_interleave", + type=bool, + default=True, + help="Interleave the head dimensions when rotary" + " embeddings are applied." + " Otherwise the head dimensions are sliced in half." + "True = default Llama from Meta (original)" + "False = used by all Hugging face models", + ) group.add( "--heads", "-heads", @@ -927,7 +938,8 @@ def model_opts(parser): "-shared_layer_norm", action="store_true", help="Use a shared layer_norm in parallel residual attention" - "Note: must be true for Falcon 7B / false for Falcon 40B", + "Note: must be true for Falcon 7B / false for Falcon 40B" + "same for GPT-J and GPT-NeoX models", ) # Alignement options group = parser.add_argument_group("Model - Alignement")
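
Editorial illustration (a minimal sketch, not code from the patch): the `interleave` switch added to `apply_rotary_emb` selects between two equivalent rotary-embedding layouts. With `interleave=True` (the original Meta Llama convention, and the previous behaviour) adjacent head dimensions (x0, x1), (x2, x3), ... form the rotated pairs; with `interleave=False` the pairs are (x_i, x_{i+d/2}), the "rotate_half" formulation that most Hugging Face checkpoints assume, because their converted q/k projection weights are typically permuted to match it. The sketch below contrasts the two conventions and checks that they agree up to a fixed permutation of the head dimension; the names `rope_angles`, `rope_interleaved` and `rope_half_split` are illustrative and are not the functions used in `onmt/modules/multi_headed_attn.py`.

```python
# Minimal sketch of the two RoPE conventions selected by rotary_interleave.
import torch


def rope_angles(dim: int, seqlen: int, base: float = 10000.0) -> torch.Tensor:
    """Per-position, per-pair rotation angles theta_{t,i} = t * base^(-2i/dim)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
    t = torch.arange(seqlen).float()
    return torch.outer(t, inv_freq)  # (seqlen, dim/2)


def rope_interleaved(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """rotary_interleave=True: rotate adjacent pairs (x0,x1), (x2,x3), ...
    This mirrors the complex-multiplication formulation (original Llama layout)."""
    x_ = x.float().reshape(*x.shape[:-1], -1, 2)        # (seq, dim/2, 2)
    x_c = torch.view_as_complex(x_)                     # (seq, dim/2)
    rot = torch.polar(torch.ones_like(angles), angles)  # e^{i*theta}
    return torch.view_as_real(x_c * rot).flatten(-2)    # back to (seq, dim)


def rope_half_split(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """rotary_interleave=False: rotate pairs (x_i, x_{i+dim/2}), i.e. the
    'rotate_half' formulation assumed by most Hugging Face checkpoints."""
    cos = torch.cat([angles.cos(), angles.cos()], dim=-1)  # (seq, dim)
    sin = torch.cat([angles.sin(), angles.sin()], dim=-1)
    half = x.shape[-1] // 2
    rotated = torch.cat([-x[..., half:], x[..., :half]], dim=-1)
    return x * cos + rotated * sin


if __name__ == "__main__":
    torch.manual_seed(0)
    seq, dim = 5, 8
    x = torch.randn(seq, dim)
    angles = rope_angles(dim, seq)

    # The two conventions differ only by a fixed permutation of the head
    # dimension: interleaved pair order (0, 2, ..., 1, 3, ...) vs. natural
    # order. Permuting the input and the output makes them agree exactly.
    perm = torch.cat([torch.arange(0, dim, 2), torch.arange(1, dim, 2)])
    out_half = rope_half_split(x[..., perm], angles)
    out_inter = rope_interleaved(x, angles)[..., perm]
    print(torch.allclose(out_half, out_inter, atol=1e-6))  # True
```

Because the two layouts differ only by this fixed permutation, the flag does not change model quality; it only has to match the layout the checkpoint's q/k weights were saved in (interleaved for the original Meta Llama weights, half-split for Hugging Face conversions).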
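
Usage note (an assumption, not stated in the patch): `--rotary_interleave` is declared with `type=bool`, and under stock argparse semantics any non-empty string converts to `True`, so passing `-rotary_interleave False` on the command line may not behave as intended depending on how OpenNMT's parser handles booleans. Setting `rotary_interleave: false` in the YAML configuration is the unambiguous way to select the Hugging Face layout, while the default `true` keeps the original Meta Llama behaviour.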