Reduce memory overhead from copy2d (#13)
* WIP Slow implementation of kv
* Initial microbenchmark
* WIP isolate RoPE
* WIP Fix failures
* WIP temporarily disable copy2d optimization
* Move rep pen to GPU
* Final cleanup
* Clean up code after testing
* Fix RTF calculations for DualAR
1 parent 43f2446, commit f543dae
Showing 5 changed files with 165 additions and 109 deletions.
@@ -1 +1,74 @@
use candle_core::{DType, Device, Result, Tensor};
use std::collections::{HashMap, VecDeque};

pub mod encode;

pub struct RepPenProcessor {
    penalty_mask: Tensor,
    one: Tensor,
    penalty_amt: Tensor,
    context: VecDeque<usize>,
    tokens_seen: HashMap<usize, usize>,
    max_ctxt_size: usize,
    vocab_size: usize,
}

impl RepPenProcessor {
    pub fn new(
        vocab_size: usize,
        max_ctxt_size: usize,
        penalty_amt: f32,
        dtype: DType,
        device: &Device,
    ) -> Result<Self> {
        let penalty_mask = Tensor::ones(vocab_size, dtype, device)?;
        // Yes, this is inelegant, but there's no scalar interface to slice_set
        let one = Tensor::ones(1, dtype, device)?;
        let penalty_amt = Tensor::from_vec(vec![penalty_amt], 1, device)?.to_dtype(dtype)?;
        Ok(Self {
            penalty_mask,
            one,
            penalty_amt,
            context: VecDeque::new(),
            tokens_seen: HashMap::new(),
            max_ctxt_size,
            vocab_size,
        })
    }

    pub fn apply(&mut self, logits: &Tensor, last_token: usize) -> Result<Tensor> {
        if last_token >= self.vocab_size {
            candle_core::bail!("Token must be within vocab size");
        }

        // Add latest token to penalties if it's not there already
        let count = self.tokens_seen.entry(last_token).or_insert(1);
        if *count == 1 {
            // This is the first time we're penalizing the token in this window, so add to mask
            self.penalty_mask
                .slice_set(&self.penalty_amt, 0, last_token)?;
        }
        self.context.push_front(last_token);

        if self.context.len() > self.max_ctxt_size {
            // If the token falling out of the window is the last of its kind, un-penalize it
            if let Some(dropped_token) = self.context.pop_back() {
                if let Some(count) = self.tokens_seen.get_mut(&dropped_token) {
                    *count -= 1;
                    if *count == 0 {
                        self.tokens_seen.remove(&dropped_token);
                        self.penalty_mask.slice_set(&self.one, 0, dropped_token)?;
                    }
                }
            }
        }

        logits.broadcast_div(&self.penalty_mask)
    }

    pub fn clear_cache(&mut self) -> Result<()> {
        self.penalty_mask = self.penalty_mask.ones_like()?;
        self.context.clear();
        Ok(())
    }
}
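For context, here is a minimal sketch (not part of this commit) of how RepPenProcessor might be driven from a decode loop. The toy vocab size, window length, penalty value, random logits, and greedy argmax sampling are all illustrative assumptions standing in for the real model.

use candle_core::{DType, Device, Result, Tensor};

// Hypothetical driver: shows where apply() sits relative to the model's logits
// and sampling. The penalty mask lives on `device`, so the division happens on
// the same device as the logits.
fn main() -> Result<()> {
    let device = Device::Cpu;
    let vocab_size = 32; // toy vocab; real models are much larger
    let mut rep_pen = RepPenProcessor::new(vocab_size, 8, 1.3, DType::F32, &device)?;

    let mut last_token = 0usize;
    for _step in 0..4 {
        // Stand-in for the model's next-token logits.
        let logits = Tensor::rand(0f32, 1f32, vocab_size, &device)?;
        // Divide logits for tokens currently in the sliding window by the penalty,
        // then pick the next token (greedy here, purely for illustration).
        let penalized = rep_pen.apply(&logits, last_token)?;
        last_token = penalized.argmax(0)?.to_scalar::<u32>()? as usize;
    }
    // Reset the window between independent generations.
    rep_pen.clear_cache()?;
    Ok(())
}

Keeping the penalty state as a dense mask tensor means the per-step work is a single broadcast division on the GPU, which matches the "Move rep pen to GPU" item in the commit message.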