Skip to content

Commit

Permalink
Add quantize_onto
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Aug 9, 2024
1 parent b7d9af0 commit ab3c58e
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 1 deletion.
17 changes: 17 additions & 0 deletions candle-core/src/quantized/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,23 @@ impl QCudaStorage {
Ok(())
}

pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> {
// Run the quantization on cpu.
let src_len = src.as_slice::<f32>()?.len();
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;

if let QStorage::Cpu(storage) = &mut qcpu_storage {
storage.from_float(src.as_slice::<f32>()?)?;
} else {
unreachable!()
}

let data = qcpu_storage.data()?;
let data = self.device.htod_sync_copy(data.as_ref()).w()?;
self.data = data;
Ok(())
}

pub fn storage_size_in_bytes(&self) -> usize {
self.data.len()
}
Expand Down
4 changes: 4 additions & 0 deletions candle-core/src/quantized/dummy_cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ impl QCudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}

pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}

pub fn storage_size_in_bytes(&self) -> usize {
0
}
Expand Down
4 changes: 4 additions & 0 deletions candle-core/src/quantized/dummy_metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ impl QMetalStorage {
Err(Error::NotCompiledWithMetalSupport)
}

pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}

pub fn storage_size_in_bytes(&self) -> usize {
0
}
Expand Down
17 changes: 17 additions & 0 deletions candle-core/src/quantized/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,23 @@ impl QMetalStorage {
Ok(())
}

pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> {
// Quantization only happens on CPU for now.
let elem_count = src.as_slice::<f32>()?.len();
let src = crate::Storage::Cpu(src);
let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;

if let QStorage::Cpu(storage) = &mut qcpu_storage {
storage.from_float(src.as_slice::<f32>()?)?;
} else {
unreachable!()
}

let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
self.buffer = buffer;
Ok(())
}

pub fn storage_size_in_bytes(&self) -> usize {
self.buffer.length() as usize
}
Expand Down
42 changes: 41 additions & 1 deletion candle-core/src/quantized/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,19 @@ impl QStorage {
}
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
(QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
_ => crate::bail!("Invalid dequantize storage locations do not match"),
_ => crate::bail!("Invalid quantize storage locations do not match"),
}
Ok(())
}

fn quantize_onto(&mut self, src: &Storage) -> Result<()> {
match (self, src) {
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
storage.from_float(src.as_slice::<f32>()?)?;
}
(QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
(QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
_ => crate::bail!("Invalid quantize source storage locations: not on cpu"),
}
Ok(())
}
Expand Down Expand Up @@ -341,6 +353,34 @@ impl QTensor {
})
}

/// Quantize `src` (currently on the CPU) to a QTensor on `dev`
pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result<Self> {
if !src.device().is_cpu() {
crate::bail!(
"`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
src.device()
)
}
let shape = src.shape();
let block_size = dtype.block_size();
check_shape(shape, block_size)?;
let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
let elem_count = shape.elem_count();
if elem_count % block_size != 0 {
crate::bail!(
"tensor size ({shape:?}) is not divisible by block size {}",
block_size
)
}
// storage is on the `dev`, src is on `cpu`
let mut storage = dev.qzeros(elem_count, dtype)?;
storage.quantize_onto(&src.storage())?;
Ok(Self {
storage,
shape: shape.clone(),
})
}

pub fn dtype(&self) -> GgmlDType {
self.storage.dtype()
}
Expand Down

0 comments on commit ab3c58e

Please sign in to comment.