From cd7bf2145807932c8a8a499ddb6bb14e47eb24fc Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 27 Sep 2024 12:58:25 -0400
Subject: [PATCH] enable backward

---
 bitsandbytes/autograd/_functions.py     |  2 +-
 bitsandbytes/backends/cpu.py            |  3 ++-
 bitsandbytes/backends/cpu_xpu_common.py | 12 ++++++++---
 bitsandbytes/functional.py              | 28 ++++++++++++++++++-------
 bitsandbytes/nn/modules.py              |  3 +--
 bitsandbytes/utils.py                   | 24 ++++++++++++++++++---
 6 files changed, 54 insertions(+), 18 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 35c2b45de..06683690c 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -552,7 +552,7 @@ def backward(ctx, grad_output):
         # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
         if req_gradA:
             if getattr(ctx.state, "ipex", False):
-                grad_A = F.gemv_4bit(grad_output, B, None, state=ctx.state)
+                grad_A = F.gemv_4bit(grad_output, B, None, state=ctx.state, backward=True)
             else:
                 grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
 
diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py
index 5d38171d5..549808c82 100644
--- a/bitsandbytes/backends/cpu.py
+++ b/bitsandbytes/backends/cpu.py
@@ -163,12 +163,13 @@ def gemv_4bit(
         transposed_A=False,
         transposed_B=False,
         state: QuantState = None,
+        backward=False,
     ) -> torch.Tensor:
         assert_on_cpu([A, B, out])
         if state is None:
             raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
 
-        return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)
+        return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state, backward)
 
     def dequantize_blockwise(
         self,
diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py
index 78473bdc4..c298962a2 100644
--- a/bitsandbytes/backends/cpu_xpu_common.py
+++ b/bitsandbytes/backends/cpu_xpu_common.py
@@ -486,6 +486,7 @@ def gemm_4bit_impl(
     transposed_A=False,
     transposed_B=False,
     state: QuantState = None,
+    backward=False,
 ) -> torch.Tensor:
     """
     Matrix-matrix multiplication with 4-bit quantization.
@@ -511,9 +512,14 @@ def gemm_4bit_impl(
         GEMM output tensor.
""" if ipex_cpu and _ipex_cpu_version_prereq(2, 5) and getattr(state, "ipex", False): - output = torch.ops.torch_ipex.woq_linear(A, B, "nf4", state.shape, - state.new_scales, state.new_zeros, None, None, state.blocksize, - ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.compensation) + if backward: + output = torch.ops.torch_ipex.woq_linear(A, state.backward_weight, "nf4", torch.Size([state.shape[1], state.shape[0]]), + state.backward_new_scales, state.backward_new_zeros, None, None, state.blocksize, + ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.backward_compensation) + else: + output = torch.ops.torch_ipex.woq_linear(A, B, "nf4", state.shape, + state.new_scales, state.new_zeros, None, None, state.blocksize, + ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.compensation) else: dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t() output = torch.matmul(A, dqB.to(A.dtype)) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 6cf64df28..b53212bfd 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1530,16 +1530,28 @@ def gemv_4bit( transposed_A=False, transposed_B=False, state=None, + backward=False, ): ensure_backend_is_available(A.device.type) - return backends[A.device.type].gemv_4bit( - A, - B, - out=out, - transposed_A=transposed_A, - transposed_B=transposed_B, - state=state, - ) + if A.device.type == "cpu": + return backends[A.device.type].gemv_4bit( + A, + B, + out=out, + transposed_A=transposed_A, + transposed_B=transposed_B, + state=state, + backward=backward, + ) + else: + return backends[A.device.type].gemv_4bit( + A, + B, + out=out, + transposed_A=transposed_A, + transposed_B=transposed_B, + state=state, + ) def igemm( diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 0635c653d..dc00acdaf 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -468,9 +468,8 @@ def forward(self, x: torch.Tensor): and not getattr(self.weight.quant_state, "ipex", False) and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0 and self.weight.quant_state.quant_type == "nf4" - and x.requires_grad == False ): - enable_ipex_fusion(self) + enable_ipex_fusion(self, x.requires_grad) # weights are cast automatically as Int8Params, but the bias has to be cast manually if self.bias is not None and self.bias.dtype != x.dtype: diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index b89edd828..e0810a6e8 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -200,23 +200,41 @@ def unpack_tensor_to_dict(tensor_data): return unpacked_dict -def enable_ipex_fusion(linear): +def enable_ipex_fusion(linear, grad=False): from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq if _ipex_cpu_version_prereq(2, 5): quant_state = linear.weight.quant_state new_weight, new_scales, new_zeros, _, compensation = \ + torch.ops.ipex_prepack.woq_linear_pack_weight( + linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]), + "nf4", + quant_state.shape, # weight shape + quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales + None, # zero_points + None, # bias + None, # batch_size + quant_state.blocksize, + 2, + ) + if grad or True: + backward_new_weight, backward_new_scales, backward_new_zeros, _, backward_compensation = \ torch.ops.ipex_prepack.woq_linear_pack_weight( - linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]), + 
+                linear.weight.t().data.reshape([quant_state.shape[1], quant_state.shape[0] // 2]),
                 "nf4",
                 quant_state.shape,  # weight shape
-                quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize),  # scales
+                quant_state.absmax.view(quant_state.shape[1], quant_state.shape[0] // quant_state.blocksize),  # scales
                 None,  # zero_points
                 None,  # bias
                 None,  # batch_size
                 quant_state.blocksize,
                 2,
             )
+            setattr(linear.weight.quant_state, "backward_weight", backward_new_weight)
+            setattr(linear.weight.quant_state, "backward_new_scales", backward_new_scales)
+            setattr(linear.weight.quant_state, "backward_new_zeros", backward_new_zeros)
+            setattr(linear.weight.quant_state, "backward_compensation", backward_compensation)
+
         linear.weight.data = new_weight.data
         setattr(linear.weight.quant_state, "ipex", True)
         setattr(linear.weight.quant_state, "new_scales", new_scales)
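
Below is a minimal, hypothetical usage sketch (not part of the patch) of the backward path this change enables. It assumes bitsandbytes is built with the CPU/IPEX backend, intel_extension_for_pytorch >= 2.5 is installed, and the layer and batch sizes are illustrative only:

    import torch
    import bitsandbytes as bnb

    # NF4 Linear4bit layer on CPU; in_features (256) is a multiple of the default
    # blocksize (64), so the IPEX fusion branch in Linear4bit.forward() can be taken.
    linear = bnb.nn.Linear4bit(256, 128, quant_type="nf4", compute_dtype=torch.bfloat16)
    linear = linear.to("cpu")  # quantizes the weight for the CPU backend

    # With this patch, forward() calls enable_ipex_fusion(self, x.requires_grad), so a
    # grad-requiring input also packs the transposed weight for the backward pass.
    x = torch.randn(4, 256, dtype=torch.bfloat16, requires_grad=True)
    out = linear(x)

    # MatMul4Bit.backward() now routes the IPEX path through
    # F.gemv_4bit(grad_output, B, None, state=ctx.state, backward=True), which reads
    # state.backward_weight / backward_new_scales / backward_new_zeros / backward_compensation.
    out.sum().backward()
    print(x.grad.shape)  # expected: torch.Size([4, 256])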