diff --git a/quantize.py b/quantize.py index 0b89a1641..9f1c33954 100644 --- a/quantize.py +++ b/quantize.py @@ -22,6 +22,8 @@ # per_token_dynamic_quant, # ) +########################################################################## +### process quantization dictionary ### def quantize_model(model: nn.Module, quantize_options): """ @@ -77,7 +79,7 @@ def quantize_model(model: nn.Module, quantize_options): ######################################################################### -##### Quantization Primitives ###### +##### Quantization Primitives ###### def dynamically_quantize_per_channel( x, @@ -281,6 +283,8 @@ def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): ) ######################################################################### +### QuantHandler API definition ### + class QuantHandler: def __init__(self, mod): @@ -299,7 +303,8 @@ def quantized_model(self) -> nn.Module: return self.mod -##### Weight-only int8 per-channel quantized code ###### +######################################################################### +##### Weight-only int8 per-channel quantized code ###### def replace_linear_weight_only_int8_per_channel(module, node_type, group_size=None): @@ -447,7 +452,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return F.linear(input, (weight.to(dtype=input.dtype).view(weight.shape[0], no_groups, -1) * scales.view(weight.shape[0], no_groups, -1)).view(weight.shape[0], -1)) -##### embedding table quantization ###### +######################################################################### +##### embedding table quantization ###### def replace_embedding_weight_only_grouped_int8_per_channel( @@ -587,8 +593,8 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor: # r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, )) -################################################################## -##### weight only int4 per channel groupwise quantized code ###### +######################################################################### +##### weight only int4 per channel groupwise quantized code ###### def _int4_prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): weight_int32, scales_and_zeros = group_quantize_tensor( @@ -727,8 +733,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.weight, self.scales_and_zeros, self.out_features, self.groupsize ) -######################################################################## -### Int8 Dynamic Activations 4 Bit Weights +######################################################################### +##### Int8 Dynamic Activations 4 Bit Weights ##### def prepare_int4_weight_and_scales_and_zeros(weight, group_size, precision): weight_int8, scales, zeros = group_quantize_tensor_symmetric( @@ -986,171 +992,190 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ) -#### GPTQ ######## - -# try: -# from GPTQ import ( # pyre-ignore[21] -# evaluate, -# GenericGPTQRunner, -# get_task_dict, -# InputRecorder, -# lm_eval, -# MultiInput, -# ) - -# except: -# pass - - -# class GPTQQuantHandler(QuantHandler): -# """ -# This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. 
-# Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement -# __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. - -# The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and -# create_quantized_state_dict. Here is a description of each function. - -# get_qparams_func: -# A function that calculates the quantization qparams for an input tensor. -# Args: -# weight: A 2d weight tensor with non-integer dtype. -# Returns: -# qparams: it can have any format but will need to be handled by the other defined functions below. - -# quantize_func: -# A function that applies quantization to an input tensor. It should be noted -# that this function needs to be able to handle quantizing the entire weight tensor, a single group, -# or a single column. -# Args: -# weight: A 2d weight tensor with non-integer dtype. -# qparams: the output from get_qparams_func -# Returns: -# quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) - - -# dequantize_func: -# A function that dequantizes an input quantized weight tensor. It should be noted -# that this function needs to be able to handle dequantizing the entire weight tensor, a single group, -# or a single column. -# Args: -# quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) -# qparams: the output from get_qparams_func -# Returns: -# weight: A 2d weight tensor with non-integer dtype. - -# combine_qparams_list_func: -# A function that combines several qparams into one qparam. -# Args: -# qparams_list: a list of qparams objects, each obtained by calling get_qparams_func -# on a single group from a weight tensor -# Returns: -# qparams: an object of the same format as the qparams above. - -# skip_layer_func: -# A function that determines which linear layers should be skipped during GPTQ -# Args: -# weight: A 2d weight tensor with non-integer dtype. -# Returns: -# skip: boolean indicating whether layer should be skipped - -# make_names_and_values_dict_func: -# A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they -# should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. -# Args: -# quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) -# qparams: the output from get_qparams_func -# Returns: -# names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the -# corresponding quantized weights and qparams. 
-# """
-# def __init__(self):
-# assert self.get_qparams_func is not None
-# assert self.quantize_func is not None
-# assert self.dequantize_func is not None
-# assert self.combine_qparams_list_func is not None
-# assert self.make_names_and_values_dict_func is not None
-
-# @staticmethod
-# def get_inputs(
-# model,
-# tokenizer,
-# calibration_tasks,
-# calibration_limit,
-# calibration_seq_length,
-# pad_calibration_inputs,
-# ) -> "MultiInput": # pyre-ignore[11]
-# input_recorder = InputRecorder(
-# model,
-# tokenizer,
-# calibration_seq_length,
-# pad_calibration_inputs,
-# )
+#########################################################################
+##### GPTQ #####

