Skip to content

Commit

Permalink
Fix output buffer allocation/initialization in CPU HQQ dequantize
Browse files: browse the repository at this point in the history
  • Loading branch information
haricot committed Oct 22, 2024
1 parent b162370 commit 62b3ae2
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions mistralrs-quant/src/hqq/hqq_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ pub(crate) struct Dequant4Bit {
}

impl Dequant4Bit {
fn dequantize<T: WithDType>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
let mut out = Vec::with_capacity(w.len());
fn dequantize<T: WithDType + Default>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
let output_size = w.len() * 2;
let mut out = vec![T::default(); output_size];
for (i, w) in w.iter().enumerate() {
let j = i % self.w;
let nrows = self.h * self.w;
Expand Down Expand Up @@ -125,8 +126,9 @@ pub(crate) struct Dequant2Bit {
}

impl Dequant2Bit {
fn dequantize<T: WithDType>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
let mut out = Vec::with_capacity(w.len());
fn dequantize<T: WithDType + Default>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
let output_size = w.len() * 4;
let mut out = vec![T::default(); output_size];
for (i, w) in w.iter().enumerate() {
let j = i % self.w;
let nrows = self.h * self.w;
Expand Down Expand Up @@ -187,8 +189,9 @@ pub(crate) struct Dequant1Bit {
}

impl Dequant1Bit {
fn dequantize<T: WithDType>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
let mut out = Vec::with_capacity(w.len());
fn dequantize<T: WithDType + Default>(&self, w: &[u8], s: &[T], z: &[T]) -> Vec<T> {
let output_size = w.len() * 8;
let mut out = vec![T::default(); output_size];
for (i, w) in w.iter().enumerate() {
let j = i % self.w;
let nrows = self.h * self.w;
Expand Down Expand Up @@ -253,8 +256,9 @@ pub(crate) struct Dequant3Bit {
}

impl Dequant3Bit {
fn dequantize<T: WithDType>(&self, w: &[i32], s: &[T], z: &[T]) -> Vec<T> {
let mut out = Vec::with_capacity(w.len());
fn dequantize<T: WithDType + Default>(&self, w: &[i32], s: &[T], z: &[T]) -> Vec<T> {
let output_size = w.len() * 10;
let mut out = vec![T::default(); output_size];
for (i, w) in w.iter().enumerate() {
let j = i % self.w;
let nrows = self.h * self.w;
Expand Down

0 comments on commit 62b3ae2

Please sign in to comment.