Fix for Sparsity Persist (#2323)
* fix sparsity persist

* helper moved to compressed-tensors
Sara Adkins authored Jun 11, 2024
1 parent 934f0d8 commit e255b17
Showing 2 changed files with 21 additions and 46 deletions.
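In outline: the old code rebuilt a boolean zero mask per block and zeroed the quantized column, but the dense error-propagation updates could still write nonzero values back into pruned positions, so sparsity was lost during GPTQ. The new code computes a float nonzero mask (W_nz_mask) once and multiplies every error-propagation update by it, so entries that start at zero stay exactly zero. A minimal sketch of that gating idea on a toy tensor; gated_update is an illustrative name, not part of the codebase:

import torch

def gated_update(W, update, nz_mask=None):
    # GPTQ-style error propagation: with a nonzero mask, pruned (zero)
    # entries receive no update and therefore stay exactly zero.
    if nz_mask is not None:
        W -= update * nz_mask
    else:
        W -= update
    return W

W = torch.tensor([[0.0, 1.5], [2.0, 0.0]])
nz_mask = (~torch.isclose(W, torch.zeros(1))).float()  # 1.0 where W != 0
W = gated_update(W, torch.full_like(W, 0.1), nz_mask)
assert torch.all(W[nz_mask == 0] == 0)  # original zeros preserved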
43 changes: 20 additions & 23 deletions src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -103,6 +103,14 @@ def fasterprune(
W = W.t()
W = W.float()

sparsity = tensor_sparsity(W)
preserve_zeros = sparsity >= SPARSITY_THRESHOLD
W_nz_mask = (
(~torch.isclose(W, torch.zeros(1, device=W.device).float())).float()
if preserve_zeros
else None
)

tick = time.time()

dead = torch.diag(self.H) == 0
@@ -119,17 +127,6 @@ def fasterprune(
self.H = torch.linalg.cholesky(self.H, upper=True)
Hinv = self.H

sparsity = tensor_sparsity(W)
mask = (
torch.where(
W == 0,
torch.tensor(1, dtype=torch.bool),
torch.tensor(0, dtype=torch.bool),
)
if sparsity >= SPARSITY_THRESHOLD
else None
)

# See section 3.4 of https://arxiv.org/abs/2203.07259
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
@@ -141,21 +138,13 @@ def fasterprune(
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]

if sparsity >= SPARSITY_THRESHOLD:
tmp = (
(~mask[:, i1:i2])
* W1**2
/ (torch.diag(Hinv1).reshape((1, -1))) ** 2
)
thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
mask1 = tmp <= thresh
if preserve_zeros:
W1_nz_mask = W_nz_mask[:, i1:i2]

for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
q = w.clone()
if sparsity >= SPARSITY_THRESHOLD:
q[mask1[:, i]] = 0

if hasattr(self.layer, "weight_fake_quant"):
scale = self.layer.weight_fake_quant.scale
@@ -216,13 +205,21 @@ def fasterprune(
Losses1[:, i] = (w - q) ** 2 / d**2

err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
if preserve_zeros:
W1[:, i:] -= w1_err * W1_nz_mask[:, i:]
else:
W1[:, i:] -= w1_err
Err1[:, i] = err1

W[:, i1:i2] = Q1
Losses += torch.sum(Losses1, 1) / 2

W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
w_err = Err1.matmul(Hinv[i1:i2, i2:])
if preserve_zeros:
W[:, i2:] -= w_err * W_nz_mask[:, i2:]
else:
W[:, i2:] -= w_err

_LOGGER.info("time %.2f" % (time.time() - tick))
_LOGGER.info("error %.2f" % torch.sum(Losses).item())
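For context, tensor_sparsity returns the fraction of zero entries, and zero-preservation only activates when that fraction reaches SPARSITY_THRESHOLD. A rough stand-in for the helper, with an assumed threshold value (the real constant lives in the library):

import torch

SPARSITY_THRESHOLD = 0.3  # assumed value, for illustration only

def tensor_sparsity(t: torch.Tensor) -> torch.Tensor:
    # Fraction of exactly-zero entries; a stand-in for the library helper.
    return (t == 0).float().mean()

W = torch.tensor([0.0, 0.0, 1.0, 2.0])
print(tensor_sparsity(W))                        # tensor(0.5000)
print(tensor_sparsity(W) >= SPARSITY_THRESHOLD)  # tensor(True)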
24 changes: 1 addition & 23 deletions (mask-structure test file)
@@ -19,6 +19,7 @@
import pytest

import sparseml
from compressed_tensors.compressors.utils import tensor_follows_mask_structure
from parameterized import parameterized_class
from tests.testing_utils import parse_params, requires_torch

@@ -28,29 +29,6 @@
)


def tensor_follows_mask_structure(tensor, mask: str = "2:4"):
"""
:param tensor: tensor to check
:param mask: mask structure to check for, in the format "n:m"
:return: True if the tensor follows the mask structure, False otherwise.
Note, some weights can incidentally be zero, so we check for
    at least n zeros in each chunk of size m
"""
import torch

n, m = tuple(map(int, mask.split(":")))
# Reshape the tensor into chunks of size m
tensor = tensor.view(-1, m)

# Count the number of zeros in each chunk
zero_counts = (tensor == 0).sum(dim=1)

    # Check if the number of zeros in each chunk is at least n
# Greater than sign is needed as some weights can incidentally
# be zero
return torch.all(zero_counts >= n)


@requires_torch
@pytest.mark.integration
@parameterized_class(parse_params(MASK_STRUCTURE_CONFIGS_DIRECTORY))
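With the helper now imported from compressed-tensors, a quick usage sketch, assuming the moved function keeps the signature shown in the deleted code above:

import torch
from compressed_tensors.compressors.utils import tensor_follows_mask_structure

# At least two zeros in every contiguous chunk of four satisfies "2:4".
t = torch.tensor([0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 4.0])
assert tensor_follows_mask_structure(t, mask="2:4")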
