From d34859b865402e6b7d950d9b34d8c774774bb2e9 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 30 Sep 2024 07:46:41 -0400 Subject: [PATCH 01/36] SmoothQuant using tensor subclassing --- test/prototype/test_smoothquant.py | 144 ++++++++++ torchao/prototype/smoothquant/__init__.py | 7 + torchao/prototype/smoothquant/api.py | 175 ++++++++++++ torchao/prototype/smoothquant/core.py | 327 ++++++++++++++++++++++ torchao/prototype/smoothquant/example.py | 0 torchao/prototype/smoothquant/readme.md | 0 6 files changed, 653 insertions(+) create mode 100644 test/prototype/test_smoothquant.py create mode 100644 torchao/prototype/smoothquant/__init__.py create mode 100644 torchao/prototype/smoothquant/api.py create mode 100644 torchao/prototype/smoothquant/core.py create mode 100644 torchao/prototype/smoothquant/example.py create mode 100644 torchao/prototype/smoothquant/readme.md diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py new file mode 100644 index 000000000..5e65ce8c2 --- /dev/null +++ b/test/prototype/test_smoothquant.py @@ -0,0 +1,144 @@ +from copy import deepcopy +import os +import pytest +import torch +import tempfile +from torchao.quantization import quantize_ +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_2, + TORCH_VERSION_AT_LEAST_2_4, + dynamically_quantize_per_channel, + dequantize_per_channel, +) + +from torchao.prototype.smoothquant import ( + insert_smooth_quant_observer, + smooth_quant, + SmoothQuantObservedLinear, + save_smooth_quant_recipe, + load_smooth_quant_recipe +) + +class ToyLinearModel(torch.nn.Module): + def __init__(self, m=512, n=256, k=128): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + self.linear3 = torch.nn.Linear(k, 1, bias=False) + + def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"): + return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)] + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x + + +bias_list = [True, False] +alpha_list = [0.5, 0.75] +quant_mode_list = ["static", "dynamic"] +devices = ["cpu"] +if torch.cuda.is_available(): + devices.append("cuda") +idtypes = (torch.float, torch.bfloat16, torch.half) + +@pytest.mark.parametrize("bias", bias_list) +@pytest.mark.parametrize("alpha", alpha_list) +@pytest.mark.parametrize("quant_mode", quant_mode_list) +@pytest.mark.parametrize("device", devices) +@pytest.mark.parametrize("idtype", idtypes) +def test_compute(bias, alpha, quant_mode, device, idtype): + class Linear(torch.nn.module): + def __init__(self, bias: bool): + super().__init__() + self.fc = torch.nn.Linear(16, 16, bias) + self.fc.weight.data = torch.randn_like(self.fc.weight.data) + + def forward(self, x): + return self.fc(x) + + original_dtype = idtype + m = Linear(bias).eval().to(original_dtype).to(device) + m_ref = deepcopy(m) + data = torch.randn(2, 16, dtype=original_dtype, device=device) + + # calibrate + insert_smooth_quant_observer(m, alpha, quant_mode, 1) + m(data) + # quantize + is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) + quantize_(m, smooth_quant(), is_observed_linear) + # m = torch.compile(m, fullgraph=True) + out = m(data) + + # reference + weight = m_ref.fc.weight.data + bias = m_ref.fc.bias + x_abs_max_per_ic = torch.abs(data).max(dim=0).values + w_abs_max_per_ic = torch.abs(weight).max(dim=0).values + 
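# Illustrative summary of the reference computed below: per input channel j the
# SmoothQuant factor is s_j = max|X[:, j]|**alpha / max|W[:, j]|**(1 - alpha);
# dividing activations by s and multiplying weights by s leaves X @ W.T unchanged
# while moving activation outliers into the weight.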
smoothing_factor = ( + torch.pow(x_abs_max_per_ic, alpha) / torch.pow( + w_abs_max_per_ic, 1 - alpha) + ) + act = data / smoothing_factor + wei = weight * smoothing_factor + qw, w_scales, w_zps = dynamically_quantize_per_channel( + wei, -128, 127, torch.int8 + ) + if (device == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \ + not TORCH_VERSION_AT_LEAST_2_2: + dqw = dequantize_per_channel(qw, w_scales, w_zps, torch.float32) + out_ref = torch.nn.functional.linear(act, dqw, bias) + elif quant_mode == "static": + pass + else: + pass + + assert torch.allclose(out, out_ref, atol = 1e-2) + + +@pytest.mark.parametrize("alpha", alpha_list) +@pytest.mark.parametrize("quant_mode", quant_mode_list) +@pytest.mark.parametrize("device", devices) +@pytest.mark.parametrize("idtype", idtypes) +def test_save_load_recipe(alpha, quant_mode, device, idtype): + dataset_size = 20 + l1, l2, l3 = 512, 256, 128 + original_dtype = idtype + n_calib_examples = 10 + sequence_length = 5 + + m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device) + m_save_load = deepcopy(m) + + dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device) + calibration_data = dataset[:n_calib_examples] + + # calibrate + insert_smooth_quant_observer(m, alpha, quant_mode, n_calib_examples) + insert_smooth_quant_observer(m_save_load, alpha, quant_mode, n_calib_examples) + + for example in calibration_data: + m(example.to(device)) + m_save_load(example.to(device)) + + with tempfile.NamedTemporaryFile() as fp: + save_path = fp.name + save_smooth_quant_recipe(m_save_load, save_path) + load_smooth_quant_recipe(m_save_load, save_path) + + # quantize + is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) + quantize_(m, smooth_quant(), is_observed_linear) + # m = torch.compile(m, fullgraph=True) + # m_save_load = torch.compile(m_save_load, fullgraph=True) + out_list = [m(data.squeeze(0)) for data in dataset] + out = torch.cat(out_list) + save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset] + save_load_out = torch.cat(save_load_out_list) + + assert out is not None + assert save_load_out is not None + assert torch.allclose(out, save_load_out, atol = 1e-2) diff --git a/torchao/prototype/smoothquant/__init__.py b/torchao/prototype/smoothquant/__init__.py new file mode 100644 index 000000000..6fced77e3 --- /dev/null +++ b/torchao/prototype/smoothquant/__init__.py @@ -0,0 +1,7 @@ +from .api import ( + insert_smooth_quant_observer, + smooth_quant, + save_smooth_quant_recipe, + load_smooth_quant_recipe, +) +from .core import SmoothQuantObservedLinear \ No newline at end of file diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py new file mode 100644 index 000000000..373b1eae0 --- /dev/null +++ b/torchao/prototype/smoothquant/api.py @@ -0,0 +1,175 @@ +import torch +import torch.nn.functional as F +import json +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, + _DTYPE_TO_QVALUE_BOUNDS, +) +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter +from torchao.dtypes import to_affine_quantized_intx_static +from torchao.dtypes.uintx.uintx import _DTYPE_TO_BIT_WIDTH +from torchao.prototype.smoothquant.core import( + SmoothQuantObserver, + SmoothQuantObservedLinear, + SmoothQuantLayoutType, +) +from typing import Dict, Optional + + +assert len(_DTYPE_TO_BIT_WIDTH) > 0, "Error importing low bit torch.uint dtypes. 
Please upgrade to torch 2.3+" + + +def insert_smooth_quant_observer( + model: torch.nn.Module, + alpha: float = 0.5, + quant_mode: str = "static", + n_calib_examples: int = 20): + """ + Inserts SmoothQuantObserver into Linear layers of a given model. + + Args: + model: The model to be modified (in place). Ensure model is on the desired device for calibration + mapping_type: symmetric or asymmetric quantization of weight + n_calib_examples: Number of examples used for calibration + """ + _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) + + quant_min = _DTYPE_TO_QVALUE_BOUNDS[torch.int8][0] + quant_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8][1] + eps = torch.finfo(torch.float32).eps + + def replace_with_observer(layer): + # creates observer and replaces linear layers with observed linear layers + observer = SmoothQuantObserver( + layer.weight, + alpha, + quant_mode, + n_calib_examples, + quant_min=quant_min, + quant_max = quant_max, + eps = eps) + return SmoothQuantObservedLinear.from_float(layer, observer) + + _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) + + +def _observed_linear_subclass_inserter(constructor): + """ + Replaces unquantized observed linear instances with quantized linear instances. + + Args: + constructor: the function which applies quantization to the observed linear layer + """ + def insert_subclass(observed_linear): + # creates the new linear layer using constructor + linear = torch.nn.Linear( + observed_linear.in_features, + observed_linear.out_features, + observed_linear.bias is not None, + device=observed_linear.weight.device, + dtype=observed_linear.weight.dtype + ) + linear.weight = torch.nn.Parameter(constructor(observed_linear), requires_grad=False) + linear.bias = observed_linear.bias + return linear + + return insert_subclass + + +def save_smooth_quant_recipe(model: torch.nn.Module, save_path: str) -> Dict[str, torch.Tensor]: + """ + Save smoothing_factors, act_scales, and wei_scales for each SmoothQuantObservedLinear layer in the model. 
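    Example (illustrative, mirroring test_save_load_recipe above; `model` is assumed to
    already contain calibrated SmoothQuantObservedLinear layers and "recipe.pt" is an
    arbitrary placeholder path):

        save_smooth_quant_recipe(model, "recipe.pt")
        # later, on a freshly observed copy of the model:
        load_smooth_quant_recipe(model, "recipe.pt")

    Loading re-applies the saved factors and replaces the observed linears with quantized ones.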
+ """ + result = {} + + def recurse(module: torch.nn.Module, name: str = ''): + for child_name, child in module.named_children(): + full_name = f"{name}.{child_name}" if name else child_name + + # Apply the analysis function to this layer + if isinstance(child, SmoothQuantObservedLinear): + smoothing_factor, act_scales, wei_scales = child.obs.calculate_qparams() + result[full_name + ".smoothing_factor"] = smoothing_factor + result[full_name + ".act_scales"] = act_scales + result[full_name + ".wei_scales"] = wei_scales + + # Recurse into child modules + recurse(child, full_name) + + recurse(model) + + torch.save(result, save_path) + +def load_smooth_quant_recipe(model: torch.nn.Module, recipe_path: str, device=None) -> torch.nn.Module: + recipe = torch.load(recipe_path, weights_only=True) + + def recurse(module: torch.nn.Module, name: str = ''): + if isinstance(module, SmoothQuantObservedLinear): + smoothing_factor = recipe.get(name + ".smoothing_factor", None) + act_scales = recipe.get(name + ".act_scales", None) + wei_scales = recipe.get(name + ".wei_scales", None) + if device is not None: + module.to(device=device) + # act_scales is None for dynamic quantization + if any(x is None for x in (smoothing_factor, wei_scales)): + return module + return smooth_quant(smoothing_factor, act_scales, wei_scales)(module) + + mod_new = module + + for child_name, child in module.named_children(): + full_name = f"{name}.{child_name}" if name else child_name + setattr(mod_new, child_name, recurse(child, full_name)) + return mod_new + + recurse(model) + + +def smooth_quant( + smoothing_factor: Optional[torch.Tensor] = None, + act_scales: Optional[torch.Tensor] = None, + wei_scales: Optional[torch.Tensor] = None + ): + """ + Quantizes linear layers when passed into quantize_() + + Args: + smoothing_factor: The smoothing factor for the layer. Acquired from the layer's observer if None. + act_scales: The activation scales for the layer. Acquired from the layer's observer if None. + wei_scales: The weight scales for the layer. Acquired from the layer's observer if None. 
+ """ + + def quantize_weight(observed_linear): + target_dtype = torch.int8 + quant_min = _DTYPE_TO_QVALUE_BOUNDS[target_dtype][0] + quant_max = _DTYPE_TO_QVALUE_BOUNDS[target_dtype][1] + nonlocal smoothing_factor, act_scales, wei_scales + # act_scales is None for dynamic quantization + if any(x is None for x in (smoothing_factor, wei_scales)): + factor, s_act, s_wei = observed_linear.obs.calculate_qparams() + weight = observed_linear.obs.weight * factor + else: + factor, s_act, s_wei = smoothing_factor, act_scales, wei_scales + weight = observed_linear.weight * factor + weight_t = weight.t().contiguous() + block_size = (weight_t.size(0), 1) + inv_smoothing_factor = 1 / factor + layout_type = SmoothQuantLayoutType( + inv_smoothing_factor, s_act, s_wei + ) + wei_zero_points = torch.zeros_like(s_wei, dtype=torch.int64) + return to_affine_quantized_intx_static( + weight_t, + s_wei, + wei_zero_points, + block_size, + target_dtype, + quant_min, + quant_max, + layout_type=layout_type + ) + + return _observed_linear_subclass_inserter(quantize_weight) + + diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py new file mode 100644 index 000000000..7991fbe14 --- /dev/null +++ b/torchao/prototype/smoothquant/core.py @@ -0,0 +1,327 @@ +from dataclasses import dataclass +from typing import Tuple, Optional + +import torch +import torch.nn.functional as F +from torch.utils._python_dispatch import return_and_correct_aliasing +from torch.ao.quantization import PerChannelMinMaxObserver, HistogramObserver +from torchao.dtypes.uintx.uintx import to_uintx +from torchao.dtypes.affine_quantized_tensor import ( + to_affine_quantized_intx, + LayoutType, + register_layout_cls, + AQTLayout, + register_aqt_quantized_linear_dispatch + +) +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, + _DTYPE_TO_QVALUE_BOUNDS, +) +from torchao.quantization.observer import ( + AffineQuantizedObserverBase, PerRow +) +from torchao.quantization.utils import ( + dynamically_quantize_per_channel, + quant_int8_per_token_matmul, +) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_4 + +class SmoothQuantObserver(AffineQuantizedObserverBase): + def __init__(self, + weight: torch.Tensor, + alpha: float = 0.5, + quant_mode: str = "static", # or dynamic + n_calib_examples: int = 20, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + zero_point_domain = ZeroPointDomain.INT, + reduce_range: Optional[bool] = False, + ): + """ + A custom observer for SmoothQuant + + Args: + weight: The weight tensor to be observed. + mapping_type: symmetric or asymmetric quantization of weight + alpha: The alpha value to determine smoothing factor + quant_mode: The mode of activation quantization, either static or dynamic + n_calib_examples: Number of examples used to calibrate observer + quant_min: The minimum quantized value + quant_max: The maximum quantized value + eps: The minimum scale to avoid dividing by zero. + scale_dtype: The data type of the scale tensor. + zero_point_dtype: The data type of the zero point tensor. + zero_point_domain: The domain of the zero point. 
+ reduce_range: Quantize act/wei to less than 8 bits on old platforms + """ + super().__init__( + MappingType.SYMMETRIC, + torch.int8, + PerRow(), + quant_min = quant_min, + quant_max = quant_max, + eps = eps, + scale_dtype = scale_dtype, + zero_point_dtype = zero_point_dtype, + preserve_zero = True, + zero_point_domain = zero_point_domain, + ) + assert weight.ndim == 2 + self.weight = weight + self.n_calib_examples = n_calib_examples + self.inputs = [] + self.device = self.weight.device + self.alpha = alpha + assert quant_mode in ["static", "dynamic"] + self.quant_mode = quant_mode + # act.shape = [mb, ic] (reshape if needed), wei.shape = [oc, ic] + # *_ic_obs are used to determine smoothing_factor + # *_mb/oc_obs are used to find qparams for quantization + self.act_ic_obs = PerChannelMinMaxObserver( + ch_axis=-1, + dtype=torch.int8, + qscheme=torch.per_channel_affine, + reduce_range=False, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + ) + self.act_obs = HistogramObserver( + dtype=torch.int8, + qscheme=torch.per_tensor_symmetric, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + ) + self.wei_ic_obs = PerChannelMinMaxObserver( + ch_axis=1, + dtype=torch.int8, + qscheme=torch.per_channel_affine, + reduce_range=False, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + ) + self.wei_oc_obs = PerChannelMinMaxObserver( + ch_axis=0, + dtype=torch.int8, + qscheme=torch.per_channel_symmetric, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + ) + self.wei_ic_obs(self.weight) + + @torch.no_grad() + def forward(self, input: torch.Tensor): + if self.quant_mode == "static": + # record inputs to find qparams for activation + if len(self.inputs) < self.n_calib_examples: + self.inputs.append(input.to("cpu").view(-1, input.size(-1))) + self.act_ic_obs(input) + return input + + def calculate_qparams(self): + # 1 Get min/max per IC from observers + wei_min_per_ic = self.wei_ic_obs.min_val + wei_max_per_ic = self.wei_ic_obs.max_val + act_min_per_ic = self.act_ic_obs.min_val + act_max_per_ic = self.act_ic_obs.max_val + x_abs_max_per_ic = ( + torch.max(torch.abs(act_min_per_ic), torch.abs(act_max_per_ic)) + self.eps + ) + w_abs_max_per_ic = ( + torch.max(torch.abs(wei_min_per_ic), torch.abs(wei_max_per_ic)) + self.eps + ) + # 2 calculate the smoothing factor + smoothing_factor = torch.pow(x_abs_max_per_ic, self.alpha) / torch.pow( + w_abs_max_per_ic, 1 - self.alpha + ) + # 3 apply smoothing factor to activations and find scales for static quantization + act_scales = None + if self.quant_mode == "static": + inv_smoothing_factor = 1 / smoothing_factor + for act in self.inputs: + act_new = act * inv_smoothing_factor + self.act_obs(act_new) + act_scale, _ = self.act_obs.calculate_qparams() + act_scales = torch.Tensor([act_scale]) + # 4 update weight and find scales + self.wei_oc_obs(self.weight * smoothing_factor) + wei_scales, _ = self.wei_oc_obs.calculate_qparams() + # 5 return results + return smoothing_factor, act_scales, wei_scales + + +class SmoothQuantObservedLinear(torch.nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + obs: SmoothQuantObserver, + device=None, + dtype=None): + super().__init__(in_features, out_features, bias, device, dtype) + assert isinstance(obs, SmoothQuantObserver) + self.obs = obs + + def forward(self, input: torch.Tensor): + input = self.obs(input) + output = F.linear(input, self.weight, self.bias) + return output + + @classmethod + def 
from_float(cls, float_linear: torch.nn.Linear, obs: SmoothQuantObserver): + observed_linear = cls( + float_linear.in_features, + float_linear.out_features, + float_linear.bias is not None, + obs, + device=float_linear.weight.device, + dtype=float_linear.weight.dtype, + ) + observed_linear.weight = float_linear.weight + observed_linear.bias = float_linear.bias + return observed_linear + + +@dataclass(frozen=True) +class SmoothQuantLayoutType(LayoutType): + inv_smoothing_factor: torch.Tensor + act_scales: torch.Tensor + wei_scales: torch.Tensor + + +def _quantized_linear_impl(input_tensor, weight_tensor, bias): + inv_smoothing_factor = weight_tensor.layout_tensor.layout_type.inv_smoothing_factor + act_scales = weight_tensor.layout_tensor.layout_type.act_scales + wei_scales = weight_tensor.layout_tensor.layout_type.wei_scales + input = input_tensor * inv_smoothing_factor + if (weight_tensor.device.type == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \ + not TORCH_VERSION_AT_LEAST_2_2: + # _int_mm is not available on CUDA before PyTorch 2.2 + # _int_mm is not available on CPU before PyTorch 2.4 + # So compute in float here + y = F.linear(input, weight_tensor.dequantize(), bias) + else: + target_dtype = torch.int8 + quant_min = _DTYPE_TO_QVALUE_BOUNDS[target_dtype][0] + quant_max = _DTYPE_TO_QVALUE_BOUNDS[target_dtype][1] + if act_scales is not None: + # dynamic quant + act_zero_points = torch.zeros_like(act_scales, dtype=torch.int64) + qx = torch.ops.quantized_decomposed.quantize_per_tensor( + input, + act_scales, + act_zero_points, + quant_min, + quant_max, + dtype=target_dtype, + ) + act_scales = act_scales * torch.ones(input.size(0), dtype=act_scales.dtype) + else: + # static quant + qx, act_scales, _ = dynamically_quantize_per_channel(input, quant_min, quant_max, target_dtype) + y = quant_int8_per_token_matmul( + qx, act_scales, weight_tensor.layout_tensor.int_data, wei_scales + ) + if bias is not None: + y += bias + return y.to(input_tensor.dtype) + + +def _linear_sq_check(input_tensor, weight_tensor, bias): + return isinstance(weight_tensor.layout_tensor, SmoothQuantAQTLayout) + + +register_aqt_quantized_linear_dispatch(_linear_sq_check, _quantized_linear_impl) + + +@register_layout_cls(SmoothQuantLayoutType) +class SmoothQuantAQTLayout(AQTLayout): + @staticmethod + def __new__( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + layout_type: LayoutType, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + layout_type: LayoutType, + ): + self.int_data = int_data + self.scale = scale + self.zero_point = zero_point + self.layout_type = layout_type + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return self.int_data, self.scale, self.zero_point + + def __tensor_flatten__(self): + return ["int_data", "scale", "zero_point"], [self.layout_type] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"] + layout_type, = tensor_attributes + return cls(int_data, 
scale, zero_point, layout_type) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is torch.ops.aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"SmoothQuantAQTLayout dispatch: attempting to run {func}, this is not supported" + ) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + layout_type: LayoutType, + ): + assert isinstance(layout_type, SmoothQuantLayoutType) + return cls(int_data, scale, zero_point, layout_type) + + def get_layout_type(self) -> LayoutType: + return self.layout_type + + def _apply_fn_to_data(self, fn): + self.int_data = fn(self.int_data) + self.scale = fn(self.scale) + self.zero_point = fn(self.zero_point) + return self + +to_smooth_quant = SmoothQuantAQTLayout.from_plain \ No newline at end of file diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchao/prototype/smoothquant/readme.md b/torchao/prototype/smoothquant/readme.md new file mode 100644 index 000000000..e69de29bb From 847f1f28a6f6cd67b6a1856a444a75ca674a252d Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 8 Oct 2024 01:21:29 -0400 Subject: [PATCH 02/36] Update UT --- test/prototype/test_smoothquant.py | 87 +++++++++++++++++---------- torchao/prototype/smoothquant/core.py | 8 +-- 2 files changed, 58 insertions(+), 37 deletions(-) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 5e65ce8c2..fdc38dd7f 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -3,14 +3,16 @@ import pytest import torch import tempfile +from torch.ao.quantization import HistogramObserver from torchao.quantization import quantize_ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_4, +) +from torchao.quantization.utils import ( dynamically_quantize_per_channel, dequantize_per_channel, ) - from torchao.prototype.smoothquant import ( insert_smooth_quant_observer, smooth_quant, @@ -50,7 +52,7 @@ def forward(self, x): @pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("idtype", idtypes) def test_compute(bias, alpha, quant_mode, device, idtype): - class Linear(torch.nn.module): + class Linear(torch.nn.Module): def __init__(self, bias: bool): super().__init__() self.fc = torch.nn.Linear(16, 16, bias) @@ -59,10 +61,9 @@ def __init__(self, bias: bool): def forward(self, x): return self.fc(x) - original_dtype = idtype - m = Linear(bias).eval().to(original_dtype).to(device) + m = Linear(bias).eval().to(idtype).to(device) m_ref = deepcopy(m) - data = torch.randn(2, 16, dtype=original_dtype, device=device) + data = torch.randn(2, 16, dtype=idtype, device=device) # calibrate insert_smooth_quant_observer(m, alpha, quant_mode, 1) @@ -70,33 +71,53 @@ def forward(self, x): # quantize is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) quantize_(m, smooth_quant(), is_observed_linear) - # m = torch.compile(m, fullgraph=True) - out = m(data) - - # reference - weight = m_ref.fc.weight.data - bias = m_ref.fc.bias - x_abs_max_per_ic = torch.abs(data).max(dim=0).values - w_abs_max_per_ic = torch.abs(weight).max(dim=0).values - smoothing_factor = ( - torch.pow(x_abs_max_per_ic, alpha) / torch.pow( - w_abs_max_per_ic, 1 - alpha) - ) - act = 
data / smoothing_factor - wei = weight * smoothing_factor - qw, w_scales, w_zps = dynamically_quantize_per_channel( - wei, -128, 127, torch.int8 - ) - if (device == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \ - not TORCH_VERSION_AT_LEAST_2_2: - dqw = dequantize_per_channel(qw, w_scales, w_zps, torch.float32) - out_ref = torch.nn.functional.linear(act, dqw, bias) - elif quant_mode == "static": - pass - else: - pass - - assert torch.allclose(out, out_ref, atol = 1e-2) + with torch.inference_mode(): + # m = torch.compile(m, fullgraph=True) + out = m(data) + + # reference + weight = m_ref.fc.weight.data.float() + b = m_ref.fc.bias.float() if bias else None + x_abs_max_per_ic = torch.abs(data).max(dim=0).values + w_abs_max_per_ic = torch.abs(weight).max(dim=0).values + smoothing_factor = ( + torch.pow(x_abs_max_per_ic, alpha) / torch.pow( + w_abs_max_per_ic, 1 - alpha) + ) + act = data / smoothing_factor + wei = weight * smoothing_factor + qw, w_scales, w_zps = dynamically_quantize_per_channel( + wei, -128, 127, torch.int8 + ) + fq_wei = dequantize_per_channel(qw, w_scales, w_zps, torch.float32) + if (device == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \ + not TORCH_VERSION_AT_LEAST_2_2: + # _int_mm is not supported in these cases + out_ref = torch.nn.functional.linear(act, fq_wei, b) + elif quant_mode == "static": + # activation is quantized per-tensor + obs = HistogramObserver( + dtype=torch.int8, + qscheme=torch.per_tensor_symmetric, + ) + obs(act.float()) + act_scale, _ = obs.calculate_qparams() + fq_act = torch.quantize_per_tensor( + act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8 + ).dequantize() + out_ref = torch.nn.functional.linear(fq_act, fq_wei, b) + else: + # activation is quantized per-row (batch * sequence_length) + qx, x_scales, x_zps = dynamically_quantize_per_channel( + act.float(), -128, 127, torch.int8 + ) + fq_act = dequantize_per_channel(qx, x_scales, x_zps, torch.float32) + out_ref = torch.nn.functional.linear(fq_act, fq_wei, b) + + # Quantized weights of the reference and the SmoothQuant model may differ by 1 + # when elements are quantized to -128 in one case and -127 in the other. 
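# (Illustrative: an integer difference of 1 is one full weight-scale step after
# dequantization, so borderline elements that clamp to -128 under one scheme and
# -127 under the other can move the output by that amount.)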
+ # So, the tolerance is relatively big here + assert torch.allclose(out, out_ref.to(idtype), atol = 0.2) @pytest.mark.parametrize("alpha", alpha_list) @@ -141,4 +162,4 @@ def test_save_load_recipe(alpha, quant_mode, device, idtype): assert out is not None assert save_load_out is not None - assert torch.allclose(out, save_load_out, atol = 1e-2) + assert torch.allclose(out, save_load_out) diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py index 7991fbe14..73e5bc282 100644 --- a/torchao/prototype/smoothquant/core.py +++ b/torchao/prototype/smoothquant/core.py @@ -207,8 +207,8 @@ def _quantized_linear_impl(input_tensor, weight_tensor, bias): input = input_tensor * inv_smoothing_factor if (weight_tensor.device.type == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \ not TORCH_VERSION_AT_LEAST_2_2: - # _int_mm is not available on CUDA before PyTorch 2.2 - # _int_mm is not available on CPU before PyTorch 2.4 + # _int_mm is not available on CUDA until PyTorch 2.2 + # _int_mm is not available on CPU until PyTorch 2.4 # So compute in float here y = F.linear(input, weight_tensor.dequantize(), bias) else: @@ -216,7 +216,7 @@ def _quantized_linear_impl(input_tensor, weight_tensor, bias): quant_min = _DTYPE_TO_QVALUE_BOUNDS[target_dtype][0] quant_max = _DTYPE_TO_QVALUE_BOUNDS[target_dtype][1] if act_scales is not None: - # dynamic quant + # static quant act_zero_points = torch.zeros_like(act_scales, dtype=torch.int64) qx = torch.ops.quantized_decomposed.quantize_per_tensor( input, @@ -228,7 +228,7 @@ def _quantized_linear_impl(input_tensor, weight_tensor, bias): ) act_scales = act_scales * torch.ones(input.size(0), dtype=act_scales.dtype) else: - # static quant + # dynamic quant qx, act_scales, _ = dynamically_quantize_per_channel(input, quant_min, quant_max, target_dtype) y = quant_int8_per_token_matmul( qx, act_scales, weight_tensor.layout_tensor.int_data, wei_scales From f03cfb3ae45a213604519f4a2188dfa228efc8cc Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 8 Oct 2024 05:55:53 -0400 Subject: [PATCH 03/36] Add SmoothQuant example --- torchao/kernel/intmm.py | 16 ++- torchao/prototype/smoothquant/core.py | 4 +- torchao/prototype/smoothquant/example.py | 156 +++++++++++++++++++++++ 3 files changed, 174 insertions(+), 2 deletions(-) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 7d076a6e8..c357747d6 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -2,7 +2,7 @@ import os import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_2 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_4 try: # Only works for torch2.2 or newer. 
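The hunk below adds a CPU implementation of int_scaled_matmul. As a rough sketch of its
intended semantics (the shapes and values here are arbitrary examples, not part of the patch):

    a = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
    b = torch.randint(-128, 127, (8, 16), dtype=torch.int8)
    scales1 = torch.rand(4, 1)                      # one scale per row of `a`
    out = (a.float() @ b.float()) * scales1         # numerically what the int8 path computes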
@@ -108,6 +108,18 @@ def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return torch.ops.torchao.int_matmul(a, b) return safe_int_mm(a, b) +lib = torch.library.Library("torchao", "FRAGMENT") +if intmm_triton is None: + lib.define("int_scaled_matmul(Tensor a, Tensor b, Tensor scales1) -> Tensor") + +@torch.library.impl(lib, "int_scaled_matmul", "CPU") +def int_scaled_matmul_cpu(a, b, scales1): + if TORCH_VERSION_AT_LEAST_2_4: + c = torch._int_mm(a, b) + return c.to(scales1.dtype) * scales1 + else: + return safe_int_mm(a, b) * scales1 + def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor: """ @@ -133,6 +145,8 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - assert scales1.dim() == 2 if intmm_triton is not None and AUTOTUNER_ENABLE: return torch.ops.torchao.int_scaled_matmul(a, b, scales1) + if all([x.device.type == "cpu" for x in (a, b, scales1)]): + return torch.ops.torchao.int_scaled_matmul(a, b, scales1) c = safe_int_mm(a, b) return c * scales1 diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py index 73e5bc282..38cb910d9 100644 --- a/torchao/prototype/smoothquant/core.py +++ b/torchao/prototype/smoothquant/core.py @@ -204,7 +204,9 @@ def _quantized_linear_impl(input_tensor, weight_tensor, bias): inv_smoothing_factor = weight_tensor.layout_tensor.layout_type.inv_smoothing_factor act_scales = weight_tensor.layout_tensor.layout_type.act_scales wei_scales = weight_tensor.layout_tensor.layout_type.wei_scales + input_shape = input_tensor.shape input = input_tensor * inv_smoothing_factor + input = input.reshape(-1, input_shape[-1]) if (weight_tensor.device.type == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \ not TORCH_VERSION_AT_LEAST_2_2: # _int_mm is not available on CUDA until PyTorch 2.2 @@ -235,7 +237,7 @@ def _quantized_linear_impl(input_tensor, weight_tensor, bias): ) if bias is not None: y += bias - return y.to(input_tensor.dtype) + return y.to(input_tensor.dtype).reshape(input_shape[:-1] + (-1,)) def _linear_sq_check(input_tensor, weight_tensor, bias): diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py index e69de29bb..a772b7c20 100644 --- a/torchao/prototype/smoothquant/example.py +++ b/torchao/prototype/smoothquant/example.py @@ -0,0 +1,156 @@ +import torch +import argparse +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +from tqdm import tqdm +import time +from torchao.prototype.smoothquant import ( + insert_smooth_quant_observer, + SmoothQuantObservedLinear, + smooth_quant +) +from torchao.quantization import quantize_ + + +# adapted from the AWQ example +def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512): + dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") + samples = [] + n_tokens = n_samples * block_size + n_run = n_tokens + for data in dataset: + line = data["text"] + line = line.strip() + line_encoded = tokenizer.encode(line) + if len(line_encoded) > 512: + continue + sample = torch.tensor([line_encoded]) + if sample.numel() == 0: + continue + samples.append(sample) + n_run -= len(line_encoded) + if n_run <= n_samples: + break + + cat_samples = torch.cat(samples, dim=1) + return [cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_samples)] + +# adapted from the AWQ example +def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True, device="cuda"): + model.eval() + 
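# Sliding-window perplexity: each stride-sized window contributes loss * target_length,
# and the final score is exp(total_negative_log_likelihood / total_tokens); see the
# torch.exp(...) at the end of this function.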
tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "right" + tokenizer.add_eos_token = False + + dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + encodings = tokenizer('\n\n'.join(dataset['text']), return_tensors='pt') + + encodings['input_ids'] = encodings['input_ids'].to(device) + + lls, t = [], [] + for i in tqdm(range(0, encodings['input_ids'].size(1), stride), disable=not verbose): + begin_loc = max(i + stride - sequence_length, 0) + end_loc = min(i + stride, encodings['input_ids'].size(1)) + trg_len = end_loc - i + input_ids = encodings['input_ids'][:,begin_loc:end_loc] + target_ids = input_ids.clone() + target_ids[:,:-trg_len] = -100 #ignore context + + t1 = time.time() + with torch.no_grad(): + log_likelihood = model(input_ids, labels=target_ids).loss * trg_len + if device == "cuda": + torch.cuda.synchronize() + t2 = time.time() + t.append((t2-t1)) + lls.append(log_likelihood) + + del input_ids, target_ids + + ppl = float(torch.exp(torch.stack(lls).sum() / end_loc)) + pred_time = sum(t)/len(t) + if(verbose): + print('perplexity', ppl) + print('time', str(pred_time) + ' sec/it') + + return {'perplexity':ppl, 'prediction_time':pred_time} + +# adapted from the AWQ example +def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): + model.eval() + model.config.use_cache = False + if tasks is None: + tasks = ["PPL"] + results = {} + if "PPL" in tasks: + results["perplexity"] = wiki2_eval(model, tokenizer, 512, verbose=True, device=device) + return results + + +def wikitext2_ppl( + model_id: str, + quant_mode: str, + calibration_size: int, + device: str, + precision:torch.dtype, + sequence_length: int, + compile: bool, + model_save_path: str): + print(f"Loading model on {device}...") + torch.manual_seed(34) + t0 = time.time() + # load any model with torch.nn.linear layers + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=precision).eval().to(device) + print(f"Time to load model: {time.time() - t0:.02f} seconds") + print(f"running calibration") + t0 = time.time() + # insert observers to find average magnitude and calculate scales + insert_smooth_quant_observer(model, 0.5, quant_mode, calibration_size) + calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length) + for batch in calibration_data: + model(batch.to(device)) + batch.to("cpu") + print(f"time for calibration: {time.time() - t0:.02f} seconds") + + is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) + print(f"running SmoothQuant with {quant_mode} quantization") + t0 = time.time() + quantize_(model, smooth_quant(), is_observed_linear) + print(f"time for quantization: {time.time() - t0:.02f} seconds") + if model_save_path is not None: + print(f"Saving quantized model to {model_save_path}") + torch.save(model, model_save_path) + if compile: + model = torch.compile(model) + + return benchmark(model, tokenizer, sequence_length, tasks=["PPL"], device=device) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") + + + # Optional arguments with default values + parser.add_argument("--model-id", "-m", type=str, help="Repository ID of the model.") + parser.add_argument("--quant-mode", type=str, help="Quantization mode, either static or dynamic.") + parser.add_argument("--calibration-samples", type=int, default=10, help="Number of samples to use for calibration. 
Default is 10.") + parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") + parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") + parser.add_argument("--seq_len", type=int, default=512, help="Length of examples to calibrate and evaluate model on. Default is 512") + parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") + parser.add_argument("--model-save-path", type=str, default=None, help="Path to store quantized model.") + + args = parser.parse_args() + + # Convert precision argument to torch dtype + precision_dtype = getattr(torch, args.precision, torch.bfloat16) + ppl = wikitext2_ppl( + args.model_id, + args.quant_mode, + args.calibration_samples, + args.device, + args.precision, + args.seq_len, + args.compile, + args.model_save_path + ) From a2518f1c76c327134f3d1561d4569cad2844fc65 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 9 Oct 2024 02:26:18 -0400 Subject: [PATCH 04/36] Remove duplicate implementation of int_scaled_matmul for CPU --- torchao/kernel/intmm_triton.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torchao/kernel/intmm_triton.py b/torchao/kernel/intmm_triton.py index d10dac0ab..4e84d9cd3 100644 --- a/torchao/kernel/intmm_triton.py +++ b/torchao/kernel/intmm_triton.py @@ -356,9 +356,3 @@ def int_scaled_matmul_cuda(a, b, scales1): int_scaled_matmul_kernel, [a, b, scales1, c], int8_mm_kernel_configs ) return int_scaled_matmul_kernel(a, b, scales1, c, best_config) - - -@torch.library.impl(lib, "int_scaled_matmul", "CPU") -def int_scaled_matmul_cpu(a, b, scales1): - c = torch._int_mm(a, b) - return c.to(scales1.dtype) * scales1 From 28fb8ce3fd32526b470a5f808b8a0f5912aa8839 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 9 Oct 2024 02:52:21 -0400 Subject: [PATCH 05/36] Update example.py --- torchao/prototype/smoothquant/example.py | 54 +++++++++++++++--------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py index a772b7c20..f2e66cd03 100644 --- a/torchao/prototype/smoothquant/example.py +++ b/torchao/prototype/smoothquant/example.py @@ -1,5 +1,6 @@ import torch import argparse +import os from transformers import AutoModelForCausalLM, AutoTokenizer from datasets import load_dataset from tqdm import tqdm @@ -95,32 +96,41 @@ def wikitext2_ppl( precision:torch.dtype, sequence_length: int, compile: bool, + model_load_path: str, model_save_path: str): print(f"Loading model on {device}...") torch.manual_seed(34) t0 = time.time() # load any model with torch.nn.linear layers tokenizer = AutoTokenizer.from_pretrained(model_id) - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=precision).eval().to(device) - print(f"Time to load model: {time.time() - t0:.02f} seconds") - print(f"running calibration") - t0 = time.time() - # insert observers to find average magnitude and calculate scales - insert_smooth_quant_observer(model, 0.5, quant_mode, calibration_size) - calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length) - for batch in calibration_data: - model(batch.to(device)) - batch.to("cpu") - print(f"time for calibration: {time.time() - t0:.02f} seconds") - - is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - print(f"running SmoothQuant with {quant_mode} quantization") - t0 = time.time() - 
quantize_(model, smooth_quant(), is_observed_linear) - print(f"time for quantization: {time.time() - t0:.02f} seconds") - if model_save_path is not None: - print(f"Saving quantized model to {model_save_path}") - torch.save(model, model_save_path) + if model_load_path is not None and os.path.exists(model_load_path): + print(f"Loading quantized model from {model_load_path}") + t0 = time.time() + model = torch.load(model_load_path, weights_only=False).to(device) + print(f"Time to load quantized model: {time.time() - t0:.02f} seconds") + else: + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=precision).eval().to(device) + print(f"Time to load model: {time.time() - t0:.02f} seconds") + print(f"running calibration") + t0 = time.time() + # insert observers to find average magnitude and calculate scales + insert_smooth_quant_observer(model, 0.5, quant_mode, calibration_size) + calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length) + for batch in calibration_data: + model(batch.to(device)) + batch.to("cpu") + print(f"time for calibration: {time.time() - t0:.02f} seconds") + + is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) + print(f"running SmoothQuant with {quant_mode} quantization") + t0 = time.time() + quantize_(model, smooth_quant(), is_observed_linear) + print(f"time for quantization: {time.time() - t0:.02f} seconds") + if model_save_path is not None: + print(f"Saving quantized model to {model_save_path}") + t0 = time.time() + torch.save(model, model_save_path) + print(f"Time to save quantized model: {time.time() - t0:.02f} seconds") if compile: model = torch.compile(model) @@ -138,6 +148,9 @@ def wikitext2_ppl( parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") parser.add_argument("--seq_len", type=int, default=512, help="Length of examples to calibrate and evaluate model on. Default is 512") parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") + parser.add_argument("--model-load-path", type=str, default=None, + help="Path to load quantized model. 
If this is provided, " + "the model will be loaded from this path instead of quantizing the model.") parser.add_argument("--model-save-path", type=str, default=None, help="Path to store quantized model.") args = parser.parse_args() @@ -152,5 +165,6 @@ def wikitext2_ppl( args.precision, args.seq_len, args.compile, + args.model_load_path, args.model_save_path ) From bada2b0c10e9de84f52132aabcd885e818a9e6c4 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 9 Oct 2024 04:09:59 -0400 Subject: [PATCH 06/36] Remove unused code --- torchao/prototype/smoothquant/api.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py index 373b1eae0..5244a923b 100644 --- a/torchao/prototype/smoothquant/api.py +++ b/torchao/prototype/smoothquant/api.py @@ -1,14 +1,9 @@ import torch -import torch.nn.functional as F -import json from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, _DTYPE_TO_QVALUE_BOUNDS, ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes import to_affine_quantized_intx_static -from torchao.dtypes.uintx.uintx import _DTYPE_TO_BIT_WIDTH from torchao.prototype.smoothquant.core import( SmoothQuantObserver, SmoothQuantObservedLinear, @@ -17,9 +12,6 @@ from typing import Dict, Optional -assert len(_DTYPE_TO_BIT_WIDTH) > 0, "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+" - - def insert_smooth_quant_observer( model: torch.nn.Module, alpha: float = 0.5, From 921efc0f6cc4b1b2454c8fea7404f458ab8dad74 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 10 Oct 2024 05:04:37 -0400 Subject: [PATCH 07/36] Implement with LinearActivationQuantizedTensor --- test/prototype/test_smoothquant.py | 6 +- torchao/dtypes/affine_quantized_tensor.py | 1 - torchao/kernel/intmm.py | 6 +- torchao/prototype/smoothquant/api.py | 54 +++++--- torchao/prototype/smoothquant/core.py | 156 +--------------------- 5 files changed, 44 insertions(+), 179 deletions(-) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index fdc38dd7f..ef30680bd 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -87,7 +87,7 @@ def forward(self, x): act = data / smoothing_factor wei = weight * smoothing_factor qw, w_scales, w_zps = dynamically_quantize_per_channel( - wei, -128, 127, torch.int8 + wei, -127, 127, torch.int8 ) fq_wei = dequantize_per_channel(qw, w_scales, w_zps, torch.float32) if (device == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \ @@ -109,7 +109,7 @@ def forward(self, x): else: # activation is quantized per-row (batch * sequence_length) qx, x_scales, x_zps = dynamically_quantize_per_channel( - act.float(), -128, 127, torch.int8 + act.float(), -127, 127, torch.int8 ) fq_act = dequantize_per_channel(qx, x_scales, x_zps, torch.float32) out_ref = torch.nn.functional.linear(fq_act, fq_wei, b) @@ -131,7 +131,7 @@ def test_save_load_recipe(alpha, quant_mode, device, idtype): n_calib_examples = 10 sequence_length = 5 - m = ToyLinearModel(l1,l2,l3).eval().to(original_dtype).to(device) + m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device) m_save_load = deepcopy(m) dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index c2c8e3c0b..db9ec04ba 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ 
b/torchao/dtypes/affine_quantized_tensor.py @@ -1346,7 +1346,6 @@ def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): isinstance(input_tensor, AffineQuantizedTensor) and _aqt_is_int8_reduced_range(input_tensor) and isinstance(weight_tensor, AffineQuantizedTensor) and - weight_tensor.is_cuda and input_tensor.dtype == weight_tensor.dtype and isinstance(input_tensor.layout_type, PlainLayoutType) and isinstance(weight_tensor.layout_type, PlainLayoutType) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index c357747d6..60be3c780 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -108,15 +108,18 @@ def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return torch.ops.torchao.int_matmul(a, b) return safe_int_mm(a, b) + lib = torch.library.Library("torchao", "FRAGMENT") if intmm_triton is None: lib.define("int_scaled_matmul(Tensor a, Tensor b, Tensor scales1) -> Tensor") + @torch.library.impl(lib, "int_scaled_matmul", "CPU") def int_scaled_matmul_cpu(a, b, scales1): if TORCH_VERSION_AT_LEAST_2_4: c = torch._int_mm(a, b) - return c.to(scales1.dtype) * scales1 + c = c.float() * scales1 + return c.to(scales1.dtype) else: return safe_int_mm(a, b) * scales1 @@ -145,6 +148,7 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - assert scales1.dim() == 2 if intmm_triton is not None and AUTOTUNER_ENABLE: return torch.ops.torchao.int_scaled_matmul(a, b, scales1) + if all([x.device.type == "cpu" for x in (a, b, scales1)]): return torch.ops.torchao.int_scaled_matmul(a, b, scales1) diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py index 5244a923b..298a3d833 100644 --- a/torchao/prototype/smoothquant/api.py +++ b/torchao/prototype/smoothquant/api.py @@ -3,12 +3,16 @@ _DTYPE_TO_QVALUE_BOUNDS, ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter -from torchao.dtypes import to_affine_quantized_intx_static +from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static +from torchao.quantization.linear_activation_quantized_tensor import ( + to_linear_activation_quantized, +) +from torchao.quantization.quant_primitives import MappingType +from torchao.quantization.utils import _get_per_token_block_size from torchao.prototype.smoothquant.core import( - SmoothQuantObserver, - SmoothQuantObservedLinear, - SmoothQuantLayoutType, -) + SmoothQuantObserver, + SmoothQuantObservedLinear, +) from typing import Dict, Optional @@ -93,9 +97,10 @@ def recurse(module: torch.nn.Module, name: str = ''): torch.save(result, save_path) + def load_smooth_quant_recipe(model: torch.nn.Module, recipe_path: str, device=None) -> torch.nn.Module: recipe = torch.load(recipe_path, weights_only=True) - + def recurse(module: torch.nn.Module, name: str = ''): if isinstance(module, SmoothQuantObservedLinear): smoothing_factor = recipe.get(name + ".smoothing_factor", None) @@ -114,9 +119,9 @@ def recurse(module: torch.nn.Module, name: str = ''): full_name = f"{name}.{child_name}" if name else child_name setattr(mod_new, child_name, recurse(child, full_name)) return mod_new - + recurse(model) - + def smooth_quant( smoothing_factor: Optional[torch.Tensor] = None, @@ -137,31 +142,40 @@ def quantize_weight(observed_linear): quant_min = _DTYPE_TO_QVALUE_BOUNDS[target_dtype][0] quant_max = _DTYPE_TO_QVALUE_BOUNDS[target_dtype][1] nonlocal smoothing_factor, act_scales, wei_scales - # act_scales is None for dynamic quantization + # act_scales is None for dynamic 
quantization thus not checked if any(x is None for x in (smoothing_factor, wei_scales)): factor, s_act, s_wei = observed_linear.obs.calculate_qparams() weight = observed_linear.obs.weight * factor else: factor, s_act, s_wei = smoothing_factor, act_scales, wei_scales weight = observed_linear.weight * factor - weight_t = weight.t().contiguous() - block_size = (weight_t.size(0), 1) - inv_smoothing_factor = 1 / factor - layout_type = SmoothQuantLayoutType( - inv_smoothing_factor, s_act, s_wei - ) + weight = weight.to(observed_linear.weight.dtype) + block_size = (1, weight.size(1)) wei_zero_points = torch.zeros_like(s_wei, dtype=torch.int64) - return to_affine_quantized_intx_static( - weight_t, + qw = to_affine_quantized_intx_static( + weight, s_wei, wei_zero_points, block_size, target_dtype, quant_min, quant_max, - layout_type=layout_type ) - - return _observed_linear_subclass_inserter(quantize_weight) + def static_quantize_act(input): + x_scales = s_act * torch.ones(input.numel() // input.size(-1), dtype=s_act.dtype) + x_zp = torch.zeros_like(x_scales, dtype=torch.int64) + act = input / factor + act = act.to(input.dtype) + return to_affine_quantized_intx_static(act, x_scales, x_zp, _get_per_token_block_size(act), target_dtype, quant_min=-127) + def dynamic_quantize_act(input): + act = input / factor + act = act.to(input.dtype) + return to_affine_quantized_intx(act, MappingType.SYMMETRIC, _get_per_token_block_size(act), target_dtype, quant_min=-127) + + is_dynamic = s_act is None + input_quant_func = dynamic_quantize_act if is_dynamic else static_quantize_act + return to_linear_activation_quantized(qw, input_quant_func) + + return _observed_linear_subclass_inserter(quantize_weight) diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py index 38cb910d9..fd2d290c1 100644 --- a/torchao/prototype/smoothquant/core.py +++ b/torchao/prototype/smoothquant/core.py @@ -1,32 +1,16 @@ from dataclasses import dataclass -from typing import Tuple, Optional - +from typing import Optional import torch import torch.nn.functional as F -from torch.utils._python_dispatch import return_and_correct_aliasing from torch.ao.quantization import PerChannelMinMaxObserver, HistogramObserver -from torchao.dtypes.uintx.uintx import to_uintx -from torchao.dtypes.affine_quantized_tensor import ( - to_affine_quantized_intx, - LayoutType, - register_layout_cls, - AQTLayout, - register_aqt_quantized_linear_dispatch - -) from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, - _DTYPE_TO_QVALUE_BOUNDS, ) from torchao.quantization.observer import ( AffineQuantizedObserverBase, PerRow ) -from torchao.quantization.utils import ( - dynamically_quantize_per_channel, - quant_int8_per_token_matmul, -) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_4 + class SmoothQuantObserver(AffineQuantizedObserverBase): def __init__(self, @@ -191,139 +175,3 @@ def from_float(cls, float_linear: torch.nn.Linear, obs: SmoothQuantObserver): observed_linear.weight = float_linear.weight observed_linear.bias = float_linear.bias return observed_linear - - -@dataclass(frozen=True) -class SmoothQuantLayoutType(LayoutType): - inv_smoothing_factor: torch.Tensor - act_scales: torch.Tensor - wei_scales: torch.Tensor - - -def _quantized_linear_impl(input_tensor, weight_tensor, bias): - inv_smoothing_factor = weight_tensor.layout_tensor.layout_type.inv_smoothing_factor - act_scales = weight_tensor.layout_tensor.layout_type.act_scales - wei_scales = 
weight_tensor.layout_tensor.layout_type.wei_scales - input_shape = input_tensor.shape - input = input_tensor * inv_smoothing_factor - input = input.reshape(-1, input_shape[-1]) - if (weight_tensor.device.type == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \ - not TORCH_VERSION_AT_LEAST_2_2: - # _int_mm is not available on CUDA until PyTorch 2.2 - # _int_mm is not available on CPU until PyTorch 2.4 - # So compute in float here - y = F.linear(input, weight_tensor.dequantize(), bias) - else: - target_dtype = torch.int8 - quant_min = _DTYPE_TO_QVALUE_BOUNDS[target_dtype][0] - quant_max = _DTYPE_TO_QVALUE_BOUNDS[target_dtype][1] - if act_scales is not None: - # static quant - act_zero_points = torch.zeros_like(act_scales, dtype=torch.int64) - qx = torch.ops.quantized_decomposed.quantize_per_tensor( - input, - act_scales, - act_zero_points, - quant_min, - quant_max, - dtype=target_dtype, - ) - act_scales = act_scales * torch.ones(input.size(0), dtype=act_scales.dtype) - else: - # dynamic quant - qx, act_scales, _ = dynamically_quantize_per_channel(input, quant_min, quant_max, target_dtype) - y = quant_int8_per_token_matmul( - qx, act_scales, weight_tensor.layout_tensor.int_data, wei_scales - ) - if bias is not None: - y += bias - return y.to(input_tensor.dtype).reshape(input_shape[:-1] + (-1,)) - - -def _linear_sq_check(input_tensor, weight_tensor, bias): - return isinstance(weight_tensor.layout_tensor, SmoothQuantAQTLayout) - - -register_aqt_quantized_linear_dispatch(_linear_sq_check, _quantized_linear_impl) - - -@register_layout_cls(SmoothQuantLayoutType) -class SmoothQuantAQTLayout(AQTLayout): - @staticmethod - def __new__( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - layout_type: LayoutType, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - kwargs["dtype"] = int_data.dtype - kwargs["requires_grad"] = False - shape = int_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - layout_type: LayoutType, - ): - self.int_data = int_data - self.scale = scale - self.zero_point = zero_point - self.layout_type = layout_type - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return self.int_data, self.scale, self.zero_point - - def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point"], [self.layout_type] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"] - layout_type, = tensor_attributes - return cls(int_data, scale, zero_point, layout_type) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is torch.ops.aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - raise NotImplementedError( - f"SmoothQuantAQTLayout dispatch: attempting to run {func}, this is not supported" - ) - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - layout_type: LayoutType, - ): - assert isinstance(layout_type, SmoothQuantLayoutType) - return cls(int_data, scale, zero_point, 
layout_type) - - def get_layout_type(self) -> LayoutType: - return self.layout_type - - def _apply_fn_to_data(self, fn): - self.int_data = fn(self.int_data) - self.scale = fn(self.scale) - self.zero_point = fn(self.zero_point) - return self - -to_smooth_quant = SmoothQuantAQTLayout.from_plain \ No newline at end of file From ad5b97e4e7b9bcdfbf9cc73361f41025e6315002 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 10 Oct 2024 08:33:06 -0400 Subject: [PATCH 08/36] Fix load/save --- torchao/kernel/intmm.py | 2 +- torchao/prototype/smoothquant/api.py | 57 +++++++++++++++++++--------- 2 files changed, 40 insertions(+), 19 deletions(-) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 60be3c780..90acb13cb 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -141,7 +141,7 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - """ M, K = a.shape K, N = b.shape - assert M == scales1.size(0) + assert M == scales1.size(0) or scales1.numel() == 1 assert 1 == scales1.size(1) assert scales1.is_contiguous() scales1 = scales1.expand((M, N)) diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py index 298a3d833..cc61bfe41 100644 --- a/torchao/prototype/smoothquant/api.py +++ b/torchao/prototype/smoothquant/api.py @@ -123,6 +123,35 @@ def recurse(module: torch.nn.Module, name: str = ''): recurse(model) +# StaticQuantizeAct and DynamicQuantizeAct are defined as classes to allow for easy serialization and deserialization +class StaticQuantizeAct: + def __init__(self, factor, x_scale, target_dtype, quant_min=-127): + super().__init__() + self.factor = factor + self.x_scale = x_scale + self.x_zp = torch.zeros_like(x_scale, dtype=torch.int64) + self.target_dtype = target_dtype + self.quant_min = quant_min + + def __call__(self, input): + x_zp = torch.zeros([1], dtype=torch.int64) + act = input / self.factor + act = act.to(input.dtype) + return to_affine_quantized_intx_static(act, self.x_scale, x_zp, list(act.shape), self.target_dtype, self.quant_min) + + +class DynamicQuantizeAct: + def __init__(self, factor, target_dtype, quant_min=-127): + self.factor = factor + self.target_dtype = target_dtype + self.quant_min = quant_min + + def __call__(self, input): + act = input / self.factor + act = act.to(input.dtype) + return to_affine_quantized_intx(act, MappingType.SYMMETRIC, _get_per_token_block_size(act), self.target_dtype, self.quant_min) + + def smooth_quant( smoothing_factor: Optional[torch.Tensor] = None, act_scales: Optional[torch.Tensor] = None, @@ -144,17 +173,17 @@ def quantize_weight(observed_linear): nonlocal smoothing_factor, act_scales, wei_scales # act_scales is None for dynamic quantization thus not checked if any(x is None for x in (smoothing_factor, wei_scales)): - factor, s_act, s_wei = observed_linear.obs.calculate_qparams() + factor, x_scale, w_scales = observed_linear.obs.calculate_qparams() weight = observed_linear.obs.weight * factor else: - factor, s_act, s_wei = smoothing_factor, act_scales, wei_scales + factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales weight = observed_linear.weight * factor weight = weight.to(observed_linear.weight.dtype) block_size = (1, weight.size(1)) - wei_zero_points = torch.zeros_like(s_wei, dtype=torch.int64) + wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64) qw = to_affine_quantized_intx_static( weight, - s_wei, + w_scales, wei_zero_points, block_size, target_dtype, @@ -162,20 +191,12 @@ def quantize_weight(observed_linear): 
quant_max, ) - def static_quantize_act(input): - x_scales = s_act * torch.ones(input.numel() // input.size(-1), dtype=s_act.dtype) - x_zp = torch.zeros_like(x_scales, dtype=torch.int64) - act = input / factor - act = act.to(input.dtype) - return to_affine_quantized_intx_static(act, x_scales, x_zp, _get_per_token_block_size(act), target_dtype, quant_min=-127) - - def dynamic_quantize_act(input): - act = input / factor - act = act.to(input.dtype) - return to_affine_quantized_intx(act, MappingType.SYMMETRIC, _get_per_token_block_size(act), target_dtype, quant_min=-127) - - is_dynamic = s_act is None - input_quant_func = dynamic_quantize_act if is_dynamic else static_quantize_act + is_dynamic = x_scale is None + if is_dynamic: + input_quant_func = DynamicQuantizeAct(factor, target_dtype) + else: + input_quant_func = StaticQuantizeAct(factor, x_scale, target_dtype) + return to_linear_activation_quantized(qw, input_quant_func) return _observed_linear_subclass_inserter(quantize_weight) From f1be01d04015e88cdb69ba32a05251f408214c72 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 10 Oct 2024 08:40:34 -0400 Subject: [PATCH 09/36] Fix device mismatch in observer --- torchao/prototype/smoothquant/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py index fd2d290c1..fa7ec6032 100644 --- a/torchao/prototype/smoothquant/core.py +++ b/torchao/prototype/smoothquant/core.py @@ -136,12 +136,12 @@ def calculate_qparams(self): act_new = act * inv_smoothing_factor self.act_obs(act_new) act_scale, _ = self.act_obs.calculate_qparams() - act_scales = torch.Tensor([act_scale]) + act_scales = torch.Tensor([act_scale]).to(self.device) # 4 update weight and find scales - self.wei_oc_obs(self.weight * smoothing_factor) + self.wei_oc_obs(self.weight * smoothing_factor.to(self.device)) wei_scales, _ = self.wei_oc_obs.calculate_qparams() # 5 return results - return smoothing_factor, act_scales, wei_scales + return smoothing_factor.to(self.device), act_scales, wei_scales.to(self.device) class SmoothQuantObservedLinear(torch.nn.Linear): From 7ee1f137945d48447f9325f8314bfdd90d830162 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 10 Oct 2024 22:10:54 -0400 Subject: [PATCH 10/36] Fix fp16 overflow issue in int_scaled_matmul --- torchao/kernel/intmm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 90acb13cb..39be8f018 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -121,7 +121,8 @@ def int_scaled_matmul_cpu(a, b, scales1): c = c.float() * scales1 return c.to(scales1.dtype) else: - return safe_int_mm(a, b) * scales1 + c = safe_int_mm(a, b) * scales1.float() + return c.to(scales1.dtype) def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor: From 427ff73d57afc24436c02757b63b8955c50dd5cf Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 11 Oct 2024 07:48:48 -0400 Subject: [PATCH 11/36] Add linear_activation_scale_quantized.py for torch.compile --- torchao/prototype/smoothquant/api.py | 40 ++-- torchao/prototype/smoothquant/example.py | 2 +- .../linear_activation_scale_quantized.py | 188 ++++++++++++++++++ 3 files changed, 212 insertions(+), 18 deletions(-) create mode 100644 torchao/quantization/linear_activation_scale_quantized.py diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py index cc61bfe41..d0509059c 100644 --- 
a/torchao/prototype/smoothquant/api.py +++ b/torchao/prototype/smoothquant/api.py @@ -4,8 +4,8 @@ ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static -from torchao.quantization.linear_activation_quantized_tensor import ( - to_linear_activation_quantized, +from torchao.quantization.linear_activation_scale_quantized import ( + to_linear_scale_activation_quantized, ) from torchao.quantization.quant_primitives import MappingType from torchao.quantization.utils import _get_per_token_block_size @@ -14,6 +14,7 @@ SmoothQuantObservedLinear, ) from typing import Dict, Optional +from torch._dynamo import is_compiling as dynamo_is_compiling def insert_smooth_quant_observer( @@ -94,7 +95,7 @@ def recurse(module: torch.nn.Module, name: str = ''): recurse(child, full_name) recurse(model) - + torch.save(result, save_path) @@ -125,9 +126,8 @@ def recurse(module: torch.nn.Module, name: str = ''): # StaticQuantizeAct and DynamicQuantizeAct are defined as classes to allow for easy serialization and deserialization class StaticQuantizeAct: - def __init__(self, factor, x_scale, target_dtype, quant_min=-127): + def __init__(self, x_scale, target_dtype, quant_min=-127): super().__init__() - self.factor = factor self.x_scale = x_scale self.x_zp = torch.zeros_like(x_scale, dtype=torch.int64) self.target_dtype = target_dtype @@ -135,21 +135,27 @@ def __init__(self, factor, x_scale, target_dtype, quant_min=-127): def __call__(self, input): x_zp = torch.zeros([1], dtype=torch.int64) - act = input / self.factor - act = act.to(input.dtype) - return to_affine_quantized_intx_static(act, self.x_scale, x_zp, list(act.shape), self.target_dtype, self.quant_min) + qx = to_affine_quantized_intx_static( + input, self.x_scale, x_zp, list(input.shape), self.target_dtype, self.quant_min + ) + if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): + return qx.tensor_impl.int_data.to(qx.dtype) + return qx class DynamicQuantizeAct: - def __init__(self, factor, target_dtype, quant_min=-127): - self.factor = factor + def __init__(self, target_dtype, quant_min=-127): self.target_dtype = target_dtype self.quant_min = quant_min def __call__(self, input): - act = input / self.factor - act = act.to(input.dtype) - return to_affine_quantized_intx(act, MappingType.SYMMETRIC, _get_per_token_block_size(act), self.target_dtype, self.quant_min) + block_size = _get_per_token_block_size(input) + qx = to_affine_quantized_intx( + input, MappingType.SYMMETRIC, block_size, self.target_dtype, self.quant_min + ) + if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): + return qx.tensor_impl.int_data.to(qx.dtype) + return qx def smooth_quant( @@ -193,10 +199,10 @@ def quantize_weight(observed_linear): is_dynamic = x_scale is None if is_dynamic: - input_quant_func = DynamicQuantizeAct(factor, target_dtype) + input_quant_func = DynamicQuantizeAct(target_dtype) else: - input_quant_func = StaticQuantizeAct(factor, x_scale, target_dtype) - - return to_linear_activation_quantized(qw, input_quant_func) + input_quant_func = StaticQuantizeAct(x_scale, target_dtype) + + return to_linear_scale_activation_quantized(qw, factor, input_quant_func) return _observed_linear_subclass_inserter(quantize_weight) diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py index f2e66cd03..e5227bbca 100644 --- a/torchao/prototype/smoothquant/example.py +++ b/torchao/prototype/smoothquant/example.py @@ -133,7 
+133,7 @@ def wikitext2_ppl( print(f"Time to save quantized model: {time.time() - t0:.02f} seconds") if compile: model = torch.compile(model) - + return benchmark(model, tokenizer, sequence_length, tasks=["PPL"], device=device) if __name__ == "__main__": diff --git a/torchao/quantization/linear_activation_scale_quantized.py b/torchao/quantization/linear_activation_scale_quantized.py new file mode 100644 index 000000000..ffef7b33f --- /dev/null +++ b/torchao/quantization/linear_activation_scale_quantized.py @@ -0,0 +1,188 @@ +import torch +from typing import Callable +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.utils import ( + TorchAOBaseTensor, + TORCH_VERSION_AT_LEAST_2_5, +) + +__all__ = [ + "LinearActivationScaleQuantizedTensor", + "to_linear_scale_activation_quantized", +] + +aten = torch.ops.aten + +class LinearActivationScaleQuantizedTensor(TorchAOBaseTensor): + """ + Applies activation scaling then quantization for linear operator, this is used to support + SmoothQuant with dynamic quantization or static quantization, user can pass in a `input_quant_func` + that is used to quantize the activation + + Args: + `original_weight_tensor`: the weight tensor, if weight need to be quantized as well, we'd need + to apply quantization to weight first, e.g. for int8 dynamic activation int8 weight quantization + we will first apply int8 quantization to weight and then apply LinearActivationScaleQuantizedTensor + on top of it + `scale`: The scale tensor to be applied to activation. + `input_quant_func` (Callable[[torch.Tensor], torch.Tensor]): a function that takes a high precision + floating point tensor and returns a quantized tensor, this is used to quantize input + """ + def __new__( + cls, + original_weight_tensor: torch.Tensor, + scale: torch.Tensor, + input_quant_func: Callable, + ): + kwargs = {} + dtype = original_weight_tensor.dtype + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + kwargs["device"] = original_weight_tensor.device + shape = original_weight_tensor.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + original_weight_tensor: torch.Tensor, + scale: torch.Tensor, + input_quant_func: Callable[[torch.Tensor], torch.Tensor], + ): + self.original_weight_tensor = original_weight_tensor + self.scale = scale + self.input_quant_func = input_quant_func + + def __repr__(self): + return (f"LinearActivationScaleQuantizedTensor({self.original_weight_tensor}, " + f"scale={self.scale}, quant_func={self.input_quant_func})") + + def __tensor_flatten__(self): + return ["original_weight_tensor", "scale"], [self.input_quant_func] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + original_weight_tensor = tensor_data_dict["original_weight_tensor"] + scale = tensor_data_dict["scale"] + input_quant_func, = tensor_attributes + return cls( + original_weight_tensor, + scale, + input_quant_func, + ) + + @staticmethod + def _quantized_linear_op(input_tensor, weight_tensor, bias): + input_quant_func = weight_tensor.input_quant_func + original_weight_tensor = weight_tensor.original_weight_tensor + scale = weight_tensor.scale + scaled_input_act = input_tensor / scale + scaled_input_act = scaled_input_act.to(input_tensor.dtype) + aqt = input_quant_func(scaled_input_act) + return torch.nn.functional.linear(aqt, original_weight_tensor, bias) + + @classmethod + def from_float(cls, input_float, scale, 
input_quant_func): + return cls(input_float, scale, input_quant_func) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.original_weight_tensor), + fn(self.scale), + self.input_quant_func, + ) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.original_weight_tensor.to(**kwargs), + self.scale.to(**kwargs), + self.input_quant_func, + ) + +implements = LinearActivationScaleQuantizedTensor.implements + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if isinstance(weight_tensor, LinearActivationScaleQuantizedTensor): + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + + raise NotImplementedError("LinearActivationScaleQuantizedTensor: No specialized dispatch found for linear op") + +@implements([aten.mm.default, aten.addmm.default]) +def _(func, types, args, kwargs): + if not args[0].is_floating_point(): + raise NotImplementedError(f"LinearActivationScaleQuantizedTensor: expecting a floating point input") + + if func == aten.addmm.default: + assert args[1].shape[-1] == args[2].shape[0], ( + f"need mat1 shape: {args[1].shape} final" + f"dim to match mat2 shape: {args[2].shape} first dim " + ) + input_tensor, weight_tensor, bias = ( + args[1], + args[2], + args[0], + ) + input_quant_func = weight_tensor.input_quant_func + original_weight_tensor = weight_tensor.original_weight_tensor + scale = weight_tensor.scale + scaled_input_act = input_tensor / scale + aqt = input_quant_func(scaled_input_act) + return func(bias, aqt, original_weight_tensor) + else: + # aten.mm.default + assert args[0].shape[-1] == args[1].shape[0], ( + f"need mat1 shape: {args[0].shape} final dim" + f"to match mat2 shape: {args[1].shape} first dim" + ) + input_tensor, weight_tensor = ( + args[0], + args[1], + ) + input_quant_func = weight_tensor.input_quant_func + original_weight_tensor = weight_tensor.original_weight_tensor + scale = weight_tensor.scale + scaled_input_act = input_tensor / scale + aqt = input_quant_func(scaled_input_act) + return func(aqt, original_weight_tensor) + + +@implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + +@implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + +@implements(aten.t.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.t) + ) + +to_linear_scale_activation_quantized = LinearActivationScaleQuantizedTensor.from_float + +if TORCH_VERSION_AT_LEAST_2_5: + # Allow a model with LinearActivationScaleQuantizedTensor weights to be loaded with `weights_only=True` + torch.serialization.add_safe_globals([LinearActivationScaleQuantizedTensor]) From 9916113eed421dee47dd7cc7f95fefc70a044fda Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sat, 12 Oct 2024 01:55:57 -0400 Subject: [PATCH 12/36] Quantize act/wei to 7 bit on old CPU platforms --- test/prototype/test_smoothquant.py | 16 +++++++++++++--- 
torchao/prototype/smoothquant/api.py | 10 ++++++---- torchao/prototype/smoothquant/core.py | 10 ---------- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index ef30680bd..7b238fbc1 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -66,7 +66,10 @@ def forward(self, x): data = torch.randn(2, 16, dtype=idtype, device=device) # calibrate - insert_smooth_quant_observer(m, alpha, quant_mode, 1) + reduce_range = device == "cpu" + insert_smooth_quant_observer( + m, alpha, quant_mode, reduce_range=reduce_range, n_calib_examples=1 + ) m(data) # quantize is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) @@ -87,7 +90,10 @@ def forward(self, x): act = data / smoothing_factor wei = weight * smoothing_factor qw, w_scales, w_zps = dynamically_quantize_per_channel( - wei, -127, 127, torch.int8 + wei, + -63 if reduce_range else -127, + 63 if reduce_range else 127, + torch.int8 ) fq_wei = dequantize_per_channel(qw, w_scales, w_zps, torch.float32) if (device == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \ @@ -99,6 +105,7 @@ def forward(self, x): obs = HistogramObserver( dtype=torch.int8, qscheme=torch.per_tensor_symmetric, + reduce_range=reduce_range, ) obs(act.float()) act_scale, _ = obs.calculate_qparams() @@ -109,7 +116,10 @@ def forward(self, x): else: # activation is quantized per-row (batch * sequence_length) qx, x_scales, x_zps = dynamically_quantize_per_channel( - act.float(), -127, 127, torch.int8 + act.float(), + -63 if reduce_range else -127, + 63 if reduce_range else 127, + torch.int8 ) fq_act = dequantize_per_channel(qx, x_scales, x_zps, torch.float32) out_ref = torch.nn.functional.linear(fq_act, fq_wei, b) diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py index d0509059c..50f36121e 100644 --- a/torchao/prototype/smoothquant/api.py +++ b/torchao/prototype/smoothquant/api.py @@ -21,6 +21,7 @@ def insert_smooth_quant_observer( model: torch.nn.Module, alpha: float = 0.5, quant_mode: str = "static", + reduce_range: bool = False, n_calib_examples: int = 20): """ Inserts SmoothQuantObserver into Linear layers of a given model. @@ -28,12 +29,13 @@ def insert_smooth_quant_observer( Args: model: The model to be modified (in place). 
Ensure model is on the desired device for calibration mapping_type: symmetric or asymmetric quantization of weight + reduce_range: Quantize act/wei to 7 bits on old CPU platforms n_calib_examples: Number of examples used for calibration """ _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) - quant_min = _DTYPE_TO_QVALUE_BOUNDS[torch.int8][0] - quant_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8][1] + quant_min = -63 if reduce_range else -127 + quant_max = 63 if reduce_range else 127 eps = torch.finfo(torch.float32).eps def replace_with_observer(layer): @@ -44,8 +46,8 @@ def replace_with_observer(layer): quant_mode, n_calib_examples, quant_min=quant_min, - quant_max = quant_max, - eps = eps) + quant_max=quant_max, + eps=eps) return SmoothQuantObservedLinear.from_float(layer, observer) _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py index fa7ec6032..9db6e2d98 100644 --- a/torchao/prototype/smoothquant/core.py +++ b/torchao/prototype/smoothquant/core.py @@ -24,7 +24,6 @@ def __init__(self, scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, zero_point_domain = ZeroPointDomain.INT, - reduce_range: Optional[bool] = False, ): """ A custom observer for SmoothQuant @@ -41,7 +40,6 @@ def __init__(self, scale_dtype: The data type of the scale tensor. zero_point_dtype: The data type of the zero point tensor. zero_point_domain: The domain of the zero point. - reduce_range: Quantize act/wei to less than 8 bits on old platforms """ super().__init__( MappingType.SYMMETRIC, @@ -70,15 +68,11 @@ def __init__(self, ch_axis=-1, dtype=torch.int8, qscheme=torch.per_channel_affine, - reduce_range=False, - quant_min=quant_min, - quant_max=quant_max, eps=eps, ) self.act_obs = HistogramObserver( dtype=torch.int8, qscheme=torch.per_tensor_symmetric, - reduce_range=reduce_range, quant_min=quant_min, quant_max=quant_max, eps=eps, @@ -87,16 +81,12 @@ def __init__(self, ch_axis=1, dtype=torch.int8, qscheme=torch.per_channel_affine, - reduce_range=False, - quant_min=quant_min, - quant_max=quant_max, eps=eps, ) self.wei_oc_obs = PerChannelMinMaxObserver( ch_axis=0, dtype=torch.int8, qscheme=torch.per_channel_symmetric, - reduce_range=reduce_range, quant_min=quant_min, quant_max=quant_max, eps=eps, From 52260b6231d82ed491c82d7d9777d7a838541d58 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sat, 12 Oct 2024 02:16:47 -0400 Subject: [PATCH 13/36] Fix device mismatch --- test/prototype/test_smoothquant.py | 5 +++-- torchao/prototype/smoothquant/api.py | 2 +- torchao/prototype/smoothquant/core.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 7b238fbc1..38e299a0c 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -105,9 +105,10 @@ def forward(self, x): obs = HistogramObserver( dtype=torch.int8, qscheme=torch.per_tensor_symmetric, - reduce_range=reduce_range, + quant_min=-63 if reduce_range else -127, + quant_max=63 if reduce_range else 127, ) - obs(act.float()) + obs(act.float().to("cpu")) act_scale, _ = obs.calculate_qparams() fq_act = torch.quantize_per_tensor( act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8 diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py index 50f36121e..86b83265d 100644 --- a/torchao/prototype/smoothquant/api.py +++ 
b/torchao/prototype/smoothquant/api.py @@ -136,7 +136,7 @@ def __init__(self, x_scale, target_dtype, quant_min=-127): self.quant_min = quant_min def __call__(self, input): - x_zp = torch.zeros([1], dtype=torch.int64) + x_zp = torch.zeros([1], dtype=torch.int64, device=self.x_scale.device) qx = to_affine_quantized_intx_static( input, self.x_scale, x_zp, list(input.shape), self.target_dtype, self.quant_min ) diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py index 9db6e2d98..3cdbde384 100644 --- a/torchao/prototype/smoothquant/core.py +++ b/torchao/prototype/smoothquant/core.py @@ -99,7 +99,7 @@ def forward(self, input: torch.Tensor): # record inputs to find qparams for activation if len(self.inputs) < self.n_calib_examples: self.inputs.append(input.to("cpu").view(-1, input.size(-1))) - self.act_ic_obs(input) + self.act_ic_obs(input.to("cpu")) return input def calculate_qparams(self): From ca50feef96cfc100b271f5b3c84e77dff4142cb8 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sat, 12 Oct 2024 03:26:23 -0400 Subject: [PATCH 14/36] Fix UT failures --- test/prototype/test_smoothquant.py | 21 ++++++++++++--------- torchao/kernel/intmm.py | 4 +++- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 38e299a0c..3818359f5 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -80,7 +80,7 @@ def forward(self, x): # reference weight = m_ref.fc.weight.data.float() - b = m_ref.fc.bias.float() if bias else None + b = m_ref.fc.bias if bias else None x_abs_max_per_ic = torch.abs(data).max(dim=0).values w_abs_max_per_ic = torch.abs(weight).max(dim=0).values smoothing_factor = ( @@ -95,7 +95,7 @@ def forward(self, x): 63 if reduce_range else 127, torch.int8 ) - fq_wei = dequantize_per_channel(qw, w_scales, w_zps, torch.float32) + fq_wei = dequantize_per_channel(qw, w_scales, w_zps, idtype) if (device == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \ not TORCH_VERSION_AT_LEAST_2_2: # _int_mm is not supported in these cases @@ -112,7 +112,7 @@ def forward(self, x): act_scale, _ = obs.calculate_qparams() fq_act = torch.quantize_per_tensor( act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8 - ).dequantize() + ).dequantize().to(idtype) out_ref = torch.nn.functional.linear(fq_act, fq_wei, b) else: # activation is quantized per-row (batch * sequence_length) @@ -122,12 +122,10 @@ def forward(self, x): 63 if reduce_range else 127, torch.int8 ) - fq_act = dequantize_per_channel(qx, x_scales, x_zps, torch.float32) + fq_act = dequantize_per_channel(qx, x_scales, x_zps, idtype) out_ref = torch.nn.functional.linear(fq_act, fq_wei, b) - # Quantized weights of the reference and the SmoothQuant model may differ by 1 - # when elements are quantized to -128 in one case and -127 in the other. 
- # So, the tolerance is relatively big here + # BFloat16 and Float16 have larger errors assert torch.allclose(out, out_ref.to(idtype), atol = 0.2) @@ -149,8 +147,13 @@ def test_save_load_recipe(alpha, quant_mode, device, idtype): calibration_data = dataset[:n_calib_examples] # calibrate - insert_smooth_quant_observer(m, alpha, quant_mode, n_calib_examples) - insert_smooth_quant_observer(m_save_load, alpha, quant_mode, n_calib_examples) + reduce_range = device == "cpu" + insert_smooth_quant_observer( + m, alpha, quant_mode, reduce_range=reduce_range, n_calib_examples=n_calib_examples + ) + insert_smooth_quant_observer( + m_save_load, alpha, quant_mode, reduce_range=reduce_range, n_calib_examples=n_calib_examples + ) for example in calibration_data: m(example.to(device)) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 39be8f018..fca0c8ad7 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -154,4 +154,6 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - return torch.ops.torchao.int_scaled_matmul(a, b, scales1) c = safe_int_mm(a, b) - return c * scales1 + # to float to avoid overflow of float16 + c = c.float() * scales1 + return c.to(scales1.dtype) From 3e90789e86062dc809382b89b7f4c04a38b0a7fc Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sat, 12 Oct 2024 04:23:11 -0400 Subject: [PATCH 15/36] Fix UT --- test/prototype/test_smoothquant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 3818359f5..cf01530b2 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -99,7 +99,7 @@ def forward(self, x): if (device == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \ not TORCH_VERSION_AT_LEAST_2_2: # _int_mm is not supported in these cases - out_ref = torch.nn.functional.linear(act, fq_wei, b) + out_ref = torch.nn.functional.linear(act.to(idtype), fq_wei, b) elif quant_mode == "static": # activation is quantized per-tensor obs = HistogramObserver( From d47fcc1a71452103bd943ba1dd72f207f73a7608 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sat, 12 Oct 2024 04:38:48 -0400 Subject: [PATCH 16/36] Don't use torch._int_mm for CPU now because it may overflow --- torchao/kernel/intmm.py | 19 ------------------- torchao/kernel/intmm_triton.py | 6 ++++++ 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index fca0c8ad7..74d396284 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -109,22 +109,6 @@ def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return safe_int_mm(a, b) -lib = torch.library.Library("torchao", "FRAGMENT") -if intmm_triton is None: - lib.define("int_scaled_matmul(Tensor a, Tensor b, Tensor scales1) -> Tensor") - - -@torch.library.impl(lib, "int_scaled_matmul", "CPU") -def int_scaled_matmul_cpu(a, b, scales1): - if TORCH_VERSION_AT_LEAST_2_4: - c = torch._int_mm(a, b) - c = c.float() * scales1 - return c.to(scales1.dtype) - else: - c = safe_int_mm(a, b) * scales1.float() - return c.to(scales1.dtype) - - def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor: """ Performs scaled integer matrix multiplication. 
@@ -150,9 +134,6 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - if intmm_triton is not None and AUTOTUNER_ENABLE: return torch.ops.torchao.int_scaled_matmul(a, b, scales1) - if all([x.device.type == "cpu" for x in (a, b, scales1)]): - return torch.ops.torchao.int_scaled_matmul(a, b, scales1) - c = safe_int_mm(a, b) # to float to avoid overflow of float16 c = c.float() * scales1 diff --git a/torchao/kernel/intmm_triton.py b/torchao/kernel/intmm_triton.py index 4e84d9cd3..d10dac0ab 100644 --- a/torchao/kernel/intmm_triton.py +++ b/torchao/kernel/intmm_triton.py @@ -356,3 +356,9 @@ def int_scaled_matmul_cuda(a, b, scales1): int_scaled_matmul_kernel, [a, b, scales1, c], int8_mm_kernel_configs ) return int_scaled_matmul_kernel(a, b, scales1, c, best_config) + + +@torch.library.impl(lib, "int_scaled_matmul", "CPU") +def int_scaled_matmul_cpu(a, b, scales1): + c = torch._int_mm(a, b) + return c.to(scales1.dtype) * scales1 From a195e7320651f72811211f77c21bb6840263a6ed Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sat, 12 Oct 2024 05:04:35 -0400 Subject: [PATCH 17/36] Remove reduce_range --- test/prototype/test_smoothquant.py | 28 +++++++--------------------- torchao/kernel/intmm.py | 7 ++++--- torchao/prototype/smoothquant/api.py | 5 +---- 3 files changed, 12 insertions(+), 28 deletions(-) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index cf01530b2..f8c792a77 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -66,10 +66,7 @@ def forward(self, x): data = torch.randn(2, 16, dtype=idtype, device=device) # calibrate - reduce_range = device == "cpu" - insert_smooth_quant_observer( - m, alpha, quant_mode, reduce_range=reduce_range, n_calib_examples=1 - ) + insert_smooth_quant_observer(m, alpha, quant_mode, n_calib_examples=1) m(data) # quantize is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) @@ -90,10 +87,7 @@ def forward(self, x): act = data / smoothing_factor wei = weight * smoothing_factor qw, w_scales, w_zps = dynamically_quantize_per_channel( - wei, - -63 if reduce_range else -127, - 63 if reduce_range else 127, - torch.int8 + wei, -127, 127, torch.int8 ) fq_wei = dequantize_per_channel(qw, w_scales, w_zps, idtype) if (device == "cpu" and not TORCH_VERSION_AT_LEAST_2_4) or \ @@ -105,8 +99,8 @@ def forward(self, x): obs = HistogramObserver( dtype=torch.int8, qscheme=torch.per_tensor_symmetric, - quant_min=-63 if reduce_range else -127, - quant_max=63 if reduce_range else 127, + quant_min=-127, + quant_max=127, ) obs(act.float().to("cpu")) act_scale, _ = obs.calculate_qparams() @@ -117,10 +111,7 @@ def forward(self, x): else: # activation is quantized per-row (batch * sequence_length) qx, x_scales, x_zps = dynamically_quantize_per_channel( - act.float(), - -63 if reduce_range else -127, - 63 if reduce_range else 127, - torch.int8 + act.float(), -127, 127, torch.int8 ) fq_act = dequantize_per_channel(qx, x_scales, x_zps, idtype) out_ref = torch.nn.functional.linear(fq_act, fq_wei, b) @@ -147,13 +138,8 @@ def test_save_load_recipe(alpha, quant_mode, device, idtype): calibration_data = dataset[:n_calib_examples] # calibrate - reduce_range = device == "cpu" - insert_smooth_quant_observer( - m, alpha, quant_mode, reduce_range=reduce_range, n_calib_examples=n_calib_examples - ) - insert_smooth_quant_observer( - m_save_load, alpha, quant_mode, reduce_range=reduce_range, n_calib_examples=n_calib_examples - ) + insert_smooth_quant_observer(m, alpha, 
quant_mode, n_calib_examples=n_calib_examples) + insert_smooth_quant_observer(m_save_load, alpha, quant_mode, n_calib_examples=n_calib_examples) for example in calibration_data: m(example.to(device)) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 74d396284..fc5929fb4 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -2,7 +2,7 @@ import os import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_2 try: # Only works for torch2.2 or newer. @@ -56,9 +56,10 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: if device_cpu or bad_dimensions_for_cublas: # fallback path - return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( + # Compute in float instead of int32 because int32 matmul is not parallelized on CPU + return torch.matmul(input.cpu().to(torch.float), mat2.cpu().to(torch.float)).to( input.device.type - ) + ).to(torch.int32) # cublas paths if not mat2.is_contiguous(): # silently gives incorrect result without this diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py index 86b83265d..b6efe4f7b 100644 --- a/torchao/prototype/smoothquant/api.py +++ b/torchao/prototype/smoothquant/api.py @@ -21,7 +21,6 @@ def insert_smooth_quant_observer( model: torch.nn.Module, alpha: float = 0.5, quant_mode: str = "static", - reduce_range: bool = False, n_calib_examples: int = 20): """ Inserts SmoothQuantObserver into Linear layers of a given model. @@ -29,13 +28,11 @@ def insert_smooth_quant_observer( Args: model: The model to be modified (in place). Ensure model is on the desired device for calibration mapping_type: symmetric or asymmetric quantization of weight - reduce_range: Quantize act/wei to 7 bits on old CPU platforms n_calib_examples: Number of examples used for calibration """ _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) - quant_min = -63 if reduce_range else -127 - quant_max = 63 if reduce_range else 127 + quant_min, quant_max = -127, 127 eps = torch.finfo(torch.float32).eps def replace_with_observer(layer): From 6627be1d313e1f813a115782e29f096248ded6dd Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 14 Oct 2024 07:37:12 -0400 Subject: [PATCH 18/36] Refine code --- torchao/prototype/smoothquant/api.py | 8 -------- torchao/prototype/smoothquant/example.py | 4 ++++ torchao/quantization/linear_activation_scale_quantized.py | 2 ++ 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py index b6efe4f7b..8536c5313 100644 --- a/torchao/prototype/smoothquant/api.py +++ b/torchao/prototype/smoothquant/api.py @@ -1,7 +1,4 @@ import torch -from torchao.quantization.quant_primitives import ( - _DTYPE_TO_QVALUE_BOUNDS, -) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static from torchao.quantization.linear_activation_scale_quantized import ( @@ -173,9 +170,6 @@ def smooth_quant( def quantize_weight(observed_linear): target_dtype = torch.int8 - quant_min = _DTYPE_TO_QVALUE_BOUNDS[target_dtype][0] - quant_max = _DTYPE_TO_QVALUE_BOUNDS[target_dtype][1] - nonlocal smoothing_factor, act_scales, wei_scales # act_scales is None for dynamic quantization thus not checked if any(x is None for x in (smoothing_factor, wei_scales)): factor, x_scale, w_scales = 
observed_linear.obs.calculate_qparams() @@ -192,8 +186,6 @@ def quantize_weight(observed_linear): wei_zero_points, block_size, target_dtype, - quant_min, - quant_max, ) is_dynamic = x_scale is None diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py index e5227bbca..19e1baab9 100644 --- a/torchao/prototype/smoothquant/example.py +++ b/torchao/prototype/smoothquant/example.py @@ -43,11 +43,15 @@ def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True, devi tokenizer.padding_side = "right" tokenizer.add_eos_token = False + print("Loading dataset") + t0 = time.time() dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') encodings = tokenizer('\n\n'.join(dataset['text']), return_tensors='pt') + print(f"Time to load dataset: {time.time() - t0:.02f} seconds") encodings['input_ids'] = encodings['input_ids'].to(device) + print("Running evaluation") lls, t = [], [] for i in tqdm(range(0, encodings['input_ids'].size(1), stride), disable=not verbose): begin_loc = max(i + stride - sequence_length, 0) diff --git a/torchao/quantization/linear_activation_scale_quantized.py b/torchao/quantization/linear_activation_scale_quantized.py index ffef7b33f..86caceee2 100644 --- a/torchao/quantization/linear_activation_scale_quantized.py +++ b/torchao/quantization/linear_activation_scale_quantized.py @@ -134,6 +134,7 @@ def _(func, types, args, kwargs): original_weight_tensor = weight_tensor.original_weight_tensor scale = weight_tensor.scale scaled_input_act = input_tensor / scale + scaled_input_act = scaled_input_act.to(input_tensor.dtype) aqt = input_quant_func(scaled_input_act) return func(bias, aqt, original_weight_tensor) else: @@ -150,6 +151,7 @@ def _(func, types, args, kwargs): original_weight_tensor = weight_tensor.original_weight_tensor scale = weight_tensor.scale scaled_input_act = input_tensor / scale + scaled_input_act = scaled_input_act.to(input_tensor.dtype) aqt = input_quant_func(scaled_input_act) return func(aqt, original_weight_tensor) From fb981e7c94e0038a1c17d390bb43ecc7b1f28f25 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 14 Oct 2024 07:47:56 -0400 Subject: [PATCH 19/36] Remove torch.compile from example --- torchao/prototype/smoothquant/example.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py index 19e1baab9..99e5b1321 100644 --- a/torchao/prototype/smoothquant/example.py +++ b/torchao/prototype/smoothquant/example.py @@ -99,7 +99,6 @@ def wikitext2_ppl( device: str, precision:torch.dtype, sequence_length: int, - compile: bool, model_load_path: str, model_save_path: str): print(f"Loading model on {device}...") @@ -135,8 +134,6 @@ def wikitext2_ppl( t0 = time.time() torch.save(model, model_save_path) print(f"Time to save quantized model: {time.time() - t0:.02f} seconds") - if compile: - model = torch.compile(model) return benchmark(model, tokenizer, sequence_length, tasks=["PPL"], device=device) @@ -151,7 +148,6 @@ def wikitext2_ppl( parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") parser.add_argument("--seq_len", type=int, default=512, help="Length of examples to calibrate and evaluate model on. 
Default is 512") - parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") parser.add_argument("--model-load-path", type=str, default=None, help="Path to load quantized model. If this is provided, " "the model will be loaded from this path instead of quantizing the model.") @@ -168,7 +164,6 @@ def wikitext2_ppl( args.device, args.precision, args.seq_len, - args.compile, args.model_load_path, args.model_save_path ) From 17c374eea5e874e3ab280707ff07a22d3cf87afa Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 15 Oct 2024 03:22:54 -0400 Subject: [PATCH 20/36] Add torch.compile in example --- torchao/prototype/smoothquant/example.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py index 99e5b1321..19e1baab9 100644 --- a/torchao/prototype/smoothquant/example.py +++ b/torchao/prototype/smoothquant/example.py @@ -99,6 +99,7 @@ def wikitext2_ppl( device: str, precision:torch.dtype, sequence_length: int, + compile: bool, model_load_path: str, model_save_path: str): print(f"Loading model on {device}...") @@ -134,6 +135,8 @@ def wikitext2_ppl( t0 = time.time() torch.save(model, model_save_path) print(f"Time to save quantized model: {time.time() - t0:.02f} seconds") + if compile: + model = torch.compile(model) return benchmark(model, tokenizer, sequence_length, tasks=["PPL"], device=device) @@ -148,6 +151,7 @@ def wikitext2_ppl( parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") parser.add_argument("--precision", type=str, default="bfloat16", help="Precision type. Default is 'bfloat16'.") parser.add_argument("--seq_len", type=int, default=512, help="Length of examples to calibrate and evaluate model on. Default is 512") + parser.add_argument("--compile", action="store_true", help="Flag to indicate if compilation is required.") parser.add_argument("--model-load-path", type=str, default=None, help="Path to load quantized model. 
If this is provided, " "the model will be loaded from this path instead of quantizing the model.") @@ -164,6 +168,7 @@ def wikitext2_ppl( args.device, args.precision, args.seq_len, + args.compile, args.model_load_path, args.model_save_path ) From bb76de69e9d27c792f46b22538522d25466cd4b9 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 15 Oct 2024 03:37:48 -0400 Subject: [PATCH 21/36] Debug CI failures --- test/prototype/test_smoothquant.py | 2 ++ torchao/kernel/intmm.py | 11 ++++------- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index f8c792a77..653faf7a0 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -52,6 +52,7 @@ def forward(self, x): @pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("idtype", idtypes) def test_compute(bias, alpha, quant_mode, device, idtype): + return class Linear(torch.nn.Module): def __init__(self, bias: bool): super().__init__() @@ -125,6 +126,7 @@ def forward(self, x): @pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("idtype", idtypes) def test_save_load_recipe(alpha, quant_mode, device, idtype): + return dataset_size = 20 l1, l2, l3 = 512, 256, 128 original_dtype = idtype diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index fc5929fb4..7d076a6e8 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -56,10 +56,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: if device_cpu or bad_dimensions_for_cublas: # fallback path - # Compute in float instead of int32 because int32 matmul is not parallelized on CPU - return torch.matmul(input.cpu().to(torch.float), mat2.cpu().to(torch.float)).to( + return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( input.device.type - ).to(torch.int32) + ) # cublas paths if not mat2.is_contiguous(): # silently gives incorrect result without this @@ -127,7 +126,7 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - """ M, K = a.shape K, N = b.shape - assert M == scales1.size(0) or scales1.numel() == 1 + assert M == scales1.size(0) assert 1 == scales1.size(1) assert scales1.is_contiguous() scales1 = scales1.expand((M, N)) @@ -136,6 +135,4 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - return torch.ops.torchao.int_scaled_matmul(a, b, scales1) c = safe_int_mm(a, b) - # to float to avoid overflow of float16 - c = c.float() * scales1 - return c.to(scales1.dtype) + return c * scales1 From 98b2de1ae542875ffc6fabdd67a34ee05473a4e9 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 15 Oct 2024 04:48:29 -0400 Subject: [PATCH 22/36] Debug CI failures (1) --- torchao/kernel/intmm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 7d076a6e8..7b8fbe72a 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -135,4 +135,6 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - return torch.ops.torchao.int_scaled_matmul(a, b, scales1) c = safe_int_mm(a, b) - return c * scales1 + # to float to avoid overflow of float16 + c = c.float() * scales1 + return c.to(scales1.dtype) From 316f5eab9c40e81530241f24e00bb65c86c8f071 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 15 Oct 2024 06:46:12 -0400 Subject: [PATCH 23/36] Debug CI failures (2) --- torchao/kernel/intmm.py | 10 +++++++--- 1 file changed, 7 
insertions(+), 3 deletions(-) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 7b8fbe72a..4970f9499 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -135,6 +135,10 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - return torch.ops.torchao.int_scaled_matmul(a, b, scales1) c = safe_int_mm(a, b) - # to float to avoid overflow of float16 - c = c.float() * scales1 - return c.to(scales1.dtype) + + if all([x.device.type == "cpu" for x in (c, scales1)]): + # to float to avoid overflow of float16 + c = c.float() * scales1 + return c.to(scales1.dtype) + + return c * scales1 From b4d8383b939eab59a6a8006d20e501c8cd93e630 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 15 Oct 2024 20:59:58 -0400 Subject: [PATCH 24/36] Debug CI failures (3) --- torchao/kernel/intmm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 4970f9499..a1f14218f 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -56,9 +56,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: if device_cpu or bad_dimensions_for_cublas: # fallback path - return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( + return torch.matmul(input.cpu().to(torch.float), mat2.cpu().to(torch.float)).to( input.device.type - ) + ).to(torch.int32) # cublas paths if not mat2.is_contiguous(): # silently gives incorrect result without this From aca06d2212d68b092e7de882317dfff06075b576 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 16 Oct 2024 03:45:08 -0400 Subject: [PATCH 25/36] Work with torch.compile --- test/prototype/test_smoothquant.py | 2 - torchao/kernel/intmm.py | 5 +- torchao/prototype/smoothquant/api.py | 44 +--- .../linear_activation_scale_quantized.py | 191 ++++++++++++++---- 4 files changed, 161 insertions(+), 81 deletions(-) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 653faf7a0..f8c792a77 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -52,7 +52,6 @@ def forward(self, x): @pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("idtype", idtypes) def test_compute(bias, alpha, quant_mode, device, idtype): - return class Linear(torch.nn.Module): def __init__(self, bias: bool): super().__init__() @@ -126,7 +125,6 @@ def forward(self, x): @pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("idtype", idtypes) def test_save_load_recipe(alpha, quant_mode, device, idtype): - return dataset_size = 20 l1, l2, l3 = 512, 256, 128 original_dtype = idtype diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index a1f14218f..9d5fb5840 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -37,6 +37,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: """ # torch.compile path if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): + if input.device.type == "cpu": + # CPU needs float for better performance and further optimizations + return out_dtype(torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float()) return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) # error checking for cublas path @@ -126,7 +129,7 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - """ M, K = a.shape K, N = b.shape - assert M == scales1.size(0) + assert M == scales1.size(0) or scales1.numel() == 1 assert 1 == scales1.size(1) 
assert scales1.is_contiguous() scales1 = scales1.expand((M, N)) diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py index 8536c5313..65981edad 100644 --- a/torchao/prototype/smoothquant/api.py +++ b/torchao/prototype/smoothquant/api.py @@ -120,40 +120,6 @@ def recurse(module: torch.nn.Module, name: str = ''): recurse(model) -# StaticQuantizeAct and DynamicQuantizeAct are defined as classes to allow for easy serialization and deserialization -class StaticQuantizeAct: - def __init__(self, x_scale, target_dtype, quant_min=-127): - super().__init__() - self.x_scale = x_scale - self.x_zp = torch.zeros_like(x_scale, dtype=torch.int64) - self.target_dtype = target_dtype - self.quant_min = quant_min - - def __call__(self, input): - x_zp = torch.zeros([1], dtype=torch.int64, device=self.x_scale.device) - qx = to_affine_quantized_intx_static( - input, self.x_scale, x_zp, list(input.shape), self.target_dtype, self.quant_min - ) - if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): - return qx.tensor_impl.int_data.to(qx.dtype) - return qx - - -class DynamicQuantizeAct: - def __init__(self, target_dtype, quant_min=-127): - self.target_dtype = target_dtype - self.quant_min = quant_min - - def __call__(self, input): - block_size = _get_per_token_block_size(input) - qx = to_affine_quantized_intx( - input, MappingType.SYMMETRIC, block_size, self.target_dtype, self.quant_min - ) - if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): - return qx.tensor_impl.int_data.to(qx.dtype) - return qx - - def smooth_quant( smoothing_factor: Optional[torch.Tensor] = None, act_scales: Optional[torch.Tensor] = None, @@ -188,12 +154,8 @@ def quantize_weight(observed_linear): target_dtype, ) - is_dynamic = x_scale is None - if is_dynamic: - input_quant_func = DynamicQuantizeAct(target_dtype) - else: - input_quant_func = StaticQuantizeAct(x_scale, target_dtype) - - return to_linear_scale_activation_quantized(qw, factor, input_quant_func) + return to_linear_scale_activation_quantized( + qw, factor, x_scale, None, target_dtype, -127, 127 + ) return _observed_linear_subclass_inserter(quantize_weight) diff --git a/torchao/quantization/linear_activation_scale_quantized.py b/torchao/quantization/linear_activation_scale_quantized.py index 86caceee2..243ad0685 100644 --- a/torchao/quantization/linear_activation_scale_quantized.py +++ b/torchao/quantization/linear_activation_scale_quantized.py @@ -1,10 +1,13 @@ import torch -from typing import Callable +from typing import Callable, Optional from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.utils import ( TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_5, ) +from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static +from torchao.quantization.utils import _get_per_token_block_size +from torchao.quantization.quant_primitives import MappingType __all__ = [ "LinearActivationScaleQuantizedTensor", @@ -31,8 +34,12 @@ class LinearActivationScaleQuantizedTensor(TorchAOBaseTensor): def __new__( cls, original_weight_tensor: torch.Tensor, - scale: torch.Tensor, - input_quant_func: Callable, + equalization_scale: torch.Tensor, + act_scales: Optional[torch.Tensor], + act_zero_points: Optional[torch.Tensor], + target_dtype: torch.dtype, + quant_min: int, + quant_max: int, ): kwargs = {} dtype = original_weight_tensor.dtype @@ -45,60 +52,147 @@ def __new__( def __init__( self, original_weight_tensor: torch.Tensor, - scale: torch.Tensor, - input_quant_func: Callable[[torch.Tensor], 
torch.Tensor], + equalization_scale: torch.Tensor, + act_scales: Optional[torch.Tensor], + act_zero_points: Optional[torch.Tensor], + target_dtype: torch.dtype, + quant_min: int, + quant_max: int, ): self.original_weight_tensor = original_weight_tensor - self.scale = scale - self.input_quant_func = input_quant_func + self.equalization_scale = equalization_scale + self.act_scales = act_scales + self.act_zero_points = act_zero_points + self.target_dtype = target_dtype + self.quant_min = quant_min + self.quant_max = quant_max def __repr__(self): return (f"LinearActivationScaleQuantizedTensor({self.original_weight_tensor}, " - f"scale={self.scale}, quant_func={self.input_quant_func})") + f"equalization_scale={self.equalization_scale}, " + f"act_scales={self.act_scales}), " + f"act_zero_points={self.act_zero_points}, " + f"target_dtype={self.target_dtype}, " + f"quant_min={self.quant_min}, " + f"quant_max={self.quant_max})" + ) def __tensor_flatten__(self): - return ["original_weight_tensor", "scale"], [self.input_quant_func] + tensor_data = [ + "original_weight_tensor", + "equalization_scale", + ] + tensor_attributes = [self.target_dtype, self.quant_min, self.quant_max] + if self.act_scales is not None: + tensor_data.append("act_scales") + if self.act_zero_points is not None: + tensor_data.append("act_zero_points") + return tensor_data, tensor_attributes @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): original_weight_tensor = tensor_data_dict["original_weight_tensor"] - scale = tensor_data_dict["scale"] - input_quant_func, = tensor_attributes + equalization_scale = tensor_data_dict["equalization_scale"] + act_scales = tensor_data_dict["act_scales"] if "act_scales" in tensor_data_dict else None + act_zero_points = tensor_data_dict["act_zero_points"] if "act_zero_points" in tensor_data_dict else None + target_dtype, quant_min, quant_max = tensor_attributes return cls( original_weight_tensor, - scale, - input_quant_func, + equalization_scale, + act_scales, + act_zero_points, + target_dtype, + quant_min, + quant_max, ) @staticmethod def _quantized_linear_op(input_tensor, weight_tensor, bias): - input_quant_func = weight_tensor.input_quant_func original_weight_tensor = weight_tensor.original_weight_tensor - scale = weight_tensor.scale - scaled_input_act = input_tensor / scale + equalization_scale = weight_tensor.equalization_scale + scaled_input_act = input_tensor / equalization_scale scaled_input_act = scaled_input_act.to(input_tensor.dtype) - aqt = input_quant_func(scaled_input_act) + if weight_tensor.act_scales is not None: + # static quant + act_zero_points = ( + weight_tensor.act_zero_points + if weight_tensor.act_zero_points is not None + else torch.zeros_like(weight_tensor.act_scales, dtype=torch.int64) + ) + aqt = to_affine_quantized_intx_static( + scaled_input_act, + weight_tensor.act_scales, + act_zero_points, + list(scaled_input_act.shape), + weight_tensor.target_dtype, + quant_min=weight_tensor.quant_min, + quant_max=weight_tensor.quant_max, + ) + else: + # dynamic quant + block_size = _get_per_token_block_size(scaled_input_act) + aqt = to_affine_quantized_intx( + scaled_input_act, + MappingType.SYMMETRIC, + block_size, + weight_tensor.target_dtype, + quant_min=weight_tensor.quant_min, + quant_max=weight_tensor.quant_max, + ) + return torch.nn.functional.linear(aqt, original_weight_tensor, bias) @classmethod - def from_float(cls, input_float, scale, input_quant_func): - return cls(input_float, scale, input_quant_func) + def 
from_float(cls, input_float, equalization_scale, act_scales, act_zero_points, target_dtype, quant_min, quant_max): + return cls( + input_float, + equalization_scale, + act_scales, + act_zero_points, + target_dtype, + quant_min, + quant_max, + ) def _apply_fn_to_data(self, fn): return self.__class__( fn(self.original_weight_tensor), - fn(self.scale), - self.input_quant_func, + fn(self.equalization_scale), + ( + fn(self.act_scales) + if self.act_scales is not None + else None + ), + ( + fn(self.act_zero_points) + if self.act_zero_points is not None + else None + ), + self.target_dtype, + self.quant_min, + self.quant_max, ) def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) return self.__class__( self.original_weight_tensor.to(**kwargs), - self.scale.to(**kwargs), - self.input_quant_func, + self.equalization_scale.to(**kwargs), + ( + self.act_scales.to(**kwargs) + if self.act_scales is not None + else None + ), + ( + self.act_zero_points.to(**kwargs) + if self.act_zero_points is not None + else None + ), + self.target_dtype, + self.quant_min, + self.quant_max, ) implements = LinearActivationScaleQuantizedTensor.implements @@ -130,29 +224,52 @@ def _(func, types, args, kwargs): args[2], args[0], ) - input_quant_func = weight_tensor.input_quant_func - original_weight_tensor = weight_tensor.original_weight_tensor - scale = weight_tensor.scale - scaled_input_act = input_tensor / scale - scaled_input_act = scaled_input_act.to(input_tensor.dtype) - aqt = input_quant_func(scaled_input_act) - return func(bias, aqt, original_weight_tensor) else: # aten.mm.default assert args[0].shape[-1] == args[1].shape[0], ( f"need mat1 shape: {args[0].shape} final dim" f"to match mat2 shape: {args[1].shape} first dim" ) - input_tensor, weight_tensor = ( + input_tensor, weight_tensor, bias = ( args[0], args[1], + None, ) - input_quant_func = weight_tensor.input_quant_func - original_weight_tensor = weight_tensor.original_weight_tensor - scale = weight_tensor.scale - scaled_input_act = input_tensor / scale - scaled_input_act = scaled_input_act.to(input_tensor.dtype) - aqt = input_quant_func(scaled_input_act) + original_weight_tensor = weight_tensor.original_weight_tensor + equalization_scale = weight_tensor.equalization_scale + scaled_input_act = input_tensor / equalization_scale + scaled_input_act = scaled_input_act.to(input_tensor.dtype) + if weight_tensor.act_scales is not None: + # static quant + act_zero_points = ( + weight_tensor.act_zero_points + if weight_tensor.act_zero_points is not None + else torch.zeros_like(weight_tensor.act_scales, dtype=torch.int64) + ) + aqt = to_affine_quantized_intx_static( + scaled_input_act, + weight_tensor.act_scales, + act_zero_points, + list(scaled_input_act.shape), + weight_tensor.target_dtype, + quant_min=weight_tensor.quant_min, + quant_max=weight_tensor.quant_max, + ) + else: + # dynamic quant + block_size = _get_per_token_block_size(scaled_input_act) + aqt = to_affine_quantized_intx( + scaled_input_act, + MappingType.SYMMETRIC, + block_size, + weight_tensor.target_dtype, + quant_min=weight_tensor.quant_min, + quant_max=weight_tensor.quant_max, + ) + + if func == aten.addmm.default: + return func(bias, aqt, original_weight_tensor) + else: return func(aqt, original_weight_tensor) From dde7545a46c9e297f8944427a2f283a9dabad149 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 16 Oct 2024 05:00:11 -0400 Subject: [PATCH 26/36] Update torchao/kernel/intmm.py --- torchao/kernel/intmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 9d5fb5840..12a01f62c 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -139,7 +139,7 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - c = safe_int_mm(a, b) - if all([x.device.type == "cpu" for x in (c, scales1)]): + if scales1.dtype == torch.half: # to float to avoid overflow of float16 c = c.float() * scales1 return c.to(scales1.dtype) From 00cfadd2e89b1fa114055f1672c6875d8d1d5b82 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 16 Oct 2024 21:35:54 -0400 Subject: [PATCH 27/36] Update readme.md --- torchao/prototype/smoothquant/example.py | 35 ++++++++++++------------ 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py index 19e1baab9..1aaba5b24 100644 --- a/torchao/prototype/smoothquant/example.py +++ b/torchao/prototype/smoothquant/example.py @@ -13,9 +13,8 @@ from torchao.quantization import quantize_ -# adapted from the AWQ example def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512): - dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") + dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='validation') samples = [] n_tokens = n_samples * block_size n_run = n_tokens @@ -36,11 +35,11 @@ def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512): cat_samples = torch.cat(samples, dim=1) return [cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_samples)] -# adapted from the AWQ example + def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True, device="cuda"): model.eval() - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "right" + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "right" tokenizer.add_eos_token = False print("Loading dataset") @@ -56,10 +55,10 @@ def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True, devi for i in tqdm(range(0, encodings['input_ids'].size(1), stride), disable=not verbose): begin_loc = max(i + stride - sequence_length, 0) end_loc = min(i + stride, encodings['input_ids'].size(1)) - trg_len = end_loc - i + trg_len = end_loc - i input_ids = encodings['input_ids'][:,begin_loc:end_loc] target_ids = input_ids.clone() - target_ids[:,:-trg_len] = -100 #ignore context + target_ids[:,:-trg_len] = -100 # ignore context t1 = time.time() with torch.no_grad(): @@ -79,8 +78,8 @@ def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True, devi print('time', str(pred_time) + ' sec/it') return {'perplexity':ppl, 'prediction_time':pred_time} - -# adapted from the AWQ example + + def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): model.eval() model.config.use_cache = False @@ -94,18 +93,18 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): def wikitext2_ppl( model_id: str, + alpha: float, quant_mode: str, - calibration_size: int, - device: str, - precision:torch.dtype, - sequence_length: int, + calibration_size: int, + device: str, + precision:torch.dtype, + sequence_length: int, compile: bool, model_load_path: str, model_save_path: str): print(f"Loading model on {device}...") torch.manual_seed(34) t0 = time.time() - # load any model with torch.nn.linear layers tokenizer = AutoTokenizer.from_pretrained(model_id) if model_load_path is not None and os.path.exists(model_load_path): print(f"Loading quantized model from {model_load_path}") @@ 
-118,7 +117,7 @@ def wikitext2_ppl( print(f"running calibration") t0 = time.time() # insert observers to find average magnitude and calculate scales - insert_smooth_quant_observer(model, 0.5, quant_mode, calibration_size) + insert_smooth_quant_observer(model, alpha, quant_mode, calibration_size) calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length) for batch in calibration_data: model(batch.to(device)) @@ -136,16 +135,17 @@ def wikitext2_ppl( torch.save(model, model_save_path) print(f"Time to save quantized model: {time.time() - t0:.02f} seconds") if compile: - model = torch.compile(model) + model = torch.compile(model, dynamic=True) return benchmark(model, tokenizer, sequence_length, tasks=["PPL"], device=device) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.") - # Optional arguments with default values parser.add_argument("--model-id", "-m", type=str, help="Repository ID of the model.") + parser.add_argument("--alpha", type=float, default=0.5, help="The alpha hyperparameter for SmoothQuant.") parser.add_argument("--quant-mode", type=str, help="Quantization mode, either static or dynamic.") parser.add_argument("--calibration-samples", type=int, default=10, help="Number of samples to use for calibration. Default is 10.") parser.add_argument("--device", type=str, default="cuda", help="Device to run the evaluation on. Default is 'cuda'.") @@ -163,6 +163,7 @@ def wikitext2_ppl( precision_dtype = getattr(torch, args.precision, torch.bfloat16) ppl = wikitext2_ppl( args.model_id, + args.alpha, args.quant_mode, args.calibration_samples, args.device, From 466d2f10450ea36883d16a601bec8e2f488b5184 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 17 Oct 2024 01:45:30 -0400 Subject: [PATCH 28/36] Update readme.md --- torchao/prototype/smoothquant/readme.md | 83 +++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/torchao/prototype/smoothquant/readme.md index e69de29bb..ca7272609 100644 --- a/torchao/prototype/smoothquant/readme.md +++ b/torchao/prototype/smoothquant/readme.md @@ -0,0 +1,83 @@ +# SmoothQuant quantization +This is a native PyTorch implementation of the algorithm described in [this paper](https://arxiv.org/abs/2211.10438). + +In this implementation, weights are smoothed (equalized) and quantized to int8 during quantization. Activations are smoothed and quantized to int8 at runtime. Quantization is done either dynamically or statically. If activations are quantized dynamically, qparams (i.e., scales) are found at runtime, while for static quantization they are determined ahead of time, during quantization. For dynamic quantization, activations are quantized per token; for static quantization, activations are quantized per tensor. Generally, dynamic quantization produces better accuracy while static quantization has better latency. In both cases, weights and activations are symmetrically quantized. + +## Quick start +Run the example code with +```bash +python example.py -m MODEL_ID --device= --quant-mode= +# An example +python example.py -m meta-llama/Llama-2-7b-hf --device=cuda --quant-mode=dynamic +``` +To use `torch.compile` for speedup, add `--compile`. You may want to export `TORCHINDUCTOR_FREEZING=1` for even better performance.
+```bash +TORCHINDUCTOR_FREEZING=1 python example.py -m MODEL_ID --device= --quant-mode= --compile +``` +To save a quantized model for reuse, specify `--model-save-path` +```bash +python example.py -m MODEL_ID --device= --quant-mode= --model-save-path ./quantized_model.pt +``` +Load it back with `--model-load-path` +```bash +python example.py -m MODEL_ID --device= --quant-mode= --model-load-path ./quantized_model.pt +``` + + +## Usage of API +The following APIs are provided: +- insert_smooth_quant_observer +- smooth_quant +- save_smooth_quant_recipe (advanced) +- load_smooth_quant_recipe (advanced) + +`insert_smooth_quant_observer` inserts observers into the model to be quantized. For example: +```python +insert_smooth_quant_observer(model, alpha=0.5, quant_mode="dynamic") +``` +After insertion, run the model for calibration on a calibration dataset or (advanced) load a recipe. + +`smooth_quant` applies SmoothQuant to each linear layer of the model. Use it by calling `torchao.quantization.quantize_`. For example: +```python +from torchao.prototype.smoothquant import SmoothQuantObservedLinear +is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) +torchao.quantization.quantize_(model, smooth_quant(), is_observed_linear) +``` +`is_observed_linear` is a filter so that we only quantize observed linear layers. + +(Advanced) `save_smooth_quant_recipe` and `load_smooth_quant_recipe` save or load a recipe for a model. + +A recipe contains smoothing factors and quantization parameters of weights and activation for all linear layers that are to be quantized. For advanced users, these parameters can be saved and modified somehow to produce better accuray, e.g., different alpha for different layres. Users can even leave some linear layers unquantized by deleting these layers in the recipe. Such modifications can be published as a recipe. By loading the recipe, it can be reused and calibration is no longer needed. + +To save a recipe, users should insert observers and run calibration first. For example, +```python +insert_smooth_quant_observer(model, alpha=0.5, quant_mode="dynamic") +for data in dataset: + model(data) +save_smooth_quant_recipe(model, "./smooth_quant_recipe.json") +``` +To load a recipe, users should insert observers first. For example, +```python +insert_smooth_quant_observer(model) +load_smooth_quant_recipe(model, "./smooth_quant_recipe.json") +``` + +## Benchmark +Running the example case with `torch.compile` on a NVIDIA A10 GPU. +### meta-llama/Llama-2-7b-hf +| Quant Method | Perplexity | Latency (sec/it) | +|-|-|-| +| SmoothQuant dynamic | 7.4318 | 0.6874 | +| SmoothQuant static | 424.995 | 0.4560 | +| AWQ UINT4 Group size 128 | 7.4837 | 1.0018 | + +### meta-llama/Meta-Llama-3-8B +| Quant Method | Perplexity | Latency (sec/it) | +|-|-|-| +| SmoothQuant dynamic | 8.8274 | 0.9008 | +| SmoothQuant static | 124.2236 | 0.5537 | +| AWQ UINT4 Group size 128 | 8.7087 | 1.2868 | + +Note: +1. Static quantization needs tuning on recipe and alpha to get good accuracy. So, the data here is for reference only. +2. AWQ's calibration runs on `mit-han-lab/pile-val-backup`, validation split while SmoothQuant's calibration runs on `wikitext/wikitext-2-raw-v1`, test split.
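For reference, the APIs documented above compose into a short end-to-end script roughly as follows. This is an illustrative sketch only: the toy model, layer sizes, and random calibration batches are hypothetical placeholders, and it assumes the prototype API exactly as documented in this readme (`insert_smooth_quant_observer`, `smooth_quant`, `SmoothQuantObservedLinear`, `torchao.quantization.quantize_`), which may still change while the feature is in `prototype`.

```python
# Minimal SmoothQuant flow (dynamic mode), assuming the prototype API above.
# The toy model and random calibration data are hypothetical placeholders.
import torch

from torchao.quantization import quantize_
from torchao.prototype.smoothquant import (
    insert_smooth_quant_observer,
    smooth_quant,
    SmoothQuantObservedLinear,
)


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(64, 64, bias=False)
        self.fc2 = torch.nn.Linear(64, 16, bias=False)

    def forward(self, x):
        return self.fc2(self.fc1(x))


model = TinyModel().eval()

# 1. Insert observers that record per-channel activation/weight magnitudes.
insert_smooth_quant_observer(model, alpha=0.5, quant_mode="dynamic")

# 2. Calibrate on a few representative batches (random here, for illustration).
with torch.no_grad():
    for _ in range(8):
        model(torch.randn(4, 64))

# 3. Replace observed linears with smoothed, int8-quantized linears.
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(model, smooth_quant(), is_observed_linear)

# 4. Run inference as usual (optionally wrap with torch.compile first).
out = model(torch.randn(4, 64))
```

The static mode follows the same steps with `quant_mode="static"`, in which case activation scales are fixed from the calibration data instead of being computed per token at runtime.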
From e970a4ab5f1dffde8c1667e328245352031c4524 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 17 Oct 2024 03:57:57 -0400 Subject: [PATCH 29/36] Debug CI failures (4) --- torchao/kernel/intmm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 12a01f62c..41eaa76cf 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -59,9 +59,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: if device_cpu or bad_dimensions_for_cublas: # fallback path - return torch.matmul(input.cpu().to(torch.float), mat2.cpu().to(torch.float)).to( + return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( input.device.type - ).to(torch.int32) + ) # cublas paths if not mat2.is_contiguous(): # silently gives incorrect result without this From 5e2abbe6b506b6d0363dc08f0bcb858c536baf56 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 17 Oct 2024 21:16:43 -0400 Subject: [PATCH 30/36] Reimplement with nested tensor subclassing --- test/prototype/test_smoothquant.py | 6 +- torchao/kernel/intmm.py | 2 +- torchao/prototype/smoothquant/api.py | 41 ++- .../linear_activation_scale_quantized.py | 307 ------------------ 4 files changed, 40 insertions(+), 316 deletions(-) delete mode 100644 torchao/quantization/linear_activation_scale_quantized.py diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index f8c792a77..17abfa0bf 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -125,6 +125,8 @@ def forward(self, x): @pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("idtype", idtypes) def test_save_load_recipe(alpha, quant_mode, device, idtype): + # This test case will trigger recompilation many times, so set a large cache_size_limit here + torch._dynamo.config.cache_size_limit = 32 dataset_size = 20 l1, l2, l3 = 512, 256, 128 original_dtype = idtype @@ -153,8 +155,8 @@ def test_save_load_recipe(alpha, quant_mode, device, idtype): # quantize is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) quantize_(m, smooth_quant(), is_observed_linear) - # m = torch.compile(m, fullgraph=True) - # m_save_load = torch.compile(m_save_load, fullgraph=True) + m = torch.compile(m, fullgraph=True) + m_save_load = torch.compile(m_save_load, fullgraph=True) out_list = [m(data.squeeze(0)) for data in dataset] out = torch.cat(out_list) save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset] diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 41eaa76cf..9d23408d7 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -38,7 +38,7 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: # torch.compile path if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): if input.device.type == "cpu": - # CPU needs float for better performance and further optimizations + # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend return out_dtype(torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float()) return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py index 65981edad..e91825b0d 100644 --- a/torchao/prototype/smoothquant/api.py +++ b/torchao/prototype/smoothquant/api.py @@ -1,8 +1,14 @@ import torch from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from 
torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static -from torchao.quantization.linear_activation_scale_quantized import ( - to_linear_scale_activation_quantized, +from torchao.quantization.linear_activation_quantized_tensor import ( + to_linear_activation_quantized, +) +from torchao.quantization.linear_activation_scale import ( + to_weight_tensor_with_linear_activation_scale_metadata, +) +from torchao.quantization.weight_tensor_linear_activation_quantization import ( + to_weight_tensor_with_linear_activation_quantization_metadata, ) from torchao.quantization.quant_primitives import MappingType from torchao.quantization.utils import _get_per_token_block_size @@ -11,7 +17,6 @@ SmoothQuantObservedLinear, ) from typing import Dict, Optional -from torch._dynamo import is_compiling as dynamo_is_compiling def insert_smooth_quant_observer( @@ -120,6 +125,22 @@ def recurse(module: torch.nn.Module, name: str = ''): recurse(model) +class _ActQuantizer: + def __init__(self, target_dtype, quant_min=-127): + self.target_dtype = target_dtype + self.quant_min = quant_min + + def dynamic_quantize(self, input): + return to_affine_quantized_intx( + input, MappingType.SYMMETRIC, _get_per_token_block_size(input), self.target_dtype, self.quant_min + ) + + def static_quantize(self, input, scale, zero_point): + return to_affine_quantized_intx_static( + input, scale, zero_point, list(input.shape), self.target_dtype, self.quant_min + ) + + def smooth_quant( smoothing_factor: Optional[torch.Tensor] = None, act_scales: Optional[torch.Tensor] = None, @@ -154,8 +175,16 @@ def quantize_weight(observed_linear): target_dtype, ) - return to_linear_scale_activation_quantized( - qw, factor, x_scale, None, target_dtype, -127, 127 - ) + if x_scale is None: + # dynamic quant + qw = to_linear_activation_quantized(qw, _ActQuantizer(target_dtype).dynamic_quantize) + else: + # static quant + x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64) + qw = to_weight_tensor_with_linear_activation_quantization_metadata( + qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point + ) + + return to_weight_tensor_with_linear_activation_scale_metadata(qw, factor.to(qw.dtype)) return _observed_linear_subclass_inserter(quantize_weight) diff --git a/torchao/quantization/linear_activation_scale_quantized.py b/torchao/quantization/linear_activation_scale_quantized.py deleted file mode 100644 index 243ad0685..000000000 --- a/torchao/quantization/linear_activation_scale_quantized.py +++ /dev/null @@ -1,307 +0,0 @@ -import torch -from typing import Callable, Optional -from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TorchAOBaseTensor, - TORCH_VERSION_AT_LEAST_2_5, -) -from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static -from torchao.quantization.utils import _get_per_token_block_size -from torchao.quantization.quant_primitives import MappingType - -__all__ = [ - "LinearActivationScaleQuantizedTensor", - "to_linear_scale_activation_quantized", -] - -aten = torch.ops.aten - -class LinearActivationScaleQuantizedTensor(TorchAOBaseTensor): - """ - Applies activation scaling then quantization for linear operator, this is used to support - SmoothQuant with dynamic quantization or static quantization, user can pass in a `input_quant_func` - that is used to quantize the activation - - Args: - `original_weight_tensor`: the weight tensor, if weight need to be quantized as well, we'd need - to apply quantization to weight first, e.g. 
for int8 dynamic activation int8 weight quantization - we will first apply int8 quantization to weight and then apply LinearActivationScaleQuantizedTensor - on top of it - `scale`: The scale tensor to be applied to activation. - `input_quant_func` (Callable[[torch.Tensor], torch.Tensor]): a function that takes a high precision - floating point tensor and returns a quantized tensor, this is used to quantize input - """ - def __new__( - cls, - original_weight_tensor: torch.Tensor, - equalization_scale: torch.Tensor, - act_scales: Optional[torch.Tensor], - act_zero_points: Optional[torch.Tensor], - target_dtype: torch.dtype, - quant_min: int, - quant_max: int, - ): - kwargs = {} - dtype = original_weight_tensor.dtype - kwargs["dtype"] = dtype - kwargs["requires_grad"] = False - kwargs["device"] = original_weight_tensor.device - shape = original_weight_tensor.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - original_weight_tensor: torch.Tensor, - equalization_scale: torch.Tensor, - act_scales: Optional[torch.Tensor], - act_zero_points: Optional[torch.Tensor], - target_dtype: torch.dtype, - quant_min: int, - quant_max: int, - ): - self.original_weight_tensor = original_weight_tensor - self.equalization_scale = equalization_scale - self.act_scales = act_scales - self.act_zero_points = act_zero_points - self.target_dtype = target_dtype - self.quant_min = quant_min - self.quant_max = quant_max - - def __repr__(self): - return (f"LinearActivationScaleQuantizedTensor({self.original_weight_tensor}, " - f"equalization_scale={self.equalization_scale}, " - f"act_scales={self.act_scales}), " - f"act_zero_points={self.act_zero_points}, " - f"target_dtype={self.target_dtype}, " - f"quant_min={self.quant_min}, " - f"quant_max={self.quant_max})" - ) - - def __tensor_flatten__(self): - tensor_data = [ - "original_weight_tensor", - "equalization_scale", - ] - tensor_attributes = [self.target_dtype, self.quant_min, self.quant_max] - if self.act_scales is not None: - tensor_data.append("act_scales") - if self.act_zero_points is not None: - tensor_data.append("act_zero_points") - return tensor_data, tensor_attributes - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - original_weight_tensor = tensor_data_dict["original_weight_tensor"] - equalization_scale = tensor_data_dict["equalization_scale"] - act_scales = tensor_data_dict["act_scales"] if "act_scales" in tensor_data_dict else None - act_zero_points = tensor_data_dict["act_zero_points"] if "act_zero_points" in tensor_data_dict else None - target_dtype, quant_min, quant_max = tensor_attributes - return cls( - original_weight_tensor, - equalization_scale, - act_scales, - act_zero_points, - target_dtype, - quant_min, - quant_max, - ) - - @staticmethod - def _quantized_linear_op(input_tensor, weight_tensor, bias): - original_weight_tensor = weight_tensor.original_weight_tensor - equalization_scale = weight_tensor.equalization_scale - scaled_input_act = input_tensor / equalization_scale - scaled_input_act = scaled_input_act.to(input_tensor.dtype) - if weight_tensor.act_scales is not None: - # static quant - act_zero_points = ( - weight_tensor.act_zero_points - if weight_tensor.act_zero_points is not None - else torch.zeros_like(weight_tensor.act_scales, dtype=torch.int64) - ) - aqt = to_affine_quantized_intx_static( - scaled_input_act, - weight_tensor.act_scales, - act_zero_points, - list(scaled_input_act.shape), - 
weight_tensor.target_dtype, - quant_min=weight_tensor.quant_min, - quant_max=weight_tensor.quant_max, - ) - else: - # dynamic quant - block_size = _get_per_token_block_size(scaled_input_act) - aqt = to_affine_quantized_intx( - scaled_input_act, - MappingType.SYMMETRIC, - block_size, - weight_tensor.target_dtype, - quant_min=weight_tensor.quant_min, - quant_max=weight_tensor.quant_max, - ) - - return torch.nn.functional.linear(aqt, original_weight_tensor, bias) - - @classmethod - def from_float(cls, input_float, equalization_scale, act_scales, act_zero_points, target_dtype, quant_min, quant_max): - return cls( - input_float, - equalization_scale, - act_scales, - act_zero_points, - target_dtype, - quant_min, - quant_max, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.original_weight_tensor), - fn(self.equalization_scale), - ( - fn(self.act_scales) - if self.act_scales is not None - else None - ), - ( - fn(self.act_zero_points) - if self.act_zero_points is not None - else None - ), - self.target_dtype, - self.quant_min, - self.quant_max, - ) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.original_weight_tensor.to(**kwargs), - self.equalization_scale.to(**kwargs), - ( - self.act_scales.to(**kwargs) - if self.act_scales is not None - else None - ), - ( - self.act_zero_points.to(**kwargs) - if self.act_zero_points is not None - else None - ), - self.target_dtype, - self.quant_min, - self.quant_max, - ) - -implements = LinearActivationScaleQuantizedTensor.implements - -@implements([torch.nn.functional.linear, aten.linear.default]) -def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - if isinstance(weight_tensor, LinearActivationScaleQuantizedTensor): - return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - - raise NotImplementedError("LinearActivationScaleQuantizedTensor: No specialized dispatch found for linear op") - -@implements([aten.mm.default, aten.addmm.default]) -def _(func, types, args, kwargs): - if not args[0].is_floating_point(): - raise NotImplementedError(f"LinearActivationScaleQuantizedTensor: expecting a floating point input") - - if func == aten.addmm.default: - assert args[1].shape[-1] == args[2].shape[0], ( - f"need mat1 shape: {args[1].shape} final" - f"dim to match mat2 shape: {args[2].shape} first dim " - ) - input_tensor, weight_tensor, bias = ( - args[1], - args[2], - args[0], - ) - else: - # aten.mm.default - assert args[0].shape[-1] == args[1].shape[0], ( - f"need mat1 shape: {args[0].shape} final dim" - f"to match mat2 shape: {args[1].shape} first dim" - ) - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - None, - ) - original_weight_tensor = weight_tensor.original_weight_tensor - equalization_scale = weight_tensor.equalization_scale - scaled_input_act = input_tensor / equalization_scale - scaled_input_act = scaled_input_act.to(input_tensor.dtype) - if weight_tensor.act_scales is not None: - # static quant - act_zero_points = ( - weight_tensor.act_zero_points - if weight_tensor.act_zero_points is not None - else torch.zeros_like(weight_tensor.act_scales, dtype=torch.int64) - ) - aqt = to_affine_quantized_intx_static( - scaled_input_act, - weight_tensor.act_scales, - act_zero_points, - list(scaled_input_act.shape), - weight_tensor.target_dtype, - quant_min=weight_tensor.quant_min, - quant_max=weight_tensor.quant_max, - ) - else: - # dynamic quant - 
block_size = _get_per_token_block_size(scaled_input_act) - aqt = to_affine_quantized_intx( - scaled_input_act, - MappingType.SYMMETRIC, - block_size, - weight_tensor.target_dtype, - quant_min=weight_tensor.quant_min, - quant_max=weight_tensor.quant_max, - ) - - if func == aten.addmm.default: - return func(bias, aqt, original_weight_tensor) - else: - return func(aqt, original_weight_tensor) - - -@implements(aten.detach.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - -@implements(aten.clone.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - -@implements(aten._to_copy.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), - ) - -@implements(aten.t.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.t) - ) - -to_linear_scale_activation_quantized = LinearActivationScaleQuantizedTensor.from_float - -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationScaleQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([LinearActivationScaleQuantizedTensor]) From 90d1b7d433d880a8bc887ac242f4a760df3c8678 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 17 Oct 2024 21:42:25 -0400 Subject: [PATCH 31/36] Test torch.compile only with PyTorch >= 2.5 --- test/prototype/test_smoothquant.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 17abfa0bf..d80d08461 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -8,6 +8,7 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, ) from torchao.quantization.utils import ( dynamically_quantize_per_channel, @@ -155,8 +156,10 @@ def test_save_load_recipe(alpha, quant_mode, device, idtype): # quantize is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) quantize_(m, smooth_quant(), is_observed_linear) - m = torch.compile(m, fullgraph=True) - m_save_load = torch.compile(m_save_load, fullgraph=True) + if TORCH_VERSION_AT_LEAST_2_5: + # earlier versions are not compatible + m = torch.compile(m, fullgraph=True) + m_save_load = torch.compile(m_save_load, fullgraph=True) out_list = [m(data.squeeze(0)) for data in dataset] out = torch.cat(out_list) save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset] From 03d490b6266302b48e65dee3d8c082c0fac8391d Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 17 Oct 2024 22:49:01 -0400 Subject: [PATCH 32/36] Debug CI failures (5) --- test/prototype/test_smoothquant.py | 2 +- torchao/kernel/intmm.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index d80d08461..002df76e2 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -45,7 +45,7 @@ def forward(self, x): devices = ["cpu"] if torch.cuda.is_available(): devices.append("cuda") -idtypes = (torch.float, torch.bfloat16, torch.half) +idtypes = (torch.float, torch.bfloat16) # torch.half @pytest.mark.parametrize("bias", bias_list) 
@pytest.mark.parametrize("alpha", alpha_list) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 9d23408d7..21ca2f91f 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -139,9 +139,9 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - c = safe_int_mm(a, b) - if scales1.dtype == torch.half: - # to float to avoid overflow of float16 - c = c.float() * scales1 - return c.to(scales1.dtype) + # if scales1.dtype == torch.half: + # # to float to avoid overflow of float16 + # c = c.float() * scales1 + # return c.to(scales1.dtype) return c * scales1 From f595ed41b99685cc16fc480ca2218965bb812bed Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 17 Oct 2024 23:58:08 -0400 Subject: [PATCH 33/36] Debug CI failures (6) --- test/prototype/test_smoothquant.py | 11 +++++++---- torchao/kernel/intmm.py | 8 ++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 002df76e2..1138aad2c 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -45,7 +45,11 @@ def forward(self, x): devices = ["cpu"] if torch.cuda.is_available(): devices.append("cuda") -idtypes = (torch.float, torch.bfloat16) # torch.half +idtypes = (torch.float, torch.bfloat16, torch.half) + +if TORCH_VERSION_AT_LEAST_2_5: + # This test case will trigger recompilation many times, so set a large cache_size_limit here + torch._dynamo.config.cache_size_limit = 32 @pytest.mark.parametrize("bias", bias_list) @pytest.mark.parametrize("alpha", alpha_list) @@ -73,7 +77,8 @@ def forward(self, x): is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) quantize_(m, smooth_quant(), is_observed_linear) with torch.inference_mode(): - # m = torch.compile(m, fullgraph=True) + if TORCH_VERSION_AT_LEAST_2_5: + m = torch.compile(m, fullgraph=True) out = m(data) # reference @@ -126,8 +131,6 @@ def forward(self, x): @pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("idtype", idtypes) def test_save_load_recipe(alpha, quant_mode, device, idtype): - # This test case will trigger recompilation many times, so set a large cache_size_limit here - torch._dynamo.config.cache_size_limit = 32 dataset_size = 20 l1, l2, l3 = 512, 256, 128 original_dtype = idtype diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 21ca2f91f..63f51e6f1 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -139,9 +139,9 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - c = safe_int_mm(a, b) - # if scales1.dtype == torch.half: - # # to float to avoid overflow of float16 - # c = c.float() * scales1 - # return c.to(scales1.dtype) + if scales1.dtype == torch.half: + # to float to avoid overflow of float16 + c = c * scales1.float() + return c.to(torch.half) return c * scales1 From 2202f694ceaf7ec547e82058112fdac512265b13 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 18 Oct 2024 01:05:23 -0400 Subject: [PATCH 34/36] Debug CI failures (7) --- test/prototype/test_spinquant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/test_spinquant.py b/test/prototype/test_spinquant.py index 1216d74e5..820746963 100644 --- a/test/prototype/test_spinquant.py +++ b/test/prototype/test_spinquant.py @@ -31,7 +31,7 @@ def test_spinquant_no_quantization(device): # Output should be the same without quantization (the rotations cancel out) # TODO: not sure if these atol/rtol are 
excessively large (it fails for smaller values) - torch.testing.assert_close(out, out_spinquant, atol=5e-2, rtol=1e-2) + torch.testing.assert_close(out, out_spinquant, atol=6e-2, rtol=1e-2) # TODO: test GPTQ compatability? \ No newline at end of file From 6ea8aa8650d4a437d9e8b49e44e1b2f47603e2cf Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 18 Oct 2024 04:06:09 -0400 Subject: [PATCH 35/36] Use MovingAvg observer for activation; Update UT and readme --- test/prototype/test_smoothquant.py | 6 ++--- torchao/kernel/intmm.py | 6 ----- torchao/prototype/smoothquant/core.py | 4 +-- torchao/prototype/smoothquant/readme.md | 34 +++++++++++++++---------- 4 files changed, 25 insertions(+), 25 deletions(-) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 1138aad2c..a071e1a0e 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -3,7 +3,7 @@ import pytest import torch import tempfile -from torch.ao.quantization import HistogramObserver +from torch.ao.quantization import MovingAverageMinMaxObserver from torchao.quantization import quantize_ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_2, @@ -45,7 +45,7 @@ def forward(self, x): devices = ["cpu"] if torch.cuda.is_available(): devices.append("cuda") -idtypes = (torch.float, torch.bfloat16, torch.half) +idtypes = (torch.float, torch.bfloat16) if TORCH_VERSION_AT_LEAST_2_5: # This test case will trigger recompilation many times, so set a large cache_size_limit here @@ -102,7 +102,7 @@ def forward(self, x): out_ref = torch.nn.functional.linear(act.to(idtype), fq_wei, b) elif quant_mode == "static": # activation is quantized per-tensor - obs = HistogramObserver( + obs = MovingAverageMinMaxObserver( dtype=torch.int8, qscheme=torch.per_tensor_symmetric, quant_min=-127, diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 63f51e6f1..e1fab0a7f 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -138,10 +138,4 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) - return torch.ops.torchao.int_scaled_matmul(a, b, scales1) c = safe_int_mm(a, b) - - if scales1.dtype == torch.half: - # to float to avoid overflow of float16 - c = c * scales1.float() - return c.to(torch.half) - return c * scales1 diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py index 3cdbde384..eb2e9a06d 100644 --- a/torchao/prototype/smoothquant/core.py +++ b/torchao/prototype/smoothquant/core.py @@ -2,7 +2,7 @@ from typing import Optional import torch import torch.nn.functional as F -from torch.ao.quantization import PerChannelMinMaxObserver, HistogramObserver +from torch.ao.quantization import PerChannelMinMaxObserver, MovingAverageMinMaxObserver from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, @@ -70,7 +70,7 @@ def __init__(self, qscheme=torch.per_channel_affine, eps=eps, ) - self.act_obs = HistogramObserver( + self.act_obs = MovingAverageMinMaxObserver( dtype=torch.int8, qscheme=torch.per_tensor_symmetric, quant_min=quant_min, diff --git a/torchao/prototype/smoothquant/readme.md b/torchao/prototype/smoothquant/readme.md index ca7272609..fac8114e6 100644 --- a/torchao/prototype/smoothquant/readme.md +++ b/torchao/prototype/smoothquant/readme.md @@ -63,21 +63,27 @@ load_smooth_quant_recipe(model, "./smooth_quant_recipe.json") ``` ## Benchmark -Running the example case with `torch.compile` on a NVIDIA A10 GPU. 
+Running the example case with `torch.compile` on a NVIDIA A10G GPU. ### meta-llama/Llama-2-7b-hf -| Quant Method | Perplexity | Latency (sec/it) | -|-|-|-| -| SmoothQuant dynamic | 7.4318 | 0.6874 | -| SmoothQuant static | 424.995 | 0.4560 | -| AWQ UINT4 Group size 128 | 7.4837 | 1.0018 | +| Quant Method | Perplexity | +|-|-| +| SmoothQuant dynamic | 7.4341 | +| SmoothQuant static | 10.6206 | ### meta-llama/Meta-Llama-3-8B -| Quant Method | Perplexity | Latency (sec/it) | -|-|-|-| -| SmoothQuant dynamic | 8.8274 | 0.9008 | -| SmoothQuant static | 124.2236 | 0.5537 | -| AWQ UINT4 Group size 128 | 8.7087 | 1.2868 | +| Quant Method | Perplexity | +|-|-| +| SmoothQuant dynamic | 8.8184 | +| SmoothQuant static | 12.4086 | -Note: -1. Static quantization needs tuning on recipe and alpha to get good accuracy. So, the data here is for reference only. -2. AWQ's calibration runs on `mit-han-lab/pile-val-backup`, validation split while SmoothQuant's calibration runs on `wikitext/wikitext-2-raw-v1`, test split. +Commands: +```bash +# dynamic quant +TORCHINDUCTOR_FREEZING=1 python example.py -m MODEL_ID --device=cuda --quant-mode=dynamic --compile +# static quant +TORCHINDUCTOR_FREEZING=1 python example.py -m MODEL_ID --device=cuda --quant-mode=static --compile +``` +Environment: +- AWS g5.12xlarge instance +- torch==2.6.0.dev20241017+cu124 +- python==3.12.6 \ No newline at end of file From fa1144c4fceb021afba5b5276bf81e11d7eb522e Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 18 Oct 2024 05:20:25 -0400 Subject: [PATCH 36/36] Revert changes to test_spinquant.py; refine readme --- test/prototype/test_spinquant.py | 2 +- torchao/prototype/smoothquant/__init__.py | 2 +- torchao/prototype/smoothquant/readme.md | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/prototype/test_spinquant.py b/test/prototype/test_spinquant.py index 820746963..1216d74e5 100644 --- a/test/prototype/test_spinquant.py +++ b/test/prototype/test_spinquant.py @@ -31,7 +31,7 @@ def test_spinquant_no_quantization(device): # Output should be the same without quantization (the rotations cancel out) # TODO: not sure if these atol/rtol are excessively large (it fails for smaller values) - torch.testing.assert_close(out, out_spinquant, atol=6e-2, rtol=1e-2) + torch.testing.assert_close(out, out_spinquant, atol=5e-2, rtol=1e-2) # TODO: test GPTQ compatability? \ No newline at end of file diff --git a/torchao/prototype/smoothquant/__init__.py b/torchao/prototype/smoothquant/__init__.py index 6fced77e3..b3b17d666 100644 --- a/torchao/prototype/smoothquant/__init__.py +++ b/torchao/prototype/smoothquant/__init__.py @@ -4,4 +4,4 @@ save_smooth_quant_recipe, load_smooth_quant_recipe, ) -from .core import SmoothQuantObservedLinear \ No newline at end of file +from .core import SmoothQuantObservedLinear diff --git a/torchao/prototype/smoothquant/readme.md b/torchao/prototype/smoothquant/readme.md index fac8114e6..23719034a 100644 --- a/torchao/prototype/smoothquant/readme.md +++ b/torchao/prototype/smoothquant/readme.md @@ -47,12 +47,12 @@ torchao.quantization.quantize_(model, smooth_quant(), is_observed_linear) (Advanced) `save_smooth_quant_recipe` and `load_smooth_quant_recipe` save or load a recipe for a model. -A recipe contains smoothing factors and quantization parameters of weights and activation for all linear layers that are to be quantized. For advanced users, these parameters can be saved and modified somehow to produce better accuray, e.g., different alpha for different layres. 
Users can even leave some linear layers unquantized by deleting these layers in the recipe. Such modifications can be published as a recipe. By loading the recipe, it can be reused and calibration is no longer needed. +A recipe contains smoothing factors and quantization parameters of weights and activation for all linear layers that are to be quantized. For advanced users, these parameters can be saved and modified as needed to produce better accuracy, e.g., different alpha for different layers. Users can even leave some linear layers unquantized by deleting these layers in the recipe. Such modifications can be published as a recipe. By loading the recipe, it can be reused and calibration is no longer needed. To save a recipe, users should insert observers and run calibration first. For example, ```python insert_smooth_quant_observer(model, alpha=0.5, quant_mode="dynamic") -for data in dataset: +for data in dataset_for_calibration: model(data) save_smooth_quant_recipe(model, "./smooth_quant_recipe.json") ``` To load a recipe, users should insert observers first. For example, ```python insert_smooth_quant_observer(model) load_smooth_quant_recipe(model, "./smooth_quant_recipe.json") ``` ## Benchmark -Running the example case with `torch.compile` on a NVIDIA A10G GPU. +Running the example with `torch.compile` on a NVIDIA A10G GPU. ### meta-llama/Llama-2-7b-hf | Quant Method | Perplexity | |-|-|