Skip to content

Commit

Permalink
int8: more tests passing, cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdouglas committed Oct 18, 2024
1 parent 0ab14fe commit fdf4745
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 47 deletions.
46 changes: 28 additions & 18 deletions bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,14 @@ def forward(
# 1. Quantize A
if len(A.shape) == 3:
A = A.reshape(-1, A.shape[-1])
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)

if ctx.needs_input_grad[1]:
# Slower path
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
else:
# Fast path
CA, SCA, coo_tensorA = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
CAt = SCAt = None

has_grad = False

Expand All @@ -322,20 +329,24 @@ def forward(
state.reset_grads()

# 2. Quantize B
(
state.CB,
state.CBt,
state.SCB,
state.SCBt,
_,
) = F.double_quant(B.to(torch.float16))
state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))

# (
# state.CB,
# state.CBt,
# state.SCB,
# state.SCBt,
# _,
# ) = F.double_quant(B.to(torch.float16))

if state.threshold > 0.0 and coo_tensorA is not None:
state.idx = torch.unique(coo_tensorA._indices()[1]).long()

# Zero out the outliers in the int8 inputs
CA[:, state.idx] = 0
# CAt[:, state.idx] = 0

if CAt is not None:
CAt[:, state.idx] = 0

# Extract the input outliers in original precision
subA = A[:, state.idx]
Expand Down Expand Up @@ -372,7 +383,7 @@ def forward(
ctx.tensors = (CAt, subA, A)
ctx.tensor_states = (SCAt, state.idx)
else:
ctx.tensors = [None, None, None] # A]
ctx.tensors = [None, None, None]
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)

Expand Down Expand Up @@ -403,17 +414,16 @@ def backward(ctx, grad_output):
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
# if req_gradB:

# grad_B = torch.matmul(grad_output.t(), A)
# if state.threshold > 0.0 and subA is not None:
# grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradB:
gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t())
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
grad_B = torch.matmul(grad_output.t(), A)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
# Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
# if req_gradB:
# gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t())
# grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
# if state.threshold > 0.0 and subA is not None:
# grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

if req_gradA:
# grad_output @ B.T
Expand Down
36 changes: 24 additions & 12 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2409,6 +2409,7 @@ def get_colrow_absmax(

if row_stats is None:
# shape [rows]; unsqueeze(-1) gives [rows,1]
# We have a CUDA kernel for row max, but not yet for cols.
row_stats = get_row_absmax(A, threshold)

if col_stats is None:
Expand Down Expand Up @@ -2521,29 +2522,42 @@ def extract_outliers_new(A: torch.Tensor, threshold: float):


def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
# TODO: Optimize/write CUDA kernel for this?
# Note: for inference, use the new int8_vectorwise_quant.

# Use CUDA kernel for rowwise and COO tensor
quant_row, row_stats, coo_tensor = int8_vectorwise_quant(A, threshold=threshold)

# PyTorch impl for colwise
_, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold)
if threshold > 0.0 and outlier_mask is not None:
A = A.masked_fill(outlier_mask, 0.0)
quant_col = torch.round(A.mul(C) / col_stats.unsqueeze(0)).to(torch.int8)

if out_row is not None:
quant_row = out_row.copy_(quant_row)
if out_col is not None:
quant_col = out_col.copy_(quant_col)

return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor


def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
assert A.dtype == torch.half
is_on_gpu([A])

rows = prod(A.shape[:-1])
cols = A.shape[-1]

row_stats = torch.empty((rows,), device=A.device, dtype=torch.float32)

out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)

if threshold > 0.0:
# Extract outliers to COO tensor:
# 1. Zero out all of the non-outliers, convert to COO.
# 2. Zero out the outliers in the dense tensor.
# TODO we could improve perf of this
# outlier_mask = A.abs() >= threshold
# coo_tensor = A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo()
# A = A.masked_fill(outlier_mask, 0.0)
coo_tensor = extract_outliers_new(A, threshold)
else:
coo_tensor = None

