This repository has been archived by the owner on Apr 27, 2024. It is now read-only.
Add exploratory work for quantizing falcon to int4
Showing 4 changed files with 243 additions and 0 deletions.
@@ -0,0 +1,18 @@
# Example from Huggingface docs:
# https://huggingface.co/docs/transformers/quantization

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Using group_size=64 because that is what TheBloke used for
# Falcon-7B-Instruct-GPTQ and because using group_size=128 results in
# an error in AutoGPTQ.
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer, group_size=64, disable_exllama=True)

quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", quantization_config=gptq_config
)

quantized_model.to("cpu")
quantized_model.save_pretrained("falcon-7b-int4-gptq")
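Not part of the commit, but a quick way to sanity-check the result of the script above is to reload the saved checkpoint and generate a few tokens. A minimal sketch, assuming auto-gptq, optimum, and accelerate are installed and a GPU is available; the prompt and token count are made up:

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "falcon-7b-int4-gptq"  # directory written by save_pretrained above
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")  # the tokenizer was not saved alongside the model
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")

inputs = tokenizer("The falcon is a", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))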
@@ -0,0 +1,75 @@
import io
import logging

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._functorch.compile_utils import strip_overloads
import torch_mlir
from transformers import AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)

PATH = "/home/ramiroleal/falcon/quantization-examples/falcon-7b-int4-gptq"
OUTPUT = "/tmp/falcon-int8-raw.mlir"
model = AutoModelForCausalLM.from_pretrained(PATH)

INPUT_SIZE = (1, 100)
for module in model.modules():
    if hasattr(module, "unpack"):
        print(f"Calling {module}.unpack()")
        module.unpack()

        continue  # skips the unpack-vs-original numerical check below
        x = torch.rand((1, 1, module.infeatures), dtype=torch.float16)
        new = module.forward(x)
        old = module.forward_old(x)

        if not torch.allclose(new, old):
            print("Max:", torch.max(torch.abs(new - old)))
            print("STD:", torch.std(new - old))
            print("Mean:", torch.mean(new - old))
            print(
                "Corr:",
                torch.corrcoef(
                    torch.stack(
                        [
                            new.flatten().to(torch.float32),
                            old.flatten().to(torch.float32),
                        ]
                    )
                ),
            )


def add_cast_to_uint4(gm):
    for node in gm.graph.nodes:
        if node.target == torch.ops.aten.bitwise_not:
            node.target = torch.ops.autogptq.cast_to_uint4
    gm.recompile()

shape_env = ShapeEnv()
with FakeTensorMode(allow_non_fake_inputs=True, shape_env=shape_env):
    input_ids = torch.randint(low=1, high=10000, size=INPUT_SIZE)
    torch._dynamo.allow_in_graph(torch.ops.autogptq.cast_to_uint4.default)
    from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import cast_to_uint4
    torch.fx.wrap(cast_to_uint4)
    # torch._dynamo.allow_in_graph('cast_to_uint4')
    model_fx = make_fx(lambda x: model(x).logits, tracing_mode="symbolic", pre_dispatch=False)(input_ids)
    # model_fx_2 = make_fx(model_fx, tracing_mode="symbolic", pre_dispatch=False)(input_ids)
    strip_overloads(model_fx)
    add_cast_to_uint4(model_fx)
    # def model_call(x):
    #     return model(x).logits
    # model_fx = torch.export.export(model_call, (input_ids,))
    mlir = torch_mlir.compile(model_fx, input_ids, output_type="raw")

# with open(OUTPUT_STR, "w") as f:
#     f.write(str(mlir))

bytecode_stream = io.BytesIO()
mlir.operation.write_bytecode(bytecode_stream)
mlir_bytes = bytecode_stream.getbuffer()
with open(OUTPUT, "bw") as f:
    f.write(mlir_bytes)
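The add_cast_to_uint4 pass above is a plain FX graph rewrite: it walks the traced graph and retargets every aten.bitwise_not call to the custom autogptq.cast_to_uint4 op. A toy sketch of the same pattern on a small function, with aten.neg standing in for the custom op so the snippet runs without auto-gptq installed:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import strip_overloads


def f(x):
    return torch.bitwise_not(x) + 1


# Trace to an aten-level FX graph, then collapse OpOverloads to packets so the
# comparison against torch.ops.aten.bitwise_not matches, as in the script above.
gm = make_fx(f)(torch.tensor([1, 2, 3]))
strip_overloads(gm)

for node in gm.graph.nodes:
    if node.target == torch.ops.aten.bitwise_not:
        node.target = torch.ops.aten.neg  # the real pass swaps in autogptq.cast_to_uint4
gm.recompile()

print(gm.code)
print(gm(torch.tensor([1, 2, 3])))  # tensor([ 0, -1, -2]): neg(x) + 1 instead of bitwise_not(x) + 1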
@@ -0,0 +1,58 @@
import io
import logging

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._functorch.compile_utils import strip_overloads
import torch_mlir
from transformers import AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)

PATH = "/home/ramiroleal/falcon/quantization-examples/falcon-7b-int8-gptq"
OUTPUT = "/tmp/falcon-int8-raw.mlir"
OUTPUT_STR = "/tmp/falcon-int8-raw-str.mlir"
model = AutoModelForCausalLM.from_pretrained(PATH)

INPUT_SIZE = (1, 100)
for module in model.modules():
    if hasattr(module, "unpack"):
        print(f"Calling {module}.unpack()")
        module.unpack()

        x = torch.rand((1, 1, module.infeatures), dtype=torch.float16)
        new = module.forward(x)
        old = module.forward_old(x)

        if not torch.allclose(new, old):
            print("Max:", torch.max(torch.abs(new - old)))
            print("STD:", torch.std(new - old))
            print("Mean:", torch.mean(new - old))
            print(
                "Corr:",
                torch.corrcoef(
                    torch.stack(
                        [
                            new.flatten().to(torch.float32),
                            old.flatten().to(torch.float32),
                        ]
                    )
                ),
            )

assert False  # stop here for now; the export below is not exercised yet
with FakeTensorMode(allow_non_fake_inputs=True):
    input_ids = torch.randint(low=1, high=10000, size=INPUT_SIZE)
    model_fx = make_fx(lambda x: model(x).logits, tracing_mode="fake")(input_ids)
    strip_overloads(model_fx)
    mlir = torch_mlir.compile(model_fx, input_ids, output_type="torch")

# with open(OUTPUT_STR, "w") as f:
#     f.write(str(mlir))

bytecode_stream = io.BytesIO()
mlir.operation.write_bytecode(bytecode_stream)
mlir_bytes = bytecode_stream.getbuffer()
with open(OUTPUT, "bw") as f:
    f.write(mlir_bytes)
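Both export scripts embed the same ad-hoc accuracy check comparing the new unpacked forward against forward_old. Purely as an illustration (this helper is not in the commit), the check could be factored into a small function:

import torch


def compare_outputs(new: torch.Tensor, old: torch.Tensor) -> None:
    # Report how far the unpacked forward drifts from the original forward.
    if torch.allclose(new, old):
        return
    diff = (new - old).to(torch.float32)
    print("Max:", torch.max(torch.abs(diff)))
    print("STD:", torch.std(diff))
    print("Mean:", torch.mean(diff))
    stacked = torch.stack([new.flatten().to(torch.float32), old.flatten().to(torch.float32)])
    print("Corr:", torch.corrcoef(stacked))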
@@ -0,0 +1,92 @@
diff --git a/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py b/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
index 04eb425..39b7fd0 100644
--- a/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
+++ b/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
@@ -17,6 +17,12 @@ except ImportError:
     autogptq_cuda_64 = None
     _autogptq_cuda_available = False
 
+def cast_to_uint4(x: torch.Tensor) -> torch.Tensor:
+    return torch.ops.aten.bitwise_not(x)
+
+goofy_lib = torch.library.Library("autogptq", "DEF")
+goofy_lib.define("autogptq::cast_to_uint4(Tensor t) -> Tensor")
+goofy_lib.impl("cast_to_uint4", cast_to_uint4)
 
 class QuantLinear(nn.Module):
     QUANT_TYPE = "cuda-old"
@@ -96,6 +102,56 @@ class QuantLinear(nn.Module):
     def post_init(self):
         pass
 
+    def unpack(self):
+        assert self.bits == 8 or self.bits == 4
+        zeros = torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits)
+        zeros = torch.bitwise_right_shift(zeros, self.wf.unsqueeze(0))
+        zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1).to(torch.uint8)
+        qzeros_unpacked = zeros.reshape(zeros.shape[0], zeros.shape[1] * zeros.shape[2])
+
+        weight = torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1)
+        weight = torch.bitwise_right_shift(weight, self.wf.unsqueeze(-1))
+        weight = torch.bitwise_and(weight, (2 ** self.bits) - 1).to(torch.uint8)
+        qweight_unpacked = weight.reshape(self.infeatures, self.outfeatures)
+
+        self.register_buffer('qweight_unpacked', qweight_unpacked)
+        self.register_buffer('qzeros_unpacked', qzeros_unpacked)
+
+    def forward(self, x):
+        assert self.bits == 8 or self.bits == 4
+        x_dtype = x.dtype
+        out_shape = x.shape[:-1] + (self.outfeatures,)
+        x = x.reshape(-1, x.shape[-1])
+
+        if self.wf.device != self.qzeros.device:
+            self.wf = self.wf.to(self.qzeros.device)
+
+        if self.bits == 4:
+            zeros = torch.ops.autogptq.cast_to_uint4(self.qzeros_unpacked)
+        else:
+            zeros = self.qzeros_unpacked
+        zeros = zeros.reshape(-1, 1, zeros.shape[-1])
+
+        scales = self.scales
+        scales = scales.reshape(-1, 1, scales.shape[-1])
+
+        if self.bits == 4:
+            weight = torch.ops.autogptq.cast_to_uint4(self.qweight_unpacked)
+        else:
+            weight = self.qweight_unpacked
+        weight = weight.reshape(-1, self.group_size, weight.shape[-1])
+
+        # Multiply by `scales` separately to avoid overflow
+        # Cast to `int16` needed to avoid precision errors and match AutoGPTQ behavior
+        bigger_dtype = torch.int16 if self.bits == 8 else torch.uint8
+        weight = scales * (weight - (zeros + 1).to(bigger_dtype))
+        weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+        out = torch.matmul(x, weight)
+
+        out = out.to(dtype=x_dtype).reshape(out_shape)  # A cast is needed here as for some reason the vecquant2matmul_faster_old still allocates a float32 output.
+        out = out + self.bias if self.bias is not None else out
+        return out
+
     def pack(self, linear, scales, zeros, g_idx):
         W = linear.weight.data.clone()
         if isinstance(linear, nn.Conv2d):
@@ -194,7 +250,7 @@ class QuantLinear(nn.Module):
         qzeros = qzeros.astype(np.int32)
         self.qzeros = torch.from_numpy(qzeros)
 
-    def forward(self, x):
+    def forward_old(self, x):
         x_dtype = x.dtype
         out_shape = x.shape[:-1] + (self.outfeatures,)
         x = x.reshape(-1, x.shape[-1])
@@ -267,7 +323,7 @@ class QuantLinear(nn.Module):
                 weight = weight.reshape(-1, self.group_size, weight.shape[2])
             else:
                 raise NotImplementedError("Only 2,3,4,8 bits are supported.")
-
+
             weight = (scales * (weight - zeros))
             weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
             out = torch.matmul(x, weight)
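The unpack() method added by this patch recovers the 4-bit (or 8-bit) values that AutoGPTQ packs into int32 words by broadcasting each word across the lanes, right-shifting, and masking. A standalone sketch of that shift-and-mask arithmetic on a single made-up word (the values are illustrative; the shift table plays the role of self.wf for bits=4):

import torch

bits = 4
mask = (1 << bits) - 1  # 0b1111

# Eight 4-bit values to pack into one int32 word; the last value is kept small
# so the packed word stays within int32 range for this illustration.
vals = torch.tensor([1, 7, 0, 15, 3, 9, 12, 5], dtype=torch.int32)
shifts = torch.arange(0, 32, bits, dtype=torch.int32)  # per-lane bit offsets

# Pack: value i occupies bits [bits*i, bits*(i+1)) of the word.
packed = torch.sum(vals << shifts).to(torch.int32)

# Unpack, as in unpack(): shift each lane down to the low bits, then mask.
unpacked = torch.bitwise_and(torch.bitwise_right_shift(packed, shifts), mask)
assert torch.equal(unpacked, vals)
print(unpacked)  # tensor([ 1,  7,  0, 15,  3,  9, 12,  5], dtype=torch.int32)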