diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
index 30a13196d9..29927c832a 100644
--- a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -137,6 +137,14 @@ def fasterprune(
             if sparsity >= SPARSITY_THRESHOLD
             else None
         )
+        # per-column quantization group indices for group-wise schemes
+        g_idx = []
+        if actorder:
+            g_idx = [perm[i] // quant_scheme.weights.group_size for i in range(self.columns)]
+            g_idx = torch.tensor(g_idx, dtype=torch.int32, device=W.device)[invperm]
+        else:
+            g_idx = [i // quant_scheme.weights.group_size for i in range(self.columns)]
+            g_idx = torch.tensor(g_idx, dtype=torch.int32, device=W.device)
 
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
@@ -148,6 +156,15 @@ def fasterprune(
             Err1 = torch.zeros_like(W1)
             Losses1 = torch.zeros_like(W1)
             Hinv1 = Hinv[i1:i2, i1:i2]
+
+            # """
+            # if not channel wise
+
+            # strategy = quant_scheme.weights.strategy
+            # if strategy is not QuantizationStrategy.CHANNEL:
+            #     idx = i
+
+            # """
 
             if sparsity >= SPARSITY_THRESHOLD:
                 tmp = (
@@ -176,6 +193,7 @@ def fasterprune(
                     else:
                         q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype)
                     q = torch.dequantize(q)
+
                 elif hasattr(self.layer, "quantization_scheme"):
                     quant_scheme = self.layer.quantization_scheme
                     if quant_scheme.weights is not None:
@@ -235,9 +253,11 @@ def fasterprune(
 
         _LOGGER.info("time %.2f" % (time.time() - tick))
         _LOGGER.info("error %.2f" % torch.sum(Losses).item())
-
+
+
         if actorder:
             W = W[:, invperm]
+            # g_idx = g_idx[invperm]
 
         if isinstance(self.layer, transformers.Conv1D):
             W = W.t()
@@ -247,6 +267,7 @@ def fasterprune(
         # place, clone() or direct assignment won't work
         self.layer.weight -= self.layer.weight
         self.layer.weight += W
+        self.g_idx = g_idx
 
     def free(self):
         """
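
For context on what the new g_idx bookkeeping computes: with group-wise quantization every weight column belongs to a group of group_size consecutive columns, and when activation ordering (actorder) permutes the columns before quantization, that grouping is defined in the permuted order and then mapped back to the original column order via the inverse permutation. Below is a minimal, standalone sketch of that mapping, following the convention used by common GPTQ implementations; the helper name build_group_indices and its arguments are illustrative and not part of this patch, and the exact mapping this patch settles on may differ (note the commented-out g_idx = g_idx[invperm] near the end of the hunk).

    from typing import Optional

    import torch


    def build_group_indices(
        num_columns: int, group_size: int, perm: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Without reordering, column i belongs to group i // group_size.
        g_idx = torch.arange(num_columns, dtype=torch.int32) // group_size
        if perm is not None:
            # With activation ordering, the column at permuted position i belongs to
            # group i // group_size; indexing with the inverse permutation expresses
            # that grouping in the original column order.
            invperm = torch.argsort(perm)
            g_idx = g_idx[invperm]
        return g_idx


    # Example: 8 columns, groups of 4, with an activation-order permutation
    perm = torch.tensor([5, 2, 7, 0, 3, 6, 1, 4])
    print(build_group_indices(8, 4))        # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
    print(build_group_indices(8, 4, perm))  # tensor([0, 1, 0, 1, 1, 0, 1, 0], dtype=torch.int32)

Storing this per-column mapping alongside the quantized weights (as self.g_idx does at the end of the patch) lets downstream kernels pick the correct scale and zero point for each column even though the saved weights are back in their original order.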