Move QAT out of prototype
Summary: Move QAT out of prototype so we can provide stronger
BC guarantees moving forward.

**BC-breaking notes**

Before:
```
from torchao.quantization.prototype.qat import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.prototype.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
    FakeQuantizer,
)
```

After:
```
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.qat.linear import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
    FakeQuantizer,
)
```
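
Downstream code that has to work on both sides of this move can guard the import. The snippet below is a minimal sketch (not part of this PR) that relies only on the paths listed above, falling back to the old prototype location on older torchao releases:

```
try:
    # torchao with this change: QAT lives under torchao.quantization.qat
    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
except ImportError:
    # Older torchao: fall back to the prototype location
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

qat_quantizer = Int8DynActInt4WeightQATQuantizer()
```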

Test Plan:
python test/quantization/test_qat.py
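
Beyond the unit tests, a short end-to-end smoke test of the new location is possible. The sketch below is illustrative only: it assumes the quantizer exposes the two-step prepare()/convert() flow, and the toy model, group size, and input shapes are placeholders rather than anything from this PR.

```
import torch
import torch.nn as nn
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Placeholder model for illustration; any module containing nn.Linear layers works
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
model = quantizer.prepare(model)   # insert fake-quantized linears for training

# ... fine-tune the prepared model here as usual ...
_ = model(torch.randn(8, 64))

model = quantizer.convert(model)   # swap fake quant for actual quantized ops
```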

ghstack-source-id: cb72a8b083ab5e203e0c6631884d2fa695e76d96
Pull Request resolved: #1091
andrewor14 committed Oct 17, 2024
1 parent 7aaf0ff commit 35dbae7
Showing 13 changed files with 21 additions and 42 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -59,7 +59,7 @@ In practice these features alongside int4 weight only quantization allow us to *
Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)

```python
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

qat_quantizer = Int8DynActInt4WeightQATQuantizer()

32 changes: 16 additions & 16 deletions test/quantization/test_qat.py
@@ -22,22 +22,22 @@
PerRow,
PerToken,
)
-from torchao.quantization.prototype.qat.api import (
+from torchao.quantization.qat.api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
)
-from torchao.quantization.prototype.qat.fake_quantizer import (
+from torchao.quantization.qat.fake_quantizer import (
FakeQuantizer,
)
-from torchao.quantization.prototype.qat.embedding import (
+from torchao.quantization.qat.embedding import (
FakeQuantizedEmbedding,
)
-from torchao.quantization.prototype.qat.linear import (
+from torchao.quantization.qat.linear import (
FakeQuantizedLinear,
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear
)
-from torchao.quantization.prototype.qat.utils import (
+from torchao.quantization.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
@@ -181,7 +181,7 @@ def _set_ptq_weight(
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
)
-from torchao.quantization.prototype.qat.linear import (
+from torchao.quantization.qat.linear import (
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear,
)
@@ -213,7 +213,7 @@ def _set_ptq_weight(

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_linear(self):
-from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
+from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

group_size = 128
@@ -238,7 +238,7 @@ def test_qat_8da4w_linear(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer(self):
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer

group_size = 16
@@ -272,7 +272,7 @@ def test_qat_8da4w_quantizer(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

with torch.device("meta"):
m = M()
@@ -287,7 +287,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
"""
-from torchao.quantization.prototype.qat import (
+from torchao.quantization.qat.linear import (
Int8DynActInt4WeightQATQuantizer,
disable_8da4w_fake_quant,
enable_8da4w_fake_quant,
@@ -346,7 +346,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
"""
-from torchao.quantization.prototype.qat import (
+from torchao.quantization.qat.linear import (
Int8DynActInt4WeightQATQuantizer,
disable_8da4w_fake_quant,
)
@@ -428,7 +428,7 @@ def _test_qat_quantized_gradients(self, quantizer):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_gradients(self):
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
self._test_qat_quantized_gradients(quantizer)

@@ -518,7 +518,7 @@ def test_qat_4w_primitives(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_linear(self):
-from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
+from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
from torchao.quantization.GPTQ import WeightOnlyInt4Linear

group_size = 128
@@ -545,14 +545,14 @@

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_4w_quantizer_gradients(self):
-from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
self._test_qat_quantized_gradients(quantizer)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_quantizer(self):
-from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

group_size = 32
@@ -630,7 +630,7 @@ def test_composable_qat_quantizer(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_4w_embedding(self):
-from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
+from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
model = M2()
x = model.example_inputs()
out = model(*x)
11 changes: 0 additions & 11 deletions torchao/quantization/prototype/qat/_module_swap_api.py

This file was deleted.

@@ -41,7 +41,7 @@ For example, on a single GPU:
```python
import torch
from torchtune.models.llama3 import llama3
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Smaller version of llama3 to fit in a single GPU
model = llama3(
@@ -2,26 +2,16 @@
ComposableQATQuantizer,
)
from .linear import (
-    disable_4w_fake_quant,
-    disable_8da4w_fake_quant,
-    enable_4w_fake_quant,
-    enable_8da4w_fake_quant,
Int4WeightOnlyQATQuantizer,
-    Int8DynActInt4WeightQATLinear,
Int8DynActInt4WeightQATQuantizer,
)
from .embedding import (
Int4WeightOnlyEmbeddingQATQuantizer,
)

__all__ = [
"disable_4w_fake_quant",
"disable_8da4w_fake_quant",
"enable_4w_fake_quant",
"enable_8da4w_fake_quant",
"ComposableQATQuantizer",
"Int4WeightOnlyQATQuantizer",
"Int4WeightOnlyEmbeddingQATQuantizer"
"Int8DynActInt4WeightQATQuantizer",
"Int8DynActInt4WeightQATLinear",
]
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -46,7 +46,7 @@ def forward(
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> torch.Tensor:
# avoid circular dependencies
-from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+from torchao.quantization.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)

@@ -88,7 +88,7 @@ def forward(
input: torch.Tensor,
) -> torch.Tensor:
# avoid circular dependencies
-from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+from torchao.quantization.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)
assert isinstance(input, AffineFakeQuantizedTensor)
2 changes: 1 addition & 1 deletion torchao/quantization/quant_api.py
@@ -220,7 +220,7 @@ def _replace_with_custom_fn_if_matches_filter(

def _is_linear(mod, *args):
# avoid circular dependencies
-from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+from torchao.quantization.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)

