
Commit

enable backward
jiqing-feng committed Sep 27, 2024
1 parent b8df1aa commit cd7bf21
Showing 6 changed files with 54 additions and 18 deletions.
2 changes: 1 addition & 1 deletion bitsandbytes/autograd/_functions.py
@@ -552,7 +552,7 @@ def backward(ctx, grad_output):
         # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
         if req_gradA:
             if getattr(ctx.state, "ipex", False):
-                grad_A = F.gemv_4bit(grad_output, B, None, state=ctx.state)
+                grad_A = F.gemv_4bit(grad_output, B, None, state=ctx.state, backward=True)
             else:
                 grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
 
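
For orientation, the shape calculus this hunk serves is the usual linear-layer one: the forward computes y = x @ W.t(), so the input gradient is grad_A = grad_output @ W. A fused kernel that, like torch.nn.functional.linear, multiplies by the transpose of whatever weight it was packed with therefore needs a second weight packed from W.t() to serve the backward direction, which is what the new backward=True path is meant to select. A minimal plain-PyTorch sketch of that equivalence (shapes and variable names are illustrative, not taken from the repo, and the fused IPEX op is assumed to behave like F.linear):

import torch

# Illustrative nn.Linear-style shapes only.
out_features, in_features, batch = 8, 16, 4
W = torch.randn(out_features, in_features)
x = torch.randn(batch, in_features)

y = x @ W.t()                      # forward: "input @ packed_weight.t()"
grad_out = torch.randn(batch, out_features)

# Backward w.r.t. the input needs grad_out @ W.
grad_x_ref = grad_out @ W

# A kernel that can only compute "input @ weight.t()" still produces this
# if it is handed the transposed weight, which is what the backward-packed
# weight introduced by this commit is for.
W_for_backward = W.t().contiguous()
grad_x_via_kernel = grad_out @ W_for_backward.t()

assert torch.allclose(grad_x_ref, grad_x_via_kernel)
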
3 changes: 2 additions & 1 deletion bitsandbytes/backends/cpu.py
@@ -163,12 +163,13 @@ def gemv_4bit(
         transposed_A=False,
         transposed_B=False,
         state: QuantState = None,
+        backward=False,
     ) -> torch.Tensor:
         assert_on_cpu([A, B, out])
         if state is None:
             raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
 
-        return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)
+        return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state, backward)
 
     def dequantize_blockwise(
         self,
12 changes: 9 additions & 3 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -486,6 +486,7 @@ def gemm_4bit_impl(
     transposed_A=False,
     transposed_B=False,
     state: QuantState = None,
+    backward=False,
 ) -> torch.Tensor:
     """
     Matrix-matrix multiplication with 4-bit quantization.
@@ -511,9 +512,14 @@
         GEMM output tensor.
     """
     if ipex_cpu and _ipex_cpu_version_prereq(2, 5) and getattr(state, "ipex", False):
-        output = torch.ops.torch_ipex.woq_linear(A, B, "nf4", state.shape,
-                                                 state.new_scales, state.new_zeros, None, None, state.blocksize,
-                                                 ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.compensation)
+        if backward:
+            output = torch.ops.torch_ipex.woq_linear(A, state.backward_weight, "nf4", torch.Size([state.shape[1], state.shape[0]]),
+                                                     state.backward_new_scales, state.backward_new_zeros, None, None, state.blocksize,
+                                                     ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.backward_compensation)
+        else:
+            output = torch.ops.torch_ipex.woq_linear(A, B, "nf4", state.shape,
+                                                     state.new_scales, state.new_zeros, None, None, state.blocksize,
+                                                     ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.compensation)
     else:
         dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
         output = torch.matmul(A, dqB.to(A.dtype))
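
A reduced sketch of the branch this hunk adds, with the IPEX op replaced by a `fused_linear` stand-in (the real call is torch.ops.torch_ipex.woq_linear with the full argument list shown above). It only highlights which quant-state fields differ between the two directions; the backward_* attributes are attached in bitsandbytes/utils.py later in this diff, and the function name here is illustrative:

import torch

def gemm_4bit_dispatch_sketch(A, B, state, backward=False, fused_linear=None):
    # The only difference between the two calls is which packed weight,
    # logical shape, scales, zeros and compensation reach the fused kernel.
    if backward:
        weight = state.backward_weight                         # packed from W.t() in enable_ipex_fusion
        shape = torch.Size([state.shape[1], state.shape[0]])   # transposed logical shape
        scales, zeros, comp = (state.backward_new_scales,
                               state.backward_new_zeros,
                               state.backward_compensation)
    else:
        weight = B                                             # forward-packed weight
        shape = state.shape
        scales, zeros, comp = state.new_scales, state.new_zeros, state.compensation
    return fused_linear(A, weight, shape, scales, zeros, comp)
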
28 changes: 20 additions & 8 deletions bitsandbytes/functional.py
@@ -1530,16 +1530,28 @@ def gemv_4bit(
     transposed_A=False,
     transposed_B=False,
     state=None,
+    backward=False,
 ):
     ensure_backend_is_available(A.device.type)
-    return backends[A.device.type].gemv_4bit(
-        A,
-        B,
-        out=out,
-        transposed_A=transposed_A,
-        transposed_B=transposed_B,
-        state=state,
-    )
+    if A.device.type == "cpu":
+        return backends[A.device.type].gemv_4bit(
+            A,
+            B,
+            out=out,
+            transposed_A=transposed_A,
+            transposed_B=transposed_B,
+            state=state,
+            backward=backward,
+        )
+    else:
+        return backends[A.device.type].gemv_4bit(
+            A,
+            B,
+            out=out,
+            transposed_A=transposed_A,
+            transposed_B=transposed_B,
+            state=state,
+        )
 
 
 def igemm(
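
The branching above forwards `backward` only to the CPU backend, since the other backends do not accept that keyword in this commit. The same dispatch could be written more compactly by attaching the keyword conditionally; a sketch of that equivalent form, relying on the module's `backends` registry and `ensure_backend_is_available` exactly as the surrounding file does (the function name is illustrative):

def gemv_4bit_sketch(A, B, out=None, transposed_A=False, transposed_B=False, state=None, backward=False):
    # Same behaviour as the if/else above: only the CPU backend's gemv_4bit
    # understands `backward`, so the keyword is added just for that device.
    ensure_backend_is_available(A.device.type)
    kwargs = dict(out=out, transposed_A=transposed_A, transposed_B=transposed_B, state=state)
    if A.device.type == "cpu":
        kwargs["backward"] = backward
    return backends[A.device.type].gemv_4bit(A, B, **kwargs)
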
3 changes: 1 addition & 2 deletions bitsandbytes/nn/modules.py
@@ -468,9 +468,8 @@ def forward(self, x: torch.Tensor):
             and not getattr(self.weight.quant_state, "ipex", False)
             and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
             and self.weight.quant_state.quant_type == "nf4"
-            and x.requires_grad == False
         ):
-            enable_ipex_fusion(self)
+            enable_ipex_fusion(self, x.requires_grad)
 
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
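
The guard change above drops the hard `x.requires_grad == False` requirement and instead forwards the flag, so fusion can also be enabled for training inputs. The signal being forwarded is simply whether the incoming activation participates in autograd, i.e. whether grad_A will be requested later; a tiny, self-contained illustration:

import torch

x_inference = torch.randn(2, 4)                     # no autograd -> only forward packing matters
x_training = torch.randn(2, 4, requires_grad=True)  # autograd -> a backward-packed weight is useful too

print(x_inference.requires_grad)  # False
print(x_training.requires_grad)   # True
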
24 changes: 21 additions & 3 deletions bitsandbytes/utils.py
@@ -200,23 +200,41 @@ def unpack_tensor_to_dict(tensor_data):
     return unpacked_dict
 
 
-def enable_ipex_fusion(linear):
+def enable_ipex_fusion(linear, grad=False):
     from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq
 
     if _ipex_cpu_version_prereq(2, 5):
         quant_state = linear.weight.quant_state
         new_weight, new_scales, new_zeros, _, compensation = \
+            torch.ops.ipex_prepack.woq_linear_pack_weight(
+                linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
+                "nf4",
+                quant_state.shape, # weight shape
+                quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
+                None, # zero_points
+                None, # bias
+                None, # batch_size
+                quant_state.blocksize,
+                2,
+            )
+        if grad or True:
+            backward_new_weight, backward_new_scales, backward_new_zeros, _, backward_compensation = \
             torch.ops.ipex_prepack.woq_linear_pack_weight(
-                linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
+                linear.weight.t().data.reshape([quant_state.shape[1], quant_state.shape[0] // 2]),
                 "nf4",
                 quant_state.shape, # weight shape
-                quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
+                quant_state.absmax.view(quant_state.shape[1], quant_state.shape[0] // quant_state.blocksize), # scales
                 None, # zero_points
                 None, # bias
                 None, # batch_size
                 quant_state.blocksize,
                 2,
            )
+            setattr(linear.weight.quant_state, "backward_weight", backward_new_weight)
+            setattr(linear.weight.quant_state, "backward_new_scales", backward_new_scales)
+            setattr(linear.weight.quant_state, "backward_new_zeros", backward_new_zeros)
+            setattr(linear.weight.quant_state, "backward_compensation", backward_compensation)
+
         linear.weight.data = new_weight.data
         setattr(linear.weight.quant_state, "ipex", True)
         setattr(linear.weight.quant_state, "new_scales", new_scales)
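
The reshape and view arithmetic in this hunk follows from NF4 storage: each uint8 byte holds two 4-bit codes, so a logical (out_features, in_features) weight occupies out_features * in_features / 2 bytes, and there is one absmax scale per `blocksize` elements. The backward packing therefore uses the transposed bookkeeping, shape[1] x shape[0] // 2 bytes and shape[1] x shape[0] // blocksize scales. A toy illustration of that counting only (random bytes, no real quantization, sizes made up so every division is exact):

import torch

out_features, in_features, blocksize = 64, 128, 32

packed_bytes = out_features * in_features // 2      # two 4-bit codes per uint8 byte
n_scales = out_features * in_features // blocksize  # one absmax per block

packed = torch.randint(0, 256, (packed_bytes, 1), dtype=torch.uint8)  # bnb-style flat storage
absmax = torch.rand(n_scales)

forward_layout = packed.reshape(out_features, in_features // 2)       # shape[0], shape[1] // 2
forward_scales = absmax.view(out_features, in_features // blocksize)  # shape[0], shape[1] // blocksize

backward_layout_shape = (in_features, out_features // 2)              # shape[1], shape[0] // 2
backward_scales_shape = (in_features, out_features // blocksize)      # shape[1], shape[0] // blocksize

print(forward_layout.shape, forward_scales.shape, backward_layout_shape, backward_scales_shape)
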
