From b8df1aad9414a669e188678b36be304400987a72 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 23 Sep 2024 10:26:22 -0400
Subject: [PATCH] use ipex op in backward

---
 bitsandbytes/autograd/_functions.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 0abd6b6df..35c2b45de 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -517,7 +517,10 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
 
         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+        if getattr(quant_state, "ipex", False):
+            output = F.gemv_4bit(A, B, out, state=quant_state)
+        else:
+            output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
 
         # 3. Save state
         ctx.state = quant_state
@@ -548,7 +551,10 @@ def backward(ctx, grad_output):
         # not supported by PyTorch. TODO: create work-around
         # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
         if req_gradA:
-            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
+            if getattr(ctx.state, "ipex", False):
+                grad_A = F.gemv_4bit(grad_output, B, None, state=ctx.state)
+            else:
+                grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
 
         return grad_A, grad_B, None, grad_bias, None
 
@@ -575,7 +581,7 @@ def matmul_4bit(
     bias=None,
 ):
     assert quant_state is not None
-    if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
+    if A.numel() == A.shape[-1] and A.device.type != "cpu" and A.requires_grad == False:
         # CPU backend does not require A to be a vector
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
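
The backward change substitutes F.gemv_4bit(grad_output, B, None, state=ctx.state) for the explicit
dequantize-then-matmul when quant_state.ipex is set. Below is a minimal sanity-check sketch, in plain
PyTorch, of the gradient identity that substitution relies on; M is a hypothetical stand-in for the
dequantized 4-bit weight, and no bitsandbytes or IPEX install is assumed. Whether gemv_4bit on the
IPEX-repacked weight produces this same product is exactly what the patch asserts, not verified here.

    # Sketch: for output = linear(A, M.t()) (i.e. output = A @ M), autograd's
    # gradient w.r.t. A equals grad_output @ M.t(), which is the product the
    # backward branch computes (via gemv_4bit on the IPEX path, or via
    # dequantize_4bit + matmul otherwise).
    import torch

    A = torch.randn(4, 16, requires_grad=True)     # activations
    M = torch.randn(16, 8)                         # stand-in for dequantize_4bit(B, quant_state)

    output = torch.nn.functional.linear(A, M.t())  # output = A @ M, shape (4, 8)
    grad_output = torch.randn_like(output)
    output.backward(grad_output)

    manual_grad_A = grad_output @ M.t()            # what MatMul4Bit.backward computes for grad_A
    assert torch.allclose(A.grad, manual_grad_A, atol=1e-6)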