[WIP] SmoothQuant using tensor subclassing #1030

Draft
wants to merge 38 commits into base: main
Changes from all commits (38 commits)
d34859b
SmoothQuant using tensor subclassing
Xia-Weiwen Sep 30, 2024
df5b49f
Merge branch 'main' into smooth_quant
Xia-Weiwen Oct 8, 2024
847f1f2
Update UT
Xia-Weiwen Oct 8, 2024
f03cfb3
Add SmoothQuant example
Xia-Weiwen Oct 8, 2024
a2518f1
Remove duplicate implementation of int_scaled_matmul for CPU
Xia-Weiwen Oct 9, 2024
28fb8ce
Update example.py
Xia-Weiwen Oct 9, 2024
bada2b0
Remove unused code
Xia-Weiwen Oct 9, 2024
921efc0
Implement with LinearActivationQuantizedTensor
Xia-Weiwen Oct 10, 2024
ad5b97e
Fix load/save
Xia-Weiwen Oct 10, 2024
f1be01d
Fix device mismatch in observer
Xia-Weiwen Oct 10, 2024
7ee1f13
Fix fp16 overflow issue in int_scaled_matmul
Xia-Weiwen Oct 11, 2024
c773386
Merge branch 'main' into smooth_quant
Xia-Weiwen Oct 11, 2024
427ff73
Add linear_activation_scale_quantized.py for torch.compile
Xia-Weiwen Oct 11, 2024
9916113
Quantize act/wei to 7 bit on old CPU platforms
Xia-Weiwen Oct 12, 2024
52260b6
Fix device mismatch
Xia-Weiwen Oct 12, 2024
ca50fee
Fix UT failures
Xia-Weiwen Oct 12, 2024
3e90789
Fix UT
Xia-Weiwen Oct 12, 2024
d47fcc1
Don't use torch._int_mm for CPU now because it may overflow
Xia-Weiwen Oct 12, 2024
a195e73
Remove reduce_range
Xia-Weiwen Oct 12, 2024
6627be1
Refine code
Xia-Weiwen Oct 14, 2024
fb981e7
Remove torch.compile from example
Xia-Weiwen Oct 14, 2024
17c374e
Add torch.compile in example
Xia-Weiwen Oct 15, 2024
bb76de6
Debug CI failures
Xia-Weiwen Oct 15, 2024
98b2de1
Debug CI failures (1)
Xia-Weiwen Oct 15, 2024
316f5ea
Debug CI failures (2)
Xia-Weiwen Oct 15, 2024
b4d8383
Debug CI failures (3)
Xia-Weiwen Oct 16, 2024
aca06d2
Work with torch.compile
Xia-Weiwen Oct 16, 2024
dde7545
Update torchao/kernel/intmm.py
Xia-Weiwen Oct 16, 2024
00cfadd
Update readme.md
Xia-Weiwen Oct 17, 2024
466d2f1
Update readme.md
Xia-Weiwen Oct 17, 2024
e970a4a
Debug CI failures (4)
Xia-Weiwen Oct 17, 2024
5e2abbe
Reimplement with nested tensor subclassing
Xia-Weiwen Oct 18, 2024
90d1b7d
Test torch.compile only with PyTorch >= 2.5
Xia-Weiwen Oct 18, 2024
03d490b
Debug CI failures (5)
Xia-Weiwen Oct 18, 2024
f595ed4
Debug CI failures (6)
Xia-Weiwen Oct 18, 2024
2202f69
Debug CI failures (7)
Xia-Weiwen Oct 18, 2024
6ea8aa8
Use MovingAvg observer for activation; Update UT and readme
Xia-Weiwen Oct 18, 2024
fa1144c
Revert changes to test_spinquant.py; refine readme
Xia-Weiwen Oct 18, 2024
173 changes: 173 additions & 0 deletions test/prototype/test_smoothquant.py
@@ -0,0 +1,173 @@
from copy import deepcopy
import os
import pytest
import torch
import tempfile
from torch.ao.quantization import MovingAverageMinMaxObserver
from torchao.quantization import quantize_
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_2,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
)
from torchao.quantization.utils import (
dynamically_quantize_per_channel,
dequantize_per_channel,
)
from torchao.prototype.smoothquant import (
insert_smooth_quant_observer,
smooth_quant,
SmoothQuantObservedLinear,
save_smooth_quant_recipe,
load_smooth_quant_recipe
)

class ToyLinearModel(torch.nn.Module):
def __init__(self, m=512, n=256, k=128):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False)
self.linear2 = torch.nn.Linear(n, k, bias=False)
self.linear3 = torch.nn.Linear(k, 1, bias=False)

def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"):
return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)]

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
return x


bias_list = [True, False]
alpha_list = [0.5, 0.75]
quant_mode_list = ["static", "dynamic"]
devices = ["cpu"]
if torch.cuda.is_available():
devices.append("cuda")
idtypes = (torch.float, torch.bfloat16)

if TORCH_VERSION_AT_LEAST_2_5:
# This test case will trigger recompilation many times, so set a large cache_size_limit here
torch._dynamo.config.cache_size_limit = 32

@pytest.mark.parametrize("bias", bias_list)
@pytest.mark.parametrize("alpha", alpha_list)
@pytest.mark.parametrize("quant_mode", quant_mode_list)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("idtype", idtypes)
def test_compute(bias, alpha, quant_mode, device, idtype):
class Linear(torch.nn.Module):
def __init__(self, bias: bool):
super().__init__()
self.fc = torch.nn.Linear(16, 16, bias)
self.fc.weight.data = torch.randn_like(self.fc.weight.data)

def forward(self, x):
return self.fc(x)

m = Linear(bias).eval().to(idtype).to(device)
m_ref = deepcopy(m)
data = torch.randn(2, 16, dtype=idtype, device=device)

# calibrate
insert_smooth_quant_observer(m, alpha, quant_mode, n_calib_examples=1)
m(data)
# quantize
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(m, smooth_quant(), is_observed_linear)
with torch.inference_mode():
if TORCH_VERSION_AT_LEAST_2_5:
m = torch.compile(m, fullgraph=True)
out = m(data)

# reference
weight = m_ref.fc.weight.data.float()
b = m_ref.fc.bias if bias else None
x_abs_max_per_ic = torch.abs(data).max(dim=0).values
w_abs_max_per_ic = torch.abs(weight).max(dim=0).values
smoothing_factor = (
torch.pow(x_abs_max_per_ic, alpha) / torch.pow(
w_abs_max_per_ic, 1 - alpha)
)
act = data / smoothing_factor
wei = weight * smoothing_factor
qw, w_scales, w_zps = dynamically_quantize_per_channel(
wei, -127, 127, torch.int8
)
fq_wei = dequantize_per_channel(qw, w_scales, w_zps, idtype)
if (device == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \
not TORCH_VERSION_AT_LEAST_2_2:
# _int_mm is not supported in these cases
out_ref = torch.nn.functional.linear(act.to(idtype), fq_wei, b)
elif quant_mode == "static":
# activation is quantized per-tensor
obs = MovingAverageMinMaxObserver(
dtype=torch.int8,
qscheme=torch.per_tensor_symmetric,
quant_min=-127,
quant_max=127,
)
obs(act.float().to("cpu"))
act_scale, _ = obs.calculate_qparams()
fq_act = torch.quantize_per_tensor(
act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8
).dequantize().to(idtype)
out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)
else:
# activation is quantized per-row (batch * sequence_length)
qx, x_scales, x_zps = dynamically_quantize_per_channel(
act.float(), -127, 127, torch.int8
)
fq_act = dequantize_per_channel(qx, x_scales, x_zps, idtype)
out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)

# BFloat16 and Float16 have larger errors
assert torch.allclose(out, out_ref.to(idtype), atol = 0.2)


@pytest.mark.parametrize("alpha", alpha_list)
@pytest.mark.parametrize("quant_mode", quant_mode_list)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("idtype", idtypes)
def test_save_load_recipe(alpha, quant_mode, device, idtype):
dataset_size = 20
l1, l2, l3 = 512, 256, 128
original_dtype = idtype
n_calib_examples = 10
sequence_length = 5

m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
m_save_load = deepcopy(m)

dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device)
calibration_data = dataset[:n_calib_examples]

# calibrate
insert_smooth_quant_observer(m, alpha, quant_mode, n_calib_examples=n_calib_examples)
insert_smooth_quant_observer(m_save_load, alpha, quant_mode, n_calib_examples=n_calib_examples)

for example in calibration_data:
m(example.to(device))
m_save_load(example.to(device))

with tempfile.NamedTemporaryFile() as fp:
save_path = fp.name
save_smooth_quant_recipe(m_save_load, save_path)
load_smooth_quant_recipe(m_save_load, save_path)

# quantize
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(m, smooth_quant(), is_observed_linear)
if TORCH_VERSION_AT_LEAST_2_5:
# earlier versions are not compatible
m = torch.compile(m, fullgraph=True)
m_save_load = torch.compile(m_save_load, fullgraph=True)
out_list = [m(data.squeeze(0)) for data in dataset]
out = torch.cat(out_list)
save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset]
save_load_out = torch.cat(save_load_out_list)

assert out is not None
assert save_load_out is not None
assert torch.allclose(out, save_load_out)
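For reference, the reference path in test_compute above mirrors the SmoothQuant smoothing step: for activation X (tokens by input channels), weight W (output channels by input channels), and migration strength alpha, the per-input-channel smoothing factor and the smoothed tensors are

\[
s_j = \frac{\bigl(\max_i |X_{ij}|\bigr)^{\alpha}}{\bigl(\max_k |W_{kj}|\bigr)^{1-\alpha}},
\qquad
\hat{X}_{ij} = \frac{X_{ij}}{s_j},
\qquad
\hat{W}_{kj} = W_{kj}\, s_j,
\]

so that \(\hat{X}\hat{W}^{\top} = XW^{\top}\), while activation outliers are partially migrated into the weights before per-channel weight quantization and per-tensor (static) or per-row (dynamic) activation quantization.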
1 change: 0 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -1346,7 +1346,6 @@ def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
isinstance(input_tensor, AffineQuantizedTensor) and
_aqt_is_int8_reduced_range(input_tensor) and
isinstance(weight_tensor, AffineQuantizedTensor) and
weight_tensor.is_cuda and
input_tensor.dtype == weight_tensor.dtype and
isinstance(input_tensor._layout, PlainLayout) and
isinstance(weight_tensor._layout, PlainLayout)
5 changes: 4 additions & 1 deletion torchao/kernel/intmm.py
@@ -37,6 +37,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
"""
# torch.compile path
if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
if input.device.type == "cpu":
# Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend
return out_dtype(torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float())
return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)

# error checking for cublas path
@@ -126,7 +129,7 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -
"""
M, K = a.shape
K, N = b.shape
assert M == scales1.size(0)
assert M == scales1.size(0) or scales1.numel() == 1
assert 1 == scales1.size(1)
assert scales1.is_contiguous()
scales1 = scales1.expand((M, N))
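The relaxed assertion above lets scales1 be a single per-tensor scale (used by SmoothQuant's static activation path) in addition to per-row (M, 1) scales. A minimal sketch of the intended semantics follows; this is not the actual kernel, and int_scaled_matmul_ref is a hypothetical name used only for illustration:

import torch

def int_scaled_matmul_ref(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor:
    # Reference semantics only: int8 x int8 matmul accumulated in int32, then scaled.
    assert a.dtype == torch.int8 and b.dtype == torch.int8
    M, _ = a.shape
    # The check relaxed by this PR: per-row (M, 1) scales or a single per-tensor scale.
    assert scales1.numel() == 1 or (scales1.size(0) == M and scales1.size(1) == 1)
    c = torch.mm(a.to(torch.int32), b.to(torch.int32))
    return c.to(scales1.dtype) * scales1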
7 changes: 7 additions & 0 deletions torchao/prototype/smoothquant/__init__.py
@@ -0,0 +1,7 @@
from .api import (
insert_smooth_quant_observer,
smooth_quant,
save_smooth_quant_recipe,
load_smooth_quant_recipe,
)
from .core import SmoothQuantObservedLinear
190 changes: 190 additions & 0 deletions torchao/prototype/smoothquant/api.py
@@ -0,0 +1,190 @@
import torch
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static
from torchao.quantization.linear_activation_quantized_tensor import (
to_linear_activation_quantized,
)
from torchao.quantization.linear_activation_scale import (
to_weight_tensor_with_linear_activation_scale_metadata,
)
from torchao.quantization.weight_tensor_linear_activation_quantization import (
to_weight_tensor_with_linear_activation_quantization_metadata,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.utils import _get_per_token_block_size
from torchao.prototype.smoothquant.core import(
SmoothQuantObserver,
SmoothQuantObservedLinear,
)
from typing import Dict, Optional


def insert_smooth_quant_observer(
model: torch.nn.Module,
alpha: float = 0.5,
quant_mode: str = "static",
n_calib_examples: int = 20):
"""
Inserts SmoothQuantObserver into Linear layers of a given model.

Args:
model: The model to be modified (in place). Ensure model is on the desired device for calibration
alpha: The smoothing factor exponent, normally in [0, 1]
quant_mode: "static" or "dynamic" quantization of activation
n_calib_examples: Number of examples used for calibration
"""
_is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)

quant_min, quant_max = -127, 127
eps = torch.finfo(torch.float32).eps

def replace_with_observer(layer):
# creates observer and replaces linear layers with observed linear layers
observer = SmoothQuantObserver(
layer.weight,
alpha,
quant_mode,
n_calib_examples,
quant_min=quant_min,
quant_max=quant_max,
eps=eps)
return SmoothQuantObservedLinear.from_float(layer, observer)

_replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)


def _observed_linear_subclass_inserter(constructor):
"""
Replaces unquantized observed linear instances with quantized linear instances.

Args:
constructor: the function which applies quantization to the observed linear layer
"""
def insert_subclass(observed_linear):
# creates the new linear layer using constructor
linear = torch.nn.Linear(
observed_linear.in_features,
observed_linear.out_features,
observed_linear.bias is not None,
device=observed_linear.weight.device,
dtype=observed_linear.weight.dtype
)
linear.weight = torch.nn.Parameter(constructor(observed_linear), requires_grad=False)
linear.bias = observed_linear.bias
return linear

return insert_subclass


def save_smooth_quant_recipe(model: torch.nn.Module, save_path: str) -> Dict[str, torch.Tensor]:
Contributor: Do we need this? Or is just saving the state_dict for the observed model enough?

Collaborator (Author): We want to have an API to modify (tune) quantization parameters, i.e. the recipe here. Do you have any concern about adding this API?

Contributor: So the state_dict is supposed to be used by other APIs to tune quantization parameters? I think that's fine if you have this use case in mind; is the model with SmoothQuantObservedLinear not serializable by itself?

Collaborator (Author): SmoothQuantObservedLinear is serializable. However, a recipe is more flexible for tuning parameters. Thanks.

"""
Save smoothing_factors, act_scales, and wei_scales for each SmoothQuantObservedLinear layer in the model.
"""
result = {}

def recurse(module: torch.nn.Module, name: str = ''):
for child_name, child in module.named_children():
full_name = f"{name}.{child_name}" if name else child_name

# Apply the analysis function to this layer
if isinstance(child, SmoothQuantObservedLinear):
smoothing_factor, act_scales, wei_scales = child.obs.calculate_qparams()
result[full_name + ".smoothing_factor"] = smoothing_factor
result[full_name + ".act_scales"] = act_scales
result[full_name + ".wei_scales"] = wei_scales

# Recurse into child modules
recurse(child, full_name)

recurse(model)

torch.save(result, save_path)


def load_smooth_quant_recipe(model: torch.nn.Module, recipe_path: str, device=None) -> torch.nn.Module:
recipe = torch.load(recipe_path, weights_only=True)

def recurse(module: torch.nn.Module, name: str = ''):
if isinstance(module, SmoothQuantObservedLinear):
smoothing_factor = recipe.get(name + ".smoothing_factor", None)
act_scales = recipe.get(name + ".act_scales", None)
wei_scales = recipe.get(name + ".wei_scales", None)
if device is not None:
module.to(device=device)
# act_scales is None for dynamic quantization
if any(x is None for x in (smoothing_factor, wei_scales)):
return module
return smooth_quant(smoothing_factor, act_scales, wei_scales)(module)

mod_new = module

for child_name, child in module.named_children():
full_name = f"{name}.{child_name}" if name else child_name
setattr(mod_new, child_name, recurse(child, full_name))
return mod_new

recurse(model)


class _ActQuantizer:
def __init__(self, target_dtype, quant_min=-127):
self.target_dtype = target_dtype
self.quant_min = quant_min

def dynamic_quantize(self, input):
return to_affine_quantized_intx(
input, MappingType.SYMMETRIC, _get_per_token_block_size(input), self.target_dtype, self.quant_min
)

def static_quantize(self, input, scale, zero_point):
return to_affine_quantized_intx_static(
input, scale, zero_point, list(input.shape), self.target_dtype, self.quant_min
)


def smooth_quant(
smoothing_factor: Optional[torch.Tensor] = None,
act_scales: Optional[torch.Tensor] = None,
wei_scales: Optional[torch.Tensor] = None
):
"""
Quantizes linear layers when passed into quantize_()

Args:
smoothing_factor: The smoothing factor for the layer. Acquired from the layer's observer if None.
act_scales: The activation scales for the layer. Acquired from the layer's observer if None.
wei_scales: The weight scales for the layer. Acquired from the layer's observer if None.
"""

def quantize_weight(observed_linear):
target_dtype = torch.int8
# act_scales is None for dynamic quantization thus not checked
if any(x is None for x in (smoothing_factor, wei_scales)):
factor, x_scale, w_scales = observed_linear.obs.calculate_qparams()
weight = observed_linear.obs.weight * factor
else:
factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales
weight = observed_linear.weight * factor
weight = weight.to(observed_linear.weight.dtype)
block_size = (1, weight.size(1))
wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64)
qw = to_affine_quantized_intx_static(
weight,
w_scales,
wei_zero_points,
block_size,
target_dtype,
)

if x_scale is None:
# dynamic quant
qw = to_linear_activation_quantized(qw, _ActQuantizer(target_dtype).dynamic_quantize)
else:
# static quant
x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64)
qw = to_weight_tensor_with_linear_activation_quantization_metadata(
qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point
)

return to_weight_tensor_with_linear_activation_scale_metadata(qw, factor.to(qw.dtype))

return _observed_linear_subclass_inserter(quantize_weight)
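For context, here is a minimal end-to-end sketch of how the API added in this PR is exercised, assembled from test_smoothquant.py above; the toy model, random calibration data, and file name are illustrative only, not part of the PR:

import torch
from copy import deepcopy
from torchao.quantization import quantize_
from torchao.prototype.smoothquant import (
    insert_smooth_quant_observer,
    smooth_quant,
    save_smooth_quant_recipe,
    load_smooth_quant_recipe,
    SmoothQuantObservedLinear,
)

# Toy model and random calibration data, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()
model_recipe = deepcopy(model)
calibration_data = [torch.randn(2, 16) for _ in range(10)]

# 1. Replace nn.Linear layers with observed linears and run calibration data through them.
for m in (model, model_recipe):
    insert_smooth_quant_observer(m, alpha=0.5, quant_mode="static", n_calib_examples=len(calibration_data))
    with torch.no_grad():
        for example in calibration_data:
            m(example)

# 2a. Quantize directly from the observers.
is_observed_linear = lambda mod, fqn: isinstance(mod, SmoothQuantObservedLinear)
quantize_(model, smooth_quant(), is_observed_linear)

# 2b. Or save the calibration results as a recipe and apply it; load_smooth_quant_recipe
#     converts the observed linears using the loaded smoothing factors and scales.
save_smooth_quant_recipe(model_recipe, "smooth_quant_recipe.pt")
load_smooth_quant_recipe(model_recipe, "smooth_quant_recipe.pt")

# 3. Per the tests, torch.compile(fullgraph=True) is exercised only with PyTorch >= 2.5.
model = torch.compile(model, fullgraph=True)
out = model(torch.randn(2, 16))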