enable new ipex API
The IPEX-packed weight is 4D, so it cannot be transposed.

Fix dequantization of IPEX-packed weights.

Check requires_grad before enabling the fusion.
jiqing-feng committed Sep 14, 2024
1 parent 2784653 commit 012c660
Showing 4 changed files with 47 additions and 41 deletions.
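Taken together, the four files below give Linear4bit a CPU inference fast path built on the IPEX >= 2.5 weight-only-quantization (WOQ) ops: the first eligible forward pass packs the nf4 weight, sets an `ipex` flag on the quant_state, and later matmuls, dequantization, and checkpoint saving key off that flag. A hedged usage sketch (illustrative, not part of the commit; it assumes intel_extension_for_pytorch >= 2.5 is installed and a bitsandbytes build whose CPU backend quantizes Params4bit when the module is moved to "cpu"):

import torch
import bitsandbytes as bnb

# nf4 layer; the second weight dimension must be a multiple of the quantization
# blocksize (64 by default) for the fusion guard in Linear4bit.forward() to pass.
# compress_statistics=False because double quantization is not supported on CPU/XPU.
linear = bnb.nn.Linear4bit(
    4096, 4096, bias=False, compute_dtype=torch.bfloat16,
    compress_statistics=False, quant_type="nf4",
)
linear = linear.to("cpu")   # assumption: quantizes the weight via the CPU backend

x = torch.randn(1, 4096)    # requires_grad is False, so the fusion guard passes
out = linear(x)             # first call packs the weight via enable_ipex_fusion(self)

# From here on quant_state.ipex is True and matmul_4bit routes to the IPEX WOQ kernel.
print(getattr(linear.weight.quant_state, "ipex", False))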
5 changes: 4 additions & 1 deletion bitsandbytes/autograd/_functions.py
@@ -583,7 +583,10 @@ def matmul_4bit(
             )
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
-            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
+            if getattr(quant_state, "ipex", False):
+                out = F.gemv_4bit(A, B, out, state=quant_state)
+            else:
+                out = F.gemv_4bit(A, B.t(), out, state=quant_state)
             if bias is not None:
                 out += bias
             return out
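The reason for the new branch is the one given in the commit message: once IPEX has prepacked the weight it is a 4D blocked tensor, and Tensor.t() is only defined for tensors with at most two dimensions, so the transpose must be skipped on that path. A small illustration (plain PyTorch, not part of the diff):

import torch

packed_like = torch.empty(2, 3, 4, 5)   # stand-in for the 4D IPEX-packed weight
try:
    packed_like.t()                     # .t() rejects tensors with more than 2 dims
except RuntimeError as err:
    print("cannot transpose a 4D tensor:", err)

# Hence the dispatch above: pass the packed weight as-is when quant_state.ipex is True,
# and keep B.t() only for the regular 2D nf4 weight.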
17 changes: 9 additions & 8 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -438,11 +438,11 @@ def dequantize_4bit_impl(
     if quant_state.nested:
         raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
 
-    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"):
-        assert quant_state.op_context is not None
-        A = quant_state.op_context.to_public(quant_state.op_context.get_weight())
-        A = A.reshape(-1)
-        absmax = quant_state.op_context.get_scales().reshape(-1)
+    if ipex_cpu and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
+        A = torch.ops.ipex_prepack.woq_linear_unpack_weight(
+            A, "nf4", quant_state.shape, 2
+        )
+        quant_state.ipex = False
 
     if out is None:
         out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
@@ -510,9 +510,10 @@ def gemm_4bit_impl(
     torch.Tensor:
         GEMM output tensor.
     """
-    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"):
-        assert state.op_context is not None
-        output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
+    if ipex_cpu and _ipex_cpu_version_prereq(2, 5) and getattr(state, "ipex", False):
+        output = torch.ops.torch_ipex.woq_linear(A, B, "nf4", state.shape,
+                        state.new_scales, state.new_zeros, None, None, state.blocksize,
+                        ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.compensation)
     else:
         dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
         output = torch.matmul(A, dqB.to(A.dtype))
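Two things worth noting in this file: dequantize_4bit_impl now unpacks the prepacked weight with woq_linear_unpack_weight and clears quant_state.ipex, so everything after the new block sees an ordinary packed nf4 buffer again, and gemm_4bit_impl feeds the tensors stored by enable_ipex_fusion (new_scales, new_zeros, compensation) straight into the fused torch.ops.torch_ipex.woq_linear kernel. A hedged sanity-check sketch of how the two paths relate (illustrative only; gemm_4bit_impl, dequantize_4bit_impl, B, and state are the names used above):

import torch

def fused_matches_fallback(A, B, state, atol=1e-1):
    # Run the fused path first: dequantize_4bit_impl clears state.ipex as a side
    # effect, which would silently push a second gemm call onto the fallback path.
    fused = gemm_4bit_impl(A, B, state=state)
    dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
    reference = torch.matmul(A, dqB.to(A.dtype))
    # The BF16 lowp mode gives limited precision, hence the loose tolerance.
    return torch.allclose(fused.float(), reference.float(), atol=atol)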
26 changes: 14 additions & 12 deletions bitsandbytes/nn/modules.py
@@ -447,32 +447,30 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         """
         if (
             getattr(self.weight, "quant_state", None) is not None
-            and getattr(self.weight.quant_state, "op_context", None) is not None
+            and getattr(self.weight.quant_state, "ipex", False)
         ):
-            context = self.weight.quant_state.op_context
-            self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])
+            original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
+                self.weight, "nf4", self.weight.quant_state.shape, 2
+            )
+            self.weight.data = original_weight.data
+            self.weight.quant_state.ipex = False
 
         super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
 
         if getattr(self.weight, "quant_state", None) is not None:
-            if (
-                self.weight.quant_state.absmax.shape.numel() == 0
-                and getattr(self.weight.quant_state, "op_context", None) is not None
-            ):
-                self.weight.quant_state.absmax = context.get_scales().reshape(-1)
-                delattr(self.weight.quant_state, "op_context")
             for k, v in self.weight.quant_state.as_dict(packed=True).items():
                 destination[prefix + "weight." + k] = v if keep_vars else v.detach()
 
     def forward(self, x: torch.Tensor):
         # Check if ipex fusion can be used
         if (
             x.device.type == "cpu"
-            and not hasattr(self.weight.quant_state, "op_context")
+            and not getattr(self.weight.quant_state, "ipex", False)
             and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
             and self.weight.quant_state.quant_type == "nf4"
+            and x.requires_grad == False
         ):
-            enable_ipex_fusion(self.weight, self.weight.quant_state)
+            enable_ipex_fusion(self)
 
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
@@ -499,7 +497,11 @@ def forward(self, x: torch.Tensor):
             x = x.to(self.compute_dtype)
 
         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
-        out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
+        if getattr(self.weight.quant_state, "ipex", False):
+            out = bnb.matmul_4bit(x, self.weight, bias=bias, quant_state=self.weight.quant_state)
+        else:
+            out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
+
 
         out = out.to(inp_dtype)
 
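The forward() guard is the single place that decides whether the IPEX path is used at all, and the new `x.requires_grad == False` condition is the "check requires_grad" item from the commit message: the WOQ kernel is inference-only, so training inputs stay on the regular bitsandbytes path. _save_to_state_dict correspondingly unpacks the weight and clears the flag before saving, so checkpoints keep the ordinary bitsandbytes nf4 layout. A sketch of the guard as a standalone predicate (names taken from the hunk above; illustrative):

import torch

def can_enable_ipex_fusion(module, x: torch.Tensor) -> bool:
    # Mirrors the condition at the top of Linear4bit.forward() after this commit.
    qs = getattr(module.weight, "quant_state", None)
    return (
        qs is not None
        and x.device.type == "cpu"
        and not getattr(qs, "ipex", False)       # weight not packed yet
        and qs.shape[1] % qs.blocksize == 0      # blockwise layout requirement
        and qs.quant_type == "nf4"               # only nf4 is handled by the WOQ path
        and not x.requires_grad                  # inference only
    )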
40 changes: 20 additions & 20 deletions bitsandbytes/utils.py
@@ -200,28 +200,28 @@ def unpack_tensor_to_dict(tensor_data):
     return unpacked_dict
 
 
-def enable_ipex_fusion(weight, quant_state):
+def enable_ipex_fusion(linear):
     from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq
 
-    if _ipex_cpu_version_prereq(2, 3):
-        import intel_extension_for_pytorch as ipex
-
-        lowp_mode = ipex.quantization.WoqLowpMode.BF16
-        quant_state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
-            weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
-            ipex.quantization.WoqWeightDtype.NF4,
-            quant_state.shape,  # weight shape
-            quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize),  # scales
-            None,  # zero_points
-            None,  # bias
-            None,  # g_idx
-            None,  # batch_size
-            quant_state.blocksize,
-            int(lowp_mode),
-            -1,  # act_quant_mode. -1 means don't quant activation
-        )
-        quant_state.absmax = torch.Tensor()
-        weight.data = torch.empty([1, 0], dtype=torch.uint8)
+    if _ipex_cpu_version_prereq(2, 5):
+        quant_state = linear.weight.quant_state
+        new_weight, new_scales, new_zeros, _, compensation = \
+            torch.ops.ipex_prepack.woq_linear_pack_weight(
+                linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
+                "nf4",
+                quant_state.shape,  # weight shape
+                quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize),  # scales
+                None,  # zero_points
+                None,  # bias
+                None,  # batch_size
+                quant_state.blocksize,
+                2,
+            )
+        linear.weight.data = new_weight.data
+        setattr(linear.weight.quant_state, "ipex", True)
+        setattr(linear.weight.quant_state, "new_scales", new_scales)
+        setattr(linear.weight.quant_state, "new_zeros", new_zeros)
+        setattr(linear.weight.quant_state, "compensation", compensation)
 
 
 class QuantState:
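enable_ipex_fusion now takes the whole Linear4bit module instead of (weight, quant_state), packs the weight with woq_linear_pack_weight, and stores everything the fused GEMM later needs directly on the quant_state. A hedged post-condition sketch (illustrative, not part of the commit):

def assert_ipex_packed(linear) -> None:
    # What the quant_state is expected to look like after enable_ipex_fusion(linear).
    qs = linear.weight.quant_state
    assert getattr(qs, "ipex", False), "packing flag not set"
    for name in ("new_scales", "new_zeros", "compensation"):
        assert hasattr(qs, name), f"missing attribute consumed by woq_linear: {name}"
    # The packed weight uses a 4D blocked layout (hence the no-transpose paths above).
    assert linear.weight.data.dim() == 4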
