diff --git a/README.md b/README.md
index 71fb25fa2..ba48dbf45 100644
--- a/README.md
+++ b/README.md
@@ -59,7 +59,7 @@ In practice these features alongside int4 weight only quantization allow us to
 * Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)

 ```python
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

 qat_quantizer = Int8DynActInt4WeightQATQuantizer()

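The README hunk above only shows the changed import line; for context, here is a minimal sketch of the QAT flow that line feeds into, using the new import path. The toy `nn.Sequential` model and `groupsize=32` are illustrative assumptions, not part of this diff; the two-step `prepare`/`convert` usage follows the pattern the torchao QAT quantizers expose.

```python
import torch
import torch.nn as nn

# New import path introduced by this change (previously torchao.quantization.prototype.qat)
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Toy model, purely for illustration; any model whose nn.Linear in_features are
# divisible by the group size works the same way.
model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 512))

qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)

# prepare: swap nn.Linear modules for fake-quantized QAT linears, then fine-tune as usual
model = qat_quantizer.prepare(model)
# ... training loop goes here ...

# convert: replace fake quantization with actual int8-activation / int4-weight quantization
model = qat_quantizer.convert(model)
```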
diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 918697724..8b62b8f57 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -22,22 +22,22 @@
     PerRow,
     PerToken,
 )
-from torchao.quantization.prototype.qat.api import (
+from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
 )
-from torchao.quantization.prototype.qat.fake_quantizer import (
+from torchao.quantization.qat.fake_quantizer import (
     FakeQuantizer,
 )
-from torchao.quantization.prototype.qat.embedding import (
+from torchao.quantization.qat.embedding import (
     FakeQuantizedEmbedding,
 )
-from torchao.quantization.prototype.qat.linear import (
+from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
     Int8DynActInt4WeightQATLinear,
     Int4WeightOnlyQATLinear
 )
-from torchao.quantization.prototype.qat.utils import (
+from torchao.quantization.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
@@ -181,7 +181,7 @@ def _set_ptq_weight(
             Int8DynActInt4WeightLinear,
             WeightOnlyInt4Linear,
         )
-        from torchao.quantization.prototype.qat.linear import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATLinear,
             Int4WeightOnlyQATLinear,
         )
@@ -213,7 +213,7 @@ def _set_ptq_weight(

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
+        from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

         group_size = 128
@@ -238,7 +238,7 @@ def test_qat_8da4w_linear(self):

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer

         group_size = 16
@@ -272,7 +272,7 @@ def test_qat_8da4w_quantizer(self):

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

         with torch.device("meta"):
             m = M()
@@ -287,7 +287,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
             enable_8da4w_fake_quant,
@@ -346,7 +346,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
         )
@@ -428,7 +428,7 @@ def _test_qat_quantized_gradients(self, quantizer):

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
         self._test_qat_quantized_gradients(quantizer)

@@ -518,7 +518,7 @@ def test_qat_4w_primitives(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
+        from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear

         group_size = 128
@@ -545,14 +545,14 @@ def test_qat_4w_linear(self):

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

         group_size = 32
@@ -630,7 +630,7 @@ def test_composable_qat_quantizer(self):

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_embedding(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
         model = M2()
         x = model.example_inputs()
         out = model(*x)
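One consequence of the rename shows up in the two `disable_fake_quant` tests above: the `enable_`/`disable_` helpers are now imported from `torchao.quantization.qat.linear` rather than from the package root. A small usage sketch, assuming the helpers are applied with `Module.apply`; the toy model and group size are made up for illustration.

```python
import torch.nn as nn

from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
# These helpers are no longer re-exported from the package root (see the
# __init__.py change below); import them from the linear submodule instead.
from torchao.quantization.qat.linear import (
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)

# Toy model and group size, for illustration only
model = nn.Sequential(nn.Linear(256, 256))
model = Int8DynActInt4WeightQATQuantizer(groupsize=32).prepare(model)

# Temporarily turn fake quantization off (e.g. to compare against fp32 outputs),
# then turn it back on before resuming QAT fine-tuning.
model.apply(disable_8da4w_fake_quant)
model.apply(enable_8da4w_fake_quant)
```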
diff --git a/torchao/quantization/prototype/qat/_module_swap_api.py b/torchao/quantization/prototype/qat/_module_swap_api.py
deleted file mode 100644
index 0b44974f2..000000000
--- a/torchao/quantization/prototype/qat/_module_swap_api.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# For backward compatibility only
-# These will be removed in the future
-
-from .linear import (
-    Int8DynActInt4WeightQATQuantizer as Int8DynActInt4WeightQATQuantizerModuleSwap,
-    Int4WeightOnlyQATQuantizer as Int4WeightOnlyQATQuantizerModuleSwap,
-    enable_8da4w_fake_quant as enable_8da4w_fake_quant_module_swap,
-    disable_8da4w_fake_quant as disable_8da4w_fake_quant_module_swap,
-    enable_4w_fake_quant as enable_4w_fake_quant_module_swap,
-    disable_4w_fake_quant as disable_4w_fake_quant_module_swap,
-)
diff --git a/torchao/quantization/prototype/qat/README.md b/torchao/quantization/qat/README.md
similarity index 98%
rename from torchao/quantization/prototype/qat/README.md
rename to torchao/quantization/qat/README.md
index 286932229..6ecccd2b1 100644
--- a/torchao/quantization/prototype/qat/README.md
+++ b/torchao/quantization/qat/README.md
@@ -41,7 +41,7 @@ For example, on a single GPU:
 ```python
 import torch
 from torchtune.models.llama3 import llama3
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

 # Smaller version of llama3 to fit in a single GPU
 model = llama3(
diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/quantization/qat/__init__.py
similarity index 55%
rename from torchao/quantization/prototype/qat/__init__.py
rename to torchao/quantization/qat/__init__.py
index 09ea6e708..09ef10af6 100644
--- a/torchao/quantization/prototype/qat/__init__.py
+++ b/torchao/quantization/qat/__init__.py
@@ -2,12 +2,7 @@
     ComposableQATQuantizer,
 )
 from .linear import (
-    disable_4w_fake_quant,
-    disable_8da4w_fake_quant,
-    enable_4w_fake_quant,
-    enable_8da4w_fake_quant,
     Int4WeightOnlyQATQuantizer,
-    Int8DynActInt4WeightQATLinear,
     Int8DynActInt4WeightQATQuantizer,
 )
 from .embedding import (
@@ -15,13 +10,8 @@
 )

 __all__ = [
-    "disable_4w_fake_quant",
-    "disable_8da4w_fake_quant",
-    "enable_4w_fake_quant",
-    "enable_8da4w_fake_quant",
     "ComposableQATQuantizer",
     "Int4WeightOnlyQATQuantizer",
     "Int4WeightOnlyEmbeddingQATQuantizer"
     "Int8DynActInt4WeightQATQuantizer",
-    "Int8DynActInt4WeightQATLinear",
 ]
diff --git a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py b/torchao/quantization/qat/affine_fake_quantized_tensor.py
similarity index 100%
rename from torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
rename to torchao/quantization/qat/affine_fake_quantized_tensor.py
diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/qat/api.py
similarity index 100%
rename from torchao/quantization/prototype/qat/api.py
rename to torchao/quantization/qat/api.py
diff --git a/torchao/quantization/prototype/qat/embedding.py b/torchao/quantization/qat/embedding.py
similarity index 100%
rename from torchao/quantization/prototype/qat/embedding.py
rename to torchao/quantization/qat/embedding.py
diff --git a/torchao/quantization/prototype/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py
similarity index 100%
rename from torchao/quantization/prototype/qat/fake_quantizer.py
rename to torchao/quantization/qat/fake_quantizer.py
diff --git a/torchao/quantization/prototype/qat/images/qat_diagram.png b/torchao/quantization/qat/images/qat_diagram.png
similarity index 100%
rename from torchao/quantization/prototype/qat/images/qat_diagram.png
rename to torchao/quantization/qat/images/qat_diagram.png
diff --git a/torchao/quantization/prototype/qat/linear.py b/torchao/quantization/qat/linear.py
similarity index 100%
rename from torchao/quantization/prototype/qat/linear.py
rename to torchao/quantization/qat/linear.py
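For downstream code, the practical upshot of the `__init__.py` change above is which names remain importable from the package root. The sketch below composes those remaining top-level quantizers; the toy model is made up, and it assumes `ComposableQATQuantizer` accepts a list of quantizers and exposes the same `prepare`/`convert` flow as the individual quantizers.

```python
import torch.nn as nn

# Names still exported from the package root after this change
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)

# Toy model with both an embedding and a linear layer, for illustration only
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(128, 256)
        self.output = nn.Linear(256, 128)

    def forward(self, x):
        return self.output(self.tok_embeddings(x))

model = TinyModel()

# Assumed API: one quantizer per layer type, composed into a single prepare/convert
quantizer = ComposableQATQuantizer([
    Int8DynActInt4WeightQATQuantizer(groupsize=32),   # linear layers
    Int4WeightOnlyEmbeddingQATQuantizer(),            # embedding layers
])
model = quantizer.prepare(model)
# ... fine-tune ...
model = quantizer.convert(model)
```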
diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/qat/utils.py
similarity index 97%
rename from torchao/quantization/prototype/qat/utils.py
rename to torchao/quantization/qat/utils.py
index 8f2dd9d13..e2234a255 100644
--- a/torchao/quantization/prototype/qat/utils.py
+++ b/torchao/quantization/qat/utils.py
@@ -46,7 +46,7 @@ def forward(
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ) -> torch.Tensor:
         # avoid circular dependencies
-        from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+        from torchao.quantization.qat.affine_fake_quantized_tensor import (
             AffineFakeQuantizedTensor,
         )

@@ -88,7 +88,7 @@ def forward(
         input: torch.Tensor,
     ) -> torch.Tensor:
         # avoid circular dependencies
-        from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+        from torchao.quantization.qat.affine_fake_quantized_tensor import (
             AffineFakeQuantizedTensor,
         )
         assert isinstance(input, AffineFakeQuantizedTensor)
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 91803fe3f..ba6fe9b2c 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -220,7 +220,7 @@ def _replace_with_custom_fn_if_matches_filter(

 def _is_linear(mod, *args):
     # avoid circular dependencies
-    from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+    from torchao.quantization.qat.affine_fake_quantized_tensor import (
         AffineFakeQuantizedTensor,
     )
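All three hunks above touch the same pattern: the `AffineFakeQuantizedTensor` import stays inside the function body, with an "avoid circular dependencies" comment, so the QAT modules and `quant_api.py` can import each other without failing at module import time. Below is a generic, self-contained sketch of that deferred-import mechanic; the standard-library module is only a stand-in for the real torchao import, and there is no actual cycle in this toy example.

```python
# Deferring an import to call time: a function-local import is resolved only
# when the function runs, so two modules that need each other can both finish
# importing before either one actually uses the other.
def describe(obj) -> str:
    # In torchao this line would import AffineFakeQuantizedTensor from the QAT
    # package; here collections.abc stands in for it.
    from collections import abc
    return "mapping" if isinstance(obj, abc.Mapping) else "other"

if __name__ == "__main__":
    print(describe({"scale": 1.0}))  # mapping
    print(describe([1, 2, 3]))       # other
```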