Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GGUF BF16 dtype support #2387

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 81 additions & 2 deletions candle-core/src/cpu/avx.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use super::{Cpu, CpuF16};
use super::{Cpu, CpuBF16, CpuF16};
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use half::f16;
use half::{bf16, f16};

pub struct CurrentCpu {}

Expand Down Expand Up @@ -146,3 +146,82 @@ impl CpuF16<ARR> for CurrentCpuF16 {
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}

pub struct CurrentCpuBF16 {}
impl CpuBF16<ARR> for CurrentCpuBF16 {
type Unit = __m256;
type Array = [__m256; ARR];

const STEP: usize = STEP;
const EPR: usize = EPR;

fn n() -> usize {
ARR
}

unsafe fn zero() -> Self::Unit {
_mm256_setzero_ps()
}

unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}

unsafe fn from_f32(v: f32) -> Self::Unit {
_mm256_set1_ps(v)
}

#[cfg(target_feature = "f16c")]
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
_mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
}

#[cfg(not(target_feature = "f16c"))]
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
let mut tmp = [0.0f32; 8];
for i in 0..8 {
tmp[i] = (*mem_addr.add(i)).to_f32();
}
_mm256_loadu_ps(tmp.as_ptr())
}

unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
_mm256_add_ps(a, b)
}

unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
_mm256_add_ps(_mm256_mul_ps(b, c), a)
}

#[cfg(target_feature = "f16c")]
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
_mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
}

#[cfg(not(target_feature = "f16c"))]
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
let mut tmp = [0.0f32; 8];
_mm256_storeu_ps(tmp.as_mut_ptr(), a);
for i in 0..8 {
*mem_addr.add(i) = bf16::from_f32(tmp[i]);
}
}

unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
let mut offset = ARR >> 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
let t1 = _mm_hadd_ps(t0, t0);
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}
7 changes: 7 additions & 0 deletions candle-core/src/cpu/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ impl VecOps for half::bf16 {
fn max(self, other: Self) -> Self {
Self::max(self, other)
}

#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
super::vec_dot_bf16(lhs, rhs, &mut res_f32, len);
*res = half::bf16::from_f32(res_f32);
}
}
impl VecOps for u8 {
#[inline(always)]
Expand Down
62 changes: 60 additions & 2 deletions candle-core/src/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,33 @@ trait CpuF16<const ARR: usize> {
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
}
use half::f16;

#[allow(unused)]
trait CpuBF16<const ARR: usize> {
type Unit;
type Array;
const STEP: usize;
const EPR: usize;

fn n() -> usize;
unsafe fn zero() -> Self::Unit;
unsafe fn zero_array() -> Self::Array;
unsafe fn load(mem_addr: *const bf16) -> Self::Unit;
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit);
}

use half::{bf16, f16};

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub mod avx;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub use avx::{CurrentCpu, CurrentCpuF16};
pub use avx::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16};

#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
Expand Down Expand Up @@ -170,6 +189,34 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f
*c = sumf;
}

#[cfg(target_feature = "avx")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
let mut sumf = 0.0f32;
let np = k & !(CurrentCpuBF16::STEP - 1);

let mut sum = CurrentCpuBF16::zero_array();
let mut ax = CurrentCpuBF16::zero_array();
let mut ay = CurrentCpuBF16::zero_array();

for i in (0..np).step_by(CurrentCpuBF16::STEP) {
for j in 0..CurrentCpuBF16::n() {
ax[j] = CurrentCpuBF16::load(a_row.add(i + j * CurrentCpuBF16::EPR));
ay[j] = CurrentCpuBF16::load(b_row.add(i + j * CurrentCpuBF16::EPR));

sum[j] = CurrentCpuBF16::vec_fma(sum[j], ax[j], ay[j]);
}
}

CurrentCpuBF16::vec_reduce(sum, &mut sumf);

// leftovers
for i in np..k {
sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
}
*c = sumf;
}

#[cfg(not(target_feature = "avx"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
Expand All @@ -180,3 +227,14 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f
}
*c = sum;
}

#[cfg(not(target_feature = "avx"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
// leftovers
let mut sum = 0.0;
for i in 0..k {
sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
}
*c = sum;
}
1 change: 1 addition & 0 deletions candle-core/src/quantized/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ impl QCudaStorage {
match self.dtype {
GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
GgmlDType::BF16 => deq::<half::bf16>(&buffer, block_len, &mut out)?,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With #2424 this can be optimized!

GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
Expand Down
46 changes: 45 additions & 1 deletion candle-core/src/quantized/k_quants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::utils::{
use super::GgmlDType;
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;
use half::{bf16, f16};
use rayon::prelude::*;

// Default to QK_K 256 rather than 64.
Expand Down Expand Up @@ -1963,3 +1963,47 @@ impl GgmlType for f16 {
Ok(())
}
}

impl GgmlType for bf16 {
const DTYPE: GgmlDType = GgmlDType::BF16;
const BLCK_SIZE: usize = 1;
type VecDotType = bf16;

fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
Self::vec_dot_unopt(n, xs, ys)
}

fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if xs.len() < n {
crate::bail!("size mismatch {} < {n}", xs.len())
}
if ys.len() < n {
crate::bail!("size mismatch {} < {n}", ys.len())
}
let mut res = 0f32;
unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
Ok(res)
}

fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
if xs.len() != ys.len() {
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
}
// TODO: vectorize
for (x, y) in xs.iter().zip(ys.iter_mut()) {
*y = bf16::from_f32(*x)
}
Ok(())
}

fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
if xs.len() != ys.len() {
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
}
// TODO: vectorize
for (x, y) in xs.iter().zip(ys.iter_mut()) {
*y = x.to_f32()
}
Ok(())
}
}
12 changes: 9 additions & 3 deletions candle-core/src/quantized/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils;
use half::f16;
use half::{bf16, f16};

pub use k_quants::GgmlType;

Expand Down Expand Up @@ -133,6 +133,7 @@ impl QStorage {
pub enum GgmlDType {
F32,
F16,
BF16,
Q4_0,
Q4_1,
Q5_0,
Expand Down Expand Up @@ -164,6 +165,8 @@ impl GgmlDType {
13 => Self::Q5K,
14 => Self::Q6K,
15 => Self::Q8K,
// https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
30 => Self::BF16,
_ => crate::bail!("unknown dtype for tensor {u}"),
};
Ok(dtype)
Expand All @@ -185,6 +188,8 @@ impl GgmlDType {
Self::Q5K => 13,
Self::Q6K => 14,
Self::Q8K => 15,
// https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
Self::BF16 => 30,
}
}

Expand All @@ -205,14 +210,15 @@ impl GgmlDType {
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
}
}
/// The type size for blocks in bytes.
pub fn type_size(&self) -> usize {
use k_quants::*;
match self {
Self::F32 => 4,
Self::F16 => 2,
Self::F16 | Self::BF16 => 2,
Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
Expand All @@ -233,7 +239,7 @@ impl GgmlDType {
pub fn block_size(&self) -> usize {
match self {
Self::F32 => 1,
Self::F16 => 1,
Self::F16 | Self::BF16 => 1,
Self::Q4_0 => k_quants::QK4_0,
Self::Q4_1 => k_quants::QK4_1,
Self::Q5_0 => k_quants::QK5_0,
Expand Down
Loading