-# try:
-# lm_eval.tasks.initialize_tasks()
-# except:
-# pass
-# task_dict = get_task_dict(calibration_tasks)
-# print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
-
-# evaluate(
-# input_recorder,
-# task_dict,
-# limit=calibration_limit,
-# )
-# inputs = input_recorder.get_recorded_inputs()
-# assert inputs is not None, (
-# f"No inputs were collected, use a task other than {calibration_tasks}, "
-# + "use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "
-# + f"{calibration_seq_length})"
-# )
-# print(f"Obtained {len(inputs[0].values)} calibration samples")
-# return inputs
+class GPTQQuantHandler(QuantHandler):
+    """
+    This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
+    Unlike the base QuantHandler class, the user does not need to implement create_quantized_state_dict; instead, they reimplement
+    __init__ so that it defines the functions for the quantization mode. The user is also expected to reimplement convert_for_runtime.
+
+    The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
+    create_quantized_state_dict. Here is a description of each function.
+
+    get_qparams_func:
+        A function that calculates the quantization qparams for an input tensor.
+        Args:
+            weight: A 2d weight tensor with non-integer dtype.
+        Returns:
+            qparams: can be in any format, as long as the other functions defined below can handle it.
+
+    quantize_func:
+        A function that applies quantization to an input tensor. Note that this function
+        must be able to handle quantizing the entire weight tensor, a single group,
+        or a single column.
+        Args:
+            weight: A 2d weight tensor with non-integer dtype.
+            qparams: the output from get_qparams_func
+        Returns:
+            quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
+
+    dequantize_func:
+        A function that dequantizes an input quantized weight tensor. Note that this function
+        must be able to handle dequantizing the entire weight tensor, a single group,
+        or a single column.
+        Args:
+            quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
+            qparams: the output from get_qparams_func
+        Returns:
+            weight: A 2d weight tensor with non-integer dtype.
+
+    combine_qparams_list_func:
+        A function that combines several qparams into one qparam.
+        Args:
+            qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
+            on a single group from a weight tensor
+        Returns:
+            qparams: an object of the same format as the qparams above.
+
+    skip_layer_func:
+        A function that determines which linear layers should be skipped during GPTQ.
+        Args:
+            weight: A 2d weight tensor with non-integer dtype.
+        Returns:
+            skip: boolean indicating whether the layer should be skipped
+
+    make_names_and_values_dict_func:
+        A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
+        should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
+        Args:
+            quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
+            qparams: the output from get_qparams_func
+        Returns:
+            names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
+            corresponding quantized weights and qparams.
+    """
+
+    def __init__(self):
+        assert self.mod is not None
+        assert self.get_qparams_func is not None
+        assert self.quantize_func is not None
+        assert self.dequantize_func is not None
+        assert self.combine_qparams_list_func is not None
+        assert self.make_names_and_values_dict_func is not None
+
+    @staticmethod
+    def get_inputs(
+        model,
+        tokenizer,
+        calibration_tasks,
+        calibration_limit,
+        calibration_seq_length,
+        pad_calibration_inputs,
+    ) -> "MultiInput":
+        input_recorder = InputRecorder(
+            model,
+            tokenizer,
+            calibration_seq_length,
+            pad_calibration_inputs,
+        )

