Move QAT out of prototype
Summary: Move QAT out of prototype so we can provide stronger backward-compatibility (BC) guarantees moving forward.

**(Future) BC-breaking notes**

Note: This commit itself doesn't break BC yet; a future PR will. The following is recorded here so the BC-breaking note is preserved somewhere.
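
When the break does land, one common migration aid (purely illustrative; not part of this commit) is to have the legacy module emit a deprecation warning before removal:

```python
# Hypothetical shim body for the legacy prototype __init__.py (not in this commit):
# warn users who still import from the old path before it is removed.
import warnings

warnings.warn(
    "torchao.quantization.prototype.qat is deprecated; "
    "import from torchao.quantization.qat instead.",
    DeprecationWarning,
    stacklevel=2,
)
```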

Before:
```
from torchao.quantization.prototype.qat import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.prototype.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
    FakeQuantizer,
)
```

After:
```
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.qat.linear import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
    FakeQuantizer,
)
```
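
For downstream code that needs to run across torchao versions during the transition, a minimal sketch (illustrative only, not part of this commit) is a guarded import that prefers the new path and falls back to the legacy one:

```python
# Try the new, non-prototype location first; fall back to the legacy path
# for older torchao releases that only ship the prototype module.
try:
    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
except ImportError:
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
```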

Test Plan:
python test/quantization/test_qat.py

ghstack-source-id: add9dcac61e45f3b4ddeed07c300cc78ee3fd23c
Pull Request resolved: #1091
andrewor14 committed Oct 17, 2024
1 parent 7aaf0ff commit e2dd867
Showing 20 changed files with 1,703 additions and 1,604 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -59,7 +59,7 @@ In practice these features alongside int4 weight only quantization allow us to *
Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)

```python
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

qat_quantizer = Int8DynActInt4WeightQATQuantizer()

87 changes: 70 additions & 17 deletions test/quantization/test_qat.py
@@ -22,22 +22,22 @@
PerRow,
PerToken,
)
from torchao.quantization.prototype.qat.api import (
from torchao.quantization.qat.api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
from torchao.quantization.qat.fake_quantizer import (
FakeQuantizer,
)
from torchao.quantization.prototype.qat.embedding import (
from torchao.quantization.qat.embedding import (
FakeQuantizedEmbedding,
)
from torchao.quantization.prototype.qat.linear import (
from torchao.quantization.qat.linear import (
FakeQuantizedLinear,
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear
)
from torchao.quantization.prototype.qat.utils import (
from torchao.quantization.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
@@ -181,7 +181,7 @@ def _set_ptq_weight(
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
)
from torchao.quantization.prototype.qat.linear import (
from torchao.quantization.qat.linear import (
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear,
)
@@ -213,7 +213,7 @@ def _set_ptq_weight(

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_linear(self):
from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

group_size = 128
@@ -238,7 +238,7 @@ def test_qat_8da4w_linear(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer

group_size = 16
@@ -272,7 +272,7 @@ def test_qat_8da4w_quantizer(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

with torch.device("meta"):
m = M()
@@ -287,7 +287,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
"""
from torchao.quantization.prototype.qat import (
from torchao.quantization.qat.linear import (
Int8DynActInt4WeightQATQuantizer,
disable_8da4w_fake_quant,
enable_8da4w_fake_quant,
@@ -346,7 +346,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
"""
from torchao.quantization.prototype.qat import (
from torchao.quantization.qat.linear import (
Int8DynActInt4WeightQATQuantizer,
disable_8da4w_fake_quant,
)
@@ -428,7 +428,7 @@ def _test_qat_quantized_gradients(self, quantizer):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_gradients(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
self._test_qat_quantized_gradients(quantizer)

@@ -518,7 +518,7 @@ def test_qat_4w_primitives(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_linear(self):
from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
from torchao.quantization.GPTQ import WeightOnlyInt4Linear

group_size = 128
@@ -545,14 +545,14 @@ def test_qat_4w_linear(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_4w_quantizer_gradients(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
self._test_qat_quantized_gradients(quantizer)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_quantizer(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

group_size = 32
@@ -630,7 +630,7 @@ def test_composable_qat_quantizer(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_4w_embedding(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
model = M2()
x = model.example_inputs()
out = model(*x)
@@ -937,6 +937,59 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
baseline_out = embedding_forward_4w(x2, fq_embedding.weight)
torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_prototype_bc(self):
"""
Just to make sure we can import all the old prototype paths.
We will remove this test in the near future when we actually break BC.
"""
from torchao.quantization.prototype.qat import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
ComposableQATQuantizer,
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyEmbeddingQATQuantizer,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.prototype.qat._module_swap_api import (
disable_4w_fake_quant_module_swap,
enable_4w_fake_quant_module_swap,
disable_8da4w_fake_quant_module_swap,
enable_8da4w_fake_quant_module_swap,
Int4WeightOnlyQATQuantizerModuleSwap,
Int8DynActInt4WeightQATQuantizerModuleSwap,
)
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
to_affine_fake_quantized,
)
from torchao.quantization.prototype.qat.api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
)
from torchao.quantization.prototype.qat.embedding import (
FakeQuantizedEmbedding,
Int4WeightOnlyEmbeddingQATQuantizer,
Int4WeightOnlyEmbedding,
Int4WeightOnlyQATEmbedding,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
FakeQuantizer,
)
from torchao.quantization.prototype.qat.linear import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
FakeQuantizedLinear,
Int4WeightOnlyQATLinear,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATLinear,
Int8DynActInt4WeightQATQuantizer,
)

if __name__ == "__main__":
unittest.main()
unittest.main()
128 changes: 3 additions & 125 deletions torchao/quantization/prototype/qat/README.md
@@ -1,125 +1,3 @@
# Quantization-Aware Training (QAT)

Quantization-Aware Training (QAT) refers to applying fake quantization during the
training or fine-tuning process, such that the final quantized model will exhibit
higher accuracies and perplexities. Fake quantization refers to rounding the float
values to quantized values without actually casting them to dtypes with lower
bit-widths, in contrast to post-training quantization (PTQ), which does cast the
quantized values to lower bit-width dtypes, e.g.:

```
# PTQ: x_q is quantized and cast to int8
# scale and zero point (zp) refer to parameters used to quantize x_float
# qmin and qmax refer to the range of quantized values
x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8)
# QAT: x_fq is still in float
# Fake quantize simulates the numerics of quantize + dequantize
x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
x_fq = (x_fq - zp) * scale
```
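
Below is a minimal runnable sketch of the same fake-quantize numerics in plain PyTorch (illustrative only; the scale, zero point, and int8 range are example values, not the library's internals):

```python
import torch

def fake_quantize(x_float: torch.Tensor, scale: float, zp: int,
                  qmin: int = -128, qmax: int = 127) -> torch.Tensor:
    # Simulate quantize + dequantize without casting to a lower bit-width dtype
    x_q = torch.round(x_float / scale + zp).clamp(qmin, qmax)
    return (x_q - zp) * scale

x = torch.randn(4, 8)
x_fq = fake_quantize(x, scale=0.1, zp=0)  # still float, but snapped to the int8 grid
```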

## API

torchao currently supports two QAT schemes for linear layers:
- int8 per token dynamic activations + int4 per group weights
- int4 per group weights (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)

QAT typically involves applying a transformation to your model before and after training.
In torchao, these are represented as the prepare and convert steps: (1) prepare inserts
fake quantize operations into linear layers, and (2) convert transforms the fake quantize
operations to actual quantize and dequantize operations after training, thereby producing
a quantized model (dequantize operations are typically fused with linear after lowering).
Between these two steps, training can proceed exactly as before.

![qat](images/qat_diagram.png)

To use QAT in torchao, apply the prepare step using the appropriate Quantizer before
training, then apply the convert step after training for inference or generation.
For example, on a single GPU:

```python
import torch
from torchtune.models.llama3 import llama3
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Smaller version of llama3 to fit in a single GPU
model = llama3(
vocab_size=4096,
num_layers=16,
num_heads=16,
num_kv_heads=4,
embed_dim=2048,
max_seq_len=2048,
).cuda()

# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting
model = qat_quantizer.prepare(model)

# Standard training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()
for i in range(10):
example = torch.randint(0, 4096, (2, 16)).cuda()
target = torch.randn((2, 16, 4096)).cuda()
output = model(example)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()

# Convert fake quantize to actual quantize operations
# The quantized model has the exact same structure as the
# quantized model produced in the corresponding PTQ flow
# through `Int8DynActInt4WeightQuantizer`
model = qat_quantizer.convert(model)

# inference or generate
```

Users can also leverage our integration with [torchtune](https://github.com/pytorch/torchtune)
and apply quantized-aware fine-tuning as follows:

```
tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
```

For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).


## Evaluation Results

Evaluation was performed on 6-8 A100 GPUs (80GB each) using the torchtune QAT
integration described above. We fine-tune [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
on the [C4 dataset](https://huggingface.co/datasets/allenai/c4) (en subset)
for 5000 steps using a group size of 256 for the weights. Note that extensive
hyperparameter tuning may further improve these results.

Results for int8 per token dynamic activations + int4 per group weights, using a learning rate of 2e-5:

| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
| ---------------- | ------ | ------ | ------ | ------ | ------ |
| No quantization | 57.86% | 76.60% | 8.905 | 1.505 | 0.590 |
| PTQ | 51.74% | 70.66% | 11.878 | 1.588 | 0.668 |
| QAT (quantized) | 57.25% | 76.51% | 9.859 | 1.534 | 0.617 |
| PTQ degradation | -6.11% | -5.94% | +2.973 | +0.083 | +0.078 |
| QAT degradation | -0.61% | -0.21% | +0.947 | +0.029 | +0.027 |

Results for int4 per group weights, using a learning rate of 2e-6. For this quantization scheme, the
quantized path uses the more efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097).

| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
| ---------------- | -------- | ------- | ------ | ------ | ------ |
| No quantization | 57.16% | 77.02% | 8.858 | 1.504 | 0.589 |
| PTQ | 55.06% | 74.24% | 10.311 | 1.547 | 0.630 |
| QAT (quantized) | 55.86% | 75.06% | 10.134 | 1.542 | 0.625 |
| PTQ degradation | -2.10% | -2.78% | +1.453 | +0.043 | +0.041 |
| QAT degradation | -1.30% | -1.96% | +1.276 | +0.038 | +0.036 |

For more details, please refer to [this blog post](https://pytorch.org/blog/quantization-aware-training).
Note: QAT has been moved to torchao/quantization/qat.
This is a legacy folder only for backward compatibility
and will be removed in the near future.
12 changes: 5 additions & 7 deletions torchao/quantization/prototype/qat/__init__.py
@@ -1,17 +1,15 @@
from .api import (
from torchao.quantization.qat import (
ComposableQATQuantizer,
Int4WeightOnlyEmbeddingQATQuantizer,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
)
from .linear import (
from torchao.quantization.qat.linear import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATLinear,
Int8DynActInt4WeightQATQuantizer,
)
from .embedding import (
Int4WeightOnlyEmbeddingQATQuantizer,
)

__all__ = [
2 changes: 1 addition & 1 deletion torchao/quantization/prototype/qat/_module_swap_api.py
@@ -1,7 +1,7 @@
# For backward compatibility only
# These will be removed in the future

from .linear import (
from torchao.quantization.qat.linear import (
Int8DynActInt4WeightQATQuantizer as Int8DynActInt4WeightQATQuantizerModuleSwap,
Int4WeightOnlyQATQuantizer as Int4WeightOnlyQATQuantizerModuleSwap,
enable_8da4w_fake_quant as enable_8da4w_fake_quant_module_swap,