Update llama7b_sparse_quantized example #2322

Merged
merged 6 commits into from
Jun 13, 2024
Changes from all commits
50 changes: 50 additions & 0 deletions examples/llama7b_one_shot_quantization.md
@@ -0,0 +1,50 @@
# Creating a Quantized Llama Model in One Shot

Quantizing a model to a lower precision reduces memory usage and can speed up inference.
This example demonstrates how to use the SparseML API to quantize a Llama model from 16 bits
to 4 bits and save it in the compressed-tensors format for inference with vLLM.

## Step 1: Select a model and dataset
For this example, we will use a TinyLlama model and the Open Platypus dataset; however,
these can be swapped out for any Hugging Face-compatible model and dataset.

```python
model = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
dataset = "open_platypus"
```

## Step 2: Configure a `GPTQModifier`
Modifiers in SparseML are used to apply optimizations to models. In this example, we use a
`GPTQModifier` to apply the GPTQ algorithm to our model, targeting all `Linear` layers
for 4-bit weight quantization. These options may be swapped out for any valid `QuantizationScheme`.

```python
from sparseml.modifiers.quantization.gptq import GPTQModifier

gptq = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
)
```


## Step 3: One-Shot Compression

The `oneshot` API applies the created modifier to the target model and dataset.
Setting `save_compressed` to `True` runs the model through `compressed-tensors` compression
after quantization is completed.

```python
from sparseml.transformers import oneshot

oneshot(
    model=model,
    dataset=dataset,
    recipe=gptq,
    save_compressed=True,
    output_dir="llama-compressed-example",
    overwrite_output_dir=True,
    max_seq_length=256,
    num_calibration_samples=256,
)
```
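
For reference (not part of this diff), a minimal sketch of serving the saved checkpoint with vLLM could look like the following. It assumes a vLLM build with compressed-tensors support and uses the `llama-compressed-example` output directory from the call above.

```python
from vllm import LLM, SamplingParams

# Load the compressed checkpoint produced by the oneshot call above
llm = LLM(model="llama-compressed-example")

# Generate a short completion to sanity-check the quantized model
sampling_params = SamplingParams(max_tokens=64, temperature=0.8)
outputs = llm.generate(["What is quantization?"], sampling_params)
print(outputs[0].outputs[0].text)
```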
80 changes: 56 additions & 24 deletions examples/llama7b_sparse_quantized/README.md
@@ -1,52 +1,84 @@
# Creating a Sparse Quantized Llama7b Model

The example in this folder runs in multiple stages to create a Llama 7b model with
a 2:4 sparsity pattern and W4A16 post-training quantization (PTQ). The model is
calibrated and trained with the ultrachat200k dataset. At least 75GB of GPU memory is
required to run this example.
This example uses SparseML and Compressed-Tensors to create a 2:4 sparse and quantized Llama2-7b model.
The model is calibrated and trained with the ultrachat200k dataset.
At least 75GB of GPU memory is required to run this example.

## Recipe Summary
Follow the steps below, or run the entire example with `python examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py`.

The recipe used for this flow is located in [2:4_w4a16_recipe.yaml](./2:4_w4a16_recipe.yaml). It contains 3 stages that are outlined below.
## Step 1: Select a model, dataset, and recipe
In this step, we select which model to use as a baseline for sparsification, a dataset to
use for calibration and finetuning, and a recipe.

Models can reference a local directory, a model in the Hugging Face Hub, or a model in the SparseZoo.

Datasets can come from a compatible local directory or the Hugging Face Hub.

Recipes are YAML files that describe how a model should be optimized during or after training.
The recipe used for this flow is located in [2:4_w4a16_recipe.yaml](./2:4_w4a16_recipe.yaml).
It contains instructions to prune the model to 2:4 sparsity, run one epoch of recovery finetuning,
and quantize to 4 bits in one shot using GPTQ.

### Stage 1: Sparsification

Runs the SparseGPT one-shot algorithm to prune the model to 50% sparsity with a 2:4
sparsity pattern. This means that 2 weights out of every group of 4 weights are masked to 0.

### Stage 2: Finetuning Recovery

This stage runs a single epoch of training on the ultrachat200k dataset while maintaining
the sparsity mask from stage 1. The purpose of this stage is to recover any accuracy lost
during the sparsification process.
### Stage 3: Quantization

Finally, we run the GPTQ one-shot algorithm to quantize all linear weights to 4 bit
channelwise.

```python
import torch
from sparseml.transformers import SparseAutoModelForCausalLM

model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
model = SparseAutoModelForCausalLM.from_pretrained(
    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
)

dataset = "ultrachat-200k"
splits = {"calibration": "train_gen[:5%]", "train": "train_gen"}

recipe = "2:4_w4a16_recipe.yaml"
```

## How to Run

We can run the entire staged recipe with one call to SparseML's `apply` pathway. This
will save a checkpoint of the model after each stage.

## Step 2: Run sparsification using `apply`
The `apply` function applies the given recipe to our model and dataset.
The hardcoded kwargs may be altered based on each model's needs.
After running, the sparsified model will be saved to `output_llama7b_2:4_w4a16_channel`.

A resolved review thread on the `apply(...)` call below:
Reviewer: The number of arguments here is very confusing, especially since most of these are related to training...
Reply: Talked to Ben and he is going to write up a README of just quantization without the training. This one is intended to be a more advanced readme showing how to do the full sparsity -> finetuning -> quantization flow.

```python
from sparseml.transformers import apply

output_dir = "output_llama7b_2:4_w4a16_channel"

apply(
    model=model,
    dataset=dataset,
    recipe=recipe,
    bf16=False,  # use full precision for training
    output_dir=output_dir,
    splits=splits,
    max_seq_length=512,
    num_calibration_samples=512,
    num_train_epochs=0.5,
    logging_steps=500,
    save_steps=5000,
    gradient_checkpointing=True,
    learning_rate=0.0001,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
)
```

```python examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py```

### Compression
## Step 3: Compression

The resulting model will be uncompressed. To save a final compressed copy of the model,
run the following:

```
```python
import torch
from sparseml.transformers import SparseAutoModelForCausalLM

compressed_output_dir = "output_llama7b_2:4_w4a16_channel_compressed"
model = SparseAutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16)
model.save_pretrained(compressed_output_dir, save_compressed=True)
```

### Custom Quantization
The current repo supports multiple quantization techniques configured using a recipe. Supported strategies are `tensor`, `group` and `channel`.
The above recipe (`2:4_w4a16_recipe.yaml`) uses channel-wise quantization specified by `strategy: "channel"` in its config group.
To quantize per tensor, change the strategy from `channel` to `tensor`. To use group-size quantization, change from `channel` to `group` and specify the group size, e.g. 128, by including `group_size: 128`. A group size quantization example is shown in `2:4_w4a16_group-128_recipe.yaml`.
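
As a rough companion to the recipe-based configuration above (not part of this diff), the same group-size setup can be sketched directly through the `GPTQModifier` API, mirroring the `config_groups` form used in the tests later in this PR; the recipe YAML remains the reference configuration.

```python
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from sparseml.modifiers.quantization.gptq import GPTQModifier

# Group-wise 4-bit weight quantization with a group size of 128, analogous to
# switching `strategy: "channel"` to `strategy: "group"` and adding
# `group_size: 128` in the recipe's config group.
gptq_group = GPTQModifier(
    ignore=["lm_head"],
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
        )
    },
)
```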
42 changes: 22 additions & 20 deletions src/sparseml/modifiers/quantization/gptq/base.py
@@ -18,9 +18,9 @@
from pydantic import Field

from compressed_tensors.quantization import (
QuantizationConfig,
QuantizationScheme,
is_preset_scheme,
preset_name_to_scheme,
)
from sparseml.core import Modifier
from sparseml.core.factory import ModifierFactory
@@ -77,6 +77,7 @@ class GPTQModifier(Modifier):
QuantizationScheme except targets, which will be set to the targets parameter
set at the modifier level. Can also be set to a dictionary of the format
`preset_scheme_name: targets` for example: `W8A8: ['Linear']` for weight 8 bit
or a string of a preset scheme if targets is provided
and activation 8 bit quantization on the Linear layers.
"""

@@ -89,7 +90,7 @@ class GPTQModifier(Modifier):
ignore: List[str] = Field(default_factory=list)
disable_quantization_observer_epoch: Optional[float] = None
num_calibration_steps: Optional[int] = None
scheme: Optional[Dict[str, Any]] = None
scheme: Optional[Union[str, Dict[str, Any]]] = None
compressible_layers_: Optional[List] = None
quantization_modifier_: Any = None

@@ -167,32 +168,33 @@ def _build_quant_modifier(self, framework):
if getattr(self, key, False)
}

if isinstance(self.targets, str):
self.targets = [self.targets]

if self.scheme is not None:
# takes precedence over config_groups

if any(is_preset_scheme(key) for key in self.scheme.keys()):
config_groups = QuantizationConfig(
config_groups=self.scheme
).config_groups
quant_args["config_groups"] = config_groups
else:
targets = self.targets or ["Linear"]
config_group = QuantizationScheme.model_validate(
{"targets": targets, **self.scheme}
)
quant_args["config_groups"] = {"config_group_0": config_group}
if isinstance(self.scheme, str) and is_preset_scheme(self.scheme):
# attach targets to scheme
self.scheme = {self.scheme: self.targets}

targets = self.targets or ["Linear"]
config_group = QuantizationScheme.model_validate(
{"targets": targets, **self.scheme}
)
quant_args["config_groups"] = {"config_group_0": config_group}
quant_args["config_groups"] = {}
for idx, key in enumerate(self.scheme.keys()):
if is_preset_scheme(key):
scheme = preset_name_to_scheme(key, self.scheme[key])
else:
scheme = QuantizationScheme.model_validate(
{"targets": self.scheme[key], **self.scheme}
)

group_name = f"group_{idx}"
quant_args["config_groups"][group_name] = scheme

if "config_groups" not in quant_args:
if "config_groups" not in quant_args or len("config_groups") == 0:
default_quant_scheme = QuantizationScheme.default_scheme(
targets=self.targets
)
quant_args["config_groups"] = {"config_group_0": default_quant_scheme}
quant_args["config_groups"] = {"group_0": default_quant_scheme}
_LOGGER.info(f"Building quantization modifier with args: {quant_args}")
vllm_quant_config = {"QuantizationModifier": quant_args}
self._build_quant_modifier_from_dict(vllm_quant_config, framework)
@@ -95,7 +95,7 @@ def test_create_default_quant_modifier(self):
modifier.on_initialize_structure(testing_harness.get_state())
assert modifier.quantize
assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
default_config_group_name = "config_group_0"
default_config_group_name = "group_0"
should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[
default_config_group_name
]
80 changes: 60 additions & 20 deletions tests/sparseml/transformers/gptq/test_oneshot.py
@@ -16,11 +16,57 @@
import shutil
import unittest

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from parameterized import parameterized_class
from sparseml.modifiers.quantization.gptq import GPTQModifier
from sparseml.transformers.sparsification.sparse_model import SparseAutoModelForCausalLM
from tests.testing_utils import requires_torch


recipe_str = """
quant_stage:
    quant_modifiers:
        GPTQModifier:
            sequential_update: false
            ignore: ["lm_head"]
            config_groups:
                group_0:
                    weights:
                        num_bits: 4
                        type: "int"
                        symmetric: true
                        strategy: "channel"
                    targets: ["Linear"]
"""

recipe_modifier_full = GPTQModifier(
    ignore=["lm_head"],
    sequential_update=False,
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"], weights=QuantizationArgs(num_bits=4, strategy="channel")
        )
    },
)

recipe_modifier_shorthand_a = GPTQModifier(
    ignore=["lm_head"], sequential_update=False, targets="Linear", scheme="W4A16"
)

recipe_modifier_shorthand_b = GPTQModifier(
    ignore=["lm_head"], sequential_update=False, scheme={"W4A16": ["Linear"]}
)


@requires_torch
@parameterized_class(
    [
        {"recipe": recipe_str},
        {"recipe": recipe_modifier_full},
        {"recipe": recipe_modifier_shorthand_a},
        {"recipe": recipe_modifier_shorthand_b},
    ]
)
class TestGPTQOneShotWithFullScheme(unittest.TestCase):
def setUp(self):
import torch
@@ -30,26 +76,6 @@ def setUp(self):
self.dataset = "open_platypus"
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

self.recipe = """
first_stage:
quant_modifiers:
GPTQModifier:
ignore: ["lm_head"]
sequential_update: True
dampening_frac: 0.001
block_size: 128
targets: ["Linear"]
scheme:
input_activations: null
output_activations: null
weights:
num_bits: 8
type: "int"
symmetric: true
strategy: "tensor"
group_size: 128
"""

def test_oneshot_application(self):
from sparseml.transformers import oneshot

@@ -68,9 +94,23 @@ def test_oneshot_application(self):
# Check that the model is quantized
assert model_loaded.quantization_config is not None

# check config is set properly
assert model_loaded.quantization_config.ignore == ["lm_head"]
assert len(model_loaded.quantization_config.config_groups) == 1
quant_scheme = model_loaded.quantization_config.config_groups["group_0"]
assert isinstance(quant_scheme, QuantizationScheme)
assert quant_scheme.targets == ["Linear"]
weight_args = model_loaded.quantization_config.config_groups["group_0"].weights
assert isinstance(weight_args, QuantizationArgs)
assert weight_args.num_bits == 4

# Check a specific layer is quantized
targetted_linear_layer = model_loaded.transformer.h[0].attn.attention.k_proj
assert hasattr(targetted_linear_layer, "quantization_scheme")

# Check lm-head is not quantized
not_targetted = model_loaded.lm_head
assert not hasattr(not_targetted, "quantization_scheme")

def tearDown(self):
shutil.rmtree(self.output)