From e255b17765add46053a2669086cbc95b3fff406c Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 11 Jun 2024 15:04:28 -0400 Subject: [PATCH 1/3] Fix for Sparsity Persist (#2323) * fix sparsity persist * helper moved to compressed-tensors --- .../quantization/gptq/utils/gptq_wrapper.py | 43 +++++++++---------- .../obcq/test_mask_structure_preservation.py | 24 +---------- 2 files changed, 21 insertions(+), 46 deletions(-) diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py index 73321c0d0aa..ded28b4123b 100644 --- a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py +++ b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py @@ -103,6 +103,14 @@ def fasterprune( W = W.t() W = W.float() + sparsity = tensor_sparsity(W) + preserve_zeros = sparsity >= SPARSITY_THRESHOLD + W_nz_mask = ( + (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float() + if preserve_zeros + else None + ) + tick = time.time() dead = torch.diag(self.H) == 0 @@ -119,17 +127,6 @@ def fasterprune( self.H = torch.linalg.cholesky(self.H, upper=True) Hinv = self.H - sparsity = tensor_sparsity(W) - mask = ( - torch.where( - W == 0, - torch.tensor(1, dtype=torch.bool), - torch.tensor(0, dtype=torch.bool), - ) - if sparsity >= SPARSITY_THRESHOLD - else None - ) - # See section 3.4 of https://arxiv.org/abs/2203.07259 for i1 in range(0, self.columns, blocksize): i2 = min(i1 + blocksize, self.columns) @@ -141,21 +138,13 @@ def fasterprune( Losses1 = torch.zeros_like(W1) Hinv1 = Hinv[i1:i2, i1:i2] - if sparsity >= SPARSITY_THRESHOLD: - tmp = ( - (~mask[:, i1:i2]) - * W1**2 - / (torch.diag(Hinv1).reshape((1, -1))) ** 2 - ) - thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)] - mask1 = tmp <= thresh + if preserve_zeros: + W1_nz_mask = W_nz_mask[:, i1:i2] for i in range(count): w = W1[:, i] d = Hinv1[i, i] q = w.clone() - if sparsity >= SPARSITY_THRESHOLD: - q[mask1[:, i]] = 0 if hasattr(self.layer, "weight_fake_quant"): scale = self.layer.weight_fake_quant.scale @@ -216,13 +205,21 @@ def fasterprune( Losses1[:, i] = (w - q) ** 2 / d**2 err1 = (w - q) / d - W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + if preserve_zeros: + W1[:, i:] -= w1_err * W1_nz_mask[:, i:] + else: + W1[:, i:] -= w1_err Err1[:, i] = err1 W[:, i1:i2] = Q1 Losses += torch.sum(Losses1, 1) / 2 - W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + w_err = Err1.matmul(Hinv[i1:i2, i2:]) + if preserve_zeros: + W[:, i2:] -= w_err * W_nz_mask[:, i2:] + else: + W[:, i2:] -= w_err _LOGGER.info("time %.2f" % (time.time() - tick)) _LOGGER.info("error %.2f" % torch.sum(Losses).item()) diff --git a/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py b/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py index a068c391431..eca6f5d2379 100644 --- a/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py +++ b/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py @@ -19,6 +19,7 @@ import pytest import sparseml +from compressed_tensors.compressors.utils import tensor_follows_mask_structure from parameterized import parameterized_class from tests.testing_utils import parse_params, requires_torch @@ -28,29 +29,6 @@ ) -def tensor_follows_mask_structure(tensor, mask: str = "2:4"): - """ - :param tensor: tensor to check - :param mask: mask structure to check for, in the format "n:m" - :return: True if the tensor follows the mask structure, False otherwise. - Note, some weights can incidentally be zero, so we check for - atleast n zeros in each chunk of size m - """ - import torch - - n, m = tuple(map(int, mask.split(":"))) - # Reshape the tensor into chunks of size m - tensor = tensor.view(-1, m) - - # Count the number of zeros in each chunk - zero_counts = (tensor == 0).sum(dim=1) - - # Check if the number of zeros in each chunk atleast n - # Greater than sign is needed as some weights can incidentally - # be zero - return torch.all(zero_counts >= n) - - @requires_torch @pytest.mark.integration @parameterized_class(parse_params(MASK_STRUCTURE_CONFIGS_DIRECTORY)) From 4e2ad0ac56ab3569aa350e21bed2f13da11b3408 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 12 Jun 2024 12:01:53 -0400 Subject: [PATCH 2/3] [GHA] Update End-to-End Nightly Build Process (#2304) * trigger nightly workflow * update condition * update * update condition * skip actual tests to speed up testing * try true conditions * try agin * try again * clean-up * update condiitions * try again * try again * try fil case * update * try new condition * try again * try again * try again * revert * try new conditions * typo * try again * try dev workflow * try again * update condition * update * try again * test failure case * update * try again * update * try nightly * add publish --------- Co-authored-by: Sara Adkins --- .github/workflows/build-container.yml | 10 ++- .github/workflows/build-nightly.yml | 22 ------ .../workflows/build-wheel-and-container.yml | 39 ++++----- .../publish-nightly-docker-images.yaml | 79 ------------------- .github/workflows/test-nightly.yml | 4 +- ...nternal.yml => test-wheel-and-publish.yml} | 39 ++++++--- 6 files changed, 59 insertions(+), 134 deletions(-) delete mode 100644 .github/workflows/build-nightly.yml delete mode 100644 .github/workflows/publish-nightly-docker-images.yaml rename .github/workflows/{test-wheel-push-to-internal.yml => test-wheel-and-publish.yml} (57%) diff --git a/.github/workflows/build-container.yml b/.github/workflows/build-container.yml index 9eda86ae0d0..ae7cc43bc52 100644 --- a/.github/workflows/build-container.yml +++ b/.github/workflows/build-container.yml @@ -53,4 +53,12 @@ jobs: build-args: | BRANCH=${{github.head_ref}} push: true - tags: ghcr.io/neuralmagic/sparseml-dev:${{ inputs.name }} \ No newline at end of file + tags: ghcr.io/neuralmagic/sparseml-dev:${{ inputs.name }} + + - name: Build Nightly Docker Container + if: ${{ inputs.dev == 'false' && inputs.release == 'false'}} + uses: docker/build-push-action@v4 + with: + context: ./docker/containers/docker_nightly + push: true + tags: ghcr.io/neuralmagic/sparseml-nightly:latest, ghcr.io/neuralmagic/sparseml-nightly:${{ steps.date.outputs.date }} \ No newline at end of file diff --git a/.github/workflows/build-nightly.yml b/.github/workflows/build-nightly.yml deleted file mode 100644 index be44d8b863e..00000000000 --- a/.github/workflows/build-nightly.yml +++ /dev/null @@ -1,22 +0,0 @@ -name: build-nightly -run-name: ${{ github.workflow }} is to create nightly wheel file for pypi -on: - push: - branches: - - 'main' - schedule: - - cron: '30 0 * * *' - workflow_dispatch: - - -jobs: - - BUILD-SPARSEML-NIGHTLY: - - uses: ./.github/workflows/util.yml - with: - runs_on: ubuntu-22.04 - run_id: ${{ github.run_id }} - build_type: nightly - testmo_project_id: 9 - secrets: inherit diff --git a/.github/workflows/build-wheel-and-container.yml b/.github/workflows/build-wheel-and-container.yml index 3eaaf674e08..421e227577a 100644 --- a/.github/workflows/build-wheel-and-container.yml +++ b/.github/workflows/build-wheel-and-container.yml @@ -4,15 +4,8 @@ on: types: [opened, synchronize, reopened] branches: - main - - 'release/[0-9]+.[0-9]+' - push: - branches: - - 'release/[0-9]+.[0-9]+' - - main - release: - types: [created, published] schedule: - - cron: '0 0 * * *' + - cron: '0 20 * * *' permissions: id-token: write @@ -23,10 +16,10 @@ concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true -# if not dev or release, will create a nightly build +# TODO: do we want to push to nightly everytime we push to main? +# if not dev or release, will create a nightly build; turning off release for now env: - PRODUCTION: ${{ github.event_name == 'schedule' || github.event_name == 'release'}} - RELEASE: ${{ github.event_name =='release' || startsWith(github.base_ref, 'release/') }} + RELEASE: 'false' DEV: ${{ github.base_ref == 'main' && github.event_name == 'pull_request'}} jobs: @@ -42,8 +35,14 @@ jobs: echo "dev=$DEV" >> $GITHUB_OUTPUT echo "release=$RELEASE" >> $GITHUB_OUTPUT - build-wheel-and-push: + test-nightly: needs: set-outputs + if: ${{ needs.set-outputs.outputs.dev == 'false' && needs.set-outputs.outputs.release == 'false'}} + uses: ./.github/workflows/test-nightly.yml + + build-wheel-and-push: + needs: [set-outputs, test-nightly] + if: ${{ always() && needs.set-outputs.outputs.dev == 'false' && needs.test-nightly.result == 'success' || always() && needs.set-outputs.outputs.dev == 'true' && needs.set-outputs.result == 'success' }} uses: ./.github/workflows/build-wheel.yml with: build-label: ubuntu-20.04 @@ -55,22 +54,24 @@ jobs: python: '3.10' secrets: inherit - test-wheel-and-push-internal: - needs: build-wheel-and-push - uses: ./.github/workflows/test-wheel-push-to-internal.yml + test-wheel-and-publish: + needs: [set-outputs, build-wheel-and-push] + if: ${{ always() && !cancelled() && needs.build-wheel-and-push.result == 'success' }} + uses: ./.github/workflows/test-wheel-and-publish.yml with: build-label: ubuntu-20.04 whl: ${{ needs.build-wheel-and-push.outputs.wheel }} python: '3.10' + dev: ${{ needs.set-outputs.outputs.dev }} + release: ${{ needs.set-outputs.outputs.release }} secrets: inherit - # TODO: add nightly and release container build steps once wheel build push - # to production is automated. Removed until then. build-container-and-push: - needs: [set-outputs, test-wheel-and-push-internal] + needs: [test-wheel-and-publish, set-outputs] + if: ${{ always() && !cancelled() && needs.test-wheel-and-publish.result == 'success' }} uses: ./.github/workflows/build-container.yml with: - build-label: k8s-eng-gpu-64G-v100-32G + build-label: k8s-eng-gpu-16G-t4-32G dev: ${{ needs.set-outputs.outputs.dev }} release: ${{ needs.set-outputs.outputs.release }} name: ${{ github.event.number }} diff --git a/.github/workflows/publish-nightly-docker-images.yaml b/.github/workflows/publish-nightly-docker-images.yaml deleted file mode 100644 index 5ca14ac08bc..00000000000 --- a/.github/workflows/publish-nightly-docker-images.yaml +++ /dev/null @@ -1,79 +0,0 @@ -name: Publish Nightly Docker Images - -on: - push: - branches: - - 'main' - schedule: - - cron: '0 1 * * *' - workflow_dispatch: -jobs: - push-nightly-docker-image: - name: Push Version Tagged Nightly Docker Images - runs-on: ubuntu-latest - permissions: - contents: read - packages: write - - steps: - - name: Set up Docker Buildx - id: buildx - uses: docker/setup-buildx-action@v2 - with: - buildkitd-flags: --debug - - - name: Login to Github Packages - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Checkout code - uses: actions/checkout@v3 - with: - fetch-depth: 1 - - - name: Get version tag - id: extract_tag - run: echo "tag=$(date +%Y%m%d)" >> $GITHUB_OUTPUT - - - name: Current Version Name - run: | - echo ${{ steps.extract_tag.outputs.tag }} - - - name: Sparseml-Nightly latest using default cuda 11.1.1 - uses: docker/build-push-action@v2 - with: - context: ./docker - build-args: | - DEPS=all - BRANCH=main - push: true - tags: | - ghcr.io/neuralmagic/sparseml-nightly:latest - - - name: Today's Sparseml-Nightly using default cuda 11.1.1 - uses: docker/build-push-action@v2 - with: - context: ./docker - build-args: | - DEPS=all - BRANCH=main - push: true - tags: | - ghcr.io/neuralmagic/sparseml-nightly:${{ steps.extract_tag.outputs.tag }} - - - name: Today's Sparseml-Nightly Base using default cuda 11.1.1 - uses: docker/build-push-action@v2 - with: - context: ./docker - build-args: | - DEPS=base - BRANCH=main - push: true - tags: | - ghcr.io/neuralmagic/sparseml-nightly:base-${{ steps.extract_tag.outputs.tag }} - - - name: Image digest - run: echo ${{ steps.docker_build.outputs.digest }} diff --git a/.github/workflows/test-nightly.yml b/.github/workflows/test-nightly.yml index 8472b6c8134..4fc1c19cd84 100644 --- a/.github/workflows/test-nightly.yml +++ b/.github/workflows/test-nightly.yml @@ -1,8 +1,7 @@ name: Run Nightly Tests on: - schedule: - - cron: '0 20 * * *' workflow_dispatch: + workflow_call: jobs: test-nightly-tests: runs-on: k8s-mle-gpu-12-vcpu-225GB-ram-2-a6000-48G @@ -33,6 +32,5 @@ jobs: run: | pytest tests/sparseml/transformers/obcq -m integration - name: Run finetune tests - if: always() run: | pytest tests/sparseml/transformers/finetune -m integration \ No newline at end of file diff --git a/.github/workflows/test-wheel-push-to-internal.yml b/.github/workflows/test-wheel-and-publish.yml similarity index 57% rename from .github/workflows/test-wheel-push-to-internal.yml rename to .github/workflows/test-wheel-and-publish.yml index 28af2f272e7..e40fa462ded 100644 --- a/.github/workflows/test-wheel-push-to-internal.yml +++ b/.github/workflows/test-wheel-and-publish.yml @@ -1,4 +1,4 @@ -name: Test Wheel and Push to Internal PyPi +name: Test Wheel and Publish on: workflow_call: inputs: @@ -11,9 +11,15 @@ on: required: true python: type: string + dev: + type: string + required: true + release: + type: string + required: true jobs: - test-wheel-and-push-internal: + test-wheel-and-publish: runs-on: ${{ inputs.build-label }} steps: - uses: actions/setup-python@v4 @@ -36,24 +42,37 @@ jobs: filename: ${{ inputs.whl }} dst: dist_s3 - - name: Set Env - run: | - pip3 install virtualenv - virtualenv venv - source venv/bin/activate - - name: Fetch name of whl run: | echo "FILENAME=$(echo dist_s3/*.whl)" >> $GITHUB_ENV - name: Install whl run: | - pip3 install $FILENAME[dev] + pip3 install $FILENAME[dev,onnxruntime,torch,torchvision,transformers] - name: Checkout code uses: actions/checkout@v3 - name: Remove src files and run tests run: | + pwd rm -rf src - make test \ No newline at end of file + make test + + - name: Make directory for wheel + run: | + mkdir dist_s3 + + - name: Pull from s3 + uses: neuralmagic/nm-actions/actions/s3_pull@main + with: + filename: ${{ inputs.whl }} + dst: dist_s3 + + - name: Publish Nightly Wheel + if: ${{ inputs.DEV == 'false' && inputs.RELEASE == 'false'}} + uses: neuralmagic/nm-actions/actions/publish-whl@main + with: + username: ${{ secrets.PYPI_PUBLIC_USER }} + password: ${{ secrets.PYPI_PUBLIC_AUTH }} + whl: ./$FILENAME \ No newline at end of file From 5c1de1c73577b9a4ca3666662a50ccff2c8acd03 Mon Sep 17 00:00:00 2001 From: Benjamin Fineran Date: Thu, 13 Jun 2024 16:04:46 -0400 Subject: [PATCH 3/3] udpate llama7b_sparse_quantized example (#2322) * udpate llama7b_sparse_quantized example * one shot llama example * Update examples/llama7b_sparse_quantized/README.md Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> * Fix GPTQ Aliases (#2327) * fix alias application with unit tests * style --------- Co-authored-by: Sara Adkins Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> --- examples/llama7b_one_shot_quantization.md | 50 ++++++++++++ examples/llama7b_sparse_quantized/README.md | 80 +++++++++++++------ .../modifiers/quantization/gptq/base.py | 42 +++++----- .../pruning/sparsegpt/test_pytorch.py | 2 +- .../transformers/gptq/test_oneshot.py | 80 ++++++++++++++----- 5 files changed, 189 insertions(+), 65 deletions(-) create mode 100644 examples/llama7b_one_shot_quantization.md diff --git a/examples/llama7b_one_shot_quantization.md b/examples/llama7b_one_shot_quantization.md new file mode 100644 index 00000000000..d3ee50e1aaf --- /dev/null +++ b/examples/llama7b_one_shot_quantization.md @@ -0,0 +1,50 @@ +# Creating a Quantized Llama Model in One Shot + +Quantizing a model to a lower precision can save on both memory and speed at inference time. +This example demonstrates how to use the SparseML API to quantize a Llama model from 16 bits +to 4 bits and save it to a compressed-tensors format for inference with vLLM. + +## Step 1: Select a model and dataset +For this example, we will use a TinyLlama model and the open platypus dataset, however +these can be swapped out for any huggingface compatible models and datasets + +```python +model = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" +dataset = "open_platypus" +``` + +## Step 2: Configure a `GPTQModifier` +Modifiers in sparseml are used to apply optimizations to models. In this example we use a +`GPTQModifier` to apply the GPTQ algorithm to our model. We target all `Linear` layers +for 4-bit weight quantization. These options may be swapped out for any valid `QuantizationScheme`. + +```python +from sparseml.modifiers.quantization.gptq import GPTQModifier + +gptq = GPTQModifier( + targets="Linear", + scheme="W4A16" +) +``` + + +### Step3: One-Shot Compression + +The `oneshot` api applies the created modifier to the target model and dataset. +Setting `save_compressed` to True runs the model through `compressed_tensors` compression +after the quantization is completed. + +```python +from sparseml.transformers import oneshot + +oneshot( + model=model, + dataset=dataset, + recipe=gptq, + save_compressed=True, + output_dir="llama-compressed-example", + overwrite_output_dir=True, + max_seq_length=256, + num_calibration_samples=256, +) +``` diff --git a/examples/llama7b_sparse_quantized/README.md b/examples/llama7b_sparse_quantized/README.md index 779696ba599..c96b6e7ca43 100644 --- a/examples/llama7b_sparse_quantized/README.md +++ b/examples/llama7b_sparse_quantized/README.md @@ -1,47 +1,79 @@ # Creating a Sparse Quantized Llama7b Model -The example in this folder runs in multiple stages to create a Llama 7b model with -a 2:4 sparsity pattern and W4A16 post training quantization (PTW). The model is -calibrated and trained with the ultachat200k dataset. At least 75GB of GPU memory is -required to run this example. +This example uses SparseML and Compressed-Tensors to create a 2:4 sparse and quantized Llama2-7b model. +The model is calibrated and trained with the ultachat200k dataset. +At least 75GB of GPU memory is required to run this example. -## Recipe Summary +Follow the steps below, or to run the example as `python examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py` -The recipe used for this flow is located in [2:4_w4a16_recipe.yaml](./2:4_w4a16_recipe.yaml). It contains 3 stages that are outlined below. +## Step 1: Select a model, dataset, and recipe +In this step, we select which model to use as a baseline for sparsification, a dataset to +use for calibration and finetuning, and a recipe. +Models can reference a local directory, model in the huggingface hub, or in the sparsezoo. -### Stage 1: Sparsification +Datasets can be from a local compatible directory or the huggingface hub. -Runs the SparseGPT one-shot algorithm to prune the model to 50% sparsity with a 2:4 -sparsity pattern. This means that 2 weights out of every group of 4 weights are masked to 0. +Recipes are YAML files that describe how a model should be optimized during or after training. +The recipe used for this flow is located in [2:4_w4a16_recipe.yaml](./2:4_w4a16_recipe.yaml). +It contains instructions to prune the model to 2:4 sparsity, run one epoch of recovery finetuning, +and quantize to 4 bits in one show using GPTQ. -### Stage 2: Finetuning Recovery - -This stage runs a single epoch of training on the ultrachat200k dataset while maintaining -the sparsity mask from stage 1. The purpose of this stage is to recover any accuracy lost -during the sparsification process. +```python +import torch +from sparseml.transformers import SparseAutoModelForCausalLM -### Stage 3: Quantization +model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base" +model = SparseAutoModelForCausalLM.from_pretrained( + model_stub, torch_dtype=torch.bfloat16, device_map="auto" +) -Finally, we run the GPTQ one-shot algorithm to quantize all linear weights to 4 bit -channelwise. +dataset = "ultrachat-200k" +splits = {"calibration": "train_gen[:5%]", "train": "train_gen"} -## How to Run +recipe = "2:4_w4a16_recipe.yaml" +``` -We can run the entire staged recipe with one call to SparseML's `apply` pathway. This -will save a checkpoint of the model after each stage. +## Step 2: Run sparsification using `apply` +The `apply` function applies the given recipe to our model and dataset. +The hardcoded kwargs may be altered based on each model's needs. +After running, the sparsified model will be saved to `output_llama7b_2:4_w4a16_channel`. + +```python +from sparseml.transformers import apply + +output_dir = "output_llama7b_2:4_w4a16_channel" + +apply( + model=model, + dataset=dataset, + recipe=recipe, + bf16=False, # use full precision for training + output_dir=output_dir, + splits=splits, + max_seq_length=512, + num_calibration_samples=512, + num_train_epochs=0.5, + logging_steps=500, + save_steps=5000, + gradient_checkpointing=True, + learning_rate=0.0001, + lr_scheduler_type="cosine", + warmup_ratio=0.1, +) +``` -```python examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py``` -### Compression +### Step 3: Compression The resulting model will be uncompressed. To save a final compressed copy of the model run the following: -``` +```python import torch from sparseml.transformers import SparseAutoModelForCausalLM +compressed_output_dir = "output_llama7b_2:4_w4a16_channel_compressed" model = SparseAutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16) model.save_pretrained(compressed_output_dir, save_compressed=True) ``` @@ -49,4 +81,4 @@ model.save_pretrained(compressed_output_dir, save_compressed=True) ### Custom Quantization The current repo supports multiple quantization techniques configured using a recipe. Supported strategies are `tensor`, `group` and `channel`. The above recipe (`2:4_w4a16_recipe.yaml`) uses channel-wise quantization specified by `strategy: "channel"` in its config group. -To use quantize per tensor, change strategy from `channel` to `tensor`. To use group size quantization, change from `channel` to `group` and specify its value, say 128, by including `group_size: 128`. Group size quantization example is shown in `2:4_w4a16_group-128_recipe.yaml` +To use quantize per tensor, change strategy from `channel` to `tensor`. To use group size quantization, change from `channel` to `group` and specify its value, say 128, by including `group_size: 128`. A group size quantization example is shown in `2:4_w4a16_group-128_recipe.yaml`. diff --git a/src/sparseml/modifiers/quantization/gptq/base.py b/src/sparseml/modifiers/quantization/gptq/base.py index 004fce2ee7a..43bc596d849 100644 --- a/src/sparseml/modifiers/quantization/gptq/base.py +++ b/src/sparseml/modifiers/quantization/gptq/base.py @@ -18,9 +18,9 @@ from pydantic import Field from compressed_tensors.quantization import ( - QuantizationConfig, QuantizationScheme, is_preset_scheme, + preset_name_to_scheme, ) from sparseml.core import Modifier from sparseml.core.factory import ModifierFactory @@ -77,6 +77,7 @@ class GPTQModifier(Modifier): QuantizationScheme except targets, which will be set to the targets parameter set at the modifier level. Can also be set to a dictionary of the format `preset_scheme_name: targets` for example: `W8A8: ['Linear']` for weight 8 bit + or a string of a preset scheme if targets is provided and activation 8 bit quantization on the Linear layers. """ @@ -89,7 +90,7 @@ class GPTQModifier(Modifier): ignore: List[str] = Field(default_factory=list) disable_quantization_observer_epoch: Optional[float] = None num_calibration_steps: Optional[int] = None - scheme: Optional[Dict[str, Any]] = None + scheme: Optional[Union[str, Dict[str, Any]]] = None compressible_layers_: Optional[List] = None quantization_modifier_: Any = None @@ -167,32 +168,33 @@ def _build_quant_modifier(self, framework): if getattr(self, key, False) } + if isinstance(self.targets, str): + self.targets = [self.targets] + if self.scheme is not None: # takes precedence over config_groups - if any(is_preset_scheme(key) for key in self.scheme.keys()): - config_groups = QuantizationConfig( - config_groups=self.scheme - ).config_groups - quant_args["config_groups"] = config_groups - else: - targets = self.targets or ["Linear"] - config_group = QuantizationScheme.model_validate( - {"targets": targets, **self.scheme} - ) - quant_args["config_groups"] = {"config_group_0": config_group} + if isinstance(self.scheme, str) and is_preset_scheme(self.scheme): + # attach targets to scheme + self.scheme = {self.scheme: self.targets} - targets = self.targets or ["Linear"] - config_group = QuantizationScheme.model_validate( - {"targets": targets, **self.scheme} - ) - quant_args["config_groups"] = {"config_group_0": config_group} + quant_args["config_groups"] = {} + for idx, key in enumerate(self.scheme.keys()): + if is_preset_scheme(key): + scheme = preset_name_to_scheme(key, self.scheme[key]) + else: + scheme = QuantizationScheme.model_validate( + {"targets": self.scheme[key], **self.scheme} + ) + + group_name = f"group_{idx}" + quant_args["config_groups"][group_name] = scheme - if "config_groups" not in quant_args: + if "config_groups" not in quant_args or len("config_groups") == 0: default_quant_scheme = QuantizationScheme.default_scheme( targets=self.targets ) - quant_args["config_groups"] = {"config_group_0": default_quant_scheme} + quant_args["config_groups"] = {"group_0": default_quant_scheme} _LOGGER.info(f"Building quantization modifier with args: {quant_args}") vllm_quant_config = {"QuantizationModifier": quant_args} self._build_quant_modifier_from_dict(vllm_quant_config, framework) diff --git a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 0fcb66eee9c..1b9f365bebf 100644 --- a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -95,7 +95,7 @@ def test_create_default_quant_modifier(self): modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize assert isinstance(modifier.quantization_modifier_, QuantizationModifier) - default_config_group_name = "config_group_0" + default_config_group_name = "group_0" should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[ default_config_group_name ] diff --git a/tests/sparseml/transformers/gptq/test_oneshot.py b/tests/sparseml/transformers/gptq/test_oneshot.py index c7c14275df1..1d2e28cc303 100644 --- a/tests/sparseml/transformers/gptq/test_oneshot.py +++ b/tests/sparseml/transformers/gptq/test_oneshot.py @@ -16,11 +16,57 @@ import shutil import unittest +from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme +from parameterized import parameterized_class +from sparseml.modifiers.quantization.gptq import GPTQModifier from sparseml.transformers.sparsification.sparse_model import SparseAutoModelForCausalLM from tests.testing_utils import requires_torch +recipe_str = """ +quant_stage: + quant_modifiers: + GPTQModifier: + sequential_update: false + ignore: ["lm_head"] + config_groups: + group_0: + weights: + num_bits: 4 + type: "int" + symmetric: true + strategy: "channel" + targets: ["Linear"] +""" + +recipe_modifier_full = GPTQModifier( + ignore=["lm_head"], + sequential_update=False, + config_groups={ + "group_0": QuantizationScheme( + targets=["Linear"], weights=QuantizationArgs(num_bits=4, strategy="channel") + ) + }, +) + +recipe_modifier_shorthand_a = GPTQModifier( + ignore=["lm_head"], sequential_update=False, targets="Linear", scheme="W4A16" +) + +recipe_modifier_shorthand_b = GPTQModifier( + ignore=["lm_head"], sequential_update=False, scheme={"W4A16": ["Linear"]} +) + + @requires_torch +@parameterized_class( + [ + {"recipe": recipe_str}, + {"recipe": recipe_modifier_full}, + {"recipe": recipe_modifier_shorthand_a}, + {"recipe": recipe_modifier_shorthand_b}, + ] +) class TestGPTQOneShotWithFullScheme(unittest.TestCase): def setUp(self): import torch @@ -30,26 +76,6 @@ def setUp(self): self.dataset = "open_platypus" self.device = "cuda:0" if torch.cuda.is_available() else "cpu" - self.recipe = """ - first_stage: - quant_modifiers: - GPTQModifier: - ignore: ["lm_head"] - sequential_update: True - dampening_frac: 0.001 - block_size: 128 - targets: ["Linear"] - scheme: - input_activations: null - output_activations: null - weights: - num_bits: 8 - type: "int" - symmetric: true - strategy: "tensor" - group_size: 128 - """ - def test_oneshot_application(self): from sparseml.transformers import oneshot @@ -68,9 +94,23 @@ def test_oneshot_application(self): # Check that the model is quantized assert model_loaded.quantization_config is not None + # check config is set properly + assert model_loaded.quantization_config.ignore == ["lm_head"] + assert len(model_loaded.quantization_config.config_groups) == 1 + quant_scheme = model_loaded.quantization_config.config_groups["group_0"] + assert isinstance(quant_scheme, QuantizationScheme) + assert quant_scheme.targets == ["Linear"] + weight_args = model_loaded.quantization_config.config_groups["group_0"].weights + assert isinstance(weight_args, QuantizationArgs) + assert weight_args.num_bits == 4 + # Check a specific layer is quantized targetted_linear_layer = model_loaded.transformer.h[0].attn.attention.k_proj assert hasattr(targetted_linear_layer, "quantization_scheme") + # Check lm-head is not quantized + not_targetted = model_loaded.lm_head + assert not hasattr(not_targetted, "quantization_scheme") + def tearDown(self): shutil.rmtree(self.output)