diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..e69de29
diff --git a/fish_speech_core/lib/models/text2semantic/mod.rs b/fish_speech_core/lib/models/text2semantic/mod.rs
index fab1f12..8709f92 100644
--- a/fish_speech_core/lib/models/text2semantic/mod.rs
+++ b/fish_speech_core/lib/models/text2semantic/mod.rs
@@ -2,9 +2,9 @@ pub mod utils;
 
 use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{
-    embedding, ops::silu, ops::softmax_last_dim, Embedding, Linear, Module, RmsNorm, VarBuilder,
+    embedding, kv_cache::KvCache, ops::silu, ops::softmax_last_dim, Embedding, Linear, Module,
+    RmsNorm, VarBuilder,
 };
-use candle_transformers::utils::repeat_kv;
 use serde::Deserialize;
 use serde_json;
 use std::fs::File;
@@ -116,18 +116,6 @@ impl Module for FeedForward {
     }
 }
 
-pub struct Cache {
-    /// TODO: Does this require Arc<Mutex>?
-    kvs: Option<(Tensor, Tensor)>,
-}
-
-impl Cache {
-    pub fn new() -> Result<Self> {
-        // Precompute freqs_cis
-        Ok(Self { kvs: None })
-    }
-}
-
 /// Returns (cos, sin) for the full possible batch size
 fn precompute_freqs_cis(
     config: &BaseModelArgs,
@@ -137,7 +125,7 @@
     let n_elem = config.dim / config.n_head;
     let theta: Vec<_> = (0..n_elem)
         .step_by(2)
-        .map(|i| 1f32 / (config.rope_base as f32).powf(i as f32 / n_elem as f32))
+        .map(|i| 1f32 / config.rope_base.powf(i as f32 / n_elem as f32))
         .collect();
     let theta = Tensor::new(theta.as_slice(), device)?;
     let idx_theta = Tensor::arange(0, config.max_seq_len as u32, device)?
@@ -173,17 +161,17 @@ pub struct Attention {
     dim: usize,
     wqkv: Linear,
     wo: Linear,
-    cache: Cache,
+    cache: KvCache,
 }
 
 impl Attention {
-    pub fn load(vb: &VarBuilder, config: &BaseModelArgs) -> Result<Self> {
+    pub fn load(vb: &VarBuilder, config: &BaseModelArgs, is_fast: bool) -> Result<Self> {
         let total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim;
         // KQV for all heads, but in a batch
         let wqkv = Linear::new(vb.get((total_head_dim, config.dim), "wqkv.weight")?, None);
         let wo = Linear::new(vb.get((config.dim, config.dim), "wo.weight")?, None);
 
-        let cache = Cache::new()?;
+        let cache = KvCache::new(2, if is_fast { config.num_codebooks } else { 1024 });
 
         Ok(Self {
             n_head: config.n_head,
@@ -204,8 +192,8 @@
         cos: &Tensor,
         sin: &Tensor,
     ) -> Result<(Tensor, Tensor)> {
-        let q_embed = candle_nn::rotary_emb::rope_i(&q, &cos, &sin)?;
-        let k_embed = candle_nn::rotary_emb::rope_i(&k, &cos, &sin)?;
+        let q_embed = candle_nn::rotary_emb::rope_i(q, cos, sin)?;
+        let k_embed = candle_nn::rotary_emb::rope_i(k, cos, sin)?;
         Ok((q_embed, k_embed))
     }
 
@@ -255,38 +243,38 @@
 
         let query_states = query_states
             .reshape((bsz, seqlen, self.n_head, self.head_dim))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;
         let key_states = key_states
             .reshape((bsz, seqlen, self.n_local_heads, self.head_dim))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;
         let value_states = value_states
             .reshape((bsz, seqlen, self.n_local_heads, self.head_dim))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;
 
-        // Logic copied from phi3.rs
-        let _seqlen_offset = match &self.cache.kvs {
-            None => 0,
-            Some((prev_k, _)) => prev_k.dim(2)?,
-        };
-        let (query_states, key_states) =
-            self.apply_rotary_emb_qkv(&query_states, &key_states, freqs_cis.0, freqs_cis.1)?;
-        let (key_states, value_states) = match &self.cache.kvs {
-            None => (key_states, value_states),
-            Some((prev_k, prev_v)) => {
-                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
-                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
-                (key_states, value_states)
-            }
-        };
-        self.cache.kvs = Some((key_states.clone(), value_states.clone()));
+        let (query_states, key_states) = self.apply_rotary_emb_qkv(
+            &query_states.contiguous()?,
+            &key_states.contiguous()?,
+            freqs_cis.0,
+            freqs_cis.1,
+        )?;
+
+        let (key_states, value_states) = self
+            .cache
+            .append(&key_states.contiguous()?, &value_states.contiguous()?)?;
 
-        // Repeat KV cache
-        let key_states = repeat_kv(key_states, self.n_head / self.n_local_heads)?.contiguous()?;
-        let value_states =
-            repeat_kv(value_states, self.n_head / self.n_local_heads)?.contiguous()?;
+        // Length changes after pulling
+        let kv_seqlen = key_states.dim(2)?;
+        let n_rep = self.n_head / self.n_local_heads;
+        // TODO: Consider whether there's a better way to do this with 2 copy2ds instead of ucopy
+        // https://github.com/huggingface/candle/pull/2043 got the duplication wrong but there might still be something there
+        let key_states = key_states
+            .unsqueeze(2)?
+            .expand((bsz, self.n_local_heads, n_rep, kv_seqlen, self.head_dim))?
+            .reshape((bsz, self.n_local_heads * n_rep, kv_seqlen, self.head_dim))?;
+        let value_states = value_states
+            .unsqueeze(2)?
+            .expand((bsz, self.n_local_heads, n_rep, kv_seqlen, self.head_dim))?
+            .reshape((bsz, self.n_local_heads * n_rep, kv_seqlen, self.head_dim))?;
 
         // TODO: Add optional flash attention
         let y =
@@ -297,7 +285,7 @@ impl Attention {
     }
 
     pub fn clear_cache(&mut self) {
-        self.cache.kvs = None;
+        self.cache.reset();
     }
 }
 
@@ -309,8 +297,8 @@ pub struct TransformerBlock {
 }
 
 impl TransformerBlock {
-    pub fn load(vb: &VarBuilder, cfg: &BaseModelArgs) -> Result<Self> {
-        let attention = Attention::load(&vb.pp("attention"), cfg)?;
+    pub fn load(vb: &VarBuilder, cfg: &BaseModelArgs, is_fast: bool) -> Result<Self> {
+        let attention = Attention::load(&vb.pp("attention"), cfg, is_fast)?;
         let feed_forward = FeedForward::load(&vb.pp("feed_forward"), cfg)?;
         let ffn_norm = RmsNorm::new(vb.get(cfg.dim, "ffn_norm.weight")?, cfg.norm_eps);
         let attention_norm = RmsNorm::new(vb.get(cfg.dim, "attention_norm.weight")?, cfg.norm_eps);
@@ -331,10 +319,9 @@
     ) -> Result<Tensor> {
         let residual = x;
         let x = self.attention_norm.forward(x)?;
-        let x = (self.attention.forward(&x, mask, freqs_cis)? + residual)?;
+        let x = (residual + self.attention.forward(&x, mask, freqs_cis)?)?;
         let residual = &x;
-        let x = residual + self.feed_forward.forward(&self.ffn_norm.forward(&x)?);
-        x
+        residual + self.feed_forward.forward(&self.ffn_norm.forward(&x)?)
     }
 }
 
@@ -364,7 +351,7 @@ impl DualARTransformer {
             cfg.dim,
         );
         let layers: Result<Vec<TransformerBlock>> = (0..cfg.n_layer)
-            .map(|l| TransformerBlock::load(&vb.pp(format!("layers.{}", l)), cfg))
+            .map(|l| TransformerBlock::load(&vb.pp(format!("layers.{}", l)), cfg, false))
            .collect();
         let layers = layers?;
         let norm = RmsNorm::new(vb.get(cfg.dim, "norm.weight")?, cfg.norm_eps);
@@ -374,7 +361,7 @@ impl DualARTransformer {
             cfg.dim,
         );
         let fast_layers: Result<Vec<TransformerBlock>> = (0..cfg.n_fast_layer)
-            .map(|l| TransformerBlock::load(&vb.pp(format!("fast_layers.{}", l)), cfg))
+            .map(|l| TransformerBlock::load(&vb.pp(format!("fast_layers.{}", l)), cfg, true))
             .collect();
         let fast_layers = fast_layers?;
         let fast_norm = RmsNorm::new(vb.get(cfg.dim, "fast_norm.weight")?, cfg.norm_eps);
@@ -417,7 +404,7 @@ impl DualARTransformer {
         // Offset the ranges for each codebook so they don't overlap
         let codebook_tokens_shifted = codebook_tokens.broadcast_add(
             &Tensor::arange_step(
-                0 as u32,
+                0,
                 (self.cfg.num_codebooks * self.cfg.codebook_size) as u32,
                 self.cfg.codebook_size as u32,
                 x.device(),
@@ -444,7 +431,7 @@ impl DualARTransformer {
 
         let mask = get_mask(seq_len, x.device())?;
         let (cos_full, sin_full) = &self.freqs_cis;
-        for (_, layer) in self.layers.iter_mut().enumerate() {
+        for layer in self.layers.iter_mut() {
             x = layer.forward(
                 &x,
                 &mask,
@@ -494,4 +481,10 @@ impl DualARTransformer {
             layer.attention.clear_cache();
         }
     }
+
+    pub fn clear_slow_layer_caches(&mut self) {
+        for layer in self.layers.iter_mut() {
+            layer.attention.clear_cache();
+        }
+    }
 }
diff --git a/fish_speech_core/lib/models/text2semantic/utils/encode.rs b/fish_speech_core/lib/models/text2semantic/utils/encode.rs
index 890b848..1b392ac 100644
--- a/fish_speech_core/lib/models/text2semantic/utils/encode.rs
+++ b/fish_speech_core/lib/models/text2semantic/utils/encode.rs
@@ -23,7 +23,7 @@ pub fn encode_tokens(
     let zeros = Tensor::zeros((num_codebooks, new_tokens.len()), DType::U32, device)?;
     let prompt = Tensor::cat(&[tokens, zeros], 0)?;
 
-    if let None = prompt_tokens {
+    if prompt_tokens.is_none() {
         return Ok(prompt);
     }
     let prompt_tokens = prompt_tokens.unwrap().to_dtype(DType::U32)?;
diff --git a/fish_speech_core/lib/models/text2semantic/utils/mod.rs b/fish_speech_core/lib/models/text2semantic/utils/mod.rs
index ec5a769..1d4b151 100644
--- a/fish_speech_core/lib/models/text2semantic/utils/mod.rs
+++ b/fish_speech_core/lib/models/text2semantic/utils/mod.rs
@@ -1 +1,74 @@
+use candle_core::{DType, Device, Result, Tensor};
+use std::collections::{HashMap, VecDeque};
+
 pub mod encode;
+
+pub struct RepPenProcessor {
+    penalty_mask: Tensor,
+    one: Tensor,
+    penalty_amt: Tensor,
+    context: VecDeque<usize>,
+    tokens_seen: HashMap<usize, usize>,
+    max_ctxt_size: usize,
+    vocab_size: usize,
+}
+
+impl RepPenProcessor {
+    pub fn new(
+        vocab_size: usize,
+        max_ctxt_size: usize,
+        penalty_amt: f32,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let penalty_mask = Tensor::ones(vocab_size, dtype, device)?;
+        // Yes, this is inelegant, but there's no scalar interface to slice_set
+        let one = Tensor::ones(1, dtype, device)?;
+        let penalty_amt = Tensor::from_vec(vec![penalty_amt], 1, device)?.to_dtype(dtype)?;
+        Ok(Self {
+            penalty_mask,
+            one,
+            penalty_amt,
+            context: VecDeque::new(),
+            tokens_seen: HashMap::new(),
+            max_ctxt_size,
+            vocab_size,
+        })
+    }
+
+    pub fn apply(&mut self, logits: &Tensor, last_token: usize) -> Result<Tensor> {
+        if last_token >= self.vocab_size {
+            candle_core::bail!("Token must be within vocab size");
+        }
+
+        // Add latest token to penalties if it's not there already
+        let count = self.tokens_seen.entry(last_token).or_insert(1);
+        if *count == 1 {
+            // This is the first time we're penalizing the token in this window, so add to mask
+            self.penalty_mask
+                .slice_set(&self.penalty_amt, 0, last_token)?;
+        }
+        self.context.push_front(last_token);
+
+        if self.context.len() > self.max_ctxt_size {
+            // If the token falling out of the window is the last of its kind, un-penalize it
+            if let Some(dropped_token) = self.context.pop_back() {
+                if let Some(count) = self.tokens_seen.get_mut(&dropped_token) {
+                    *count -= 1;
+                    if *count == 0 {
+                        self.tokens_seen.remove(&dropped_token);
+                        self.penalty_mask.slice_set(&self.one, 0, dropped_token)?;
+                    }
+                }
+            }
+        }
+
+        logits.broadcast_div(&self.penalty_mask)
+    }
+
+    pub fn clear_cache(&mut self) -> Result<()> {
+        self.penalty_mask = self.penalty_mask.ones_like()?;
+        self.context.clear();
+        Ok(())
+    }
+}
diff --git a/fish_speech_core/src/bin/llama_generate.rs b/fish_speech_core/src/bin/llama_generate.rs
index ddc0c61..65965a8 100644
--- a/fish_speech_core/src/bin/llama_generate.rs
+++ b/fish_speech_core/src/bin/llama_generate.rs
@@ -2,9 +2,8 @@ use anyhow::Error;
 use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Module, VarBuilder};
 use candle_transformers::generation::{LogitsProcessor, Sampling};
-use candle_transformers::utils::apply_repeat_penalty;
 use clap::Parser;
-use fish_speech_core::models::text2semantic::utils::encode::encode_tokens;
+use fish_speech_core::models::text2semantic::utils::{encode::encode_tokens, RepPenProcessor};
 use fish_speech_core::models::text2semantic::{BaseModelArgs, DualARTransformer};
 use fish_speech_core::models::vqgan::config::WhichModel;
 use indicatif::{ProgressBar, ProgressStyle};
@@ -28,20 +27,6 @@ use tokenizers::Tokenizer;
 // Ok(())
 // }
 
-fn apply_rep_pen(
-    logits: &Tensor,
-    tokens: &[u32],
-    rep_pen: f32,
-    repeat_last_n: usize,
-) -> Result<Tensor> {
-    if rep_pen == 1. {
-        Ok(logits.clone())
-    } else {
-        let start_at = tokens.len().saturating_sub(repeat_last_n);
-        apply_repeat_penalty(&logits, rep_pen, &tokens[start_at..])
-    }
-}
-
 /// Extremely stripped-down softmax for two tokens
 fn softmax_sample(pad_prob: f32, eos_prob: f32, pad_id: u32, eos_id: u32) -> u32 {
     // Compute softmax
@@ -71,12 +56,11 @@ fn decode_one_token_ar(
     input_pos: usize,
     im_end_id: u32,
     pad_id: u32,
-    previous_tokens: Option<&Tensor>,
-    sampling_args: &SamplingArgs,
-) -> Result<Tensor> {
+    previous_token: Option<Vec<u32>>,
+    rep_pens: &mut [RepPenProcessor],
+) -> Result<(Vec<u32>, Tensor)> {
     let (logits, hidden_states) = model.forward_generate(&x, input_pos)?;
     let slow_logits = logits.flatten_all()?;
-    let repeat_window_size = 16;
 
     let pad_prob = slow_logits
         .i(pad_id as usize)?
@@ -93,27 +77,23 @@ fn decode_one_token_ar(
     let mut x = hidden_states;
 
     for codebook_idx in 0..model.cfg.num_codebooks {
-        // TODO: Figure out what the heck input_pos is
         let logits = model
            .forward_generate_fast(&x, codebook_idx)?
            .flatten_all()?;
-        let logits_adj = match previous_tokens {
-            Some(ctxt) => apply_rep_pen(
-                &logits,
-                &ctxt.i((codebook_idx + 1, ..))?.to_vec1()?,
-                sampling_args.repetition_penalty,
-                repeat_window_size,
-            )?,
-            None => logits,
+        let logits_adj = match &previous_token {
+            None => logits.clone(),
+            Some(t) => rep_pens[codebook_idx].apply(&logits, t[codebook_idx + 1] as usize)?,
         };
         let a = fast_logits_processor.sample(&logits_adj.flatten_all()?)?;
-        // println!("Codebook shape: {:?}", prev_codes[codebook_idx + 1].shape());
         let a_tensor = Tensor::from_slice(&[a], 1, x.device())?;
         x = model.fast_embeddings.forward(&a_tensor)?.unsqueeze(0)?;
         codebooks.push(a);
     }
-    Tensor::from_vec(codebooks, model.cfg.num_codebooks + 1, x.device())?.unsqueeze(D::Minus1)
+    let codes_tensor =
+        Tensor::from_vec(codebooks.clone(), model.cfg.num_codebooks + 1, x.device())?
+            .unsqueeze(D::Minus1)?;
+    Ok((codebooks, codes_tensor))
 }
 
 /// Takes a conditioning sequence as input and generates as many tokens as requested
 fn generate(
@@ -134,8 +114,21 @@
         },
     };
     let mut fast_logits_processor = LogitsProcessor::from_sampling(42, sampling);
+    let maybe_fast_rep_pens: Result<Vec<RepPenProcessor>> = (0..model.cfg.num_codebooks)
+        .map(|_| {
+            RepPenProcessor::new(
+                model.cfg.codebook_size,
+                16,
+                sampling_args.repetition_penalty,
+                model.fast_embeddings.embeddings().dtype(),
+                model.fast_embeddings.embeddings().device(),
+            )
+        })
+        .collect();
+    let mut fast_rep_pens = maybe_fast_rep_pens?;
+
     let start_pp = Instant::now();
-    let mut cur_token = decode_one_token_ar(
+    let (mut previous_token, mut cur_token) = decode_one_token_ar(
         model,
         &mut fast_logits_processor,
         prompt,
@@ -143,7 +136,7 @@ fn generate(
         im_end_id,
         pad_id,
         None,
-        &sampling_args,
+        &mut fast_rep_pens,
     )?;
     let dt = start_pp.elapsed();
     let mut input_pos = prompt.dim(D::Minus1)?;
@@ -167,26 +160,25 @@ fn generate(
 
     let start_decode = Instant::now();
     for i in 1..max_new_tokens {
-        let next_token = decode_one_token_ar(
+        let (next_indices, next_token) = decode_one_token_ar(
             model,
             &mut fast_logits_processor,
             &cur_token,
             input_pos,
             im_end_id,
             pad_id,
-            Some(&previous_tokens),
-            sampling_args,
+            Some(previous_token),
+            &mut fast_rep_pens,
         )?;
         previous_tokens = Tensor::cat(&[previous_tokens, next_token.clone()], D::Minus1)?;
         spinner.inc(1);
         spinner.set_message(format!("Tokens: {}", i));
-        if let Some(semantic_token) = next_token.i((0, 0))?.to_vec0::<u32>().ok() {
-            if semantic_token == im_end_id {
-                break;
-            }
+        if next_indices[0] == im_end_id {
+            break;
         }
         input_pos += 1;
         cur_token = next_token;
+        previous_token = next_indices;
     }
     let dt = start_decode.elapsed();
     let out_len = previous_tokens.dim(1)? as f64;
@@ -194,8 +186,8 @@
         "{} tokens generated ({:.2} tokens/s, {:.3}ms / token, RTF: {:.3})",
         out_len,
         out_len / dt.as_secs_f64(),
-        (dt.as_secs_f64() * 1e3) / out_len,
-        (out_len / 43.07) / dt.as_secs_f64()
+        (dt.as_secs_f64() * 1e3) / (out_len - 1f64),
+        (out_len / 21.535) / dt.as_secs_f64()
     );
     previous_tokens.i((1.., ..))
 }
@@ -214,11 +206,11 @@ fn generate_long(
     };
 
     let conditioning_prompts =
-        load_prompt_texts(&args.prompt_tokens, args.prompt_text.clone(), &device)?;
+        load_prompt_texts(&args.prompt_tokens, args.prompt_text.clone(), device)?;
 
     let encoded_prompts: Result<Tensor> = conditioning_prompts
         .iter()
-        .map(|(t, c)| encode_tokens(&tokenizer, &t, &device, Some(c), model.cfg.num_codebooks))
+        .map(|(t, c)| encode_tokens(tokenizer, t, device, Some(c), model.cfg.num_codebooks))
         .try_fold(
             Tensor::from_slice(
                 &(vec![] as Vec<u32>),
@@ -231,7 +223,7 @@
     let encoded = vec![encode_tokens(
         &tokenizer,
         &args.text,
-        &device,
+        device,
         None,
         model.cfg.num_codebooks,
     )?];
@@ -251,6 +243,7 @@
         pad_id,
         &sampling_args,
     )?;
+    model.clear_slow_layer_caches();
     let res = res.broadcast_sub(&Tensor::ones_like(&res)?)?;
 
     res.write_npy(&args.out_path)?;
@@ -265,7 +258,7 @@ struct SamplingArgs {
 }
 
 fn load_prompt_texts(
-    prompt_tokens: &Vec<PathBuf>,
+    prompt_tokens: &[PathBuf],
     prompt_texts: Vec<String>,
     device: &Device,
 ) -> anyhow::Result<Vec<(String, Tensor)>> {
@@ -277,13 +270,10 @@ fn load_prompt_texts(
        )))?
     }
 
-    let codes: Result<Vec<Tensor>> = prompt_tokens
-        .iter()
-        .map(|path| Tensor::read_npy(path))
-        .collect();
+    let codes: Result<Vec<Tensor>> = prompt_tokens.iter().map(Tensor::read_npy).collect();
     let codes: Result<Vec<Tensor>> = codes?.into_iter().map(|c| c.to_device(device)).collect();
 
-    Ok(prompt_texts.into_iter().zip(codes?.into_iter()).collect())
+    Ok(prompt_texts.into_iter().zip(codes?).collect())
 }
 
 #[derive(Parser, Debug)]
@@ -358,7 +348,7 @@ fn main() -> anyhow::Result<()> {
     let config = BaseModelArgs::from_json_file(checkpoint_dir.join("config.json"))?;
     let tokenizer = Tokenizer::from_file(checkpoint_dir.join("tokenizer.json")).unwrap();
     // TODO: Figure out why BF16 is breaking on Metal
-    #[cfg(any(feature = "cuda"))]
+    #[cfg(feature = "cuda")]
     let dtype = DType::BF16;
     #[cfg(not(feature = "cuda"))]
     let dtype = DType::F32;
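
For reference, a minimal sketch of how the new RepPenProcessor behaves on its own, outside the generation loop. It assumes only the API added in utils/mod.rs above (new and apply); the CPU device, F32 dtype, vocab size of 1024, penalty of 1.2, and the token indices are illustrative values chosen for this sketch, not defaults taken from the crate.

use candle_core::{DType, Device, IndexOp, Result, Tensor};
use fish_speech_core::models::text2semantic::utils::RepPenProcessor;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Same argument shape as generate() uses per codebook: vocab size, window of 16, penalty, dtype, device
    let mut rep_pen = RepPenProcessor::new(1024, 16, 1.2, DType::F32, &device)?;
    let logits = Tensor::ones(1024, DType::F32, &device)?;

    // The token sampled at the previous step (index 42, arbitrary) gets its logit divided by the penalty
    let penalized = rep_pen.apply(&logits, 42)?;
    assert!(penalized.i(42)?.to_vec0::<f32>()? < 1.0);

    // Once 16 newer distinct tokens have pushed index 42 out of the window, its logit is restored
    for t in 0..16 {
        rep_pen.apply(&logits, t)?;
    }
    let restored = rep_pen.apply(&logits, 100)?;
    assert_eq!(restored.i(42)?.to_vec0::<f32>()?, 1.0);
    Ok(())
}

Keeping one processor per codebook, as generate() now does, lets the penalty state persist across decoding steps instead of re-slicing and re-scanning previous_tokens on every call to decode_one_token_ar.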