diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs
index c8de6ab30..57d3fac3f 100644
--- a/mistralrs-core/src/layers.rs
+++ b/mistralrs-core/src/layers.rs
@@ -2,6 +2,7 @@
 
 use std::{
     collections::HashMap,
+    f32::consts::PI,
     ops::Mul,
     str::FromStr,
     sync::{
@@ -15,10 +16,13 @@ use candle_core::{
     DType, Device, IndexOp, Result, Shape, Tensor, D,
 };
 use candle_nn::{Linear, Module, VarBuilder};
+use serde::Deserialize;
 
 pub use crate::layers_masker::CausalMasker;
 pub use crate::layers_utils::{flash_attn, repeat_kv};
-use crate::{cublaslt::CUBLASLT_HANDLE, pipeline::Phi3RopeScaling, INHIBIT_GEMM_F16};
+use crate::{
+    cublaslt::CUBLASLT_HANDLE, models::llama, pipeline::Phi3RopeScaling, INHIBIT_GEMM_F16,
+};
 
 #[derive(Debug, Clone)]
 pub struct RmsNorm {
@@ -245,6 +249,144 @@ impl PhiRotaryEmbedding {
     }
 }
 
+/// RoPE for Llama3
+#[derive(Debug, Clone)]
+pub enum Llama3RotaryEmbedding {
+    Llama3 {
+        sin: Tensor,
+        cos: Tensor,
+        is_gptx: bool,
+    },
+    Default(RotaryEmbedding),
+}
+
+#[derive(Debug, Clone, Deserialize, Default)]
+pub enum Llama3RopeType {
+    #[serde(rename = "llama3")]
+    Llama3,
+    #[default]
+    #[serde(rename = "default")]
+    Default,
+}
+
+#[derive(Debug, Clone, Deserialize, Default)]
+pub struct Llama3RopeConfig {
+    pub factor: f32,
+    pub low_freq_factor: f32,
+    pub high_freq_factor: f32,
+    pub original_max_position_embeddings: usize,
+    pub rope_type: Llama3RopeType,
+}
+
+fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
+    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
+    (0..head_dim)
+        .step_by(2)
+        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
+        .collect()
+}
+
+// https://github.com/huggingface/transformers/blob/1392a6867f40a55dfabaf306745c67627598b1af/src/transformers/modeling_rope_utils.py#L298
+impl Llama3RotaryEmbedding {
+    pub fn new(dtype: DType, cfg: &llama::Config, dev: &Device, is_gpt_neox: bool) -> Result<Self> {
+        match &cfg.rope_scaling {
+            None
+            | Some(Llama3RopeConfig {
+                rope_type: Llama3RopeType::Default,
+                ..
+            }) => Ok(Self::Default(RotaryEmbedding::new(
+                cfg.rope_theta,
+                cfg.hidden_size / cfg.num_attention_heads,
+                cfg.max_position_embeddings,
+                dev,
+                is_gpt_neox,
+                dtype,
+            )?)),
+            Some(rope_scaling) => {
+                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+                    / rope_scaling.low_freq_factor;
+                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+                    / rope_scaling.high_freq_factor;
+
+                let inv_freq = calculate_default_inv_freq(cfg)
+                    .into_iter()
+                    .map(|freq| {
+                        let wavelen = 2. * PI / freq;
+                        if wavelen < high_freq_wavelen {
+                            freq
+                        } else if wavelen > low_freq_wavelen {
+                            freq / rope_scaling.factor
+                        } else {
+                            let smooth = (rope_scaling.original_max_position_embeddings as f32
+                                / wavelen
+                                - rope_scaling.low_freq_factor)
+                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
+                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
+                        }
+                    })
+                    .collect::<Vec<_>>();
+                let inv_freq_len = inv_freq.len();
+                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
+
+                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
+                    .to_dtype(DType::F32)?
+                    .reshape((cfg.max_position_embeddings, 1))?;
+                let freqs = t.matmul(&inv_freq)?;
+                let sin = freqs.sin()?.to_dtype(dtype)?;
+                let cos = freqs.cos()?.to_dtype(dtype)?;
+                Ok(Self::Llama3 {
+                    sin,
+                    cos,
+                    is_gptx: is_gpt_neox,
+                })
+            }
+        }
+    }
+
+    pub fn forward(
+        &self,
+        positions: &[usize],
+        positions_kernel: &Tensor,
+        q: &mut Tensor,
+        k: &mut Tensor,
+        b_sz: usize,
+    ) -> Result<()> {
+        match self {
+            Self::Llama3 { sin, cos, is_gptx } => {
+                let (b_sz_seq_len, h, n_embd) = q.dims3()?;
+                *q = q
+                    .reshape((b_sz, b_sz_seq_len / b_sz, h, n_embd))?
+                    .transpose(1, 2)?;
+                let (b_sz_seq_len, h, n_embd) = k.dims3()?;
+                *k = k
+                    .reshape((b_sz, b_sz_seq_len / b_sz, h, n_embd))?
+                    .transpose(1, 2)?;
+
+                let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
+                let mut q_embeds = Vec::new();
+                let mut k_embeds = Vec::new();
+                for (i, offset) in positions.iter().enumerate() {
+                    let cos = cos.narrow(0, *offset, seq_len)?;
+                    let sin = sin.narrow(0, *offset, seq_len)?;
+                    let rope = if *is_gptx {
+                        candle_nn::rotary_emb::rope
+                    } else {
+                        candle_nn::rotary_emb::rope_i
+                    };
+                    let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
+                    let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
+                    q_embeds.push(q_embed);
+                    k_embeds.push(k_embed);
+                }
+                *q = Tensor::cat(&q_embeds, 0)?;
+                *k = Tensor::cat(&k_embeds, 0)?;
+                Ok(())
+            }
+            Self::Default(rope) => rope.forward(positions, positions_kernel, q, k, b_sz),
+        }
+    }
+}
+
 /// Matrix multiplication, configurable to be via f16 (to use the faster GEMM kernels) optionally.
 pub struct MatMul;
 
diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs
index 12bf82078..67fc62506 100644
--- a/mistralrs-core/src/models/llama.rs
+++ b/mistralrs-core/src/models/llama.rs
@@ -1,9 +1,7 @@
 #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
 
 use candle_core::{quantized::QMatMul, DType, Device, Result, Tensor};
-use candle_nn::{
-    embedding, linear_no_bias as linear, Embedding, Module, RotaryEmbedding, VarBuilder,
-};
+use candle_nn::{embedding, linear_no_bias as linear, Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use std::sync::Arc;
 
@@ -14,7 +12,10 @@ use crate::{
     },
     device_map::DeviceMapper,
     get_delta_from_lora_ab,
-    layers::{repeat_kv, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention},
+    layers::{
+        repeat_kv, CausalMasker, Llama3RopeConfig, Llama3RotaryEmbedding, MatMul, RmsNorm,
+        ScaledDotProductAttention,
+    },
     layers_masker::PastKvLenCache,
     merge_delta,
     paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
@@ -37,6 +38,7 @@ pub struct Config {
     pub rms_norm_eps: f64,
     pub rope_theta: f32,
     pub max_position_embeddings: usize,
+    pub rope_scaling: Option<Llama3RopeConfig>,
 }
 
 struct CausalSelfAttention {
@@ -48,7 +50,7 @@ struct CausalSelfAttention {
     num_key_value_heads: usize,
     head_dim: usize,
     use_flash_attn: bool,
-    rotary_emb: Arc<RotaryEmbedding>,
+    rotary_emb: Arc<Llama3RotaryEmbedding>,
     max_seq_len: usize,
     paged_attn: Option<PagedAttention>,
 }
@@ -167,7 +169,7 @@ impl CausalSelfAttention {
     fn load(
         vb: VarBuilder,
         cfg: &Config,
-        rope: Arc<RotaryEmbedding>,
+        rope: Arc<Llama3RotaryEmbedding>,
         paged_attn: Option<PagedAttention>,
     ) -> Result<Self> {
         let size_in = cfg.hidden_size;
@@ -321,7 +323,7 @@ impl Block {
         mapper: &dyn DeviceMapper,
         layer_idx: usize,
         loading_isq: bool,
-        rope: Arc<RotaryEmbedding>,
+        rope: Arc<Llama3RotaryEmbedding>,
         paged_attn: Option<PagedAttention>,
     ) -> Result<Self> {
         let attn = CausalSelfAttention::load(
@@ -440,15 +442,8 @@ impl Llama {
                 .device_for(i, false)
                 .unwrap_or(&normal_loading_metadata.real_device);
             let rotary_emb = Arc::new(
-                RotaryEmbedding::new(
-                    cfg.rope_theta,
-                    head_dim,
-                    cfg.max_position_embeddings,
-                    device,
-                    is_gptx,
-                    vb.dtype(),
-                )
-                .expect("Failed to create RoPE"),
+                Llama3RotaryEmbedding::new(vb.dtype(), cfg, device, is_gptx)
+                    .expect("Failed to create RoPE"),
             );
             let paged_attn = match &attention_mechanism {
                 AttentionImplementation::Eager => None,
diff --git a/mistralrs-core/src/pipeline/normal_loaders.rs b/mistralrs-core/src/pipeline/normal_loaders.rs
index 7b8460749..416716cff 100644
--- a/mistralrs-core/src/pipeline/normal_loaders.rs
+++ b/mistralrs-core/src/pipeline/normal_loaders.rs
@@ -1,6 +1,7 @@
 use std::{collections::HashMap, fmt::Debug, str::FromStr};
 
 use crate::{
+    layers::Llama3RopeConfig,
     lora::{LoraConfig, Ordering},
     paged_attention::AttentionImplementation,
 };
@@ -307,6 +308,7 @@ struct LlamaBasicConfig {
     #[serde(default = "default_rope")]
     rope_theta: f32,
    max_position_embeddings: usize,
+    rope_scaling: Option<Llama3RopeConfig>,
 }
 
 fn default_rope() -> f32 {
@@ -329,6 +331,7 @@ impl LlamaBasicConfig {
             rope_theta: basic_config.rope_theta,
             use_flash_attn,
             max_position_embeddings: basic_config.max_position_embeddings,
+            rope_scaling: basic_config.rope_scaling,
         })
     }
 }
diff --git a/mistralrs-core/src/vision_models/llava/config.rs b/mistralrs-core/src/vision_models/llava/config.rs
index 87196ea57..038fb985f 100644
--- a/mistralrs-core/src/vision_models/llava/config.rs
+++ b/mistralrs-core/src/vision_models/llava/config.rs
@@ -1,6 +1,7 @@
 use candle_nn::Activation;
 use serde::Deserialize;
 
+use crate::layers::Llama3RopeConfig;
 use crate::serde_default_fn;
 
 use crate::models::llama::Config as LLaMAConfig;
@@ -43,6 +44,7 @@ pub struct LLaVATextConfig {
     #[serde(default = "default_vocab_size")]
     pub vocab_size: usize,
     pub sliding_window: Option<usize>,
+    pub rope_scaling: Option<Llama3RopeConfig>,
 }
 
 serde_default_fn!(usize, default_num_hidden_layers, 32);
@@ -77,6 +79,7 @@ impl Config {
             rms_norm_eps: self.text_config.rms_norm_eps,
             rope_theta: self.text_config.rope_theta,
             max_position_embeddings: self.text_config.max_position_embeddings,
+            rope_scaling: self.text_config.rope_scaling.clone(),
         }
     }
 
diff --git a/mistralrs-core/src/xlora_models/llama.rs b/mistralrs-core/src/xlora_models/llama.rs
index efb2d4e0f..a0af0ba17 100644
--- a/mistralrs-core/src/xlora_models/llama.rs
+++ b/mistralrs-core/src/xlora_models/llama.rs
@@ -2,14 +2,14 @@
 
 use crate::{
     amoe::AnyMoeBaseModelMixin,
-    layers::ScaledDotProductAttention,
+    layers::{Llama3RotaryEmbedding, ScaledDotProductAttention},
     lora::{linear_no_bias as linear, LinearLayerLike, LoraConfig, Ordering},
     paged_attention::ModelConfigMetadata,
     pipeline::{text_models_inputs_processor::PagedAttentionInputMetadata, IsqModel},
     utils::progress::NiceProgressBar,
 };
 use candle_core::{quantized::QMatMul, DType, Device, Result, Tensor};
-use candle_nn::{embedding, Embedding, Module, RotaryEmbedding, VarBuilder};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use std::{collections::HashMap, sync::Arc};
 use tqdm::Iter;
 use tracing::info;
@@ -33,7 +33,7 @@ struct CausalSelfAttention {
     num_key_value_heads: usize,
     head_dim: usize,
     use_flash_attn: bool,
-    rotary_emb: Arc<RotaryEmbedding>,
+    rotary_emb: Arc<Llama3RotaryEmbedding>,
     max_seq_len: usize,
 }
 
@@ -158,7 +158,7 @@ impl CausalSelfAttention {
         mapper: &dyn DeviceMapper,
         layer_idx: usize,
         loading_isq: bool,
-        rope: Arc<RotaryEmbedding>,
+        rope: Arc<Llama3RotaryEmbedding>,
         preload_adapters: &Option<HashMap<String, (VarBuilder, LoraConfig)>>,
     ) -> Result<Self> {
         let size_in = cfg.hidden_size;
@@ -369,7 +369,7 @@ impl Block {
         mapper: &dyn DeviceMapper,
         layer_idx: usize,
         loading_isq: bool,
-        rope: Arc<RotaryEmbedding>,
+        rope: Arc<Llama3RotaryEmbedding>,
         preload_adapters: &Option<HashMap<String, (VarBuilder, LoraConfig)>>,
     ) -> Result<Self> {
         let attn = CausalSelfAttention::load(
@@ -610,21 +610,18 @@ impl XLoraLlama {
             cfg.rms_norm_eps,
             mapper.set_nm_device(vb.pp("model.norm"), false),
         )?;
-        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
         let mut blocks: Vec<_> =
             NiceProgressBar::<_, 'b'>(0..cfg.num_hidden_layers, "Loading repeating layers")
                 .into_iter()
                 .map(|i| {
                     let rotary_emb = Arc::new(
-                        RotaryEmbedding::new(
-                            cfg.rope_theta,
-                            head_dim,
-                            cfg.max_position_embeddings,
+                        Llama3RotaryEmbedding::new(
+                            vb.dtype(),
+                            cfg,
                             mapper
                                 .device_for(i, false)
                                 .unwrap_or(&normal_loading_metadata.real_device),
                             is_gptx,
-                            vb.dtype(),
                        )
                        .expect("Failed to create RoPE"),
                    );
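
The heart of the patch is the wavelength-gated rescaling of the RoPE inverse frequencies in Llama3RotaryEmbedding::new: frequencies whose wavelength is shorter than original_max_position_embeddings / high_freq_factor pass through unchanged, frequencies whose wavelength is longer than original_max_position_embeddings / low_freq_factor are divided by factor, and the band in between is linearly interpolated between the two regimes. The standalone sketch below reproduces that rule in plain Rust with no candle dependency so it can be inspected or unit-tested in isolation; the RopeScaling struct, the scale_inv_freq helper, and the concrete numbers in main (factor 8.0, head_dim 128, rope_theta 500000.0, 8192 original positions, roughly what Llama 3.1 configs carry) are illustrative assumptions of this sketch, not part of the patch itself.

use std::f32::consts::PI;

/// Parameters of a rope_scaling block (mirrors Llama3RopeConfig in the patch).
struct RopeScaling {
    factor: f32,
    low_freq_factor: f32,
    high_freq_factor: f32,
    original_max_position_embeddings: usize,
}

/// Default RoPE inverse frequencies, as in calculate_default_inv_freq.
fn default_inv_freq(head_dim: usize, rope_theta: f32) -> Vec<f32> {
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

/// The Llama 3 wavelength rule applied to a single inverse frequency.
fn scale_inv_freq(freq: f32, s: &RopeScaling) -> f32 {
    let low_freq_wavelen = s.original_max_position_embeddings as f32 / s.low_freq_factor;
    let high_freq_wavelen = s.original_max_position_embeddings as f32 / s.high_freq_factor;
    let wavelen = 2. * PI / freq;
    if wavelen < high_freq_wavelen {
        // Short wavelengths (high frequencies) are left untouched.
        freq
    } else if wavelen > low_freq_wavelen {
        // Long wavelengths (low frequencies) are slowed down by `factor`.
        freq / s.factor
    } else {
        // The middle band interpolates linearly between the two regimes.
        let smooth = (s.original_max_position_embeddings as f32 / wavelen - s.low_freq_factor)
            / (s.high_freq_factor - s.low_freq_factor);
        (1. - smooth) * freq / s.factor + smooth * freq
    }
}

fn main() {
    // Illustrative values only (approximately what Llama 3.1 config.json files ship).
    let scaling = RopeScaling {
        factor: 8.0,
        low_freq_factor: 1.0,
        high_freq_factor: 4.0,
        original_max_position_embeddings: 8192,
    };
    for (pair, freq) in default_inv_freq(128, 500_000.0).into_iter().enumerate() {
        println!(
            "dim pair {pair:2}: inv_freq {freq:.3e} -> {:.3e}",
            scale_inv_freq(freq, &scaling)
        );
    }
}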
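
On the loading side, the new rope_scaling: Option<Llama3RopeConfig> fields are filled by serde straight from the model's config.json, and the #[serde(rename = "llama3")] / #[serde(rename = "default")] attributes decide which arm of Llama3RotaryEmbedding::new is taken. As a rough, self-contained illustration of the input shape this expects, the snippet below re-declares the two types from layers.rs and parses a hand-written rope_scaling block with serde_json; the crate setup (serde with the derive feature plus serde_json) and the numeric values are assumptions of the example, not something this patch adds.

// Requires the `serde` crate (with the "derive" feature) and `serde_json`.
use serde::Deserialize;

// Local stand-ins mirroring the definitions added to mistralrs-core/src/layers.rs.
#[derive(Debug, Deserialize, Default)]
enum Llama3RopeType {
    #[serde(rename = "llama3")]
    Llama3,
    #[default]
    #[serde(rename = "default")]
    Default,
}

#[derive(Debug, Deserialize)]
struct Llama3RopeConfig {
    factor: f32,
    low_freq_factor: f32,
    high_freq_factor: f32,
    original_max_position_embeddings: usize,
    rope_type: Llama3RopeType,
}

fn main() -> serde_json::Result<()> {
    // The shape of the `rope_scaling` block in a Llama 3.1 style config.json;
    // the numbers here are illustrative.
    let raw = r#"{
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3"
    }"#;
    let parsed: Llama3RopeConfig = serde_json::from_str(raw)?;
    println!("{parsed:?}");
    Ok(())
}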