Move and rename GranularityType -> Granularity (#1038)
* Make module swap the main QAT flow again

Summary: Following #987, this
commit makes module swap the main QAT flow today. We remove all
tensor subclass fake quantize injection logic since it is not
needed in either the short-term or the long-term plan for QAT.
In the short term, we will continue to use a full module swap
flow, and only migrate to the long term flow once there is
general distributed support for tensor subclasses and when
tensor subclass composability provides meaningful benefits.

Test Plan:
python test/quantization/test_qat.py

[ghstack-poisoned]

* Move and rename GranularityType -> Granularity

Summary: Move GranularityType into its own granularity.py module
and rename it to Granularity, making it a first-class field like
MappingType and ZeroPointDomain.

Test Plan: CI

[ghstack-poisoned]

andrewor14 authored Oct 10, 2024
1 parent 107e378 commit 0f6bae5
Showing 15 changed files with 143 additions and 111 deletions.
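At a glance: granularity classes now live in `torchao.quantization.granularity`, and the observer keyword argument is `granularity` instead of `granularity_type`. A minimal before/after sketch, based on the test diffs below (illustrative only):

```python
import torch

# New import location introduced by this commit (previously torchao.quantization.observer).
from torchao.quantization.granularity import PerTensor
from torchao.quantization.observer import AffineQuantizedMinMaxObserver
from torchao.quantization.quant_primitives import MappingType

# Before: AffineQuantizedMinMaxObserver(..., granularity_type=PerTensor(), ...)
# After: the keyword argument is now `granularity`.
obs = AffineQuantizedMinMaxObserver(
    MappingType.SYMMETRIC,
    torch.float8_e4m3fn,
    granularity=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
)
```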
10 changes: 8 additions & 2 deletions test/dtypes/test_affine_quantized_float.py
@@ -26,11 +26,17 @@
float8_weight_only,
quantize_,
)
from torchao.quantization.observer import PerRow, PerTensor
from torchao.quantization.granularity import (
PerRow,
PerTensor,
)
from torchao.quantization.quant_api import (
float8_static_activation_float8_weight,
)
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
from torchao.quantization.quant_primitives import (
MappingType,
choose_qparams_affine,
)

random.seed(0)
torch.manual_seed(0)
22 changes: 12 additions & 10 deletions test/quantization/test_observer.py
@@ -9,11 +9,13 @@
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase

from torchao.quantization.observer import (
AffineQuantizedMinMaxObserver,
from torchao.quantization.granularity import (
PerAxis,
PerTensor,
)
from torchao.quantization.observer import (
AffineQuantizedMinMaxObserver,
)
from torchao.quantization.quant_api import (
insert_observers_,
)
@@ -42,7 +44,7 @@ def test_min_max_per_tensor_affine(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity_type=PerTensor(),
granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
@@ -54,7 +56,7 @@ def test_min_max_per_channel_affine(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity_type=PerAxis(axis=0),
granularity=PerAxis(axis=0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
@@ -68,7 +70,7 @@ def test_block_size_calc_success(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
@@ -87,7 +89,7 @@ def test_block_size_calc_success(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerAxis(1),
granularity=PerAxis(1),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
@@ -102,7 +104,7 @@ def test_block_size_row_errors(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerAxis(0),
granularity=PerAxis(0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
@@ -121,7 +123,7 @@ def test_block_size_row_errors(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerAxis(1),
granularity=PerAxis(1),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
@@ -149,7 +151,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
input_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
@@ -159,7 +161,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
weight_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
4 changes: 2 additions & 2 deletions torchao/_models/llama/eval.py
@@ -24,9 +24,9 @@
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
)
from torchao.quantization.observer import PerRow, PerTensor
from torchao._models._eval import TransformerEvalWrapper, InputRecorder
from torchao._models.llama.model import prepare_inputs_for_model
from torchao.quantization.granularity import PerRow, PerTensor

from tokenizer import get_tokenizer
import time
@@ -255,4 +255,4 @@ def run_evaluation(
args.calibration_limit,
args.calibration_seq_length,
args.pad_calibration_inputs,
)
)
2 changes: 1 addition & 1 deletion torchao/_models/llama/generate.py
@@ -216,7 +216,7 @@ def main(
float8_weight_only,
float8_dynamic_activation_float8_weight,
)
from torchao.quantization.observer import PerTensor, PerRow
from torchao.quantization.granularity import PerTensor, PerRow
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
2 changes: 1 addition & 1 deletion torchao/prototype/awq/api.py
@@ -1,13 +1,13 @@
import torch
import torch.nn.functional as F

from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
_DTYPE_TO_QVALUE_BOUNDS,
)
from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
from torchao.quantization.observer import PerGroup
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType
from torchao.dtypes import(
9 changes: 5 additions & 4 deletions torchao/prototype/awq/core.py
@@ -7,20 +7,21 @@
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.granularity import Granularity
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from torchao.quantization.observer import (
AffineQuantizedObserverBase, GranularityType
AffineQuantizedObserverBase,
)


class AWQObserver(AffineQuantizedObserverBase):
def __init__(self,
weight: torch.Tensor,
bias: torch.Tensor,
quantization_granularity: GranularityType,
quantization_granularity: Granularity,
mapping_type: MappingType,
target_dtype: torch.dtype,
n_validation_examples: int,
@@ -40,7 +41,7 @@ def __init__(self,
Args:
weight: The weight tensor to be observed.
bias: The bias tensor to be observed.
quantization_granularity: Granularity type which specifies how many weights share the same scale/zero point
quantization_granularity: Granularity which specifies how many weights share the same scale/zero point
input_dtype: The data type of the input tensor.
mapping_type: Always set to asymmetric
target_dtype: The target data type of the quantized tensor
@@ -153,4 +154,4 @@ def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver):
observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
observed_linear.weight = float_linear.weight
observed_linear.bias = float_linear.bias
return observed_linear
return observed_linear
2 changes: 1 addition & 1 deletion torchao/quantization/README.md
@@ -137,7 +137,7 @@ change_linear_weights_to_int8_dqtensors(model)
```python
# for torch 2.4+
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerTensor
from torchao.quantization.quant_api import PerTensor
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
```
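The same API supports row-wise float8 quantization; a sketch assuming the granularity import path added in this commit (the same import used in generate.py below), with `model` being the module to quantize:

```python
# for torch 2.4+
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.granularity import PerRow
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```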

8 changes: 5 additions & 3 deletions torchao/quantization/autoquant.py
@@ -12,12 +12,14 @@
from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType, Float8LayoutType
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torch.utils._python_dispatch import return_and_correct_aliasing
from .quant_primitives import (
safe_int_mm,
from .granularity import (
PerAxis,
PerRow,
PerTensor,
)
from .quant_primitives import safe_int_mm
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
from torchao.quantization.utils import quantize_activation_per_token_absmax
from torchao.quantization.observer import PerAxis, PerTensor, PerRow
from torchao.float8.inference import Float8MMConfig

import torch.nn.functional as F
76 changes: 76 additions & 0 deletions torchao/quantization/granularity.py
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass


@dataclass(frozen=True)
class Granularity:
"""
Base class for representing the granularity of quantization.
This class serves as a parent for specific granularity types used in
quantization operations, such as per-tensor or per-axis quantization.
"""
pass

@dataclass(frozen=True)
class PerTensor(Granularity):
"""
Represents per-tensor granularity in quantization.
This granularity type calculates the quantization parameters
based on the entire tensor.
"""
pass

@dataclass(frozen=True)
class PerAxis(Granularity):
"""
Represents per-axis granularity in quantization.
This granularity type calculates different quantization parameters
along a specified axis of the tensor.
For example, if the input tensor has shape [8, 16] and axis=0, then
the quantization parameters are calculated for each row of the tensor,
giving a total of 8 quantization parameters.
Attributes:
axis (int): The axis along which reduction is performed.
"""
axis: int

@dataclass(frozen=True)
class PerGroup(Granularity):
"""
Represents per-channel group granularity in quantization.
This granularity type calculates different quantization parameters
for each group of <group_size> elements.
For example, if the input tensor has shape [8, 16] and the group size is 4, then
the input tensor is reshaped to [32, 4] and
quantization parameters are calculated for each group of 4 elements,
giving a total of 32 quantization parameters.
Attributes:
group_size (int): The size of each quantization group
"""
group_size: int

class PerRow(Granularity):
"""
Represents row-wise granularity in quantization.
This is a special case of per-axis quantization and is unique to Float8 matmuls,
where the input is quantized with a block_size of (1, ..., input.shape[-1]) and the weight
is quantized with a block_size of (1, weight.shape[1]).
"""
pass
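To make the docstrings above concrete, here is a short illustrative sketch (block-size semantics assumed from the docstrings; not part of the diff) of how each granularity maps to a block shape and a number of quantization parameters for an [8, 16] tensor:

```python
import torch

x = torch.randn(8, 16)  # example weight tensor with 128 elements

# PerTensor: one set of qparams for the whole tensor -> block_size == x.shape.
per_tensor_block = tuple(x.shape)        # (8, 16), 1 set of qparams

# PerAxis(axis=0): one set of qparams per slice along axis 0 -> 8 sets.
per_axis_block = (1, x.shape[1])         # (1, 16)

# PerGroup(group_size=4): elements are grouped in chunks of 4 -> 128 / 4 = 32 sets.
per_group_view = x.reshape(-1, 4)        # shape (32, 4)

# PerRow: float8-style row-wise quantization with block_size (1, ..., x.shape[-1]).
per_row_block = (1, x.shape[-1])         # (1, 16)
```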