Commit: add g_idx
horheynm committed Jun 24, 2024
1 parent d39e7f9 commit 828e185
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -137,6 +137,14 @@ def fasterprune(
if sparsity >= SPARSITY_THRESHOLD
else None
)

# map each weight column to its quantization group index
g_idx = [i // quant_scheme.weights.group_size for i in range(self.columns)]
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=W.device)
if actorder:
    # columns were quantized in activation order, so map the per-position
    # group ids back to the original column order
    g_idx = g_idx[invperm]

# See section 3.4 of https://arxiv.org/abs/2203.07259
for i1 in range(0, self.columns, blocksize):
@@ -148,6 +156,15 @@ def fasterprune(
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]

# """
# if not channel wise

# strategy = quant_scheme.weights.strategy
# if strategy is not QuantizationStrategy.CHANNEL:
# idx = i

# """

if sparsity >= SPARSITY_THRESHOLD:
tmp = (
@@ -176,6 +193,7 @@ def fasterprune(
else:
q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype)
q = torch.dequantize(q)

elif hasattr(self.layer, "quantization_scheme"):
quant_scheme = self.layer.quantization_scheme
if quant_scheme.weights is not None:
@@ -235,9 +253,11 @@ def fasterprune(

_LOGGER.info("time %.2f" % (time.time() - tick))
_LOGGER.info("error %.2f" % torch.sum(Losses).item())



if actorder:
    W = W[:, invperm]

if isinstance(self.layer, transformers.Conv1D):
W = W.t()
@@ -247,6 +267,7 @@ def fasterprune(
# place, clone() or direct assignment won't work
self.layer.weight -= self.layer.weight
self.layer.weight += W
self.g_idx = g_idx

def free(self):
"""
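
For context, a minimal standalone sketch (not part of this commit) of the column-to-group mapping that g_idx encodes, and of how such a mapping is typically consumed for grouped dequantization. The sizes, the random permutation, and the scales tensor are illustrative assumptions; perm and invperm stand in for the activation-order permutation computed earlier in fasterprune.

import torch

columns, group_size = 8, 4          # illustrative sizes, not the layer's real shape
num_groups = columns // group_size

# default ordering: columns 0-3 map to group 0, columns 4-7 to group 1
g_idx = torch.tensor([i // group_size for i in range(columns)], dtype=torch.int32)

# with actorder, columns are quantized at permuted (activation-ordered) positions;
# indexing with invperm maps the per-position group ids back to original columns
perm = torch.randperm(columns)      # stand-in for the Hessian-diagonal ordering
invperm = torch.argsort(perm)
g_idx_actorder = g_idx[invperm]     # g_idx_actorder[j] == invperm[j] // group_size

# typical consumer: look up the per-group scale for each original column
scales = torch.rand(3, num_groups)                     # [out_features, num_groups], illustrative
per_column_scales = scales[:, g_idx_actorder.long()]   # [out_features, columns]

print(g_idx.tolist())               # [0, 0, 0, 0, 1, 1, 1, 1]
print(g_idx_actorder.tolist())      # group of each original column under reordering
print(per_column_scales.shape)      # torch.Size([3, 8])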