Support Llama 3.1 scaled rope (#618)
* Support llama 3.1 rope

* Rename in enum

* Clippy

* Reshape
EricLBuehler authored Jul 24, 2024
1 parent 9569269 commit 7acdd1c
Showing 5 changed files with 168 additions and 28 deletions.
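
For context, the core of this change is the Llama 3 frequency-scaling rule applied to the RoPE inverse frequencies (see the `layers.rs` hunk below). The following standalone sketch reproduces that rule outside of mistral.rs; the struct name, `rope_theta`, `head_dim`, and scaling values are illustrative assumptions (typical Llama 3.1 settings), not values taken from this commit.

```rust
// Minimal sketch (not part of the diff): the frequency-dependent scaling rule
// that `Llama3RotaryEmbedding::new` applies in layers.rs below.
use std::f32::consts::PI;

// Mirrors `Llama3RopeConfig`; a local stand-in for this sketch.
struct RopeScaling {
    factor: f32,
    low_freq_factor: f32,
    high_freq_factor: f32,
    original_max_position_embeddings: usize,
}

fn scale_inv_freq(inv_freq: &[f32], s: &RopeScaling) -> Vec<f32> {
    let low_freq_wavelen = s.original_max_position_embeddings as f32 / s.low_freq_factor;
    let high_freq_wavelen = s.original_max_position_embeddings as f32 / s.high_freq_factor;
    inv_freq
        .iter()
        .map(|&freq| {
            let wavelen = 2. * PI / freq;
            if wavelen < high_freq_wavelen {
                // High-frequency components are left untouched.
                freq
            } else if wavelen > low_freq_wavelen {
                // Low-frequency components are slowed down by `factor`.
                freq / s.factor
            } else {
                // In between, interpolate smoothly between the two regimes.
                let smooth = (s.original_max_position_embeddings as f32 / wavelen
                    - s.low_freq_factor)
                    / (s.high_freq_factor - s.low_freq_factor);
                (1. - smooth) * freq / s.factor + smooth * freq
            }
        })
        .collect()
}

fn main() {
    // Illustrative Llama 3.1-style values (assumed; check the model's config.json).
    let scaling = RopeScaling {
        factor: 8.0,
        low_freq_factor: 1.0,
        high_freq_factor: 4.0,
        original_max_position_embeddings: 8192,
    };
    // Default inv_freq for a 128-dim head with rope_theta = 500000.0 (assumed values).
    let head_dim = 128usize;
    let rope_theta = 500_000f32;
    let inv_freq: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / rope_theta.powf(i as f32 / head_dim as f32))
        .collect();
    let scaled = scale_inv_freq(&inv_freq, &scaling);
    println!("first unscaled/scaled pair: {} -> {}", inv_freq[0], scaled[0]);
}
```
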
144 changes: 143 additions & 1 deletion mistralrs-core/src/layers.rs
@@ -2,6 +2,7 @@

use std::{
collections::HashMap,
f32::consts::PI,
ops::Mul,
str::FromStr,
sync::{
@@ -15,10 +16,13 @@ use candle_core::{
DType, Device, IndexOp, Result, Shape, Tensor, D,
};
use candle_nn::{Linear, Module, VarBuilder};
use serde::Deserialize;

pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::{flash_attn, repeat_kv};
use crate::{cublaslt::CUBLASLT_HANDLE, pipeline::Phi3RopeScaling, INHIBIT_GEMM_F16};
use crate::{
cublaslt::CUBLASLT_HANDLE, models::llama, pipeline::Phi3RopeScaling, INHIBIT_GEMM_F16,
};

#[derive(Debug, Clone)]
pub struct RmsNorm {
@@ -245,6 +249,144 @@ impl PhiRotaryEmbedding {
}
}

/// RoPE for Llama3
#[derive(Debug, Clone)]
pub enum Llama3RotaryEmbedding {
Llama3 {
sin: Tensor,
cos: Tensor,
is_gptx: bool,
},
Default(RotaryEmbedding),
}

#[derive(Debug, Clone, Deserialize, Default)]
pub enum Llama3RopeType {
#[serde(rename = "llama3")]
Llama3,
#[default]
#[serde(rename = "default")]
Default,
}

#[derive(Debug, Clone, Deserialize, Default)]
pub struct Llama3RopeConfig {
pub factor: f32,
pub low_freq_factor: f32,
pub high_freq_factor: f32,
pub original_max_position_embeddings: usize,
pub rope_type: Llama3RopeType,
}

fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
(0..head_dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
.collect()
}

// https://github.com/huggingface/transformers/blob/1392a6867f40a55dfabaf306745c67627598b1af/src/transformers/modeling_rope_utils.py#L298
impl Llama3RotaryEmbedding {
pub fn new(dtype: DType, cfg: &llama::Config, dev: &Device, is_gpt_neox: bool) -> Result<Self> {
match &cfg.rope_scaling {
None
| Some(Llama3RopeConfig {
rope_type: Llama3RopeType::Default,
..
}) => Ok(Self::Default(RotaryEmbedding::new(
cfg.rope_theta,
cfg.hidden_size / cfg.num_attention_heads,
cfg.max_position_embeddings,
dev,
is_gpt_neox,
dtype,
)?)),
Some(rope_scaling) => {
let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
/ rope_scaling.low_freq_factor;
let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
/ rope_scaling.high_freq_factor;

let inv_freq = calculate_default_inv_freq(cfg)
.into_iter()
.map(|freq| {
let wavelen = 2. * PI / freq;
if wavelen < high_freq_wavelen {
freq
} else if wavelen > low_freq_wavelen {
freq / rope_scaling.factor
} else {
let smooth = (rope_scaling.original_max_position_embeddings as f32
/ wavelen
- rope_scaling.low_freq_factor)
/ (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
(1. - smooth) * freq / rope_scaling.factor + smooth * freq
}
})
.collect::<Vec<_>>();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;

let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
let freqs = t.matmul(&inv_freq)?;
let sin = freqs.sin()?.to_dtype(dtype)?;
let cos = freqs.cos()?.to_dtype(dtype)?;
Ok(Self::Llama3 {
sin,
cos,
is_gptx: is_gpt_neox,
})
}
}
}

pub fn forward(
&self,
positions: &[usize],
positions_kernel: &Tensor,
q: &mut Tensor,
k: &mut Tensor,
b_sz: usize,
) -> Result<()> {
match self {
Self::Llama3 { sin, cos, is_gptx } => {
let (b_sz_seq_len, h, n_embd) = q.dims3()?;
*q = q
.reshape((b_sz, b_sz_seq_len / b_sz, h, n_embd))?
.transpose(1, 2)?;
let (b_sz_seq_len, h, n_embd) = k.dims3()?;
*k = k
.reshape((b_sz, b_sz_seq_len / b_sz, h, n_embd))?
.transpose(1, 2)?;

let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let mut q_embeds = Vec::new();
let mut k_embeds = Vec::new();
for (i, offset) in positions.iter().enumerate() {
let cos = cos.narrow(0, *offset, seq_len)?;
let sin = sin.narrow(0, *offset, seq_len)?;
let rope = if *is_gptx {
candle_nn::rotary_emb::rope
} else {
candle_nn::rotary_emb::rope_i
};
let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
q_embeds.push(q_embed);
k_embeds.push(k_embed);
}
*q = Tensor::cat(&q_embeds, 0)?;
*k = Tensor::cat(&k_embeds, 0)?;
Ok(())
}
Self::Default(rope) => rope.forward(positions, positions_kernel, q, k, b_sz),
}
}
}

/// Matrix multiplication, configurable to be via f16 (to use the faster GEMM kernels) optionally.
pub struct MatMul;

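The `forward` path above picks between candle's two rotary kernels depending on `is_gptx`. The sketch below is an assumption-laden usage example, not part of the commit: it builds a small cos/sin table the same way `Llama3RotaryEmbedding::new` does and applies both `candle_nn::rotary_emb::rope` (GPT-NeoX half-split layout) and `rope_i` (interleaved pairs) to a random tensor. Tensor sizes and the theta value are arbitrary.

```rust
// Sketch (assumes the candle_core and candle_nn crates as dependencies).
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, h, t, d) = (1, 2, 4, 8);
    let q = Tensor::randn(0f32, 1f32, (b, h, t, d), &dev)?;

    // inv_freq for a d-dim head, then the (t, d/2) cos/sin tables,
    // mirroring the construction in `Llama3RotaryEmbedding::new`.
    let inv_freq: Vec<f32> = (0..d)
        .step_by(2)
        .map(|i| 1f32 / 10_000f32.powf(i as f32 / d as f32))
        .collect();
    let inv_freq = Tensor::from_vec(inv_freq, (1, d / 2), &dev)?;
    let pos = Tensor::arange(0u32, t as u32, &dev)?
        .to_dtype(DType::F32)?
        .reshape((t, 1))?;
    let freqs = pos.matmul(&inv_freq)?;
    let (cos, sin) = (freqs.cos()?, freqs.sin()?);

    // `rope` rotates the two halves of the head dim (GPT-NeoX style);
    // `rope_i` rotates interleaved pairs. `forward` above chooses via `is_gptx`.
    let gptx = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
    let interleaved = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
    println!("{:?} {:?}", gptx.dims(), interleaved.dims());
    Ok(())
}
```
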
27 changes: 11 additions & 16 deletions mistralrs-core/src/models/llama.rs
@@ -1,9 +1,7 @@
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{quantized::QMatMul, DType, Device, Result, Tensor};
use candle_nn::{
embedding, linear_no_bias as linear, Embedding, Module, RotaryEmbedding, VarBuilder,
};
use candle_nn::{embedding, linear_no_bias as linear, Embedding, Module, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;

@@ -14,7 +12,10 @@ use crate::{
},
device_map::DeviceMapper,
get_delta_from_lora_ab,
layers::{repeat_kv, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention},
layers::{
repeat_kv, CausalMasker, Llama3RopeConfig, Llama3RotaryEmbedding, MatMul, RmsNorm,
ScaledDotProductAttention,
},
layers_masker::PastKvLenCache,
merge_delta,
paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
@@ -37,6 +38,7 @@ pub struct Config {
pub rms_norm_eps: f64,
pub rope_theta: f32,
pub max_position_embeddings: usize,
pub rope_scaling: Option<Llama3RopeConfig>,
}

struct CausalSelfAttention {
@@ -48,7 +50,7 @@ struct CausalSelfAttention {
num_key_value_heads: usize,
head_dim: usize,
use_flash_attn: bool,
rotary_emb: Arc<RotaryEmbedding>,
rotary_emb: Arc<Llama3RotaryEmbedding>,
max_seq_len: usize,
paged_attn: Option<PagedAttention>,
}
@@ -167,7 +169,7 @@ impl CausalSelfAttention {
fn load(
vb: VarBuilder,
cfg: &Config,
rope: Arc<RotaryEmbedding>,
rope: Arc<Llama3RotaryEmbedding>,
paged_attn: Option<PagedAttention>,
) -> Result<Self> {
let size_in = cfg.hidden_size;
@@ -321,7 +323,7 @@ impl Block {
mapper: &dyn DeviceMapper,
layer_idx: usize,
loading_isq: bool,
rope: Arc<RotaryEmbedding>,
rope: Arc<Llama3RotaryEmbedding>,
paged_attn: Option<PagedAttention>,
) -> Result<Self> {
let attn = CausalSelfAttention::load(
@@ -440,15 +442,8 @@ impl Llama {
.device_for(i, false)
.unwrap_or(&normal_loading_metadata.real_device);
let rotary_emb = Arc::new(
RotaryEmbedding::new(
cfg.rope_theta,
head_dim,
cfg.max_position_embeddings,
device,
is_gptx,
vb.dtype(),
)
.expect("Failed to create RoPE"),
Llama3RotaryEmbedding::new(vb.dtype(), cfg, device, is_gptx)
.expect("Failed to create RoPE"),
);
let paged_attn = match &attention_mechanism {
AttentionImplementation::Eager => None,
3 changes: 3 additions & 0 deletions mistralrs-core/src/pipeline/normal_loaders.rs
@@ -1,6 +1,7 @@
use std::{collections::HashMap, fmt::Debug, str::FromStr};

use crate::{
layers::Llama3RopeConfig,
lora::{LoraConfig, Ordering},
paged_attention::AttentionImplementation,
};
@@ -307,6 +308,7 @@ struct LlamaBasicConfig {
#[serde(default = "default_rope")]
rope_theta: f32,
max_position_embeddings: usize,
rope_scaling: Option<Llama3RopeConfig>,
}

fn default_rope() -> f32 {
@@ -329,6 +331,7 @@ impl LlamaBasicConfig {
rope_theta: basic_config.rope_theta,
use_flash_attn,
max_position_embeddings: basic_config.max_position_embeddings,
rope_scaling: basic_config.rope_scaling,
})
}
}
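The new `rope_scaling: Option<Llama3RopeConfig>` field means the loader now picks up a `rope_scaling` block from the model's `config.json`. The sketch below shows the shape of a block that deserializes into `Llama3RopeConfig`; it assumes `serde` (with the derive feature) and `serde_json` as dependencies, copies the struct definitions from the `layers.rs` hunk above, and uses JSON values that are merely illustrative of Llama 3.1 checkpoints.

```rust
// Sketch: parsing a `rope_scaling` JSON block into `Llama3RopeConfig`.
use serde::Deserialize;

#[derive(Debug, Clone, Deserialize, Default)]
pub enum Llama3RopeType {
    #[serde(rename = "llama3")]
    Llama3,
    #[default]
    #[serde(rename = "default")]
    Default,
}

#[derive(Debug, Clone, Deserialize, Default)]
pub struct Llama3RopeConfig {
    pub factor: f32,
    pub low_freq_factor: f32,
    pub high_freq_factor: f32,
    pub original_max_position_embeddings: usize,
    pub rope_type: Llama3RopeType,
}

fn main() -> serde_json::Result<()> {
    // Illustrative values; real checkpoints ship their own numbers.
    let json = r#"{
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3"
    }"#;
    let cfg: Llama3RopeConfig = serde_json::from_str(json)?;
    println!("{cfg:?}");
    Ok(())
}
```
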
3 changes: 3 additions & 0 deletions mistralrs-core/src/vision_models/llava/config.rs
@@ -1,6 +1,7 @@
use candle_nn::Activation;
use serde::Deserialize;

use crate::layers::Llama3RopeConfig;
use crate::serde_default_fn;

use crate::models::llama::Config as LLaMAConfig;
@@ -43,6 +44,7 @@ pub struct LLaVATextConfig {
#[serde(default = "default_vocab_size")]
pub vocab_size: usize,
pub sliding_window: Option<usize>,
pub rope_scaling: Option<Llama3RopeConfig>,
}

serde_default_fn!(usize, default_num_hidden_layers, 32);
@@ -77,6 +79,7 @@ impl Config {
rms_norm_eps: self.text_config.rms_norm_eps,
rope_theta: self.text_config.rope_theta,
max_position_embeddings: self.text_config.max_position_embeddings,
rope_scaling: self.text_config.rope_scaling.clone(),
}
}

19 changes: 8 additions & 11 deletions mistralrs-core/src/xlora_models/llama.rs
@@ -2,14 +2,14 @@

use crate::{
amoe::AnyMoeBaseModelMixin,
layers::ScaledDotProductAttention,
layers::{Llama3RotaryEmbedding, ScaledDotProductAttention},
lora::{linear_no_bias as linear, LinearLayerLike, LoraConfig, Ordering},
paged_attention::ModelConfigMetadata,
pipeline::{text_models_inputs_processor::PagedAttentionInputMetadata, IsqModel},
utils::progress::NiceProgressBar,
};
use candle_core::{quantized::QMatMul, DType, Device, Result, Tensor};
use candle_nn::{embedding, Embedding, Module, RotaryEmbedding, VarBuilder};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;
@@ -33,7 +33,7 @@ struct CausalSelfAttention {
num_key_value_heads: usize,
head_dim: usize,
use_flash_attn: bool,
rotary_emb: Arc<RotaryEmbedding>,
rotary_emb: Arc<Llama3RotaryEmbedding>,
max_seq_len: usize,
}

@@ -158,7 +158,7 @@ impl CausalSelfAttention {
mapper: &dyn DeviceMapper,
layer_idx: usize,
loading_isq: bool,
rope: Arc<RotaryEmbedding>,
rope: Arc<Llama3RotaryEmbedding>,
preload_adapters: &Option<HashMap<String, (VarBuilder, LoraConfig)>>,
) -> Result<Self> {
let size_in = cfg.hidden_size;
@@ -369,7 +369,7 @@ impl Block {
mapper: &dyn DeviceMapper,
layer_idx: usize,
loading_isq: bool,
rope: Arc<RotaryEmbedding>,
rope: Arc<Llama3RotaryEmbedding>,
preload_adapters: &Option<HashMap<String, (VarBuilder, LoraConfig)>>,
) -> Result<Self> {
let attn = CausalSelfAttention::load(
@@ -610,21 +610,18 @@ impl XLoraLlama {
cfg.rms_norm_eps,
mapper.set_nm_device(vb.pp("model.norm"), false),
)?;
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
let mut blocks: Vec<_> =
NiceProgressBar::<_, 'b'>(0..cfg.num_hidden_layers, "Loading repeating layers")
.into_iter()
.map(|i| {
let rotary_emb = Arc::new(
RotaryEmbedding::new(
cfg.rope_theta,
head_dim,
cfg.max_position_embeddings,
Llama3RotaryEmbedding::new(
vb.dtype(),
cfg,
mapper
.device_for(i, false)
.unwrap_or(&normal_loading_metadata.real_device),
is_gptx,
vb.dtype(),
)
.expect("Failed to create RoPE"),
);
