From a5a0428bd00ab6e9388fdc46e8675780035dfec1 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Wed, 16 Oct 2024 08:33:31 -0700 Subject: [PATCH] Move QAT out of prototype Summary: Move QAT out of prototype so we can provide stronger BC guarantees moving forward. **BC-breaking notes** Before: ``` from torchao.quantization.prototype.qat import ( disable_4w_fake_quant, disable_8da4w_fake_quant, enable_4w_fake_quant, enable_8da4w_fake_quant, ComposableQATQuantizer, Int4WeightOnlyQATQuantizer, Int4WeightOnlyEmbeddingQATQuantizer Int8DynActInt4WeightQATQuantizer, Int8DynActInt4WeightQATLinear, ) from torchao.quantization.prototype.qat.api import ( FakeQuantizeConfig, ) from torchao.quantization.prototype.qat.fake_quantizer import ( FakeQuantizer, ) ``` After: ``` from torchao.quantization.qat import ( ComposableQATQuantizer, Int4WeightOnlyQATQuantizer, Int4WeightOnlyEmbeddingQATQuantizer Int8DynActInt4WeightQATQuantizer, ) from torchao.quantization.qat.linear import ( disable_4w_fake_quant, disable_8da4w_fake_quant, enable_4w_fake_quant, enable_8da4w_fake_quant, Int8DynActInt4WeightQATLinear, ) from torchao.quantization.qat.api import ( FakeQuantizeConfig, ) from torchao.quantization.qat.fake_quantizer import ( FakeQuantizer, ) ``` Test Plan: python test/quantization/test_qat.py [ghstack-poisoned] --- README.md | 2 +- test/quantization/test_qat.py | 32 +++++++++--------- .../prototype/qat/_module_swap_api.py | 11 ------ .../{prototype => }/qat/README.md | 2 +- .../{prototype => }/qat/__init__.py | 10 ------ .../qat/affine_fake_quantized_tensor.py | 0 .../quantization/{prototype => }/qat/api.py | 0 .../{prototype => }/qat/embedding.py | 0 .../{prototype => }/qat/fake_quantizer.py | 0 .../qat/images/qat_diagram.png | Bin .../{prototype => }/qat/linear.py | 0 .../quantization/{prototype => }/qat/utils.py | 4 +-- torchao/quantization/quant_api.py | 2 +- 13 files changed, 21 insertions(+), 42 deletions(-) delete mode 100644 torchao/quantization/prototype/qat/_module_swap_api.py rename torchao/quantization/{prototype => }/qat/README.md (98%) rename torchao/quantization/{prototype => }/qat/__init__.py (55%) rename torchao/quantization/{prototype => }/qat/affine_fake_quantized_tensor.py (100%) rename torchao/quantization/{prototype => }/qat/api.py (100%) rename torchao/quantization/{prototype => }/qat/embedding.py (100%) rename torchao/quantization/{prototype => }/qat/fake_quantizer.py (100%) rename torchao/quantization/{prototype => }/qat/images/qat_diagram.png (100%) rename torchao/quantization/{prototype => }/qat/linear.py (100%) rename torchao/quantization/{prototype => }/qat/utils.py (97%) diff --git a/README.md b/README.md index 71fb25fa2..ba48dbf45 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ In practice these features alongside int4 weight only quantization allow us to * Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/) ```python -from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer +from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer qat_quantizer = Int8DynActInt4WeightQATQuantizer() diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 2bc3fce36..b3998efad 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -22,20 +22,20 @@ PerRow, PerToken, ) -from torchao.quantization.prototype.qat.api import ( +from torchao.quantization.qat.api import ( ComposableQATQuantizer, FakeQuantizeConfig, ) -from torchao.quantization.prototype.qat.fake_quantizer import ( +from torchao.quantization.qat.fake_quantizer import ( FakeQuantizer, ) -from torchao.quantization.prototype.qat.embedding import ( +from torchao.quantization.qat.embedding import ( FakeQuantizedEmbedding, ) -from torchao.quantization.prototype.qat.linear import ( +from torchao.quantization.qat.linear import ( FakeQuantizedLinear, ) -from torchao.quantization.prototype.qat.utils import ( +from torchao.quantization.qat.utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, _fake_quantize_per_token, @@ -175,7 +175,7 @@ def _set_ptq_weight( Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, ) - from torchao.quantization.prototype.qat.linear import ( + from torchao.quantization.qat.linear import ( Int8DynActInt4WeightQATLinear, Int4WeightOnlyQATLinear, ) @@ -207,7 +207,7 @@ def _set_ptq_weight( @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_linear(self): - from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear + from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear group_size = 128 @@ -232,7 +232,7 @@ def test_qat_8da4w_linear(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer(self): - from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer group_size = 16 @@ -266,7 +266,7 @@ def test_qat_8da4w_quantizer(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_meta_weights(self): - from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer with torch.device("meta"): m = M() @@ -281,7 +281,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward. """ - from torchao.quantization.prototype.qat import ( + from torchao.quantization.qat.linear import ( Int8DynActInt4WeightQATQuantizer, disable_8da4w_fake_quant, enable_8da4w_fake_quant, @@ -340,7 +340,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward. """ - from torchao.quantization.prototype.qat import ( + from torchao.quantization.qat.linear import ( Int8DynActInt4WeightQATQuantizer, disable_8da4w_fake_quant, ) @@ -422,7 +422,7 @@ def _test_qat_quantized_gradients(self, quantizer): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_gradients(self): - from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16) self._test_qat_quantized_gradients(quantizer) @@ -512,7 +512,7 @@ def test_qat_4w_primitives(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_linear(self): - from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear + from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear from torchao.quantization.GPTQ import WeightOnlyInt4Linear group_size = 128 @@ -539,14 +539,14 @@ def test_qat_4w_linear(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_4w_quantizer_gradients(self): - from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer + from torchao.quantization.qat import Int4WeightOnlyQATQuantizer quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8) self._test_qat_quantized_gradients(quantizer) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_quantizer(self): - from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer + from torchao.quantization.qat import Int4WeightOnlyQATQuantizer from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer group_size = 32 @@ -624,7 +624,7 @@ def test_composable_qat_quantizer(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_4w_embedding(self): - from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer + from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer model = M2() x = model.example_inputs() out = model(*x) diff --git a/torchao/quantization/prototype/qat/_module_swap_api.py b/torchao/quantization/prototype/qat/_module_swap_api.py deleted file mode 100644 index 0b44974f2..000000000 --- a/torchao/quantization/prototype/qat/_module_swap_api.py +++ /dev/null @@ -1,11 +0,0 @@ -# For backward compatibility only -# These will be removed in the future - -from .linear import ( - Int8DynActInt4WeightQATQuantizer as Int8DynActInt4WeightQATQuantizerModuleSwap, - Int4WeightOnlyQATQuantizer as Int4WeightOnlyQATQuantizerModuleSwap, - enable_8da4w_fake_quant as enable_8da4w_fake_quant_module_swap, - disable_8da4w_fake_quant as disable_8da4w_fake_quant_module_swap, - enable_4w_fake_quant as enable_4w_fake_quant_module_swap, - disable_4w_fake_quant as disable_4w_fake_quant_module_swap, -) diff --git a/torchao/quantization/prototype/qat/README.md b/torchao/quantization/qat/README.md similarity index 98% rename from torchao/quantization/prototype/qat/README.md rename to torchao/quantization/qat/README.md index 286932229..6ecccd2b1 100644 --- a/torchao/quantization/prototype/qat/README.md +++ b/torchao/quantization/qat/README.md @@ -41,7 +41,7 @@ For example, on a single GPU: ```python import torch from torchtune.models.llama3 import llama3 -from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer +from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer # Smaller version of llama3 to fit in a single GPU model = llama3( diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/quantization/qat/__init__.py similarity index 55% rename from torchao/quantization/prototype/qat/__init__.py rename to torchao/quantization/qat/__init__.py index 09ea6e708..09ef10af6 100644 --- a/torchao/quantization/prototype/qat/__init__.py +++ b/torchao/quantization/qat/__init__.py @@ -2,12 +2,7 @@ ComposableQATQuantizer, ) from .linear import ( - disable_4w_fake_quant, - disable_8da4w_fake_quant, - enable_4w_fake_quant, - enable_8da4w_fake_quant, Int4WeightOnlyQATQuantizer, - Int8DynActInt4WeightQATLinear, Int8DynActInt4WeightQATQuantizer, ) from .embedding import ( @@ -15,13 +10,8 @@ ) __all__ = [ - "disable_4w_fake_quant", - "disable_8da4w_fake_quant", - "enable_4w_fake_quant", - "enable_8da4w_fake_quant", "ComposableQATQuantizer", "Int4WeightOnlyQATQuantizer", "Int4WeightOnlyEmbeddingQATQuantizer" "Int8DynActInt4WeightQATQuantizer", - "Int8DynActInt4WeightQATLinear", ] diff --git a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py b/torchao/quantization/qat/affine_fake_quantized_tensor.py similarity index 100% rename from torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py rename to torchao/quantization/qat/affine_fake_quantized_tensor.py diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/qat/api.py similarity index 100% rename from torchao/quantization/prototype/qat/api.py rename to torchao/quantization/qat/api.py diff --git a/torchao/quantization/prototype/qat/embedding.py b/torchao/quantization/qat/embedding.py similarity index 100% rename from torchao/quantization/prototype/qat/embedding.py rename to torchao/quantization/qat/embedding.py diff --git a/torchao/quantization/prototype/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py similarity index 100% rename from torchao/quantization/prototype/qat/fake_quantizer.py rename to torchao/quantization/qat/fake_quantizer.py diff --git a/torchao/quantization/prototype/qat/images/qat_diagram.png b/torchao/quantization/qat/images/qat_diagram.png similarity index 100% rename from torchao/quantization/prototype/qat/images/qat_diagram.png rename to torchao/quantization/qat/images/qat_diagram.png diff --git a/torchao/quantization/prototype/qat/linear.py b/torchao/quantization/qat/linear.py similarity index 100% rename from torchao/quantization/prototype/qat/linear.py rename to torchao/quantization/qat/linear.py diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/qat/utils.py similarity index 97% rename from torchao/quantization/prototype/qat/utils.py rename to torchao/quantization/qat/utils.py index 8f2dd9d13..e2234a255 100644 --- a/torchao/quantization/prototype/qat/utils.py +++ b/torchao/quantization/qat/utils.py @@ -46,7 +46,7 @@ def forward( zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> torch.Tensor: # avoid circular dependencies - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + from torchao.quantization.qat.affine_fake_quantized_tensor import ( AffineFakeQuantizedTensor, ) @@ -88,7 +88,7 @@ def forward( input: torch.Tensor, ) -> torch.Tensor: # avoid circular dependencies - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + from torchao.quantization.qat.affine_fake_quantized_tensor import ( AffineFakeQuantizedTensor, ) assert isinstance(input, AffineFakeQuantizedTensor) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 91803fe3f..ba6fe9b2c 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -220,7 +220,7 @@ def _replace_with_custom_fn_if_matches_filter( def _is_linear(mod, *args): # avoid circular dependencies - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + from torchao.quantization.qat.affine_fake_quantized_tensor import ( AffineFakeQuantizedTensor, )