Skip to content

Commit

Permalink
int8: inference optimizations, some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdouglas committed Oct 18, 2024
1 parent 510a880 commit 0ab14fe
Show file tree
Hide file tree
Showing 12 changed files with 246 additions and 235 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ CMakeFiles/
bitsandbytes.dir/
Debug/
Release/
cmake-build-*/

# IDE local files
.vs/
.idea/

# Distribution / packaging
.Python
Expand Down
20 changes: 10 additions & 10 deletions bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,17 +335,17 @@ def forward(

# Zero out the outliers in the int8 inputs
CA[:, state.idx] = 0
CAt[:, state.idx] = 0
# CAt[:, state.idx] = 0

# Extract the input outliers in original precision
subA = A[:, state.idx]

# Extract the corresponding weights
if state.has_fp16_weights:
state.subB = B[:, state.idx].t().contiguous()
state.subB = B[:, state.idx].t() # .contiguous()
else:
outliers = state.CB[:, state.idx].clone()
state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
outliers = state.CB[:, state.idx] # .clone()
state.subB = (7.874016e-3 * outliers * state.SCB.view(-1, 1)).t().to(A.dtype)
else:
subA = None

Expand All @@ -372,14 +372,14 @@ def forward(
ctx.tensors = (CAt, subA, A)
ctx.tensor_states = (SCAt, state.idx)
else:
ctx.tensors = [None, None, A]
ctx.tensors = [None, None, None] # A]
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)

output_shape = (*input_shape[:-1], state.CB.shape[0])

if len(input_shape) == 3:
return output.view(output_shape).clone()
return output.reshape(output_shape) # .clone()
else:
return output

Expand Down Expand Up @@ -417,10 +417,10 @@ def backward(ctx, grad_output):

if req_gradA:
# grad_output @ B.T
if state.CBt is not None:
gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None:
# if state.CBt is not None:
# gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
# grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
if state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else:
Expand Down
173 changes: 34 additions & 139 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2400,64 +2400,38 @@ def get_colrow_absmax(
outlier_mask = None

if row_stats is None or col_stats is None:
# view as 2D
absA = A.abs().view(-1, A.shape[-1])

if threshold > 0.0:
# Filter outliers from stats when enabled
outlier_mask = absA >= threshold
absA.masked_fill_(outlier_mask, 0.0)

# For parity with tests build nnz_block_ptr.
nnz_block_ptr = torch.zeros(absA.shape[0] + 1, dtype=torch.int64, device=A.device)
nnz_block_ptr[1:] = outlier_mask.sum(1).cumsum(0)

if row_stats is None:
# shape [rows]; unsqueeze(-1) gives [rows,1]
row_stats = absA.amax(dim=1, keepdim=False).float()
row_stats = get_row_absmax(A, threshold)

if col_stats is None:
# shape [cols]; unsqueeze(0) gives [1,cols]
col_stats = absA.amax(dim=0, keepdim=False).float()

return row_stats, col_stats, outlier_mask


def get_colrow_absmax_old(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
def get_row_absmax(A, threshold=0.0):
assert A.dtype == torch.float16
device = A.device

rows = prod(A.shape[:-1])
cols = A.shape[-1]
if len(A.shape) == 3:
rows = A.shape[0] * A.shape[1]
else:
rows = A.shape[0]

col_tiles = (cols + 255) // 256
tiled_rows = ((rows + 15) // 16) * 16
if row_stats is None:
row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0)
if col_stats is None:
col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0)

if nnz_block_ptr is None and threshold > 0.0:
nnz_block_ptr = torch.zeros(((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device)

ptrA = get_ptr(A)
ptrRowStats = get_ptr(row_stats)
ptrColStats = get_ptr(col_stats)
ptrNnzrows = get_ptr(nnz_block_ptr)
rows = ct.c_int32(rows)
cols = ct.c_int32(cols)
row_stats = torch.empty((rows,), dtype=torch.float32, device=A.device)

prev_device = pre_call(A.device)
is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
post_call(prev_device)
is_on_gpu([A, row_stats])

if threshold > 0.0:
nnz_block_ptr.cumsum_(0)
with torch.cuda.device_of(A):
lib.cget_row_stats(get_ptr(A), get_ptr(row_stats), ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))

return row_stats, col_stats, nnz_block_ptr
return row_stats


class COOSparseTensor:
Expand Down Expand Up @@ -2541,127 +2515,48 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)


# @torch.compile
def extract_outliers_new(A: torch.Tensor, threshold: float):
outlier_mask = A.abs() >= threshold
return A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo()


def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
# TODO: Optimize/write CUDA kernel for this
assert A.dtype == torch.half

if row_stats is None or col_stats is None:
row_stats, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold)
rows = prod(A.shape[:-1])
cols = A.shape[-1]

row_stats = torch.empty((rows,), device=A.device, dtype=torch.float32)

out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)

if threshold > 0.0:
# Extract outliers to COO tensor:
# 1. Zero out all of the non-outliers, convert to COO.
# 2. Zero out the outliers in the dense tensor.
# TODO we could improve perf of this
# is_outlier = A.abs() >= threshold
coo_tensor = A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo()
A = A.masked_fill(outlier_mask, 0.0)
# outlier_mask = A.abs() >= threshold
# coo_tensor = A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo()
# A = A.masked_fill(outlier_mask, 0.0)
coo_tensor = extract_outliers_new(A, threshold)
else:
coo_tensor = None

