Add example for AWQ-like flow (#990)
* Add example for AWQ-like flow

Summary:
Added an example and a utility for an AWQ-like flow that applies an extra
equalization scale tensor to the input activation

Test Plan:
python tutorials/calibration_flow/awq_like.py


* typo

* docs

* doc

* typo
jerryzh168 authored Oct 2, 2024
1 parent 9229df9 commit 378e6a8
Showing 3 changed files with 324 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torchao/quantization/__init__.py
@@ -16,6 +16,9 @@
    LinearActivationQuantizedTensor,
    to_linear_activation_quantized,
)
from .linear_activation_scale import (  # noqa: F401
    to_weight_tensor_with_linear_activation_scale_metadata,
)

__all__ = [
"swap_conv2d_1x1_to_linear"
@@ -47,6 +50,7 @@
"fpx_weight_only",
"LinearActivationQuantizedTensor",
"to_linear_activation_quantized",
"to_weight_tensor_with_linear_activation_scale_metadata",
"float8_weight_only",
"float8_dynamic_activation_float8_weight"
]
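
For orientation, a minimal sketch of what the newly exported helper does at the linear call site (the shapes and values below are illustrative, not from this commit):

import torch
import torch.nn.functional as F
from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata

weight = torch.randn(32, 64)                # plain weight
equalization_scale = torch.rand(64) + 0.5   # per-input-feature scale
wrapped = to_weight_tensor_with_linear_activation_scale_metadata(weight, equalization_scale)

x = torch.randn(2, 64)
# the linear dispatch divides the input activation by the scale first
out = F.linear(x, wrapped)
assert torch.allclose(out, F.linear(x / equalization_scale, weight))

The wrapped tensor behaves like a regular weight elsewhere; only the F.linear path applies the scale.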
165 changes: 165 additions & 0 deletions torchao/quantization/linear_activation_scale.py
@@ -0,0 +1,165 @@
from typing import Any, Callable, Dict, Optional

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_5,
    TorchAOBaseTensor,
)

__all__ = [
"WeightTensorWithLinearActivationScaleMetadata",
"to_weight_tensor_with_linear_activation_scale_metadata",
]

aten = torch.ops.aten


class WeightTensorWithLinearActivationScaleMetadata(TorchAOBaseTensor):
"""
Tensor subclass that wraps a weight tensor and provides metadata for linear activation scaling.
Right now we hardcode how we apply the scale:
scaled_linear_act = input_act / scale
out = F.linear(scaled_linear_act, weight, ...)
We can generalize this to accept a function as well if needed.
Args:
original_weight_tensor (torch.Tensor): The weight tensor to be wrapped.
scale (torch.Tensor): The scale tensor to be applied to activation.
"""

    original_weight_tensor: torch.Tensor
    scale: torch.Tensor

    def __new__(
        cls,
        original_weight_tensor: torch.Tensor,
        scale: torch.Tensor,
    ):
        kwargs = {}
        kwargs["dtype"] = original_weight_tensor.dtype
        kwargs["requires_grad"] = False
        kwargs["device"] = original_weight_tensor.device
        shape = original_weight_tensor.shape
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        original_weight_tensor: torch.Tensor,
        scale: torch.Tensor,
    ):
        self.original_weight_tensor = original_weight_tensor
        self.scale = scale

    def __repr__(self):
        return f"WeightTensorWithLinearActivationScaleMetadata({self.original_weight_tensor}, scale={self.scale})"

    def __tensor_flatten__(self):
        tensor_data = ["original_weight_tensor", "scale"]
        return tensor_data, []

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        return cls(
            tensor_data_dict["original_weight_tensor"],
            tensor_data_dict["scale"],
        )

    @staticmethod
    def _quantized_linear_op(
        input_tensor: torch.Tensor,
        weight_tensor: torch.Tensor,
        bias: Optional[torch.Tensor],
    ):
        original_weight_tensor = weight_tensor.original_weight_tensor
        scale = weight_tensor.scale
        # Note: we can make this function configurable as well
        scaled_input_act = input_tensor / scale
        return torch.nn.functional.linear(
            scaled_input_act, original_weight_tensor, bias
        )

    @classmethod
    def from_float(
        cls,
        input_float: torch.Tensor,
        scale: torch.Tensor,
    ):
        return cls(input_float, scale)

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.original_weight_tensor),
            fn(self.scale),
        )

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        device = kwargs.pop("device")
        return self.__class__(
            self.original_weight_tensor.to(device),
            self.scale.to(device),
        )


implements = WeightTensorWithLinearActivationScaleMetadata.implements


@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    if isinstance(weight_tensor, WeightTensorWithLinearActivationScaleMetadata):
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)

    raise NotImplementedError(
        "WeightTensorWithLinearActivationScaleMetadata: No specialized dispatch found for linear op"
    )


@implements(aten.detach.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
    )


@implements(aten.clone.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
    )


@implements(aten._to_copy.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func,
        args,
        kwargs,
        args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
    )


@implements(aten.t.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.t)
    )


to_weight_tensor_with_linear_activation_scale_metadata = (
    WeightTensorWithLinearActivationScaleMetadata.from_float
)

if TORCH_VERSION_AT_LEAST_2_5:
    # Allow a model with WeightTensorWithLinearActivationScaleMetadata weights
    # to be loaded with `weights_only=True`
    torch.serialization.add_safe_globals(
        [WeightTensorWithLinearActivationScaleMetadata]
    )
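
Because the subclass is registered with torch.serialization.add_safe_globals on PyTorch 2.5+, checkpoints containing these wrapped weights can be loaded with weights_only=True. A minimal round-trip sketch (the file name and tensor values are illustrative):

import torch
from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata

wrapped = to_weight_tensor_with_linear_activation_scale_metadata(
    torch.randn(32, 64), torch.ones(64)
)
torch.save({"linear.weight": wrapped}, "checkpoint.pt")
state_dict = torch.load("checkpoint.pt", weights_only=True)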
155 changes: 155 additions & 0 deletions tutorials/calibration_flow/awq_like.py
@@ -0,0 +1,155 @@
"""
Demo for awq like flow that applies equalization scale to input activation
* insert_observers_: insert observer for activation and weight
* quantize_: convert the observed linear module to quantized linear module
* we first quantize the weight with to_affine_quantized_intx/floatx
* then we apply equalization scale to linear activation with to_weight_tensor_with_linear_activation_scale_metadata (input activation will be divided by equalization_scale), and then call F.linear with
scaled input activation and quantized weight (so we can reuse the efficient quantized linear kernels used by quantized weight)
"""
import torch
import copy

import torch.nn.functional as F
from torch import Tensor
from torchao.dtypes import (
    to_affine_quantized_intx_static,
    to_affine_quantized_floatx_static,
    Float8LayoutType,
)
from torchao.quantization.utils import compute_error
from torchao.quantization import quantize_
from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.quantization.observer import (
    AffineQuantizedMinMaxObserver,
    PerTensor,
    PerAxis,
)
from torchao.quantization.quant_primitives import (
    MappingType,
    FP8_TYPES,
)


class ObservedLinear(torch.nn.Linear):
    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.act_obs = act_obs
        self.weight_obs = weight_obs

    def forward(self, input: Tensor):
        observed_input = self.act_obs(input)
        observed_weight = self.weight_obs(self.weight)
        return F.linear(observed_input, observed_weight, self.bias)

    @classmethod
    def from_float(cls, float_linear, act_obs, weight_obs):
        observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, weight_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
        observed_linear.weight = float_linear.weight
        observed_linear.bias = float_linear.bias
        return observed_linear

def insert_observers_(model, act_obs, weight_obs):
    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)

    def replacement_fn(m):
        copied_act_obs = copy.deepcopy(act_obs)
        copied_weight_obs = copy.deepcopy(weight_obs)
        return ObservedLinear.from_float(m, copied_act_obs, copied_weight_obs)

    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)

# converting an observed linear module to a linear module with quantized weights
# (and scaled activations) using tensor subclasses
def apply_awq(target_dtype: torch.dtype):
    def _apply_awq_to_linear(observed_linear):
        # weight quantization
        weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()

        def weight_quant_func(weight):
            block_size = (1, weight.shape[1])
            if target_dtype == torch.uint8:
                return to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
            elif target_dtype == torch.float8_e4m3fn:
                return to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8LayoutType(mm_config=None))
            else:
                raise ValueError(f"Unsupported target dtype {target_dtype}")

        linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
        linear.weight = observed_linear.weight
        linear.bias = observed_linear.bias

        # activation scaling
        # pretend this is the equalization scale; in reality the `act_obs` should
        # be an observer that can calculate the equalization scale
        equalization_scale, _ = observed_linear.act_obs.calculate_qparams()
        equalization_scale = torch.ones_like(equalization_scale)

        # fold the equalization scale into the weight before quantization
        linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight * equalization_scale), requires_grad=False)

        # wrap the quantized weight so that F.linear divides the input activation
        # by the equalization scale before hitting the quantized kernel
        linear.weight = torch.nn.Parameter(to_weight_tensor_with_linear_activation_scale_metadata(linear.weight, equalization_scale), requires_grad=False)

        return linear

    return _apply_awq_to_linear



######## Test ##########
class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, k, bias=False)
        self.linear2 = torch.nn.Linear(k, n, bias=False)

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def test_awq(target_dtype: torch.dtype, mapping_type: MappingType):
    print(f"Testing AWQ-like flow with {target_dtype}:")
    torch.manual_seed(0)

    dtype = torch.bfloat16
    m = ToyLinearModel().eval().to(dtype).to("cuda")

    example_inputs = m.example_inputs(dtype=dtype, device="cuda")
    print("example inputs shape:", example_inputs[0].shape)

    act_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps)
    weight_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps)

    before_quant = m(*example_inputs)

    insert_observers_(m, act_obs, weight_obs)
    # calibration
    for _ in range(10):
        m(*example_inputs)

    is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)

    # quantized linear represented as an nn.Linear with modified tensor subclass weights
    # for both activation scaling and weight quantization
    quantize_(m, apply_awq(target_dtype), is_observed_linear)
    print("quantized model (applying tensor subclass to weight):", m)
    after_quant = m(*example_inputs)
    assert compute_error(before_quant, after_quant) > 25
    print("test passed")


if __name__ == "__main__":
    test_awq(torch.uint8, MappingType.ASYMMETRIC)
    test_awq(torch.float8_e4m3fn, MappingType.SYMMETRIC)
