Add generic fake quantized embedding for QAT #1085

Merged
merged 1 commit on Oct 16, 2024
Commits on Oct 16, 2024

  1. Add generic fake quantized embedding for QAT

    Summary: This is equivalent to #1020
    but for nn.Embedding. This commit adds a generic fake quantized
    embedding module to replace the uses of the existing more specific
    QAT embeddings. For example, `Int4WeightOnlyQATEmbedding` can be
    expressed as follows:
    
    ```
    import torch

    from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
    from torchao.quantization.prototype.qat.embedding import FakeQuantizedEmbedding

    group_size = 32  # example group size
    weight_config = FakeQuantizeConfig(
        dtype=torch.int4,
        group_size=group_size,
        is_symmetric=True,
    )
    fq_embedding = FakeQuantizedEmbedding(16, 32, weight_config=weight_config)
    ```
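    For intuition: fake quantization rounds each weight group onto the
    int4 integer grid and immediately dequantizes back to float, so the
    forward pass sees the quantization error during training. A minimal
    pure-Python sketch of the symmetric per-group scheme (illustrative
    only, not torchao's implementation; `fake_quantize_group` is a
    hypothetical helper):

    ```
    def fake_quantize_group(weights, num_bits=4):
        # Symmetric scheme: representable range is [-8, 7] for int4
        qmax = 2 ** (num_bits - 1) - 1
        scale = max(abs(w) for w in weights) / qmax
        if scale == 0:
            return list(weights)
        # Round each weight onto the integer grid, clamp, then dequantize
        q = [max(-qmax - 1, min(qmax, round(w / scale))) for w in weights]
        return [qi * scale for qi in q]

    group = [0.9, -0.35, 0.02, 0.7]
    print(fake_quantize_group(group))
    ```

    Each returned value is a float that lies exactly on the int4 grid
    for that group's scale, which is what the fake quantized module
    emulates for the embedding weight.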
    
    Test Plan:
    python test/quantization/test_qat.py -k test_qat_4w_embedding
    python test/quantization/test_qat.py -k test_fake_quantized_embedding_4w
    andrewor14 committed Oct 16, 2024
    Commit: 53239e2