-# @torch.no_grad()
-# def create_quantized_state_dict(
-# self,
-# tokenizer,
-# blocksize,
-# percdamp,
-# groupsize,
-# calibration_tasks,
-# calibration_limit,
-# calibration_seq_length,
-# pad_calibration_inputs,
-# ) -> Dict:
-# inputs = GPTQQuantHandler.get_inputs(
-# self.mod,
-# tokenizer,
-# calibration_tasks,
-# calibration_limit,
-# calibration_seq_length,
-# pad_calibration_inputs,
-# )
-# print("Tracing model for GPTQ")
-# GPTQ_runner = GenericGPTQRunner(
-# self.mod,
-# inputs,
-# blocksize,
-# percdamp,
-# groupsize,
-# ).configure_quantization_mode(
-# self.get_qparams_func, # pyre-ignore[16]
-# self.quantize_func, # pyre-ignore[16]
-# self.dequantize_func, # pyre-ignore[16]
-# self.combine_qparams_list_func, # pyre-ignore[16]
-# self.make_names_and_values_dict_func, # pyre-ignore[16]
-# self.skip_layer_func, # pyre-ignore[16]
-# )
+        try:
+            lm_eval.tasks.initialize_tasks()
+        except Exception:
+            pass
+        task_dict = get_task_dict(calibration_tasks)
+        print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
+
+        evaluate(
+            input_recorder,
+            task_dict,
+            limit=calibration_limit,
+        )
+        inputs = input_recorder.get_recorded_inputs()
+        assert inputs is not None, (
+            f"No inputs were collected, use a task other than {calibration_tasks}, "
+            "use option pad_calibration_inputs, or decrease calibration_seq_length (currently "
+            f"{calibration_seq_length})"
+        )
+        print(f"Obtained {len(inputs[0].values)} calibration samples")
+        return inputs
+
+    @torch.no_grad()
+    def create_quantized_state_dict(
+        self,
+        tokenizer,
+        blocksize,
+        percdamp,
+        groupsize,
+        calibration_tasks,
+        calibration_limit,
+        calibration_seq_length,
+        pad_calibration_inputs,
+    ) -> Dict:
+        inputs = GPTQQuantHandler.get_inputs(
+            self.mod,
+            tokenizer,
+            calibration_tasks,
+            calibration_limit,
+            calibration_seq_length,
+            pad_calibration_inputs,
+        )
+        print("Tracing model for GPTQ")
+        GPTQ_runner = GenericGPTQRunner(
+            self.mod,
+            inputs,
+            blocksize,
+            percdamp,
+            groupsize,
+        ).configure_quantization_mode(
+            self.get_qparams_func,
+            self.quantize_func,
+            self.dequantize_func,
+            self.combine_qparams_list_func,
+            self.make_names_and_values_dict_func,
+            self.skip_layer_func,
+        )
+
+        print("Applying GPTQ to weights")
+        GPTQ_runner.run()
+        return GPTQ_runner.get_quantized_state_dict()
+
+    def convert_for_runtime(self) -> "nn.Module":
+        # subclasses are expected to override this
+        raise NotImplementedError

-# print("Applying GPTQ to weights")
to weights") -# GPTQ_runner.run() -# return GPTQ_runner.get_quantized_state_dict() -# def convert_for_runtime(self) -> "nn.Module": -# pass +class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler): + def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): + from model import find_multiple + self.mod = mod + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding = padding + self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize) + self.quantize_func = lambda w, qparams: \ + group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize) + self.dequantize_func = lambda q, qparams: \ + group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float() + self.combine_qparams_list_func = lambda qparams_list: \ + [torch.cat(x, dim=1) for x in zip(*qparams_list)] + # skip unless padding=True or its correctly sized + self.skip_layer_func = lambda linear_weight: not ( + _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding + ) + # we need to do the padding here, both for q and the qparams if necessary + def make_names_and_values_dict_func(q, qparams): + k = q.shape[1] + new_k = find_multiple(k, 1024) + # how much we need to pad the weight + delta_k = new_k - q.shape[1] + final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) + scales_and_zeros = pack_scales_and_zeros(*qparams) + # how many new groups we need for padded weight + delta_groups = new_k // groupsize - scales_and_zeros.shape[0] + final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) + return {"weight": final_q, "scales_and_zeros": final_s_and_z} + self.make_names_and_values_dict_func = make_names_and_values_dict_func + super().__init__() + + + def convert_for_runtime(self, use_cuda): + replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda) + return self.mod + + def quantized_model(self) -> nn.Module: + model_updated_state_dict = self.create_quantized_state_dict() + self.convert_for_runtime() + self.mod.load_state_dict(model_updated_state_dict) + return self.mod + # class Int8DynActInt4WeightGPTQQuantHandler(GPTQQuantHandler):