is_on_gpu([A, row_stats])

with torch.cuda.device_of(A):
lib.cint8_vector_quant(
get_ptr(A),
Expand All @@ -2554,9 +2568,7 @@ def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None,
ct.c_int32(cols),
)

# TODO: col_stats

return out_row, None, row_stats, None, coo_tensor # coo_tensor #col_stats.flatten().float(), coo_tensor
return out_row, row_stats, coo_tensor # coo_tensor #col_stats.flatten().float(), coo_tensor


def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
Expand Down
4 changes: 2 additions & 2 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3612,7 +3612,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
#pragma unroll
for(int k = 0; k < num_values_8bit/4; k++)
{
#if __CUDA_ARCH__ >= 800
#if BNB_BF16_AVAILABLE
local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax;
local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax;
#else
Expand Down Expand Up @@ -3649,7 +3649,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
#pragma unroll
for(int k = 0; k < num_values_4bit/4; k++)
{
#if __CUDA_ARCH__ >= 800
#if BNB_BF16_AVAILABLE
local_C += (float)(local_A[k]*local_B[k]);
#else
// bf16 multipliation not supported
Expand Down
17 changes: 10 additions & 7 deletions tests/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,16 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
if not has_fp16_weights:
if not transpose[0] and not transpose[1]:
B2 = B2.t().contiguous()
(
state.CB,
CBt,
state.SCB,
SCBt,
coo_tensorB,
) = bnb.functional.double_quant(B2.to(torch.float16))

state.CB, state.SCB, _ = bnb.functional.int8_vectorwise_quant(B2.to(torch.float16))

# (
# state.CB,
# CBt,
# state.SCB,
# SCBt,
# coo_tensorB,
# ) = bnb.functional.double_quant(B2.to(torch.float16))
B2 = state.CB

if not transpose[0] and transpose[1]:
Expand Down
30 changes: 25 additions & 5 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,17 +1132,37 @@ def test_overflow():
c2 = torch.matmul(a.float(), b.float().t())


@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_double_quant(dim1, dim2):
threshold = 2.00
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()

idx = torch.abs(A) >= threshold
CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)

if coo_tensor is not None:
A1 = A * idx
A2 = coo_tensor.to_dense()
torch.testing.assert_close(A1, A2)

A1 = A * (idx == 0)
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)


# @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1"))
# @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_double_quant(dim1, dim2):
def test_coo_int8_vectorwise_quant(dim1, dim2):
threshold = 3.00
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()

idx = torch.abs(A) >= threshold
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold)

if coo_tensor is not None:
A1 = A * idx
Expand Down Expand Up @@ -1239,13 +1259,13 @@ def test_integrated_sparse_decomp(dim1, dim2):
w1 = torch.randn(dim1, dim2).cuda().half()
out1 = torch.matmul(A, w1.t())

Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
Cw1, statsw1, coo_tensor = F.int8_vectorwise_quant(w1)
CA, statsA, coo_tensor = F.int8_vectorwise_quant(A)

out1_32, Sout1_32 = F.igemmlt(CA, Cw1)
out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold)

out1_32, Sout1_32 = F.igemmlt(CA, Cw1)
out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
Expand Down
7 changes: 4 additions & 3 deletions tests/test_linear8bitlt.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,11 @@ def test_linear_no_igemmlt():

assert linear_custom.state.CB is not None
assert not linear_custom.state.has_fp16_weights
assert torch.allclose(fx_ref, fx_ours, atol=0.02)
assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)

# assert linear_custom.state.CxB is None
idx = torch.isclose(fx_ref, fx_ours, atol=0.02, rtol=1e-5)
assert (idx == 0).sum().item() < fx_ref.numel() * 2.5e-4
torch.testing.assert_close(fx_ref, fx_ours, atol=0.03, rtol=1e-5)
torch.testing.assert_close(x_ref.grad, x_ours.grad, atol=0.01, rtol=1e-5)


@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
Expand Down

0 comments on commit fdf4745

Please sign in to comment.