use ipex op in backward
jiqing-feng committed Sep 23, 2024
1 parent 012c660 commit b8df1aa
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions bitsandbytes/autograd/_functions.py
```diff
@@ -517,7 +517,10 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
 
         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+        if getattr(quant_state, "ipex", False):
+            output = F.gemv_4bit(A, B, out, state=quant_state)
+        else:
+            output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
 
         # 3. Save state
         ctx.state = quant_state
```
```diff
@@ -548,7 +551,10 @@ def backward(ctx, grad_output):
         # not supported by PyTorch. TODO: create work-around
         # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
         if req_gradA:
-            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
+            if getattr(ctx.state, "ipex", False):
+                grad_A = F.gemv_4bit(grad_output, B, None, state=ctx.state)
+            else:
+                grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
 
         return grad_A, grad_B, None, grad_bias, None
```

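Both hunks apply the same runtime dispatch: if the quant state carries an `ipex` flag (presumably set once the 4-bit weight has been repacked for Intel Extension for PyTorch), the fused `F.gemv_4bit` op consumes the packed weight directly; otherwise the generic dequantize-then-matmul path runs. A minimal, self-contained sketch of that pattern — `MockQuantState`, `mock_gemv_4bit`, and `mock_dequant_matmul` are illustrative stand-ins, not the bitsandbytes API:

```python
import torch

class MockQuantState:
    """Illustrative stand-in for F.QuantState; only the field the dispatch reads."""
    def __init__(self, ipex: bool = False):
        self.ipex = ipex  # flag checked via getattr(state, "ipex", False)

def mock_gemv_4bit(A, W, state):
    # Stand-in for F.gemv_4bit: in the real code this is a fused kernel that
    # consumes the IPEX-repacked weight directly, with no dequantize step.
    return A @ W.t()

def mock_dequant_matmul(A, W, state):
    # Stand-in for the generic path: dequantize the weight, then matmul.
    return A @ W.t()

def linear_4bit(A, W, state):
    # The dispatch this commit adds to both MatMul4Bit.forward and backward:
    if getattr(state, "ipex", False):
        return mock_gemv_4bit(A, W, state)
    return mock_dequant_matmul(A, W, state)

A, W = torch.randn(2, 8), torch.randn(4, 8)
out = linear_4bit(A, W, MockQuantState(ipex=True))   # fused path
ref = linear_4bit(A, W, MockQuantState(ipex=False))  # fallback path
assert torch.allclose(out, ref)
```

Using `getattr` with a `False` default keeps the check compatible with quant states created before the flag existed.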
```diff
@@ -575,7 +581,7 @@ def matmul_4bit(
     bias=None,
 ):
     assert quant_state is not None
-    if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
+    if A.numel() == A.shape[-1] and A.device.type != "cpu" and A.requires_grad == False:
         # CPU backend does not require A to be a vector
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
```
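The last hunk changes the routing in `matmul_4bit`: previously any gradient-free CPU input (or any vector input) took the inference `gemv_4bit` path; now CPU tensors are excluded, so they always go through `MatMul4Bit`, whose forward and backward can call the IPEX op themselves. A small sketch of the predicate, runnable on CPU (the `True` case needs an accelerator):

```python
import torch

def takes_gemv_fast_path(A: torch.Tensor) -> bool:
    # Mirrors the condition after this commit: a single-vector activation
    # (numel equals the last dim), on a non-CPU device, with no grad needed.
    return A.numel() == A.shape[-1] and A.device.type != "cpu" and not A.requires_grad

A = torch.randn(1, 1, 4096)  # a single token's activations, on CPU
print(takes_gemv_fast_path(A))  # False: CPU inputs are now routed to MatMul4Bit
if torch.cuda.is_available():
    print(takes_gemv_fast_path(A.cuda()))  # True: inference gemv path on an accelerator
```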
