Fix GPTQ Aliases #2327

Merged 2 commits into compression-example-update on Jun 12, 2024

Conversation

@Satrat commented on Jun 11, 2024

neuralmagic/compressed-tensors#81 must be merged first

When specifying a scheme preset, the quantization modifier for GPTQ was not being properly initialized. In the example code below, despite specifying a W4A16 scheme, the quantization config was always empty:

Building quantization modifier with args: {'config_groups': {'config_group_0': QuantizationScheme(targets=['Linear'], weights=None, input_activations=None, output_activations=None)}}
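For comparison, here is a rough sketch of what the resolved config group should contain once the W4A16 preset is actually applied: the weights field populated with 4-bit quantization args and the activations left unquantized. The import path and the exact preset values (4-bit symmetric weights, 16-bit activations) are assumptions for illustration, not taken from this PR.

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# Illustrative only: the W4A16 alias should yield a scheme along these lines,
# with weights actually quantized instead of weights=None as in the log above.
expected_scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, symmetric=True),  # assumed preset values
    input_activations=None,   # "A16": activations stay in 16-bit
    output_activations=None,
)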

The fix was to update the GPTQ modifier initialization to correctly apply the preset scheme. I've also added unit tests to confirm all variants of the GPTQ recipe function as intended.
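As a rough illustration of the kind of check those tests perform (this is a sketch, not the actual test code from this PR), each supported way of specifying the scheme should survive construction rather than being silently dropped:

import pytest
from sparseml.modifiers.quantization.gptq import GPTQModifier

# Sketch only: the second variant's keyword spelling is an assumption, and the
# real tests go further, verifying that the quantization modifier built from
# each recipe variant ends up with a populated config.
@pytest.mark.parametrize(
    "kwargs",
    [
        {"scheme": {"W4A16": ["Linear"]}},            # preset alias keyed by targets
        {"scheme": "W4A16", "targets": ["Linear"]},   # assumed alternative spelling
    ],
)
def test_gptq_scheme_is_not_dropped(kwargs):
    modifier = GPTQModifier(**kwargs)
    assert modifier.scheme is not None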

Example Code

import torch
from datasets import load_dataset
from sparseml.transformers import SparseAutoModelForCausalLM, oneshot
from sparseml.modifiers.quantization.gptq import GPTQModifier
from transformers import AutoTokenizer

NUM_CALIBRATION_SAMPLES = 16
MAX_SEQ_LEN = 2048
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

gptq = GPTQModifier(
    scheme={"W4A16": ["Linear"]}
)

ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
ds = ds.map(lambda batch: {"text": tokenizer.apply_chat_template(batch["messages"], tokenize=False)})

oneshot(
    model=model,
    dataset=ds,
    recipe=gptq,
    max_seq_length=MAX_SEQ_LEN,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
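
As an optional sanity check (not part of the PR), once oneshot completes the scheme attached to the model's Linear layers should no longer be empty. The quantization_scheme attribute name is how compressed-tensors annotates quantized modules; treat it as an assumption if your version differs.

# Print the scheme attached to the first quantized module, if any.
for name, module in model.named_modules():
    scheme = getattr(module, "quantization_scheme", None)
    if scheme is not None:
        print(name, scheme)
        break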

@Satrat Satrat changed the base branch from main to compression-example-update June 11, 2024 19:51
@Satrat Satrat merged commit 1a2d534 into compression-example-update Jun 12, 2024
@Satrat Satrat deleted the sa/fix_gptq_aliases branch June 12, 2024 15:19
bfineran added a commit that referenced this pull request Jun 13, 2024
* update llama7b_sparse_quantized example

* one shot llama example

* Update examples/llama7b_sparse_quantized/README.md

Co-authored-by: dbogunowicz <[email protected]>

* Fix GPTQ Aliases (#2327)

* fix alias application with unit tests

* style

---------

Co-authored-by: Sara Adkins <[email protected]>
Co-authored-by: dbogunowicz <[email protected]>