-
Notifications
You must be signed in to change notification settings - Fork 141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] SmoothQuant using tensor subclassing #1030
Draft
Xia-Weiwen
wants to merge
38
commits into
pytorch:main
Choose a base branch
from
Xia-Weiwen:smooth_quant
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+805
−2
Draft
Changes from 28 commits
Commits
Show all changes
38 commits
Select commit
Hold shift + click to select a range
d34859b
SmoothQuant using tensor subclassing
Xia-Weiwen df5b49f
Merge branch 'main' into smooth_quant
Xia-Weiwen 847f1f2
Update UT
Xia-Weiwen f03cfb3
Add SmoothQuant example
Xia-Weiwen a2518f1
Remove duplicate implementation of int_scaled_matmul for CPU
Xia-Weiwen 28fb8ce
Update example.py
Xia-Weiwen bada2b0
Remove unused code
Xia-Weiwen 921efc0
Implement with LinearActivationQuantizedTensor
Xia-Weiwen ad5b97e
Fix load/save
Xia-Weiwen f1be01d
Fix device mismatch in observer
Xia-Weiwen 7ee1f13
Fix fp16 overflow issue in int_scaled_matmul
Xia-Weiwen c773386
Merge branch 'main' into smooth_quant
Xia-Weiwen 427ff73
Add linear_activation_scale_quantized.py for torch.compile
Xia-Weiwen 9916113
Quantize act/wei to 7 bit on old CPU platforms
Xia-Weiwen 52260b6
Fix device mismatch
Xia-Weiwen ca50fee
Fix UT failures
Xia-Weiwen 3e90789
Fix UT
Xia-Weiwen d47fcc1
Don't use torch._int_mm for CPU now because it may overflow
Xia-Weiwen a195e73
Remove reduce_range
Xia-Weiwen 6627be1
Refine code
Xia-Weiwen fb981e7
Remove torch.compile from example
Xia-Weiwen 17c374e
Add torch.compile in example
Xia-Weiwen bb76de6
Debug CI failures
Xia-Weiwen 98b2de1
Debug CI failures (1)
Xia-Weiwen 316f5ea
Debug CI failures (2)
Xia-Weiwen b4d8383
Debug CI failures (3)
Xia-Weiwen aca06d2
Work with torch.compile
Xia-Weiwen dde7545
Update torchao/kernel/intmm.py
Xia-Weiwen 00cfadd
Update readme.md
Xia-Weiwen 466d2f1
Update readme.md
Xia-Weiwen e970a4a
Debug CI failures (4)
Xia-Weiwen 5e2abbe
Reimplement with nested tensor subclassing
Xia-Weiwen 90d1b7d
Test torch.compile only with PyTorch >= 2.5
Xia-Weiwen 03d490b
Debug CI failures (5)
Xia-Weiwen f595ed4
Debug CI failures (6)
Xia-Weiwen 2202f69
Debug CI failures (7)
Xia-Weiwen 6ea8aa8
Use MovingAvg observer for activation; Update UT and readme
Xia-Weiwen fa1144c
Revert changes to test_spinquant.py; refine readme
Xia-Weiwen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
from copy import deepcopy | ||
import os | ||
import pytest | ||
import torch | ||
import tempfile | ||
from torch.ao.quantization import HistogramObserver | ||
from torchao.quantization import quantize_ | ||
from torchao.utils import ( | ||
TORCH_VERSION_AT_LEAST_2_2, | ||
TORCH_VERSION_AT_LEAST_2_4, | ||
) | ||
from torchao.quantization.utils import ( | ||
dynamically_quantize_per_channel, | ||
dequantize_per_channel, | ||
) | ||
from torchao.prototype.smoothquant import ( | ||
insert_smooth_quant_observer, | ||
smooth_quant, | ||
SmoothQuantObservedLinear, | ||
save_smooth_quant_recipe, | ||
load_smooth_quant_recipe | ||
) | ||
|
||
class ToyLinearModel(torch.nn.Module): | ||
def __init__(self, m=512, n=256, k=128): | ||
super().__init__() | ||
self.linear1 = torch.nn.Linear(m, n, bias=False) | ||
self.linear2 = torch.nn.Linear(n, k, bias=False) | ||
self.linear3 = torch.nn.Linear(k, 1, bias=False) | ||
|
||
def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"): | ||
return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)] | ||
|
||
def forward(self, x): | ||
x = self.linear1(x) | ||
x = self.linear2(x) | ||
x = self.linear3(x) | ||
return x | ||
|
||
|
||
bias_list = [True, False] | ||
alpha_list = [0.5, 0.75] | ||
quant_mode_list = ["static", "dynamic"] | ||
devices = ["cpu"] | ||
if torch.cuda.is_available(): | ||
devices.append("cuda") | ||
idtypes = (torch.float, torch.bfloat16, torch.half) | ||
|
||
@pytest.mark.parametrize("bias", bias_list) | ||
@pytest.mark.parametrize("alpha", alpha_list) | ||
@pytest.mark.parametrize("quant_mode", quant_mode_list) | ||
@pytest.mark.parametrize("device", devices) | ||
@pytest.mark.parametrize("idtype", idtypes) | ||
def test_compute(bias, alpha, quant_mode, device, idtype): | ||
class Linear(torch.nn.Module): | ||
def __init__(self, bias: bool): | ||
super().__init__() | ||
self.fc = torch.nn.Linear(16, 16, bias) | ||
self.fc.weight.data = torch.randn_like(self.fc.weight.data) | ||
|
||
def forward(self, x): | ||
return self.fc(x) | ||
|
||
m = Linear(bias).eval().to(idtype).to(device) | ||
m_ref = deepcopy(m) | ||
data = torch.randn(2, 16, dtype=idtype, device=device) | ||
|
||
# calibrate | ||
insert_smooth_quant_observer(m, alpha, quant_mode, n_calib_examples=1) | ||
m(data) | ||
# quantize | ||
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) | ||
quantize_(m, smooth_quant(), is_observed_linear) | ||
with torch.inference_mode(): | ||
# m = torch.compile(m, fullgraph=True) | ||
out = m(data) | ||
|
||
# reference | ||
weight = m_ref.fc.weight.data.float() | ||
b = m_ref.fc.bias if bias else None | ||
x_abs_max_per_ic = torch.abs(data).max(dim=0).values | ||
w_abs_max_per_ic = torch.abs(weight).max(dim=0).values | ||
smoothing_factor = ( | ||
torch.pow(x_abs_max_per_ic, alpha) / torch.pow( | ||
w_abs_max_per_ic, 1 - alpha) | ||
) | ||
act = data / smoothing_factor | ||
wei = weight * smoothing_factor | ||
qw, w_scales, w_zps = dynamically_quantize_per_channel( | ||
wei, -127, 127, torch.int8 | ||
) | ||
fq_wei = dequantize_per_channel(qw, w_scales, w_zps, idtype) | ||
if (device == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \ | ||
not TORCH_VERSION_AT_LEAST_2_2: | ||
# _int_mm is not supported in these cases | ||
out_ref = torch.nn.functional.linear(act.to(idtype), fq_wei, b) | ||
elif quant_mode == "static": | ||
# activation is quantized per-tensor | ||
obs = HistogramObserver( | ||
dtype=torch.int8, | ||
qscheme=torch.per_tensor_symmetric, | ||
quant_min=-127, | ||
quant_max=127, | ||
) | ||
obs(act.float().to("cpu")) | ||
act_scale, _ = obs.calculate_qparams() | ||
fq_act = torch.quantize_per_tensor( | ||
act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8 | ||
).dequantize().to(idtype) | ||
out_ref = torch.nn.functional.linear(fq_act, fq_wei, b) | ||
else: | ||
# activation is quantized per-row (batch * sequence_length) | ||
qx, x_scales, x_zps = dynamically_quantize_per_channel( | ||
act.float(), -127, 127, torch.int8 | ||
) | ||
fq_act = dequantize_per_channel(qx, x_scales, x_zps, idtype) | ||
out_ref = torch.nn.functional.linear(fq_act, fq_wei, b) | ||
|
||
# BFloat16 and Float16 have larger errors | ||
assert torch.allclose(out, out_ref.to(idtype), atol = 0.2) | ||
|
||
|
||
@pytest.mark.parametrize("alpha", alpha_list) | ||
@pytest.mark.parametrize("quant_mode", quant_mode_list) | ||
@pytest.mark.parametrize("device", devices) | ||
@pytest.mark.parametrize("idtype", idtypes) | ||
def test_save_load_recipe(alpha, quant_mode, device, idtype): | ||
dataset_size = 20 | ||
l1, l2, l3 = 512, 256, 128 | ||
original_dtype = idtype | ||
n_calib_examples = 10 | ||
sequence_length = 5 | ||
|
||
m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device) | ||
m_save_load = deepcopy(m) | ||
|
||
dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device) | ||
calibration_data = dataset[:n_calib_examples] | ||
|
||
# calibrate | ||
insert_smooth_quant_observer(m, alpha, quant_mode, n_calib_examples=n_calib_examples) | ||
insert_smooth_quant_observer(m_save_load, alpha, quant_mode, n_calib_examples=n_calib_examples) | ||
|
||
for example in calibration_data: | ||
m(example.to(device)) | ||
m_save_load(example.to(device)) | ||
|
||
with tempfile.NamedTemporaryFile() as fp: | ||
save_path = fp.name | ||
save_smooth_quant_recipe(m_save_load, save_path) | ||
load_smooth_quant_recipe(m_save_load, save_path) | ||
|
||
# quantize | ||
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) | ||
quantize_(m, smooth_quant(), is_observed_linear) | ||
# m = torch.compile(m, fullgraph=True) | ||
# m_save_load = torch.compile(m_save_load, fullgraph=True) | ||
out_list = [m(data.squeeze(0)) for data in dataset] | ||
out = torch.cat(out_list) | ||
save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset] | ||
save_load_out = torch.cat(save_load_out_list) | ||
|
||
assert out is not None | ||
assert save_load_out is not None | ||
assert torch.allclose(out, save_load_out) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .api import ( | ||
insert_smooth_quant_observer, | ||
smooth_quant, | ||
save_smooth_quant_recipe, | ||
load_smooth_quant_recipe, | ||
) | ||
from .core import SmoothQuantObservedLinear |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
import torch | ||
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter | ||
from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static | ||
from torchao.quantization.linear_activation_scale_quantized import ( | ||
to_linear_scale_activation_quantized, | ||
) | ||
from torchao.quantization.quant_primitives import MappingType | ||
from torchao.quantization.utils import _get_per_token_block_size | ||
from torchao.prototype.smoothquant.core import( | ||
SmoothQuantObserver, | ||
SmoothQuantObservedLinear, | ||
) | ||
from typing import Dict, Optional | ||
from torch._dynamo import is_compiling as dynamo_is_compiling | ||
|
||
|
||
def insert_smooth_quant_observer( | ||
model: torch.nn.Module, | ||
alpha: float = 0.5, | ||
quant_mode: str = "static", | ||
n_calib_examples: int = 20): | ||
""" | ||
Inserts SmoothQuantObserver into Linear layers of a given model. | ||
|
||
Args: | ||
model: The model to be modified (in place). Ensure model is on the desired device for calibration | ||
mapping_type: symmetric or asymmetric quantization of weight | ||
n_calib_examples: Number of examples used for calibration | ||
""" | ||
_is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) | ||
|
||
quant_min, quant_max = -127, 127 | ||
eps = torch.finfo(torch.float32).eps | ||
|
||
def replace_with_observer(layer): | ||
# creates observer and replaces linear layers with observed linear layers | ||
observer = SmoothQuantObserver( | ||
layer.weight, | ||
alpha, | ||
quant_mode, | ||
n_calib_examples, | ||
quant_min=quant_min, | ||
quant_max=quant_max, | ||
eps=eps) | ||
return SmoothQuantObservedLinear.from_float(layer, observer) | ||
|
||
_replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) | ||
|
||
|
||
def _observed_linear_subclass_inserter(constructor): | ||
""" | ||
Replaces unquantized observed linear instances with quantized linear instances. | ||
|
||
Args: | ||
constructor: the function which applies quantization to the observed linear layer | ||
""" | ||
def insert_subclass(observed_linear): | ||
# creates the new linear layer using constructor | ||
linear = torch.nn.Linear( | ||
observed_linear.in_features, | ||
observed_linear.out_features, | ||
observed_linear.bias is not None, | ||
device=observed_linear.weight.device, | ||
dtype=observed_linear.weight.dtype | ||
) | ||
linear.weight = torch.nn.Parameter(constructor(observed_linear), requires_grad=False) | ||
linear.bias = observed_linear.bias | ||
return linear | ||
|
||
return insert_subclass | ||
|
||
|
||
def save_smooth_quant_recipe(model: torch.nn.Module, save_path: str) -> Dict[str, torch.Tensor]: | ||
""" | ||
Save smoothing_factors, act_scales, and wei_scales for each SmoothQuantObservedLinear layer in the model. | ||
""" | ||
result = {} | ||
|
||
def recurse(module: torch.nn.Module, name: str = ''): | ||
for child_name, child in module.named_children(): | ||
full_name = f"{name}.{child_name}" if name else child_name | ||
|
||
# Apply the analysis function to this layer | ||
if isinstance(child, SmoothQuantObservedLinear): | ||
smoothing_factor, act_scales, wei_scales = child.obs.calculate_qparams() | ||
result[full_name + ".smoothing_factor"] = smoothing_factor | ||
result[full_name + ".act_scales"] = act_scales | ||
result[full_name + ".wei_scales"] = wei_scales | ||
|
||
# Recurse into child modules | ||
recurse(child, full_name) | ||
|
||
recurse(model) | ||
|
||
torch.save(result, save_path) | ||
|
||
|
||
def load_smooth_quant_recipe(model: torch.nn.Module, recipe_path: str, device=None) -> torch.nn.Module: | ||
recipe = torch.load(recipe_path, weights_only=True) | ||
|
||
def recurse(module: torch.nn.Module, name: str = ''): | ||
if isinstance(module, SmoothQuantObservedLinear): | ||
smoothing_factor = recipe.get(name + ".smoothing_factor", None) | ||
act_scales = recipe.get(name + ".act_scales", None) | ||
wei_scales = recipe.get(name + ".wei_scales", None) | ||
if device is not None: | ||
module.to(device=device) | ||
# act_scales is None for dynamic quantization | ||
if any(x is None for x in (smoothing_factor, wei_scales)): | ||
return module | ||
return smooth_quant(smoothing_factor, act_scales, wei_scales)(module) | ||
|
||
mod_new = module | ||
|
||
for child_name, child in module.named_children(): | ||
full_name = f"{name}.{child_name}" if name else child_name | ||
setattr(mod_new, child_name, recurse(child, full_name)) | ||
return mod_new | ||
|
||
recurse(model) | ||
|
||
|
||
def smooth_quant( | ||
smoothing_factor: Optional[torch.Tensor] = None, | ||
act_scales: Optional[torch.Tensor] = None, | ||
wei_scales: Optional[torch.Tensor] = None | ||
): | ||
""" | ||
Quantizes linear layers when passed into quantize_() | ||
|
||
Args: | ||
smoothing_factor: The smoothing factor for the layer. Acquired from the layer's observer if None. | ||
act_scales: The activation scales for the layer. Acquired from the layer's observer if None. | ||
wei_scales: The weight scales for the layer. Acquired from the layer's observer if None. | ||
""" | ||
|
||
def quantize_weight(observed_linear): | ||
target_dtype = torch.int8 | ||
# act_scales is None for dynamic quantization thus not checked | ||
if any(x is None for x in (smoothing_factor, wei_scales)): | ||
factor, x_scale, w_scales = observed_linear.obs.calculate_qparams() | ||
weight = observed_linear.obs.weight * factor | ||
else: | ||
factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales | ||
weight = observed_linear.weight * factor | ||
weight = weight.to(observed_linear.weight.dtype) | ||
block_size = (1, weight.size(1)) | ||
wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64) | ||
qw = to_affine_quantized_intx_static( | ||
weight, | ||
w_scales, | ||
wei_zero_points, | ||
block_size, | ||
target_dtype, | ||
) | ||
|
||
return to_linear_scale_activation_quantized( | ||
qw, factor, x_scale, None, target_dtype, -127, 127 | ||
) | ||
|
||
return _observed_linear_subclass_inserter(quantize_weight) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need this? or just saving the state_dict for observed model is enough?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We want to have an API to modify (tune) quantization parameters, i.e. the recipe here. Do you have any concern about adding this API?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so the state_dict is supposed to be used by other APIs to tune quantization parameters? I think that's fine if you have this use case in mind, is the model with SmoothQuantObservedLinear not serializable by itself?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SmoothQuantObservedLinear
is serializable. However, a recipe is more flexible to tune parameters. Thanks.