Reduce memory overhead from copy2d (#13)
* WIP Slow implementation of kv

* Initial microbenchmark

* WIP isolate RoPE

* WIP Fix failures

* WIP temporarily disable copy2d optimization

* Move rep pen to GPU

* Final cleanup

* Clean up code after testing

* Fix RTF calculations for DualAR
EndlessReform authored Oct 6, 2024
1 parent 43f2446 commit f543dae
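
For context on the diffs below: the attention layers previously kept their KV history as an Option<(Tensor, Tensor)> rebuilt with Tensor::cat on every decoding step (presumably the source of the copy2d overhead named in the title), and this commit switches them to candle_nn::kv_cache::KvCache, which preallocates the buffer and appends in place. A minimal sketch of that API as it is used here, with illustrative shapes; only the 1024-slot capacity mirrors the slow layers below, nothing else is taken from the model config:

    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::kv_cache::KvCache;

    fn kv_cache_demo(device: &Device) -> Result<()> {
        // Concatenate along dim 2 (the sequence axis), with room for 1024 positions,
        // mirroring `KvCache::new(2, 1024)` in the slow layers below.
        let mut cache = KvCache::new(2, 1024);
        for _step in 0..4 {
            // One new position per step: (batch, kv_heads, 1, head_dim) -- illustrative shape.
            let k = Tensor::zeros((1, 2, 1, 64), DType::F32, device)?;
            let v = Tensor::zeros((1, 2, 1, 64), DType::F32, device)?;
            // append() writes into the preallocated buffer and returns the full
            // keys/values accumulated so far, replacing the old Tensor::cat path.
            let (k_all, v_all) = cache.append(&k, &v)?;
            debug_assert_eq!(k_all.dims(), v_all.dims());
        }
        cache.reset(); // equivalent of the old `self.cache.kvs = None`
        Ok(())
    }
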
Showing 5 changed files with 165 additions and 109 deletions.
Empty file added Dockerfile
105 changes: 49 additions & 56 deletions fish_speech_core/lib/models/text2semantic/mod.rs
@@ -2,9 +2,9 @@ pub mod utils;

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{
embedding, ops::silu, ops::softmax_last_dim, Embedding, Linear, Module, RmsNorm, VarBuilder,
embedding, kv_cache::KvCache, ops::silu, ops::softmax_last_dim, Embedding, Linear, Module,
RmsNorm, VarBuilder,
};
use candle_transformers::utils::repeat_kv;
use serde::Deserialize;
use serde_json;
use std::fs::File;
@@ -116,18 +116,6 @@ impl Module for FeedForward {
}
}

pub struct Cache {
/// TODO: Does this require Arc<Mutex>>?
kvs: Option<(Tensor, Tensor)>,
}

impl Cache {
pub fn new() -> Result<Self> {
// Precompute freqs_cis
Ok(Self { kvs: None })
}
}

/// Returns (cos, sin) for the full possible batch size
fn precompute_freqs_cis(
config: &BaseModelArgs,
@@ -137,7 +125,7 @@ fn precompute_freqs_cis(
let n_elem = config.dim / config.n_head;
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / (config.rope_base as f32).powf(i as f32 / n_elem as f32))
.map(|i| 1f32 / config.rope_base.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, config.max_seq_len as u32, device)?
@@ -173,17 +161,17 @@ pub struct Attention {
dim: usize,
wqkv: Linear,
wo: Linear,
cache: Cache,
cache: KvCache,
}

impl Attention {
pub fn load(vb: &VarBuilder, config: &BaseModelArgs) -> Result<Self> {
pub fn load(vb: &VarBuilder, config: &BaseModelArgs, is_fast: bool) -> Result<Self> {
let total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim;
// KQV for all heads, but in a batch
let wqkv = Linear::new(vb.get((total_head_dim, config.dim), "wqkv.weight")?, None);
let wo = Linear::new(vb.get((config.dim, config.dim), "wo.weight")?, None);

let cache = Cache::new()?;
let cache = KvCache::new(2, if is_fast { config.num_codebooks } else { 1024 });

Ok(Self {
n_head: config.n_head,
@@ -204,8 +192,8 @@ impl Attention {
cos: &Tensor,
sin: &Tensor,
) -> Result<(Tensor, Tensor)> {
let q_embed = candle_nn::rotary_emb::rope_i(&q, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope_i(&k, &cos, &sin)?;
let q_embed = candle_nn::rotary_emb::rope_i(q, cos, sin)?;
let k_embed = candle_nn::rotary_emb::rope_i(k, cos, sin)?;
Ok((q_embed, k_embed))
}

@@ -255,38 +243,38 @@ impl Attention {

let query_states = query_states
.reshape((bsz, seqlen, self.n_head, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
.transpose(1, 2)?;
let key_states = key_states
.reshape((bsz, seqlen, self.n_local_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
.transpose(1, 2)?;
let value_states = value_states
.reshape((bsz, seqlen, self.n_local_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
.transpose(1, 2)?;

// Logic copied from phi3.rs
let _seqlen_offset = match &self.cache.kvs {
None => 0,
Some((prev_k, _)) => prev_k.dim(2)?,
};
let (query_states, key_states) =
self.apply_rotary_emb_qkv(&query_states, &key_states, freqs_cis.0, freqs_cis.1)?;
let (key_states, value_states) = match &self.cache.kvs {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.cache.kvs = Some((key_states.clone(), value_states.clone()));
let (query_states, key_states) = self.apply_rotary_emb_qkv(
&query_states.contiguous()?,
&key_states.contiguous()?,
freqs_cis.0,
freqs_cis.1,
)?;

let (key_states, value_states) = self
.cache
.append(&key_states.contiguous()?, &value_states.contiguous()?)?;

// Repeat KV cache
let key_states = repeat_kv(key_states, self.n_head / self.n_local_heads)?.contiguous()?;
let value_states =
repeat_kv(value_states, self.n_head / self.n_local_heads)?.contiguous()?;
// Length changes after pulling
let kv_seqlen = key_states.dim(2)?;
let n_rep = self.n_head / self.n_local_heads;
// TODO: Consider whether there's a better way to do this with 2 copy2ds instead of ucopy
// https://github.com/huggingface/candle/pull/2043 got the duplication wrong but there might still be something there
let key_states = key_states
.unsqueeze(2)?
.expand((bsz, self.n_local_heads, n_rep, kv_seqlen, self.head_dim))?
.reshape((bsz, self.n_local_heads * n_rep, kv_seqlen, self.head_dim))?;
let value_states = value_states
.unsqueeze(2)?
.expand((bsz, self.n_local_heads, n_rep, kv_seqlen, self.head_dim))?
.reshape((bsz, self.n_local_heads * n_rep, kv_seqlen, self.head_dim))?;

// TODO: Add optional flash attention
let y =
@@ -297,7 +285,7 @@
}

pub fn clear_cache(&mut self) {
self.cache.kvs = None;
self.cache.reset();
}
}

@@ -309,8 +297,8 @@ pub struct TransformerBlock {
}

impl TransformerBlock {
pub fn load(vb: &VarBuilder, cfg: &BaseModelArgs) -> Result<Self> {
let attention = Attention::load(&vb.pp("attention"), cfg)?;
pub fn load(vb: &VarBuilder, cfg: &BaseModelArgs, is_fast: bool) -> Result<Self> {
let attention = Attention::load(&vb.pp("attention"), cfg, is_fast)?;
let feed_forward = FeedForward::load(&vb.pp("feed_forward"), cfg)?;
let ffn_norm = RmsNorm::new(vb.get(cfg.dim, "ffn_norm.weight")?, cfg.norm_eps);
let attention_norm = RmsNorm::new(vb.get(cfg.dim, "attention_norm.weight")?, cfg.norm_eps);
@@ -331,10 +319,9 @@ impl TransformerBlock {
) -> Result<Tensor> {
let residual = x;
let x = self.attention_norm.forward(x)?;
let x = (self.attention.forward(&x, mask, freqs_cis)? + residual)?;
let x = (residual + self.attention.forward(&x, mask, freqs_cis)?)?;
let residual = &x;
let x = residual + self.feed_forward.forward(&self.ffn_norm.forward(&x)?);
x
residual + self.feed_forward.forward(&self.ffn_norm.forward(&x)?)
}
}

@@ -364,7 +351,7 @@ impl DualARTransformer {
cfg.dim,
);
let layers: Result<Vec<TransformerBlock>> = (0..cfg.n_layer)
.map(|l| TransformerBlock::load(&vb.pp(format!("layers.{}", l)), cfg))
.map(|l| TransformerBlock::load(&vb.pp(format!("layers.{}", l)), cfg, false))
.collect();
let layers = layers?;
let norm = RmsNorm::new(vb.get(cfg.dim, "norm.weight")?, cfg.norm_eps);
@@ -374,7 +361,7 @@
cfg.dim,
);
let fast_layers: Result<Vec<TransformerBlock>> = (0..cfg.n_fast_layer)
.map(|l| TransformerBlock::load(&vb.pp(format!("fast_layers.{}", l)), cfg))
.map(|l| TransformerBlock::load(&vb.pp(format!("fast_layers.{}", l)), cfg, true))
.collect();
let fast_layers = fast_layers?;
let fast_norm = RmsNorm::new(vb.get(cfg.dim, "fast_norm.weight")?, cfg.norm_eps);
@@ -417,7 +404,7 @@ impl DualARTransformer {
// Offset the ranges for each codebook so they don't overlap
let codebook_tokens_shifted = codebook_tokens.broadcast_add(
&Tensor::arange_step(
0 as u32,
0,
(self.cfg.num_codebooks * self.cfg.codebook_size) as u32,
self.cfg.codebook_size as u32,
x.device(),
@@ -444,7 +431,7 @@
let mask = get_mask(seq_len, x.device())?;

let (cos_full, sin_full) = &self.freqs_cis;
for (_, layer) in self.layers.iter_mut().enumerate() {
for layer in self.layers.iter_mut() {
x = layer.forward(
&x,
&mask,
@@ -494,4 +481,10 @@ impl DualARTransformer {
layer.attention.clear_cache();
}
}

pub fn clear_slow_layer_caches(&mut self) {
for layer in self.layers.iter_mut() {
layer.attention.clear_cache();
}
}
}
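
Aside on the attention change above: candle_transformers::utils::repeat_kv is dropped, and the grouped KV heads are instead broadcast up to the query-head count with unsqueeze/expand/reshape, so only the final reshape materializes a copy. A standalone sketch of the same pattern (a hypothetical helper, not part of this commit):

    use candle_core::{Result, Tensor};

    /// Expand (bsz, n_kv_heads, seq, head_dim) to (bsz, n_kv_heads * n_rep, seq, head_dim).
    fn expand_kv_heads(xs: &Tensor, n_rep: usize) -> Result<Tensor> {
        let (bsz, n_kv_heads, seq_len, head_dim) = xs.dims4()?;
        xs.unsqueeze(2)?                                           // (bsz, kv, 1, seq, hd)
            .expand((bsz, n_kv_heads, n_rep, seq_len, head_dim))?  // stride-0 broadcast, no copy yet
            .reshape((bsz, n_kv_heads * n_rep, seq_len, head_dim)) // one contiguous copy here
    }
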
2 changes: 1 addition & 1 deletion fish_speech_core/lib/models/text2semantic/utils/encode.rs
@@ -23,7 +23,7 @@ pub fn encode_tokens(
let zeros = Tensor::zeros((num_codebooks, new_tokens.len()), DType::U32, device)?;
let prompt = Tensor::cat(&[tokens, zeros], 0)?;

if let None = prompt_tokens {
if prompt_tokens.is_none() {
return Ok(prompt);
}
let prompt_tokens = prompt_tokens.unwrap().to_dtype(DType::U32)?;
73 changes: 73 additions & 0 deletions fish_speech_core/lib/models/text2semantic/utils/mod.rs
@@ -1 +1,74 @@
use candle_core::{DType, Device, Result, Tensor};
use std::collections::{HashMap, VecDeque};

pub mod encode;

pub struct RepPenProcessor {
penalty_mask: Tensor,
one: Tensor,
penalty_amt: Tensor,
context: VecDeque<usize>,
tokens_seen: HashMap<usize, usize>,
max_ctxt_size: usize,
vocab_size: usize,
}

impl RepPenProcessor {
pub fn new(
vocab_size: usize,
max_ctxt_size: usize,
penalty_amt: f32,
dtype: DType,
device: &Device,
) -> Result<Self> {
let penalty_mask = Tensor::ones(vocab_size, dtype, device)?;
// Yes, this is inelegant, but there's no scalar interface to slice_set
let one = Tensor::ones(1, dtype, device)?;
let penalty_amt = Tensor::from_vec(vec![penalty_amt], 1, device)?.to_dtype(dtype)?;
Ok(Self {
penalty_mask,
one,
penalty_amt,
context: VecDeque::new(),
tokens_seen: HashMap::new(),
max_ctxt_size,
vocab_size,
})
}

pub fn apply(&mut self, logits: &Tensor, last_token: usize) -> Result<Tensor> {
if last_token >= self.vocab_size {
candle_core::bail!("Token must be within vocab size");
}

// Add latest token to penalties if it's not there already
let count = self.tokens_seen.entry(last_token).or_insert(1);
if *count == 1 {
// This is the first time we're penalizing the token in this window, so add to mask
self.penalty_mask
.slice_set(&self.penalty_amt, 0, last_token)?;
}
self.context.push_front(last_token);

if self.context.len() > self.max_ctxt_size {
// If the token falling out of the window is the last of its kind, un-penalize it
if let Some(dropped_token) = self.context.pop_back() {
if let Some(count) = self.tokens_seen.get_mut(&dropped_token) {
*count -= 1;
if *count == 0 {
self.tokens_seen.remove(&dropped_token);
self.penalty_mask.slice_set(&self.one, 0, dropped_token)?;
}
}
}
}

logits.broadcast_div(&self.penalty_mask)
}

pub fn clear_cache(&mut self) -> Result<()> {
self.penalty_mask = self.penalty_mask.ones_like()?;
self.context.clear();
Ok(())
}
}
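
The RepPenProcessor above is what the "Move rep pen to GPU" bullet refers to: the repetition penalty lives in a vocab-sized mask on the device and logits are divided by it, so nothing round-trips to the CPU. A rough sketch of how it might be driven from a sampling loop (the vocab size, window, penalty, and greedy argmax are placeholder choices, and RepPenProcessor is assumed to be in scope from this module):

    use candle_core::{DType, Device, Result, Tensor};

    fn greedy_sample_with_rep_pen(step_logits: Vec<Tensor>, device: &Device) -> Result<Vec<usize>> {
        // 32_000-token vocab, 64-token window, 1.2x penalty: illustrative values only.
        let mut rep_pen = RepPenProcessor::new(32_000, 64, 1.2, DType::F32, device)?;
        let mut last_token = 0usize; // seeded arbitrarily for the sketch
        let mut out = Vec::with_capacity(step_logits.len());
        for logits in step_logits {
            // Divide the 1-D logits by the running penalty mask, then take the argmax
            // (a real sampler would apply temperature / top-p here instead).
            let penalized = rep_pen.apply(&logits, last_token)?;
            last_token = penalized.argmax(0)?.to_scalar::<u32>()? as usize;
            out.push(last_token);
        }
        rep_pen.clear_cache()?; // reset the mask between generations
        Ok(out)
    }
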