Fix GPTQ Aliases (#2327)
* fix alias application with unit tests

* style
Sara Adkins authored Jun 12, 2024
1 parent d637be9 commit 1a2d534
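The "aliases" in question are compressed-tensors preset scheme names (for example "W4A16") passed through GPTQModifier's `scheme` shorthand. As a minimal sketch of that shorthand (not part of the diff below; it mirrors the shorthand recipes exercised by the new tests in test_oneshot.py):

from sparseml.modifiers.quantization.gptq import GPTQModifier

# Preset alias applied to every Linear layer except lm_head; the new tests
# treat this as equivalent to spelling out a full config_groups entry with
# 4-bit weights.
modifier = GPTQModifier(
    ignore=["lm_head"],
    sequential_update=False,
    scheme={"W4A16": ["Linear"]},
)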
Showing 3 changed files with 74 additions and 40 deletions.
32 changes: 13 additions & 19 deletions src/sparseml/modifiers/quantization/gptq/base.py
@@ -18,9 +18,9 @@
from pydantic import Field

from compressed_tensors.quantization import (
QuantizationConfig,
QuantizationScheme,
is_preset_scheme,
preset_name_to_scheme,
)
from sparseml.core import Modifier
from sparseml.core.factory import ModifierFactory
@@ -178,29 +178,23 @@ def _build_quant_modifier(self, framework):
# attach targets to scheme
self.scheme = {self.scheme: self.targets}

if any(is_preset_scheme(key) for key in self.scheme.keys()):
config_groups = QuantizationConfig(
config_groups=self.scheme
).config_groups
quant_args["config_groups"] = config_groups
else:
targets = self.targets or ["Linear"]
config_group = QuantizationScheme.model_validate(
{"targets": targets, **self.scheme}
)
quant_args["config_groups"] = {"config_group_0": config_group}
quant_args["config_groups"] = {}
for idx, key in enumerate(self.scheme.keys()):
if is_preset_scheme(key):
scheme = preset_name_to_scheme(key, self.scheme[key])
else:
scheme = QuantizationScheme.model_validate(
{"targets": self.scheme[key], **self.scheme}
)

targets = self.targets or ["Linear"]
config_group = QuantizationScheme.model_validate(
{"targets": targets, **self.scheme}
)
quant_args["config_groups"] = {"config_group_0": config_group}
group_name = f"group_{idx}"
quant_args["config_groups"][group_name] = scheme

if "config_groups" not in quant_args:
if "config_groups" not in quant_args or len("config_groups") == 0:
default_quant_scheme = QuantizationScheme.default_scheme(
targets=self.targets
)
quant_args["config_groups"] = {"config_group_0": default_quant_scheme}
quant_args["config_groups"] = {"group_0": default_quant_scheme}
_LOGGER.info(f"Building quantization modifier with args: {quant_args}")
vllm_quant_config = {"QuantizationModifier": quant_args}
self._build_quant_modifier_from_dict(vllm_quant_config, framework)
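In short, the rewritten block builds one named config group per key of the `scheme` dict ("group_0", "group_1", ...), resolving preset aliases with preset_name_to_scheme rather than overwriting config_groups wholesale. A standalone sketch of the preset-alias path (the helper name is illustrative, not part of the commit; it assumes the same compressed-tensors helpers imported above):

from compressed_tensors.quantization import is_preset_scheme, preset_name_to_scheme

def expand_alias_scheme(scheme: dict) -> dict:
    """Illustrative helper: map {preset_alias: target_list} entries to the
    named config groups produced by the loop above."""
    config_groups = {}
    for idx, (alias, targets) in enumerate(scheme.items()):
        if not is_preset_scheme(alias):
            raise ValueError(f"unknown preset scheme: {alias}")
        # resolve the alias (e.g. "W4A16") into a full QuantizationScheme
        config_groups[f"group_{idx}"] = preset_name_to_scheme(alias, targets)
    return config_groups

# expand_alias_scheme({"W4A16": ["Linear"]})
# -> {"group_0": QuantizationScheme(targets=["Linear"], weights=...)}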
@@ -95,7 +95,7 @@ def test_create_default_quant_modifier(self):
modifier.on_initialize_structure(testing_harness.get_state())
assert modifier.quantize
assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
default_config_group_name = "config_group_0"
default_config_group_name = "group_0"
should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[
default_config_group_name
]
80 changes: 60 additions & 20 deletions tests/sparseml/transformers/gptq/test_oneshot.py
@@ -16,11 +16,57 @@
import shutil
import unittest

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from parameterized import parameterized_class
from sparseml.modifiers.quantization.gptq import GPTQModifier
from sparseml.transformers.sparsification.sparse_model import SparseAutoModelForCausalLM
from tests.testing_utils import requires_torch


recipe_str = """
quant_stage:
quant_modifiers:
GPTQModifier:
sequential_update: false
ignore: ["lm_head"]
config_groups:
group_0:
weights:
num_bits: 4
type: "int"
symmetric: true
strategy: "channel"
targets: ["Linear"]
"""

recipe_modifier_full = GPTQModifier(
ignore=["lm_head"],
sequential_update=False,
config_groups={
"group_0": QuantizationScheme(
targets=["Linear"], weights=QuantizationArgs(num_bits=4, strategy="channel")
)
},
)

recipe_modifier_shorthand_a = GPTQModifier(
ignore=["lm_head"], sequential_update=False, targets="Linear", scheme="W4A16"
)

recipe_modifier_shorthand_b = GPTQModifier(
ignore=["lm_head"], sequential_update=False, scheme={"W4A16": ["Linear"]}
)


@requires_torch
@parameterized_class(
[
{"recipe": recipe_str},
{"recipe": recipe_modifier_full},
{"recipe": recipe_modifier_shorthand_a},
{"recipe": recipe_modifier_shorthand_b},
]
)
class TestGPTQOneShotWithFullScheme(unittest.TestCase):
def setUp(self):
import torch
@@ -30,26 +76,6 @@ def setUp(self):
self.dataset = "open_platypus"
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

self.recipe = """
first_stage:
quant_modifiers:
GPTQModifier:
ignore: ["lm_head"]
sequential_update: True
dampening_frac: 0.001
block_size: 128
targets: ["Linear"]
scheme:
input_activations: null
output_activations: null
weights:
num_bits: 8
type: "int"
symmetric: true
strategy: "tensor"
group_size: 128
"""

def test_oneshot_application(self):
from sparseml.transformers import oneshot

@@ -68,9 +94,23 @@ def test_oneshot_application(self):
# Check that the model is quantized
assert model_loaded.quantization_config is not None

# check config is set properly
assert model_loaded.quantization_config.ignore == ["lm_head"]
assert len(model_loaded.quantization_config.config_groups) == 1
quant_scheme = model_loaded.quantization_config.config_groups["group_0"]
assert isinstance(quant_scheme, QuantizationScheme)
assert quant_scheme.targets == ["Linear"]
weight_args = model_loaded.quantization_config.config_groups["group_0"].weights
assert isinstance(weight_args, QuantizationArgs)
assert weight_args.num_bits == 4

# Check a specific layer is quantized
targetted_linear_layer = model_loaded.transformer.h[0].attn.attention.k_proj
assert hasattr(targetted_linear_layer, "quantization_scheme")

# Check lm-head is not quantized
not_targetted = model_loaded.lm_head
assert not hasattr(not_targetted, "quantization_scheme")

def tearDown(self):
shutil.rmtree(self.output)
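End to end, a hedged usage sketch (not from the repository) of applying the alias shorthand with oneshot and inspecting the saved config; the oneshot keyword arguments and the small example checkpoint are assumptions based on how the test above drives the entrypoint:

from sparseml.modifiers.quantization.gptq import GPTQModifier
from sparseml.transformers import oneshot
from sparseml.transformers.sparsification.sparse_model import SparseAutoModelForCausalLM

# Same shorthand as recipe_modifier_shorthand_b above: W4A16 on every Linear
# layer except the lm_head.
recipe = GPTQModifier(
    ignore=["lm_head"], sequential_update=False, scheme={"W4A16": ["Linear"]}
)

oneshot(
    model="roneneldan/TinyStories-1M",  # assumed small example checkpoint
    dataset="open_platypus",
    recipe=recipe,
    output_dir="./gptq_w4a16",
    num_calibration_samples=8,
)

# After the fix, the resolved scheme lands under a group named "group_0".
model = SparseAutoModelForCausalLM.from_pretrained("./gptq_w4a16")
print(model.quantization_config.config_groups["group_0"].weights.num_bits)  # 4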