# Quantize
scaled_A = A.mul(C)
# quant_row = torch.round(A * (C / row_stats.unsqueeze(-1))).to(torch.int8)
# quant_col = torch.round(A * (C / col_stats.unsqueeze(0))).to(torch.int8)
quant_row = torch.round(scaled_A / row_stats.unsqueeze(-1)).to(torch.int8)
quant_col = torch.round(scaled_A / col_stats.unsqueeze(0)).to(torch.int8)

if out_row is not None:
quant_row = out_row.copy_(quant_row)
if out_col is not None:
quant_col = out_col.copy_(quant_col)

return quant_row, quant_col, row_stats.flatten().float(), col_stats.flatten().float(), coo_tensor
is_on_gpu([A, row_stats])


def double_quant_old(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
device = A.device
assert A.dtype == torch.half
assert device.type == "cuda"
prev_device = pre_call(A.device)

cols = A.shape[-1]
if len(A.shape) == 3:
rows = A.shape[0] * A.shape[1]
else:
rows = A.shape[0]

if row_stats is None or col_stats is None:
row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)

if out_col is None:
out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
if out_row is None:
out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)

coo_tensor = None
ptrA = get_ptr(A)
ptrColStats = get_ptr(col_stats)
ptrRowStats = get_ptr(row_stats)
ptrOutCol = get_ptr(out_col)
ptrOutRow = get_ptr(out_row)

is_on_gpu([A, col_stats, row_stats, out_col, out_row])
if threshold > 0.0:
nnz = nnz_row_ptr[-1].item()
if nnz > 0:
coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device)
ptrRowIdx = get_ptr(coo_tensor.rowidx)
ptrColIdx = get_ptr(coo_tensor.colidx)
ptrVal = get_ptr(coo_tensor.values)
ptrRowPtr = get_ptr(nnz_row_ptr)

lib.cdouble_rowcol_quant(
ptrA,
ptrRowStats,
ptrColStats,
ptrOutCol,
ptrOutRow,
ptrRowIdx,
ptrColIdx,
ptrVal,
ptrRowPtr,
ct.c_float(threshold),
ct.c_int32(rows),
ct.c_int32(cols),
)
val, idx = torch.sort(coo_tensor.rowidx)
coo_tensor.rowidx = val
coo_tensor.colidx = coo_tensor.colidx[idx]
coo_tensor.values = coo_tensor.values[idx]
else:
lib.cdouble_rowcol_quant(
ptrA,
ptrRowStats,
ptrColStats,
ptrOutCol,
ptrOutRow,
None,
None,
None,
None,
ct.c_float(0.0),
ct.c_int32(rows),
ct.c_int32(cols),
)
else:
lib.cdouble_rowcol_quant(
ptrA,
ptrRowStats,
ptrColStats,
ptrOutCol,
ptrOutRow,
None,
None,
None,
None,
with torch.cuda.device_of(A):
lib.cint8_vector_quant(
get_ptr(A),
get_ptr(out_row),
get_ptr(row_stats),
ct.c_float(threshold),
ct.c_int32(rows),
ct.c_int32(cols),
)
post_call(prev_device)

return out_row, out_col, row_stats, col_stats, coo_tensor
# TODO: col_stats

return out_row, None, row_stats, None, coo_tensor # coo_tensor #col_stats.flatten().float(), coo_tensor


def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
Expand Down
6 changes: 3 additions & 3 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,9 +1008,9 @@ def forward(self, x: torch.Tensor):

out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

if not self.state.has_fp16_weights:
if self.state.CB is not None:
self.weight.data = self.state.CB
if not self.state.has_fp16_weights and self.state.CB is not None:
self.weight.data = self.state.CB

return out


Expand Down
30 changes: 12 additions & 18 deletions bitsandbytes/research/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ class SwitchBackBnb(torch.autograd.Function):
@staticmethod
# TODO: the B008 on the line below is a likely bug; the current implementation will
# have each SwitchBackBnb instance share a single MatmulLtState instance!!!
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B008
def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
state = state or MatmulLtState()

# default to pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
Expand Down Expand Up @@ -222,7 +224,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00
# idx = torch.unique(coo_tensorA.colidx).long()
idx = torch.unique(coo_tensorA._indices()[1]).long()
CA[:, idx] = 0
CAt[:, idx] = 0
# CAt[:, idx] = 0
subA = A[:, idx]
state.subB = B[:, idx].t().contiguous()
state.idx = idx
Expand All @@ -249,29 +251,21 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00
state.CBt,
state.SCB,
state.SCBt,
coo_tensorB,
_,
) = F.double_quant(B.to(torch.float16))
state.SB = (state.CB.shape, "row")
else:
has_grad = False

if coo_tensorA is not None and not state.has_fp16_weights:
# extract outliers
state.idx = torch.unique(coo_tensorA._indices()[1]).long()

# outlier_idx = torch.unique(coo_tensorA.colidx)
outlier_idx = torch.unique(coo_tensorA._indices()[1]).long()
state.idx = outlier_idx
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
# else:
# state.idx = outlier_idx
# outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
outliers = state.CB[:, state.idx.long()].clone()
state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
# CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]

shapeB = state.SB[0]
Expand Down Expand Up @@ -318,6 +312,7 @@ def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, A = ctx.tensors
SCAt, idx = ctx.tensor_states
Expand All @@ -340,11 +335,10 @@ def backward(ctx, grad_output):
grad_B = torch.matmul(grad_output.t(), A)

if req_gradA:
if state.CBt is not None:
gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

elif state.CB is not None:
# if state.CBt is not None:
# gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
# grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
if state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else:
Expand Down
Loading

0 comments on commit 0ab14fe

Please sign in to comment.