From 7620818ebdd973ea8eaed628d9eba71615642ae7 Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Thu, 21 Dec 2023 23:28:44 +0000
Subject: [PATCH] Add exploratory work for quantizing falcon to int4

---
 examples/falcon/falcon-7b-quantizer.py | 18 +++++
 examples/falcon/falcon-int4.py         | 79 ++++++++++++++++++++++++
 examples/falcon/falcon-int8.py         | 58 ++++++++++++++++
 examples/falcon/uint4-fix.diff         | 95 +++++++++++++++++++++++++++
 4 files changed, 250 insertions(+)
 create mode 100644 examples/falcon/falcon-7b-quantizer.py
 create mode 100644 examples/falcon/falcon-int4.py
 create mode 100644 examples/falcon/falcon-int8.py
 create mode 100644 examples/falcon/uint4-fix.diff

diff --git a/examples/falcon/falcon-7b-quantizer.py b/examples/falcon/falcon-7b-quantizer.py
new file mode 100644
index 0000000..e403b64
--- /dev/null
+++ b/examples/falcon/falcon-7b-quantizer.py
@@ -0,0 +1,18 @@
+# Example from Huggingface docs:
+# https://huggingface.co/docs/transformers/quantization
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+# Using group_size=64 because that is what TheBloke used for
+# Falcon-7B-Instruct-GPTQ and because using group_size=128 results in
+# an error in AutoGPTQ.
+gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer, group_size=64, disable_exllama=True)
+
+quantized_model = AutoModelForCausalLM.from_pretrained(
+    model_id, device_map="auto", quantization_config=gptq_config
+)
+
+quantized_model.to("cpu")
+quantized_model.save_pretrained("falcon-7b-int4-gptq")
diff --git a/examples/falcon/falcon-int4.py b/examples/falcon/falcon-int4.py
new file mode 100644
index 0000000..6b10ed6
--- /dev/null
+++ b/examples/falcon/falcon-int4.py
@@ -0,0 +1,79 @@
+import io
+import logging
+
+import torch
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.experimental.symbolic_shapes import ShapeEnv
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch._functorch.compile_utils import strip_overloads
+import torch_mlir
+from transformers import AutoModelForCausalLM
+
+logging.basicConfig(level=logging.INFO)
+
+PATH = "/home/ramiroleal/falcon/quantization-examples/falcon-7b-int4-gptq"
+OUTPUT = "/tmp/falcon-int4-raw.mlir"
+model = AutoModelForCausalLM.from_pretrained(PATH)
+
+INPUT_SIZE = (1, 100)
+for module in model.modules():
+    if hasattr(module, "unpack"):
+        print(f"Calling {module}.unpack()")
+        module.unpack()
+
+        continue  # Skip the forward/forward_old comparison below; only unpack() is needed here.
+        x = torch.rand((1, 1, module.infeatures), dtype=torch.float16)
+        new = module.forward(x)
+        old = module.forward_old(x)
+
+        if not torch.allclose(new, old):
+            print("Max:", torch.max(torch.abs(new - old)))
+            print("STD:", torch.std(new - old))
+            print("Mean:", torch.mean(new - old))
+            print(
+                "Corr:",
+                torch.corrcoef(
+                    torch.stack(
+                        [
+                            new.flatten().to(torch.float32),
+                            old.flatten().to(torch.float32),
+                        ]
+                    )
+                ),
+            )
+
+
+def add_cast_to_uint4(gm):
+    for node in gm.graph.nodes:
+        if node.target == torch.ops.aten.bitwise_not:
+            node.target = torch.ops.autogptq.cast_to_uint4
+    gm.recompile()
+
+shape_env = ShapeEnv()
+with FakeTensorMode(allow_non_fake_inputs=True, shape_env=shape_env):
+    input_ids = torch.randint(low=1, high=10000, size=INPUT_SIZE)
+    torch._dynamo.allow_in_graph(torch.ops.autogptq.cast_to_uint4.default)
+    from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import cast_to_uint4
+    torch.fx.wrap(cast_to_uint4)
+    #torch._dynamo.allow_in_graph('cast_to_uint4')
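+    # Trace the HF model to an FX graph with symbolic shapes. The custom
+    # `autogptq.cast_to_uint4` op is implemented as `aten.bitwise_not`, so
+    # after tracing, `add_cast_to_uint4` rewrites those nodes back to the
+    # custom op so the uint4 cast is still identifiable in the graph handed to torch-mlir.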
+    model_fx = make_fx(lambda x: model(x).logits, tracing_mode="symbolic", pre_dispatch=False)(input_ids)
+    #model_fx_2 = make_fx(model_fx, tracing_mode="symbolic", pre_dispatch=False)(input_ids)
+    strip_overloads(model_fx)
+    add_cast_to_uint4(model_fx)
+# def model_call(x):
+#     return model(x).logits
+# model_fx = torch.export.export(model_call, (input_ids,))
+    mlir = torch_mlir.compile(model_fx, input_ids, output_type="raw")
+
+# with open(OUTPUT_STR, "w") as f:
+#     f.write(str(mlir))
+
+bytecode_stream = io.BytesIO()
+mlir.operation.write_bytecode(bytecode_stream)
+mlir_bytes = bytecode_stream.getbuffer()
+with open(OUTPUT, "bw") as f:
+    f.write(mlir_bytes)
diff --git a/examples/falcon/falcon-int8.py b/examples/falcon/falcon-int8.py
new file mode 100644
index 0000000..2449089
--- /dev/null
+++ b/examples/falcon/falcon-int8.py
@@ -0,0 +1,58 @@
+import io
+import logging
+
+import torch
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch._functorch.compile_utils import strip_overloads
+import torch_mlir
+from transformers import AutoModelForCausalLM
+
+logging.basicConfig(level=logging.INFO)
+
+PATH = "/home/ramiroleal/falcon/quantization-examples/falcon-7b-int8-gptq"
+OUTPUT = "/tmp/falcon-int8-raw.mlir"
+OUTPUT_STR = "/tmp/falcon-int8-raw-str.mlir"
+model = AutoModelForCausalLM.from_pretrained(PATH)
+
+INPUT_SIZE = (1, 100)
+for module in model.modules():
+    if hasattr(module, "unpack"):
+        print(f"Calling {module}.unpack()")
+        module.unpack()
+
+        x = torch.rand((1, 1, module.infeatures), dtype=torch.float16)
+        new = module.forward(x)
+        old = module.forward_old(x)
+
+        if not torch.allclose(new, old):
+            print("Max:", torch.max(torch.abs(new - old)))
+            print("STD:", torch.std(new - old))
+            print("Mean:", torch.mean(new - old))
+            print(
+                "Corr:",
+                torch.corrcoef(
+                    torch.stack(
+                        [
+                            new.flatten().to(torch.float32),
+                            old.flatten().to(torch.float32),
+                        ]
+                    )
+                ),
+            )
+
+assert False  # Intentional stop: only the forward/forward_old comparison above is exercised.
+with FakeTensorMode(allow_non_fake_inputs=True):
+    input_ids = torch.randint(low=1, high=10000, size=INPUT_SIZE)
+    model_fx = make_fx(lambda x: model(x).logits, tracing_mode="fake")(input_ids)
+    strip_overloads(model_fx)
+    mlir = torch_mlir.compile(model_fx, input_ids, output_type="torch")
+
+# with open(OUTPUT_STR, "w") as f:
+#     f.write(str(mlir))
+
+bytecode_stream = io.BytesIO()
+mlir.operation.write_bytecode(bytecode_stream)
+mlir_bytes = bytecode_stream.getbuffer()
+with open(OUTPUT, "bw") as f:
+    f.write(mlir_bytes)
diff --git a/examples/falcon/uint4-fix.diff b/examples/falcon/uint4-fix.diff
new file mode 100644
index 0000000..8e9f052
--- /dev/null
+++ b/examples/falcon/uint4-fix.diff
@@ -0,0 +1,95 @@
+diff --git a/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py b/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
+index 04eb425..39b7fd0 100644
+--- a/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
++++ b/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
+@@ -17,6 +17,12 @@ except ImportError:
+     autogptq_cuda_64 = None
+     _autogptq_cuda_available = False
+ 
++def cast_to_uint4(x: torch.Tensor) -> torch.Tensor:
++    return torch.ops.aten.bitwise_not(x)
++
++goofy_lib = torch.library.Library("autogptq", "DEF")
++goofy_lib.define("autogptq::cast_to_uint4(Tensor t) -> Tensor")
++goofy_lib.impl("cast_to_uint4", cast_to_uint4)
+ 
+ class QuantLinear(nn.Module):
+     QUANT_TYPE = "cuda-old"
+@@ -96,6 +102,59 @@ class QuantLinear(nn.Module):
+     def post_init(self):
+         pass
+ 
++    def unpack(self):
++        assert self.bits == 8 or self.bits == 4
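++        # Each packed int32 holds 32 // self.bits quantized values; expand,
++        # right-shift by the per-slot offsets in `self.wf`, and mask to the
++        # low `self.bits` bits to recover them as uint8.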
++        zeros = torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits)
++        zeros = torch.bitwise_right_shift(zeros, self.wf.unsqueeze(0))
++        zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1).to(torch.uint8)
++        qzeros_unpacked = zeros.reshape(zeros.shape[0], zeros.shape[1] * zeros.shape[2])
++
++        weight = torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1)
++        weight = torch.bitwise_right_shift(weight, self.wf.unsqueeze(-1))
++        weight = torch.bitwise_and(weight, (2 ** self.bits) - 1).to(torch.uint8)
++        qweight_unpacked = weight.reshape(self.infeatures, self.outfeatures)
++
++        self.register_buffer('qweight_unpacked', qweight_unpacked)
++        self.register_buffer('qzeros_unpacked', qzeros_unpacked)
++
++    def forward(self, x):
++        assert self.bits == 8 or self.bits == 4
++        x_dtype = x.dtype
++        out_shape = x.shape[:-1] + (self.outfeatures,)
++        x = x.reshape(-1, x.shape[-1])
++
++        if self.wf.device != self.qzeros.device:
++            self.wf = self.wf.to(self.qzeros.device)
++
++        if self.bits == 4:
++            zeros = torch.ops.autogptq.cast_to_uint4(self.qzeros_unpacked)
++        else:
++            zeros = self.qzeros_unpacked
++        zeros = zeros.reshape(-1, 1, zeros.shape[-1])
++
++        scales = self.scales
++        scales = scales.reshape(-1, 1, scales.shape[-1])
++
++        if self.bits == 4:
++            weight = torch.ops.autogptq.cast_to_uint4(self.qweight_unpacked)
++        else:
++            weight = self.qweight_unpacked
++        weight = weight.reshape(-1, self.group_size, weight.shape[-1])
++
++        # Multiply by `scales` separately to avoid overflow
++        # Cast to `int16` needed to avoid precision errors and match AutoGPTQ behavior
++        bigger_dtype = torch.int16 if self.bits == 8 else torch.uint8
++        weight = scales * (weight - (zeros + 1).to(bigger_dtype))
++        weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
++        out = torch.matmul(x, weight)
++
++        out = out.to(dtype=x_dtype).reshape(out_shape)  # A cast is needed here as for some reason the vecquant2matmul_faster_old still allocates a float32 output.
++        out = out + self.bias if self.bias is not None else out
++        return out
++
+     def pack(self, linear, scales, zeros, g_idx):
+         W = linear.weight.data.clone()
+         if isinstance(linear, nn.Conv2d):
+@@ -194,7 +253,7 @@ class QuantLinear(nn.Module):
+         qzeros = qzeros.astype(np.int32)
+         self.qzeros = torch.from_numpy(qzeros)
+ 
+-    def forward(self, x):
++    def forward_old(self, x):
+         x_dtype = x.dtype
+         out_shape = x.shape[:-1] + (self.outfeatures,)
+         x = x.reshape(-1, x.shape[-1])
+@@ -267,7 +326,7 @@ class QuantLinear(nn.Module):
+             weight = weight.reshape(-1, self.group_size, weight.shape[2])
+         else:
+             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+-            
++
+         weight = (scales * (weight - zeros))
+         weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+         out = torch.matmul(x, weight)
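For reference, a minimal standalone sketch of the shift-and-mask scheme that unpack() above reverses, assuming AutoGPTQ's 4-bit layout (eight values per packed int32, slot offsets 0, 4, ..., 28, matching `wf` for bits=4); the variable names below are illustrative only and not part of the patch:

import functools

import torch

bits = 4
vals = torch.randint(0, 2 ** bits, (32 // bits,), dtype=torch.int32)  # eight 4-bit values
shifts = torch.arange(0, 32, bits, dtype=torch.int32)                 # [0, 4, ..., 28]

# Pack: each value lands in its own 4-bit slot of a single int32.
packed = functools.reduce(torch.bitwise_or, torch.bitwise_left_shift(vals, shifts))

# Unpack: arithmetic right shift per slot, then mask to the low 4 bits.
# The mask is what makes int32 sign extension harmless, which is also why
# unpack() masks with (2 ** self.bits) - 1 after torch.bitwise_right_shift.
unpacked = torch.bitwise_and(torch.bitwise_right_shift(packed, shifts), (2 ** bits) - 1)

assert torch.equal(unpacked, vals)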