From 4057ba567ff633f7a5746bd90f07d2cc9024f7c8 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Fri, 7 Jun 2024 12:49:08 -0400
Subject: [PATCH 1/4] update llama7b_sparse_quantized example

---
 examples/llama7b_sparse_quantized/README.md | 84 ++++++++++++++-------
 1 file changed, 58 insertions(+), 26 deletions(-)

diff --git a/examples/llama7b_sparse_quantized/README.md b/examples/llama7b_sparse_quantized/README.md
index 59a8b98bca6..af86973ffa1 100644
--- a/examples/llama7b_sparse_quantized/README.md
+++ b/examples/llama7b_sparse_quantized/README.md
@@ -1,46 +1,76 @@
 # Creating a Sparse Quantized Llama7b Model
 
-The example in this folder runs in multiple stages to create a Llama 7b model with
-a 2:4 sparsity pattern and W4A16 post training quantization (PTW). The model is
-calibrated and trained with the ultachat200k dataset. At least 75GB of GPU memory is
-required to run this example.
+This example uses SparseML and Compressed-Tensors to create a 2:4 sparse and quantized Llama2-7b model.
+The model is calibrated and trained with the ultrachat200k dataset.
+At least 75GB of GPU memory is required to run this example.
 
-## Recipe Summary
+Follow the steps below, or run the full example with `python examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py`.
 
-The recipe used for this flow is located in [2:4_w4a16_recipe.yaml](./2:4_w4a16_recipe.yaml). It contains 3 stages that are outlined below.
+## Step 1: Select a model, dataset, and recipe
+In this step, we select which model to use as a baseline for sparsification, a dataset to
+use for calibration and finetuning, and a recipe.
+Models can reference a local directory, a model on the Hugging Face Hub, or a SparseZoo stub.
 
-### Stage 1: Sparsification
+Datasets can come from a local compatible directory or the Hugging Face Hub.
 
-Runs the SparseGPT one-shot algorithm to prune the model to 50% sparsity with a 2:4
-sparsity pattern. This means that 2 weights out of every group of 4 weights are masked to 0.
+Recipes are YAML files that describe how a model should be optimized during or after training.
+The recipe used for this flow is located in [2:4_w4a16_recipe.yaml](./2:4_w4a16_recipe.yaml).
+It contains instructions to prune the model to 2:4 sparsity, run one epoch of recovery finetuning,
+and quantize to 4 bits in one shot using GPTQ.
 
-### Stage 2: Finetuning Recovery
-
-This stage runs a single epoch of training on the ultrachat200k dataset while maintaining
-the sparsity mask from stage 1. The purpose of this stage is to recover any accuracy lost
-during the sparsification process.
+```python
+import torch
+from sparseml.transformers import SparseAutoModelForCausalLM
 
-### Stage 3: Quantization
+model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
+)
 
-Finally, we run the GPTQ one-shot algorithm to quantize all linear weights to 4 bit
-channelwise.
+dataset = "ultrachat-200k"
+splits = {"calibration": "train_gen[:5%]", "train": "train_gen"}
 
-## How to Run
+recipe = "2:4_w4a16_recipe.yaml"
+```
 
-We can run the entire staged recipe with one call to SparseML's `apply` pathway. This
-will save a checkpoint of the model after each stage.
+## Step 2: Run sparsification using `apply`
+The `apply` function applies the given recipe to our model and dataset.
+The hardcoded kwargs may be altered based on each model's needs.
+After running, the sparsified model will be saved to `output_llama7b_2:4_w4a16_channel`.
+
+```python
+from sparseml.transformers import apply
+
+output_dir = "output_llama7b_2:4_w4a16_channel"
+
+apply(
+    model=model,
+    dataset=dataset,
+    recipe=recipe,
+    bf16=False,  # use full precision for training
+    output_dir=output_dir,
+    splits=splits,
+    max_seq_length=512,
+    num_calibration_samples=512,
+    num_train_epochs=0.5,
+    logging_steps=500,
+    save_steps=5000,
+    gradient_checkpointing=True,
+    learning_rate=0.0001,
+    lr_scheduler_type="cosine",
+    warmup_ratio=0.1,
+)
+```
 
-```python examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py```
 
-### Compression
+### Step3: Compression
 
 The resulting model will be uncompressed. To save a final compressed copy of the model run the following:
 
-```
-import torch
-from sparseml import SparseAutoModelForCausalLM
+```python
+compressed_output_dir = "output_llama7b_2:4_w4a16_channel_compressed"
 
 model = SparseAutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16)
 model.save_pretrained(compressed_output_dir, save_compressed=True)
@@ -49,4 +79,6 @@ model.save_pretrained(compressed_output_dir, save_compressed=True)
 
 ### Custom Quantization
 The current repo supports multiple quantization techniques configured using a recipe. Supported strategies are `tensor`, `group` and `channel`. The above recipe (`2:4_w4a16_recipe.yaml`) uses channel-wise quantization specified by `strategy: "channel"` in its config group.
-To use quantize per tensor, change strategy from `channel` to `tensor`. To use group size quantization, change from `channel` to `group` and specify its value, say 128, by including `group_size: 128`. Group size quantization example is shown in `2:4_w4a16_group-128_recipe.yaml`
\ No newline at end of file
+To quantize per tensor, change the strategy from `channel` to `tensor`.
+To use group size quantization, change from `channel` to `group` and specify its value, say 128, by including `group_size: 128`.
+A group size quantization example is shown in `2:4_w4a16_group-128_recipe.yaml`
\ No newline at end of file

From 7c53e0c17cc4845b75646b94c63a590886606c91 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Fri, 7 Jun 2024 17:15:58 -0400
Subject: [PATCH 2/4] one shot llama example

---
 examples/llama7b_one_shot_quantization.md | 50 +++++++++++++++++++
 .../modifiers/quantization/gptq/base.py   | 10 +++-
 2 files changed, 59 insertions(+), 1 deletion(-)
 create mode 100644 examples/llama7b_one_shot_quantization.md

diff --git a/examples/llama7b_one_shot_quantization.md b/examples/llama7b_one_shot_quantization.md
new file mode 100644
index 00000000000..d3ee50e1aaf
--- /dev/null
+++ b/examples/llama7b_one_shot_quantization.md
@@ -0,0 +1,50 @@
+# Creating a Quantized Llama Model in One Shot
+
+Quantizing a model to a lower precision can reduce memory usage and speed up inference.
+This example demonstrates how to use the SparseML API to quantize a Llama model from 16 bits
+to 4 bits and save it to a compressed-tensors format for inference with vLLM.
+
+## Step 1: Select a model and dataset
+For this example, we will use a TinyLlama model and the open platypus dataset; however,
+these can be swapped out for any Hugging Face compatible model and dataset.
+
+```python
+model = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+dataset = "open_platypus"
+```
+
+## Step 2: Configure a `GPTQModifier`
+Modifiers in SparseML are used to apply optimizations to models. In this example, we use a
+`GPTQModifier` to apply the GPTQ algorithm to our model. We target all `Linear` layers
+for 4-bit weight quantization.
+These options may be swapped out for any valid `QuantizationScheme`.
+
+```python
+from sparseml.modifiers.quantization.gptq import GPTQModifier
+
+gptq = GPTQModifier(
+    targets="Linear",
+    scheme="W4A16"
+)
+```
+
+## Step 3: One-Shot Compression
+
+The `oneshot` API applies the created modifier to the target model and dataset.
+Setting `save_compressed` to `True` runs the model through `compressed_tensors` compression
+after the quantization is completed.
+
+```python
+from sparseml.transformers import oneshot
+
+oneshot(
+    model=model,
+    dataset=dataset,
+    recipe=gptq,
+    save_compressed=True,
+    output_dir="llama-compressed-example",
+    overwrite_output_dir=True,
+    max_seq_length=256,
+    num_calibration_samples=256,
+)
+```
diff --git a/src/sparseml/modifiers/quantization/gptq/base.py b/src/sparseml/modifiers/quantization/gptq/base.py
index 004fce2ee7a..2e565e77d91 100644
--- a/src/sparseml/modifiers/quantization/gptq/base.py
+++ b/src/sparseml/modifiers/quantization/gptq/base.py
@@ -77,6 +77,7 @@ class GPTQModifier(Modifier):
         QuantizationScheme except targets, which will be set to the targets parameter
         set at the modifier level. Can also be set to a dictionary of the format
         `preset_scheme_name: targets` for example: `W8A8: ['Linear']` for weight 8 bit
         and activation 8 bit quantization on the Linear layers.
+        Can also be set to a string of a preset scheme if targets is provided.
     """
 
@@ -89,7 +90,7 @@ class GPTQModifier(Modifier):
     ignore: List[str] = Field(default_factory=list)
     disable_quantization_observer_epoch: Optional[float] = None
     num_calibration_steps: Optional[int] = None
-    scheme: Optional[Dict[str, Any]] = None
+    scheme: Optional[Union[str, Dict[str, Any]]] = None
     compressible_layers_: Optional[List] = None
     quantization_modifier_: Any = None
 
@@ -167,9 +168,16 @@ def _build_quant_modifier(self, framework):
             if getattr(self, key, False)
         }
 
+        if isinstance(self.targets, str):
+            self.targets = [self.targets]
+
         if self.scheme is not None:
             # takes precedence over config_groups
 
+            if isinstance(self.scheme, str) and is_preset_scheme(self.scheme):
+                # attach targets to scheme
+                self.scheme = {self.scheme: self.targets}
+
             if any(is_preset_scheme(key) for key in self.scheme.keys()):
                 config_groups = QuantizationConfig(
                     config_groups=self.scheme

From 3de03c2abba7468a1f581a7643fc46ed5353b814 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Mon, 10 Jun 2024 11:34:13 -0400
Subject: [PATCH 3/4] Update examples/llama7b_sparse_quantized/README.md

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
---
 examples/llama7b_sparse_quantized/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/llama7b_sparse_quantized/README.md b/examples/llama7b_sparse_quantized/README.md
index af86973ffa1..939cdcda39f 100644
--- a/examples/llama7b_sparse_quantized/README.md
+++ b/examples/llama7b_sparse_quantized/README.md
@@ -64,7 +64,7 @@ apply(
 ```
 
-### Step3: Compression
+### Step 3: Compression
 
 The resulting model will be uncompressed.
 To save a final compressed copy of the model run the following:

From 1a2d5340389ced9b0a53c1f15500ddcd5d1a8634 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Wed, 12 Jun 2024 11:19:44 -0400
Subject: [PATCH 4/4] Fix GPTQ Aliases (#2327)

* fix alias application with unit tests

* style
---
 .../modifiers/quantization/gptq/base.py       | 32 +++-----
 .../pruning/sparsegpt/test_pytorch.py         |  2 +-
 .../transformers/gptq/test_oneshot.py         | 80 ++++++++++++++-----
 3 files changed, 74 insertions(+), 40 deletions(-)

diff --git a/src/sparseml/modifiers/quantization/gptq/base.py b/src/sparseml/modifiers/quantization/gptq/base.py
index 2e565e77d91..43bc596d849 100644
--- a/src/sparseml/modifiers/quantization/gptq/base.py
+++ b/src/sparseml/modifiers/quantization/gptq/base.py
@@ -18,9 +18,9 @@
 from pydantic import Field
 
 from compressed_tensors.quantization import (
-    QuantizationConfig,
     QuantizationScheme,
     is_preset_scheme,
+    preset_name_to_scheme,
 )
 from sparseml.core import Modifier
 from sparseml.core.factory import ModifierFactory
@@ -178,29 +178,23 @@ def _build_quant_modifier(self, framework):
                 # attach targets to scheme
                 self.scheme = {self.scheme: self.targets}
 
-            if any(is_preset_scheme(key) for key in self.scheme.keys()):
-                config_groups = QuantizationConfig(
-                    config_groups=self.scheme
-                ).config_groups
-                quant_args["config_groups"] = config_groups
-            else:
-                targets = self.targets or ["Linear"]
-                config_group = QuantizationScheme.model_validate(
-                    {"targets": targets, **self.scheme}
-                )
-                quant_args["config_groups"] = {"config_group_0": config_group}
+            quant_args["config_groups"] = {}
+            for idx, key in enumerate(self.scheme.keys()):
+                if is_preset_scheme(key):
+                    scheme = preset_name_to_scheme(key, self.scheme[key])
+                else:
+                    scheme = QuantizationScheme.model_validate(
+                        {"targets": self.scheme[key], **self.scheme}
+                    )
 
-            targets = self.targets or ["Linear"]
-            config_group = QuantizationScheme.model_validate(
-                {"targets": targets, **self.scheme}
-            )
-            quant_args["config_groups"] = {"config_group_0": config_group}
+                group_name = f"group_{idx}"
+                quant_args["config_groups"][group_name] = scheme
 
-        if "config_groups" not in quant_args:
+        if "config_groups" not in quant_args or len(quant_args["config_groups"]) == 0:
             default_quant_scheme = QuantizationScheme.default_scheme(
                 targets=self.targets
             )
-            quant_args["config_groups"] = {"config_group_0": default_quant_scheme}
+            quant_args["config_groups"] = {"group_0": default_quant_scheme}
         _LOGGER.info(f"Building quantization modifier with args: {quant_args}")
         vllm_quant_config = {"QuantizationModifier": quant_args}
         self._build_quant_modifier_from_dict(vllm_quant_config, framework)
diff --git a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py
index 0fcb66eee9c..1b9f365bebf 100644
--- a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py
+++ b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py
@@ -95,7 +95,7 @@ def test_create_default_quant_modifier(self):
         modifier.on_initialize_structure(testing_harness.get_state())
         assert modifier.quantize
         assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
-        default_config_group_name = "config_group_0"
+        default_config_group_name = "group_0"
         should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[
             default_config_group_name
         ]
diff --git a/tests/sparseml/transformers/gptq/test_oneshot.py b/tests/sparseml/transformers/gptq/test_oneshot.py
index c7c14275df1..1d2e28cc303 100644
--- a/tests/sparseml/transformers/gptq/test_oneshot.py
+++ b/tests/sparseml/transformers/gptq/test_oneshot.py
@@ -16,11 +16,57 @@
 import shutil
 import unittest
 
+from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from parameterized import parameterized_class
+
+from sparseml.modifiers.quantization.gptq import GPTQModifier
 from sparseml.transformers.sparsification.sparse_model import SparseAutoModelForCausalLM
 from tests.testing_utils import requires_torch
 
 
+recipe_str = """
+quant_stage:
+    quant_modifiers:
+        GPTQModifier:
+            sequential_update: false
+            ignore: ["lm_head"]
+            config_groups:
+                group_0:
+                    weights:
+                        num_bits: 4
+                        type: "int"
+                        symmetric: true
+                        strategy: "channel"
+                    targets: ["Linear"]
+"""
+
+recipe_modifier_full = GPTQModifier(
+    ignore=["lm_head"],
+    sequential_update=False,
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"], weights=QuantizationArgs(num_bits=4, strategy="channel")
+        )
+    },
+)
+
+recipe_modifier_shorthand_a = GPTQModifier(
+    ignore=["lm_head"], sequential_update=False, targets="Linear", scheme="W4A16"
+)
+
+recipe_modifier_shorthand_b = GPTQModifier(
+    ignore=["lm_head"], sequential_update=False, scheme={"W4A16": ["Linear"]}
+)
+
+
 @requires_torch
+@parameterized_class(
+    [
+        {"recipe": recipe_str},
+        {"recipe": recipe_modifier_full},
+        {"recipe": recipe_modifier_shorthand_a},
+        {"recipe": recipe_modifier_shorthand_b},
+    ]
+)
 class TestGPTQOneShotWithFullScheme(unittest.TestCase):
     def setUp(self):
         import torch
@@ -30,26 +76,6 @@ def setUp(self):
         self.output = "./oneshot_output"
         self.model = "roneneldan/TinyStories-1M"
         self.dataset = "open_platypus"
         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-        self.recipe = """
-        first_stage:
-            quant_modifiers:
-                GPTQModifier:
-                    ignore: ["lm_head"]
-                    sequential_update: True
-                    dampening_frac: 0.001
-                    block_size: 128
-                    targets: ["Linear"]
-                    scheme:
-                        input_activations: null
-                        output_activations: null
-                        weights:
-                            num_bits: 8
-                            type: "int"
-                            symmetric: true
-                            strategy: "tensor"
-                            group_size: 128
-        """
-
     def test_oneshot_application(self):
         from sparseml.transformers import oneshot
@@ -68,9 +94,23 @@ def test_oneshot_application(self):
 
         # Check that the model is quantized
         assert model_loaded.quantization_config is not None
 
+        # check config is set properly
+        assert model_loaded.quantization_config.ignore == ["lm_head"]
+        assert len(model_loaded.quantization_config.config_groups) == 1
+        quant_scheme = model_loaded.quantization_config.config_groups["group_0"]
+        assert isinstance(quant_scheme, QuantizationScheme)
+        assert quant_scheme.targets == ["Linear"]
+        weight_args = model_loaded.quantization_config.config_groups["group_0"].weights
+        assert isinstance(weight_args, QuantizationArgs)
+        assert weight_args.num_bits == 4
+
         # Check a specific layer is quantized
         targetted_linear_layer = model_loaded.transformer.h[0].attn.attention.k_proj
         assert hasattr(targetted_linear_layer, "quantization_scheme")
 
+        # Check lm-head is not quantized
+        not_targetted = model_loaded.lm_head
+        assert not hasattr(not_targetted, "quantization_scheme")
+
     def tearDown(self):
         shutil.rmtree(self.output)
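
For readers following the `### Custom Quantization` notes in the first patch, the sketch below shows what a group-size variant of the modifier objects defined in `test_oneshot.py` above might look like. It is not part of the patch series: it simply swaps the channelwise `QuantizationArgs` used by `recipe_modifier_full` for `strategy="group"` with `group_size=128` (the setting the README attributes to `2:4_w4a16_group-128_recipe.yaml`), and the name `recipe_modifier_group128` is hypothetical.

```python
# Hypothetical sketch, not part of the patch series: a group-size (group_size=128)
# counterpart to recipe_modifier_full from test_oneshot.py above.
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from sparseml.modifiers.quantization.gptq import GPTQModifier

recipe_modifier_group128 = GPTQModifier(
    ignore=["lm_head"],
    sequential_update=False,
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            # group-wise weight quantization instead of the channelwise args above
            weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
        )
    },
)
```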
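Similarly, the one-shot example in the second patch saves its output to `llama-compressed-example` but stops short of showing inference. A minimal smoke-test sketch is below, assuming the compressed checkpoint can be reloaded directly with `SparseAutoModelForCausalLM.from_pretrained` and that a tokenizer is available from the original TinyLlama stub; the prompt text and generation settings are arbitrary.

```python
# Minimal sketch under the assumptions stated above; not part of the patch series.
import torch
from transformers import AutoTokenizer
from sparseml.transformers import SparseAutoModelForCausalLM

# Directory written by the oneshot() call in llama7b_one_shot_quantization.md.
model_dir = "llama-compressed-example"
model = SparseAutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)

# Fall back to the original model stub for the tokenizer if one was not saved.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")

inputs = tokenizer("Quantization reduces memory because", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```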