This repository has been archived by the owner on Apr 27, 2024. It is now read-only.

Commit

Add exploratory work for quantizing falcon to int4
ramiro050 committed Dec 21, 2023
1 parent a623c87 commit 7620818
Showing 4 changed files with 243 additions and 0 deletions.
18 changes: 18 additions & 0 deletions examples/falcon/falcon-7b-quantizer.py
@@ -0,0 +1,18 @@
# Example from Huggingface docs:
# https://huggingface.co/docs/transformers/quantization

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Using group_size=64 because that is what TheBloke used for
# Falcon-7B-Instruct-GPTQ and because using group_size=128 results in
# an error in AutoGPTQ.
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer, group_size=64, disable_exllama=True)

quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", quantization_config=gptq_config
)

quantized_model.to("cpu")
quantized_model.save_pretrained("falcon-7b-int4-gptq")
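
For reference, here is a minimal sketch (not part of this commit) of loading the saved falcon-7b-int4-gptq directory back for inference. It assumes auto-gptq and optimum are installed; the prompt and generation settings are placeholders.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
# transformers reads the GPTQ quantization_config saved alongside the weights
# and rebuilds the quantized linear layers via auto-gptq/optimum.
model = AutoModelForCausalLM.from_pretrained("falcon-7b-int4-gptq", device_map="auto")

inputs = tokenizer("The falcon is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
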
75 changes: 75 additions & 0 deletions examples/falcon/falcon-int4.py
@@ -0,0 +1,75 @@
import io
import logging

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._functorch.compile_utils import strip_overloads
import torch_mlir
from transformers import AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)

PATH = "/home/ramiroleal/falcon/quantization-examples/falcon-7b-int4-gptq"
OUTPUT = "/tmp/falcon-int8-raw.mlir"
model = AutoModelForCausalLM.from_pretrained(PATH)

INPUT_SIZE = (1, 100)
for module in model.modules():
    if hasattr(module, "unpack"):
        print(f"Calling {module}.unpack()")
        module.unpack()

        # The numerical comparison below is currently skipped; drop the
        # `continue` to compare the new forward() against forward_old().
        continue
        x = torch.rand((1, 1, module.infeatures), dtype=torch.float16)
        new = module.forward(x)
        old = module.forward_old(x)

        if not torch.allclose(new, old):
            print("Max:", torch.max(torch.abs(new - old)))
            print("STD:", torch.std(new - old))
            print("Mean:", torch.mean(new - old))
            print(
                "Corr:",
                torch.corrcoef(
                    torch.stack(
                        [
                            new.flatten().to(torch.float32),
                            old.flatten().to(torch.float32),
                        ]
                    )
                ),
            )


def add_cast_to_uint4(gm):
    # cast_to_uint4 is implemented as aten.bitwise_not (see uint4-fix.diff), so
    # the traced graph contains bitwise_not nodes; retarget them back to the
    # custom op after strip_overloads so the graph carries the uint4 cast.
    for node in gm.graph.nodes:
        if node.target == torch.ops.aten.bitwise_not:
            node.target = torch.ops.autogptq.cast_to_uint4
    gm.recompile()

shape_env = ShapeEnv()
with FakeTensorMode(allow_non_fake_inputs=True, shape_env=shape_env):
    input_ids = torch.randint(low=1, high=10000, size=INPUT_SIZE)
    torch._dynamo.allow_in_graph(torch.ops.autogptq.cast_to_uint4.default)
    from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import cast_to_uint4
    torch.fx.wrap(cast_to_uint4)
    # torch._dynamo.allow_in_graph('cast_to_uint4')
    model_fx = make_fx(lambda x: model(x).logits, tracing_mode="symbolic", pre_dispatch=False)(input_ids)
    # model_fx_2 = make_fx(model_fx, tracing_mode="symbolic", pre_dispatch=False)(input_ids)
    strip_overloads(model_fx)
    add_cast_to_uint4(model_fx)
    # def model_call(x):
    #     return model(x).logits
    # model_fx = torch.export.export(model_call, (input_ids,))
    mlir = torch_mlir.compile(model_fx, input_ids, output_type="raw")

# with open(OUTPUT_STR, "w") as f:
# f.write(str(mlir))

bytecode_stream = io.BytesIO()
mlir.operation.write_bytecode(bytecode_stream)
mlir_bytes = bytecode_stream.getbuffer()
with open(OUTPUT, "bw") as f:
f.write(mlir_bytes)
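
As a standalone illustration (not from this commit) of the node-retargeting trick used by add_cast_to_uint4 above: register a throwaway custom op, trace a toy function with make_fx, then swap a chosen aten target for the custom op. The demo namespace and passthrough op are made-up names for this sketch; the commit itself uses autogptq::cast_to_uint4 from uint4-fix.diff.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Throwaway custom op (hypothetical names) standing in for autogptq.cast_to_uint4.
demo_lib = torch.library.Library("demo", "DEF")
demo_lib.define("passthrough(Tensor t) -> Tensor")
demo_lib.impl("passthrough", lambda t: t.clone(), "CompositeExplicitAutograd")

def f(x):
    return torch.ops.aten.bitwise_not(x) + 1

gm = make_fx(f)(torch.zeros(4, dtype=torch.int32))
for node in gm.graph.nodes:
    if node.target == torch.ops.aten.bitwise_not.default:
        node.target = torch.ops.demo.passthrough.default
gm.recompile()
print(gm.code)  # the bitwise_not call site now targets torch.ops.demo.passthrough
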
58 changes: 58 additions & 0 deletions examples/falcon/falcon-int8.py
@@ -0,0 +1,58 @@
import io
import logging

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._functorch.compile_utils import strip_overloads
import torch_mlir
from transformers import AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)

PATH = "/home/ramiroleal/falcon/quantization-examples/falcon-7b-int8-gptq"
OUTPUT = "/tmp/falcon-int8-raw.mlir"
OUTPUT_STR = "/tmp/falcon-int8-raw-str.mlir"
model = AutoModelForCausalLM.from_pretrained(PATH)

INPUT_SIZE = (1, 100)
for module in model.modules():
    if hasattr(module, "unpack"):
        print(f"Calling {module}.unpack()")
        module.unpack()

        x = torch.rand((1, 1, module.infeatures), dtype=torch.float16)
        new = module.forward(x)
        old = module.forward_old(x)

        if not torch.allclose(new, old):
            print("Max:", torch.max(torch.abs(new - old)))
            print("STD:", torch.std(new - old))
            print("Mean:", torch.mean(new - old))
            print(
                "Corr:",
                torch.corrcoef(
                    torch.stack(
                        [
                            new.flatten().to(torch.float32),
                            old.flatten().to(torch.float32),
                        ]
                    )
                ),
            )

assert False  # Exploratory stop: halt here so the tracing/compilation below does not run.
with FakeTensorMode(allow_non_fake_inputs=True):
    input_ids = torch.randint(low=1, high=10000, size=INPUT_SIZE)
    model_fx = make_fx(lambda x: model(x).logits, tracing_mode="fake")(input_ids)
    strip_overloads(model_fx)
    mlir = torch_mlir.compile(model_fx, input_ids, output_type="torch")

# with open(OUTPUT_STR, "w") as f:
# f.write(str(mlir))

bytecode_stream = io.BytesIO()
mlir.operation.write_bytecode(bytecode_stream)
mlir_bytes = bytecode_stream.getbuffer()
with open(OUTPUT, "bw") as f:
f.write(mlir_bytes)
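
A small side note (not part of the commit) on the strip_overloads call that both scripts make before torch_mlir.compile: it rewrites per-overload targets such as aten.add.Tensor back to their overload packets (aten.add), which can be seen on a toy graph:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import strip_overloads

gm = make_fx(lambda x: x + 1)(torch.zeros(2))
before = [n.target for n in gm.graph.nodes if n.op == "call_function"]
strip_overloads(gm)
after = [n.target for n in gm.graph.nodes if n.op == "call_function"]
# `before` holds OpOverloads (e.g. aten.add.Tensor); `after` holds packets (aten.add).
print(before, after)
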
92 changes: 92 additions & 0 deletions examples/falcon/uint4-fix.diff
@@ -0,0 +1,92 @@
diff --git a/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py b/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
index 04eb425..39b7fd0 100644
--- a/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
+++ b/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
@@ -17,6 +17,12 @@ except ImportError:
    autogptq_cuda_64 = None
    _autogptq_cuda_available = False

+def cast_to_uint4(x: torch.Tensor) -> torch.Tensor:
+    return torch.ops.aten.bitwise_not(x)
+
+goofy_lib = torch.library.Library("autogptq", "DEF")
+goofy_lib.define("autogptq::cast_to_uint4(Tensor t) -> Tensor")
+goofy_lib.impl("cast_to_uint4", cast_to_uint4)

class QuantLinear(nn.Module):
    QUANT_TYPE = "cuda-old"
@@ -96,6 +102,56 @@ class QuantLinear(nn.Module):
    def post_init(self):
        pass

+    def unpack(self):
+        assert self.bits == 8 or self.bits == 4
+        zeros = torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits)
+        zeros = torch.bitwise_right_shift(zeros, self.wf.unsqueeze(0))
+        zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1).to(torch.uint8)
+        qzeros_unpacked = zeros.reshape(zeros.shape[0], zeros.shape[1] * zeros.shape[2])
+
+        weight = torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1)
+        weight = torch.bitwise_right_shift(weight, self.wf.unsqueeze(-1))
+        weight = torch.bitwise_and(weight, (2 ** self.bits) - 1).to(torch.uint8)
+        qweight_unpacked = weight.reshape(self.infeatures, self.outfeatures)
+
+        self.register_buffer('qweight_unpacked', qweight_unpacked)
+        self.register_buffer('qzeros_unpacked', qzeros_unpacked)
+
+    def forward(self, x):
+        assert self.bits == 8 or self.bits == 4
+        x_dtype = x.dtype
+        out_shape = x.shape[:-1] + (self.outfeatures,)
+        x = x.reshape(-1, x.shape[-1])
+
+        if self.wf.device != self.qzeros.device:
+            self.wf = self.wf.to(self.qzeros.device)
+
+        if self.bits == 4:
+            zeros = torch.ops.autogptq.cast_to_uint4(self.qzeros_unpacked)
+        else:
+            zeros = self.qzeros_unpacked
+        zeros = zeros.reshape(-1, 1, zeros.shape[-1])
+
+        scales = self.scales
+        scales = scales.reshape(-1, 1, scales.shape[-1])
+
+        if self.bits == 4:
+            weight = torch.ops.autogptq.cast_to_uint4(self.qweight_unpacked)
+        else:
+            weight = self.qweight_unpacked
+        weight = weight.reshape(-1, self.group_size, weight.shape[-1])
+
+        # Multiply by `scales` separately to avoid overflow
+        # Cast to `int16` needed to avoid precision errors and match AutoGPTQ behavior
+        bigger_dtype = torch.int16 if self.bits == 8 else torch.uint8
+        weight = scales * (weight - (zeros + 1).to(bigger_dtype))
+        weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+        out = torch.matmul(x, weight)
+
+        out = out.to(dtype=x_dtype).reshape(out_shape)  # A cast is needed here as for some reason the vecquant2matmul_faster_old still allocates a float32 output.
+        out = out + self.bias if self.bias is not None else out
+        return out
+
    def pack(self, linear, scales, zeros, g_idx):
        W = linear.weight.data.clone()
        if isinstance(linear, nn.Conv2d):
@@ -194,7 +250,7 @@ class QuantLinear(nn.Module):
        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)

-    def forward(self, x):
+    def forward_old(self, x):
        x_dtype = x.dtype
        out_shape = x.shape[:-1] + (self.outfeatures,)
        x = x.reshape(-1, x.shape[-1])
@@ -267,7 +323,7 @@ class QuantLinear(nn.Module):
                weight = weight.reshape(-1, self.group_size, weight.shape[2])
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
-
+
            weight = (scales * (weight - zeros))
            weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
            out = torch.matmul(x, weight)
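
As a quick standalone check (not part of the commit) of the shift-and-mask arithmetic that unpack() relies on for bits=4: eight 4-bit values packed into one int32 word come back out via bitwise_right_shift plus a 0xF mask, the same pattern applied above to qzeros and qweight.

import torch

nibbles = [1, 7, 0, 15, 3, 9, 4, 2]                         # eight 4-bit values
packed = sum(v << (4 * i) for i, v in enumerate(nibbles))   # pack little-end-first
packed = torch.tensor(packed, dtype=torch.int32)

shifts = torch.arange(0, 32, 4, dtype=torch.int32)          # plays the role of self.wf for bits=4
unpacked = torch.bitwise_and(torch.bitwise_right_shift(packed, shifts), 0xF)
assert unpacked.tolist() == nibbles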
