From 7038f8b54e3cbca1c3128bba6a2a5e99700dc36f Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Thu, 10 Oct 2024 14:48:08 -0700 Subject: [PATCH] Rename AQT#2 LayoutType -> Layout (#1049) --- benchmarks/benchmark_fp6.py | 4 +- test/dtypes/test_affine_quantized.py | 4 +- test/dtypes/test_floatx.py | 6 +- test/hqq/test_hqq_affine.py | 4 +- test/integration/test_integration.py | 6 +- test/quantization/test_qat.py | 2 +- test/sparsity/test_marlin.py | 6 +- test/sparsity/test_sparse_api.py | 10 +- torchao/_models/llama/eval.py | 4 +- torchao/_models/llama/generate.py | 4 +- torchao/_models/sam/eval_combo.py | 8 +- torchao/dtypes/__init__.py | 24 +- torchao/dtypes/affine_quantized_tensor.py | 288 +++++++++--------- torchao/dtypes/floatx/__init__.py | 2 +- torchao/dtypes/floatx/floatx.py | 42 +-- torchao/dtypes/uintx/__init__.py | 2 +- torchao/dtypes/uintx/uintx.py | 12 +- torchao/dtypes/utils.py | 8 +- torchao/prototype/autoround/core.py | 10 +- torchao/prototype/awq/api.py | 12 +- torchao/prototype/awq/core.py | 6 +- torchao/prototype/hqq/example.py | 14 +- torchao/quantization/GPTQ_MT.py | 7 +- torchao/quantization/autoquant.py | 16 +- .../scripts/BO_acc_throughput.py | 2 +- torchao/quantization/quant_api.py | 54 ++-- torchao/sparsity/README.md | 10 +- .../sparsity/prototype/superblock/utils.py | 8 +- torchao/sparsity/sparse_api.py | 8 +- torchao/utils.py | 44 +-- tutorials/calibration_flow/awq_like.py | 4 +- tutorials/calibration_flow/static_quant.py | 10 +- .../my_dtype_tensor_subclass.py | 60 ++-- .../my_trainable_tensor_subclass.py | 18 +- 34 files changed, 358 insertions(+), 361 deletions(-) diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index 509ea6e86..c6d28c0bd 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -2,14 +2,14 @@ import pandas as pd import torch.nn.functional as F from torchao.dtypes import to_affine_quantized_fpx -from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayoutType +from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayout from torchao.utils import benchmark_torch_function_in_microseconds from tqdm import tqdm def benchmark(m: int, k: int, n: int): float_data = torch.randn(n, k, dtype=torch.half, device="cuda") - fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayoutType(3, 2)) + fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayout(3, 2)) fp16_weight = fp6_weight.dequantize(torch.half) fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 2265be31e..4e98ffd56 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -10,7 +10,7 @@ int8_dynamic_activation_int8_semi_sparse_weight, float8_weight_only, ) -from torchao.dtypes import SemiSparseLayoutType +from torchao.dtypes import SemiSparseLayout from torch.testing._internal import common_utils from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @@ -31,7 +31,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool): base_functions.append(int4_weight_only(group_size=32)) if do_sparse: - base_functions.append(int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())) + base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())) if is_cuda_8_9: base_functions.append(float8_weight_only()) diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index f228c4c0c..751476d04 100644 
--- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -10,7 +10,7 @@ ) from torchao.dtypes.floatx import ( FloatxTensorCoreAQTTensorImpl, - FloatxTensorCoreLayoutType, + FloatxTensorCoreLayout, to_scaled_tc_floatx, from_scaled_tc_floatx, ) @@ -81,8 +81,8 @@ def test_to_copy_device(self, ebits, mbits): x = torch.randn(256, 64) scale = choose_qparams_affine_floatx(x, ebits, mbits) x = quantize_affine_floatx(x, scale, ebits, mbits) - layout_type = FloatxTensorCoreLayoutType(ebits, mbits) - floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, layout_type).cuda() + _layout = FloatxTensorCoreLayout(ebits, mbits) + floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, _layout).cuda() assert floatx_tensor_impl.device.type == "cuda" floatx_tensor_impl = floatx_tensor_impl.cpu() assert floatx_tensor_impl.device.type == "cpu" diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index c1177d2d4..7eda0ab5d 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -4,9 +4,9 @@ to_affine_quantized_intx, ZeroPointDomain, PlainAQTTensorImpl, - PlainLayoutType, + PlainLayout, TensorCoreTiledAQTTensorImpl, - TensorCoreTiledLayoutType, + TensorCoreTiledLayout, MappingType, ) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 46799b491..9d3d60ed4 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -19,7 +19,7 @@ from torchao.quantization.dynamic_quant import ( DynamicallyPerAxisQuantizedLinear, ) -from torchao.dtypes import TensorCoreTiledLayoutType +from torchao.dtypes import TensorCoreTiledLayout from torchao.quantization.quant_api import ( int4_weight_only, int8_weight_only, @@ -876,7 +876,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])): for groupsize in [64, 32]: for inner_k_tiles in [4, 2]: - kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)} + kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)} def api(mod): kwargs_copy = kwargs.copy() @@ -888,7 +888,7 @@ def api(mod): unwrap_tensor_subclass(mod) else: kwargs_copy["inner_k_tiles"] = inner_k_tiles - del kwargs_copy["layout_type"] + del kwargs_copy["layout"] change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy) self._test_lin_weight_subclass_api_impl( diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index e1e670d5d..f0c5601ab 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -13,7 +13,7 @@ import torch from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 from torchao.dtypes import ( - TensorCoreTiledLayoutType, + TensorCoreTiledLayout, ) from torchao.quantization.prototype.qat.api import ( ComposableQATQuantizer, diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index c12f32ef6..173afd7da 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -5,7 +5,7 @@ from torch import nn from torch.testing._internal.common_utils import TestCase, run_tests from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -from torchao.dtypes import MarlinSparseLayoutType +from torchao.dtypes import MarlinSparseLayout from torchao.sparsity.sparse_api import apply_fake_sparsity from torchao.quantization.quant_api import int4_weight_only, 
quantize_ from torchao.sparsity.marlin import ( @@ -50,7 +50,7 @@ def test_quant_sparse_marlin_layout_eager(self): dense_result = model_copy(self.input.bfloat16()).half() # Sparse + quantized - quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType())) + quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout())) sparse_result = self.model(self.input) assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close" @@ -67,7 +67,7 @@ def test_quant_sparse_marlin_layout_compile(self): dense_result = model_copy(self.input.bfloat16()).half() # Sparse + quantized - quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType())) + quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout())) self.model.forward = torch.compile(self.model.forward, fullgraph=True) sparse_result = self.model(self.input) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 9d2535e55..fb0fa1b8e 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -5,7 +5,7 @@ import torch from torch import nn from torch.testing._internal import common_utils -from torchao.dtypes import MarlinSparseLayoutType, SemiSparseLayoutType +from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout from torchao.quantization.quant_api import ( int4_weight_only, int8_dynamic_activation_int8_weight, @@ -74,7 +74,7 @@ def test_quant_semi_sparse(self, compile): quantize_( model, - int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()), + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), ) if compile: model = torch.compile(model) @@ -108,7 +108,7 @@ def test_sparse_marlin(self, compile): dense_result = model_copy(input.bfloat16()).half() # Sparse + quantized - quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) + quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) if compile: model = torch.compile(model) sparse_result = model(input) @@ -185,12 +185,12 @@ def test_sparse(self, compile): quantize_(model_copy, int8_dynamic_activation_int8_weight()) reference = model_copy(input) - from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType + from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout quantize_( model, int8_dynamic_activation_int8_weight( - layout_type=BlockSparseLayoutType(blocksize=64) + layout=BlockSparseLayout(blocksize=64) ), ) if compile: diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 9ef75e256..482cb86af 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -100,8 +100,8 @@ def run_evaluation( group_size = int(_quant_args[2]) quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) if "marlin" in quantization: - from torchao.dtypes import MarlinSparseLayoutType - quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) + from torchao.dtypes import MarlinSparseLayout + quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) if "int4wo" in quantization and "gptq" in quantization: # avoid circular imports from torchao._models._eval import InputRecorder diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 23ed9864f..4f2cb4ffc 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -230,8 +230,8 @@ def main( assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" quantize_(model, 
int4_weight_only(group_size=groupsize)) if "marlin" in quantization: - from torchao.dtypes import MarlinSparseLayoutType - quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) + from torchao.dtypes import MarlinSparseLayout + quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) if quantization.startswith("awq"): diff --git a/torchao/_models/sam/eval_combo.py b/torchao/_models/sam/eval_combo.py index 3074fd684..cb3f1afb9 100644 --- a/torchao/_models/sam/eval_combo.py +++ b/torchao/_models/sam/eval_combo.py @@ -11,7 +11,7 @@ from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only from torchao.sparsity import sparsify_, apply_fake_sparsity, semi_sparse_weight -from torchao.dtypes import SemiSparseLayoutType, MarlinSparseLayoutType +from torchao.dtypes import SemiSparseLayout, MarlinSparseLayout from torchao.utils import unwrap_tensor_subclass from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @@ -315,7 +315,7 @@ def mlp_only(mod, name): int8_dynamic_activation_int8_weight(), attn_only) quantize_(predictor.model.image_encoder, - int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()), + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), mlp_lin1_only) sparsify_(predictor.model.image_encoder, semi_sparse_weight(), @@ -326,11 +326,11 @@ def mlp_only(mod, name): # apply sparsify first to set qparams apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only) - from torchao.dtypes import MarlinSparseLayoutType + from torchao.dtypes import MarlinSparseLayout quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight(), attn_only) - quantize_(predictor.model.image_encoder, int4_weight_only(layout_type=MarlinSparseLayoutType()), mlp_lin1_only) + quantize_(predictor.model.image_encoder, int4_weight_only(layout=MarlinSparseLayout()), mlp_lin1_only) sparsify_(predictor.model.image_encoder, semi_sparse_weight(), mlp_lin2_only) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 8d4be52dc..4ab0c3f70 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -9,13 +9,13 @@ to_affine_quantized_fpx, to_affine_quantized_floatx, to_affine_quantized_floatx_static, - LayoutType, - PlainLayoutType, - SemiSparseLayoutType, - TensorCoreTiledLayoutType, - Float8LayoutType, + Layout, + PlainLayout, + SemiSparseLayout, + TensorCoreTiledLayout, + Float8Layout, Float8AQTTensorImpl, - MarlinSparseLayoutType, + MarlinSparseLayout, ) __all__ = [ @@ -28,11 +28,11 @@ "to_affine_quantized_fpx", "to_affine_quantized_floatx", "to_affine_quantized_floatx_static", - "LayoutType", - "PlainLayoutType", - "SemiSparseLayoutType", - "TensorCoreTiledLayoutType", - "Float8LayoutType", + "Layout", + "PlainLayout", + "SemiSparseLayout", + "TensorCoreTiledLayout", + "Float8Layout", "Float8AQTTensorImpl", - "MarlinSparseLayoutType", + "MarlinSparseLayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 0fa864f8b..ff5deb7b0 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -23,8 +23,8 @@ ) from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.dtypes.utils import ( - LayoutType, - PlainLayoutType, + Layout, + PlainLayout, is_device, get_out_shape, ) @@ -65,7 +65,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ pass - def 
get_layout_type(self) -> LayoutType: + def get_layout(self) -> Layout: pass @classmethod @@ -74,15 +74,15 @@ def from_plain( data: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, ): - """ Construct a TensorImpl from data, scale, zero_point and the layout_type""" + """ Construct a TensorImpl from data, scale, zero_point and the _layout""" pass def __repr__(self): data, scale, zero_point = self.get_plain() - layout_type = self.get_layout_type() - return f"{self.__class__.__name__}(data={str(data)}... , scale={str(scale)}... , zero_point={str(zero_point)}... , layout_type={layout_type})" + _layout = self.get_layout() + return f"{self.__class__.__name__}(data={str(data)}... , scale={str(scale)}... , zero_point={str(zero_point)}... , _layout={_layout})" ############################## @@ -195,16 +195,16 @@ def __repr__(self): ) def _quantization_type(self): - return f"shape={self.shape}, block_size={self.block_size}, device={self.device}, layout_type={self.layout_type}, tensor_impl_dtype={self.tensor_impl.dtype}, quant_min={self.quant_min}, quant_max={self.quant_max}" + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}, _layout={self._layout}, tensor_impl_dtype={self.tensor_impl.dtype}, quant_min={self.quant_min}, quant_max={self.quant_max}" def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: if output_dtype is None: output_dtype = self.dtype - from torchao.dtypes.floatx import FloatxTensorCoreLayoutType - if isinstance(self.layout_type, FloatxTensorCoreLayoutType): + from torchao.dtypes.floatx import FloatxTensorCoreLayout + if isinstance(self._layout, FloatxTensorCoreLayout): int_data, scale = self.tensor_impl.get_plain() - return dequantize_affine_floatx(int_data, scale, self.layout_type.ebits, self.layout_type.mbits, output_dtype=output_dtype) + return dequantize_affine_floatx(int_data, scale, self._layout.ebits, self._layout.mbits, output_dtype=output_dtype) else: data, scale, zero_point = self.tensor_impl.get_plain() dq = dequantize_affine( @@ -265,11 +265,11 @@ def from_hp_to_intx( zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, - layout_type: LayoutType = PlainLayoutType(), + _layout: Layout = PlainLayout(), use_hqq: bool = False, ): original_shape = input_float.shape - input_float = layout_type.pre_process(input_float) + input_float = _layout.pre_process(input_float) if use_hqq: assert zero_point_domain == ZeroPointDomain.FLOAT and mapping_type == MappingType.ASYMMETRIC and quant_min==0, "Invalid input parameters for HQQ quantization." 
@@ -288,9 +288,9 @@ def from_hp_to_intx( data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) # Note: output will be uint8 tensor for sub byte tensors for now - data = layout_type.post_process(data) - tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) - tensor_impl = tensor_impl_ctr(data, scale, zero_point, layout_type) + data = _layout.post_process(data) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout) return cls( tensor_impl, block_size, @@ -312,20 +312,20 @@ def from_hp_to_intx_static( quant_min: Optional[int] = None, quant_max: Optional[int] = None, zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, - layout_type: LayoutType = PlainLayoutType(), + _layout: Layout = PlainLayout(), ): if target_dtype not in FP8_TYPES: assert zero_point_domain is not None, "zero_point_domain must be specified for non-fp8 types" assert zero_point is not None, "zero_point must be specified for non-fp8 types" original_shape = input_float.shape - input_float, scale, zero_point = layout_type.pre_process_static(input_float, scale, zero_point, block_size) + input_float, scale, zero_point = _layout.pre_process_static(input_float, scale, zero_point, block_size) int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) - int_data = layout_type.post_process(int_data) + int_data = _layout.post_process(int_data) - tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) - tensor_impl = tensor_impl_ctr(int_data, scale, zero_point, layout_type) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(int_data, scale, zero_point, _layout) return cls( tensor_impl, block_size, @@ -342,7 +342,7 @@ def from_hp_to_floatx( input_float: torch.Tensor, block_size: Tuple[int, ...], target_dtype: torch.dtype, - layout_type: LayoutType, + _layout: Layout, scale_dtype: Optional[torch.dtype] = None, ): @@ -359,7 +359,7 @@ def from_hp_to_floatx( zero_point_dtype=None, preserve_zero=True, zero_point_domain=None, - layout_type=layout_type, + _layout=_layout, use_hqq=False, ) else: @@ -372,7 +372,7 @@ def from_hp_to_floatx_static( scale: torch.Tensor, block_size: Tuple[int, ...], target_dtype: torch.dtype, - layout_type: LayoutType, + _layout: Layout, ): if target_dtype in FP8_TYPES: @@ -385,7 +385,7 @@ def from_hp_to_floatx_static( quant_min=math.ceil(torch.finfo(target_dtype).min), quant_max=math.ceil(torch.finfo(target_dtype).max), zero_point_domain=None, - layout_type=layout_type, + _layout=_layout, ) else: raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static") @@ -394,24 +394,24 @@ def from_hp_to_floatx_static( def from_hp_to_fpx( cls, input_float: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, ): - from torchao.dtypes.floatx import FloatxTensorCoreLayoutType - assert isinstance(layout_type, FloatxTensorCoreLayoutType), f"Only FloatxTensorCoreLayoutType is supported for floatx, got {layout_type}" + from torchao.dtypes.floatx import FloatxTensorCoreLayout + assert isinstance(_layout, FloatxTensorCoreLayout), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" original_shape = input_float.shape - input_float = layout_type.pre_process(input_float) + input_float = _layout.pre_process(input_float) # per axis quantization, where axis = 1 block_size = list(input_float.shape) 
block_size[1] = 1 - ebits, mbits = layout_type.ebits, layout_type.mbits + ebits, mbits = _layout.ebits, _layout.mbits # Note: these ops are hardcoded to have per axis quantization (axis=1) right now scale = choose_qparams_affine_floatx(input_float, ebits, mbits) floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) - floatx_packed = layout_type.post_process(floatx_unpacked) + floatx_packed = _layout.post_process(floatx_unpacked) - tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) - tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, layout_type) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) return cls( tensor_impl, block_size, @@ -420,8 +420,8 @@ def from_hp_to_fpx( ) @property - def layout_type(self) -> LayoutType: - return self.tensor_impl.layout_type + def _layout(self) -> Layout: + return self.tensor_impl._layout def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) @@ -464,13 +464,13 @@ def _apply_fn_to_data(self, fn): ###################################################### -# LayoutType and TensorImpl Subclass Registration # +# Layout and TensorImpl Subclass Registration # ###################################################### register_layout = AffineQuantizedTensor.register_layout get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor @dataclass(frozen=True) -class SemiSparseLayoutType(LayoutType): +class SemiSparseLayout(Layout): def pre_process(self, input: torch.Tensor) -> torch.Tensor: # prune to 2:4 if not already @@ -481,12 +481,12 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: @dataclass(frozen=True) -class BlockSparseLayoutType(LayoutType): +class BlockSparseLayout(Layout): blocksize: int = 64 @dataclass(frozen=True) -class TensorCoreTiledLayoutType(LayoutType): +class TensorCoreTiledLayout(Layout): inner_k_tiles: int = 8 def pre_process(self, input: torch.Tensor) -> torch.Tensor: @@ -523,12 +523,12 @@ def extra_repr(self): @dataclass(frozen=True) -class Float8LayoutType(LayoutType): +class Float8Layout(Layout): mm_config: Optional[Float8MMConfig] = None @dataclass(frozen=True) -class MarlinSparseLayoutType(LayoutType): +class MarlinSparseLayout(Layout): def pre_process(self, input: torch.Tensor) -> torch.Tensor: """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. 
@@ -548,7 +548,7 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: return w_24.t() -@register_layout(PlainLayoutType) +@register_layout(PlainLayout) class PlainAQTTensorImpl(AQTTensorImpl): """ TensorImpl storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point @@ -564,7 +564,7 @@ def __new__( int_data: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, ): kwargs = {} kwargs["device"] = int_data.device @@ -581,23 +581,23 @@ def __init__( int_data: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, ): self.int_data = int_data self.scale = scale self.zero_point = zero_point - self.layout_type = layout_type + self._layout = _layout def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point"], [self.layout_type] + return ["int_data", "scale", "zero_point"], [self._layout] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"] - layout_type, = tensor_attributes - return cls(int_data, scale, zero_point, layout_type) + _layout, = tensor_attributes + return cls(int_data, scale, zero_point, _layout) def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) @@ -605,7 +605,7 @@ def to(self, *args, **kwargs): self.int_data.to(kwargs["device"]), self.scale.to(kwargs["device"]), self.zero_point.to(kwargs["device"]), - self.layout_type, + self._layout, ) def _apply_fn_to_data(self, fn): @@ -613,7 +613,7 @@ def _apply_fn_to_data(self, fn): fn(self.int_data), fn(self.scale), fn(self.zero_point), - self.layout_type, + self._layout, ) @classmethod @@ -633,7 +633,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): elif func is aten.t.default: tensor = args[0] new = tensor.__class__( - tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor.layout_type + tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor._layout ) return return_and_correct_aliasing(func, args, kwargs, new) @@ -645,7 +645,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif dim == 1: assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return PlainAQTTensorImpl(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self.layout_type) + return PlainAQTTensorImpl(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self._layout) else: raise NotImplementedError(f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") @@ -658,8 +658,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self.int_data, self.scale, self.zero_point - def get_layout_type(self) -> LayoutType: - return self.layout_type + def get_layout(self) -> Layout: + return self._layout @classmethod def from_plain( @@ -667,12 +667,12 @@ def from_plain( int_data: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor], - layout_type: LayoutType, + _layout: Layout, ): - assert isinstance(layout_type, PlainLayoutType) - return cls(int_data, scale, zero_point, layout_type) + assert isinstance(_layout, PlainLayout) + return cls(int_data, scale, zero_point, _layout) 
-@register_layout(SemiSparseLayoutType) +@register_layout(SemiSparseLayout) class SemiSparseAQTTensorImpl(PlainAQTTensorImpl): """ TensorImpl storage class for semi_sparse_cusparselt layout for affine quantized tensor @@ -706,13 +706,13 @@ def from_plain( int_data: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor], - layout_type: LayoutType, + _layout: Layout, ): - assert isinstance(layout_type, SemiSparseLayoutType) + assert isinstance(_layout, SemiSparseLayout) int_data_compressed = torch._cslt_compress(int_data) - return cls(int_data_compressed, scale, zero_point, layout_type) + return cls(int_data_compressed, scale, zero_point, _layout) -@register_layout(BlockSparseLayoutType) +@register_layout(BlockSparseLayout) class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): bsr_crow_indices: Optional[torch.Tensor] bsr_col_indices: Optional[torch.Tensor] @@ -731,7 +731,7 @@ def __new__( # noqa: PYI034 bsr_values: Optional[torch.Tensor], scale: Optional[torch.Tensor], zero_point: Optional[torch.Tensor], - layout_type: LayoutType, + _layout: Layout, requires_grad: bool = False, ): if bsr_values is None: @@ -755,7 +755,7 @@ def __init__( # noqa: PYI034 bsr_values: Optional[torch.Tensor], scale: Optional[torch.Tensor], zero_point: Optional[torch.Tensor], - layout_type: LayoutType, + _layout: Layout, requires_grad: bool = False, ): self.bsr_crow_indices = bsr_crow_indices @@ -763,13 +763,13 @@ def __init__( # noqa: PYI034 self.bsr_values = bsr_values self.scale = scale self.zero_point = zero_point - self.layout_type = layout_type + self._layout = _layout def __tensor_flatten__(self): inner_tensors = list( filter(lambda x: getattr(self, x) is not None, self.__slots__) ) - tensor_meta = (self.shape, self.layout_type, self.requires_grad) + tensor_meta = (self.shape, self._layout, self.requires_grad) return inner_tensors, tensor_meta @classmethod @@ -780,7 +780,7 @@ def __tensor_unflatten__( outer_size, outer_stride, ) -> torch.Tensor: - shape, layout_type, requires_grad = tensor_meta + shape, _layout, requires_grad = tensor_meta return cls( shape=shape, bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), @@ -788,13 +788,13 @@ def __tensor_unflatten__( bsr_values=inner_tensors.get("bsr_values", None), scale=inner_tensors.get("scale", None), zero_point=inner_tensors.get("zero_point", None), - layout_type=layout_type, + _layout=_layout, requires_grad=requires_grad, ) @classmethod - def from_plain(cls, int_data, scale, zero_point, layout_type): - bsr_tensor = int_data.to_sparse_bsr(layout_type.blocksize) + def from_plain(cls, int_data, scale, zero_point, _layout): + bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize) return cls( shape=int_data.shape, bsr_crow_indices=bsr_tensor.crow_indices(), @@ -802,7 +802,7 @@ def from_plain(cls, int_data, scale, zero_point, layout_type): bsr_values=bsr_tensor.values(), scale=scale, zero_point=zero_point, - layout_type = layout_type, + _layout = _layout, requires_grad=False, ) @@ -818,7 +818,7 @@ def _apply_fn_to_data(self, func): bsr_values=func(self.bsr_values), scale=self.scale, zero_point=self.zero_point, - layout_type=self.layout_type, + _layout=self._layout, requires_grad=self.requires_grad, ) @@ -852,7 +852,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) -@register_layout(MarlinSparseLayoutType) +@register_layout(MarlinSparseLayout) class MarlinSparseAQTTensorImpl(AQTTensorImpl): """ TensorImpl storage class for sparse_marlin_24 
layout for affine quantized tensor. @@ -877,7 +877,7 @@ def __new__( scale: torch.Tensor, zero_point: torch.Tensor, meta: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, original_shape: torch.Size, group_size: int, num_bits: int, @@ -898,7 +898,7 @@ def __init__( scale: torch.Tensor, zero_point: torch.Tensor, meta: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, original_shape: torch.Size, group_size: int, num_bits: int, @@ -907,7 +907,7 @@ def __init__( self.scale = scale self.zero_point = zero_point self.meta = meta - self.layout_type = layout_type + self._layout = _layout self.original_shape = original_shape self.group_size = group_size self.num_bits = num_bits @@ -926,7 +926,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point", "meta"], [self.layout_type, self.original_shape, self.group_size, self.num_bits] + return ["int_data", "scale", "zero_point", "meta"], [self._layout, self.original_shape, self.group_size, self.num_bits] @classmethod def __tensor_unflatten__( @@ -936,8 +936,8 @@ def __tensor_unflatten__( scale = tensor_data_dict["scale"] zero_point = tensor_data_dict["zero_point"] meta = tensor_data_dict["meta"] - layout_type, original_shape, group_size, num_bits = tensor_attributes - return cls(int_data, scale, zero_point, meta, layout_type, original_shape, group_size, num_bits) + _layout, original_shape, group_size, num_bits = tensor_attributes + return cls(int_data, scale, zero_point, meta, _layout, original_shape, group_size, num_bits) def get_plain(self): from torchao.sparsity.marlin import unpack_from_marlin_24 # avoid circular import @@ -959,10 +959,10 @@ def from_plain( int_data: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, ): from torchao.sparsity.marlin import pack_to_marlin_24, const # avoid circular import - assert isinstance(layout_type, MarlinSparseLayoutType) + assert isinstance(_layout, MarlinSparseLayout) # Linear layers are (in_features, out_features) but the int_data that is reaching this point # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. 
@@ -1007,12 +1007,12 @@ def from_plain( return cls( marlin_24_q_w_comp, marlin_24_s, zero_point, - meta, layout_type, q_w_24.shape, + meta, _layout, q_w_24.shape, group_size, num_bits ) - def get_layout_type(self) -> LayoutType: - return self.layout_type + def get_layout(self) -> Layout: + return self._layout def _apply_fn_to_data(self, fn): self.int_data = fn(self.int_data) @@ -1022,7 +1022,7 @@ def _apply_fn_to_data(self, fn): return self -@register_layout(Float8LayoutType) +@register_layout(Float8Layout) class Float8AQTTensorImpl(AQTTensorImpl): """ TensorImpl storage class for float8 tensor impl for affine quantized tensor @@ -1036,7 +1036,7 @@ def __new__( float8_data: torch.Tensor, scale: torch.Tensor, transposed: bool, - layout_type: LayoutType, + _layout: Layout, ): kwargs = {} kwargs["device"] = float8_data.device @@ -1053,12 +1053,12 @@ def __init__( float8_data: torch.Tensor, scale: torch.Tensor, transposed: bool, - layout_type: LayoutType, + _layout: Layout, ): self.float8_data = float8_data self.scale = scale self.transposed = transposed - self.layout_type = layout_type + self._layout = _layout def _apply_fn_to_data(self, fn): """ Applys a fn to all tensor components stored on this class""" @@ -1072,19 +1072,19 @@ def to(self, *args, **kwargs): self.float8_data.to(kwargs["device"]), self.scale.to(kwargs["device"]), self.transposed, - self.layout_type, + self._layout, ) def __tensor_flatten__(self): - return ["float8_data", "scale"], [self.transposed, self.layout_type] + return ["float8_data", "scale"], [self.transposed, self._layout] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): float8_data, scale = tensor_data_dict["float8_data"], tensor_data_dict["scale"] - transposed, layout_type, = tensor_attributes - return cls(float8_data, scale, transposed, layout_type) + transposed, _layout, = tensor_attributes + return cls(float8_data, scale, transposed, _layout) @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): @@ -1112,7 +1112,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif dim == 1: assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self.layout_type) + return Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout) else: raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") else: @@ -1125,8 +1125,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: return self.float8_data, self.scale, None - def get_layout_type(self) -> LayoutType: - return self.layout_type + def get_layout(self) -> Layout: + return self._layout @classmethod def from_plain( @@ -1134,24 +1134,24 @@ def from_plain( data: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor], - layout_type: LayoutType, + _layout: Layout, ): """ Main entrypoint for constructing Float8TensorImpl""" assert _is_float8_type(data.dtype), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" - assert isinstance(layout_type, Float8LayoutType), f"Float8 TensorImpl must be constructed from Float8LayoutType but got {layout_type}" - return cls(data, scale, False, layout_type) + assert isinstance(_layout, 
Float8Layout), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}" + return cls(data, scale, False, _layout) def __repr__(self): float8_data, scale, _ = self.get_plain() - layout_type = self.get_layout_type() + _layout = self.get_layout() return (f"{self.__class__.__name__}(\n" f"float8_data={float8_data},\n" f"scale={scale},\n" f"transposed={self.transposed}, " - f"layout_type={layout_type})") + f"_layout={_layout})") -@register_layout(TensorCoreTiledLayoutType) +@register_layout(TensorCoreTiledLayout) class TensorCoreTiledAQTTensorImpl(AQTTensorImpl): """ TensorImpl storage class for tensor_core_tiled tensor impl for affine quantized tensor, this is for int4 only, @@ -1168,7 +1168,7 @@ def __new__( packed_weight: torch.Tensor, scale_and_zero: torch.Tensor, transposed: bool, - layout_type: LayoutType, + _layout: Layout, ): kwargs = {} kwargs["device"] = packed_weight.device @@ -1185,23 +1185,23 @@ def __init__( packed_weight: torch.Tensor, scale_and_zero: torch.Tensor, transposed: bool, - layout_type: LayoutType, + _layout: Layout, ): self.packed_weight = packed_weight self.scale_and_zero = scale_and_zero self.transposed = False - self.layout_type = layout_type + self._layout = _layout def __tensor_flatten__(self): - return ["packed_weight", "scale_and_zero"], [self.transposed, self.layout_type] + return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"] - transposed, layout_type, = tensor_attributes - return cls(packed_weight, scale_and_zero, transposed, layout_type) + transposed, _layout, = tensor_attributes + return cls(packed_weight, scale_and_zero, transposed, _layout) @classmethod def from_plain( @@ -1209,22 +1209,22 @@ def from_plain( int_data: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor], - layout_type: LayoutType + _layout: Layout ): - assert isinstance(layout_type, TensorCoreTiledLayoutType) + assert isinstance(_layout, TensorCoreTiledLayout) if TORCH_VERSION_AT_LEAST_2_5: int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" else: assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles) + packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, _layout.inner_k_tiles) scale = scale.reshape(int_data.shape[0], -1) zero_point = zero_point.reshape(int_data.shape[0], -1) from torchao.quantization.utils import pack_tinygemm_scales_and_zeros scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) - return cls(packed_weight, scale_and_zero, False, layout_type) + return cls(packed_weight, scale_and_zero, False, _layout) def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) @@ -1235,7 +1235,7 @@ def to(self, *args, **kwargs): self.packed_weight.to(device), self.scale_and_zero.to(device), self.transposed, - self.layout_type, + self._layout, ) def _apply_fn_to_data(self, fn): @@ -1300,8 +1300,8 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, 
zero_point_domain) return int_data, scale, zero - def get_layout_type(self) -> LayoutType: - return self.layout_type + def get_layout(self) -> Layout: + return self._layout ##################################################### @@ -1348,8 +1348,8 @@ def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): isinstance(weight_tensor, AffineQuantizedTensor) and weight_tensor.is_cuda and input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor.layout_type, PlainLayoutType) and - isinstance(weight_tensor.layout_type, PlainLayoutType) + isinstance(input_tensor._layout, PlainLayout) and + isinstance(weight_tensor._layout, PlainLayout) ) def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): @@ -1390,8 +1390,8 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_check(input_tensor, weig isinstance(weight_tensor, AffineQuantizedTensor) and weight_tensor.is_cuda and input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor.layout_type, PlainLayoutType) and - isinstance(weight_tensor.layout_type, SemiSparseLayoutType) + isinstance(input_tensor._layout, PlainLayout) and + isinstance(weight_tensor._layout, SemiSparseLayout) ) def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias): @@ -1421,8 +1421,8 @@ def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, isinstance(weight_tensor, AffineQuantizedTensor) and weight_tensor.is_cuda and input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor.layout_type, PlainLayoutType) and - isinstance(weight_tensor.layout_type, BlockSparseLayoutType) + isinstance(input_tensor._layout, PlainLayout) and + isinstance(weight_tensor._layout, BlockSparseLayout) ) @@ -1462,7 +1462,7 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): weight_tensor.dtype == torch.bfloat16 and len(weight_tensor.shape) == 2 and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT and - isinstance(weight_tensor.layout_type, TensorCoreTiledLayoutType) + isinstance(weight_tensor._layout, TensorCoreTiledLayout) ) @@ -1516,7 +1516,7 @@ def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): weight_tensor.block_size[0] == 1 and weight_tensor.block_size[1] == weight_tensor.shape[1] and weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor.layout_type, PlainLayoutType) + isinstance(weight_tensor._layout, PlainLayout) ) def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): @@ -1539,7 +1539,7 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): return y def _linear_f16_act_floatx_weight_check(input_tensor, weight_tensor, bias): - from torchao.dtypes.floatx import FloatxTensorCoreLayoutType + from torchao.dtypes.floatx import FloatxTensorCoreLayout return ( # input is native float32 tensor not is_traceable_wrapper_subclass(input_tensor) and @@ -1547,18 +1547,18 @@ def _linear_f16_act_floatx_weight_check(input_tensor, weight_tensor, bias): input_tensor.dtype == torch.float16 and # weight is floatx Tensor isinstance(weight_tensor, AffineQuantizedTensor) and - isinstance(weight_tensor.layout_type, FloatxTensorCoreLayoutType) and + isinstance(weight_tensor._layout, FloatxTensorCoreLayout) and ( # weight is using fp6 quantization - (weight_tensor.layout_type.ebits == 3 and - weight_tensor.layout_type.mbits == 2) or - (weight_tensor.layout_type.ebits == 2 and - weight_tensor.layout_type.mbits == 3) or + (weight_tensor._layout.ebits == 3 and + 
weight_tensor._layout.mbits == 2) or + (weight_tensor._layout.ebits == 2 and + weight_tensor._layout.mbits == 3) or # weight is using fp5 quantization - (weight_tensor.layout_type.ebits == 2 and - weight_tensor.layout_type.mbits == 2) or - (weight_tensor.layout_type.ebits == 3 and - weight_tensor.layout_type.mbits == 1) + (weight_tensor._layout.ebits == 2 and + weight_tensor._layout.mbits == 2) or + (weight_tensor._layout.ebits == 3 and + weight_tensor._layout.mbits == 1) ) ) @@ -1577,8 +1577,8 @@ def _linear_f16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 out = quant_llm_linear( - weight.layout_type.ebits, - weight.layout_type.mbits, + weight._layout.ebits, + weight._layout.mbits, act_reshaped, weight.tensor_impl.packed_floatx_data, weight.tensor_impl.scale, @@ -1598,7 +1598,7 @@ def _linear_fp8_act_fp8_weight_check( def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return ( isinstance(aqt, AffineQuantizedTensor) and - isinstance(aqt.layout_type, Float8LayoutType) + isinstance(aqt._layout, Float8Layout) and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt)) ) @@ -1620,7 +1620,7 @@ def _linear_fp8_act_fp8_weight_impl( bias: Optional[torch.Tensor], ): """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" - scaled_mm_config = weight_tensor.layout_type.mm_config + scaled_mm_config = weight_tensor._layout.mm_config out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) # Weight tensor preprocessing @@ -1666,7 +1666,7 @@ def _linear_fp_act_fp8_weight_check( input_tensor.is_floating_point() and # weight is float8 quantized affine quantized tensor isinstance(weight_tensor, AffineQuantizedTensor) and - isinstance(weight_tensor.layout_type, Float8LayoutType) + isinstance(weight_tensor._layout, Float8Layout) and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and (weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor)) ) @@ -1685,7 +1685,7 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, input_tensor.dtype == torch.float16 and len(weight_tensor.shape) == 2 and weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor.layout_type, MarlinSparseLayoutType) + isinstance(weight_tensor._layout, MarlinSparseLayout) ) def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias): @@ -1753,8 +1753,8 @@ def _(func, types, args, kwargs): try: return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: - # fallback path is only called when user did not specify a specfic quantized linear implementation with `layout_type.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None: + # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` + if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -1780,8 +1780,8 @@ def _(func, types, args, kwargs): weight_tensor = 
weight_tensor.t() return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: - # fallback path is only called when user did not specify a specfic quantized linear implementation with `layout_type.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None: + # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` + if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -1804,8 +1804,8 @@ def _(func, types, args, kwargs): weight_tensor = weight_tensor.t() return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: - # fallback path is only called when user did not specify a specfic quantized linear implementation with `layout_type.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None: + # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` + if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: raise e if isinstance(input_tensor, AffineQuantizedTensor): diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 39461d886..d7559015f 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1 +1 @@ -from .floatx import FloatxTensorCoreLayoutType, FloatxTensorCoreAQTTensorImpl, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP +from .floatx import FloatxTensorCoreLayout, FloatxTensorCoreAQTTensorImpl, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP diff --git a/torchao/dtypes/floatx/floatx.py b/torchao/dtypes/floatx/floatx.py index 5a9aab035..f86210637 100644 --- a/torchao/dtypes/floatx/floatx.py +++ b/torchao/dtypes/floatx/floatx.py @@ -6,7 +6,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, _n_ones from torchao.dtypes.utils import ( - LayoutType, + Layout, ) from torchao.quantization.quant_api import _get_linear_subclass_inserter from dataclasses import dataclass @@ -353,13 +353,13 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> # quantization api integrations @dataclass(frozen=True) -class FloatxTensorCoreLayoutType(LayoutType): +class FloatxTensorCoreLayout(Layout): """Layout type for FloatxTensorCoreAQTTensorImpl """ ebits: int mbits: int -@register_layout(FloatxTensorCoreLayoutType) +@register_layout(FloatxTensorCoreLayout) class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), it has a internal tensor field of "packed_floatx_data", which is packed from the @@ -386,11 +386,11 @@ def __new__( cls, packed_floatx_data: torch.Tensor, scale: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, ): 
assert packed_floatx_data.ndim == 2 assert packed_floatx_data.dtype == torch.uint8 - shape = (packed_floatx_data.shape[0], packed_floatx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8) + shape = (packed_floatx_data.shape[0], packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8) kwargs = {} kwargs["device"] = packed_floatx_data.device kwargs["layout"] = ( @@ -404,25 +404,25 @@ def __init__( self, packed_floatx_data: torch.Tensor, scale: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, ): self.packed_floatx_data = packed_floatx_data self.scale = scale - self.layout_type = layout_type + self._layout = _layout def __tensor_flatten__(self): - return ["packed_floatx_data", "scale"], [self.layout_type] + return ["packed_floatx_data", "scale"], [self._layout] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): packed_floatx_data, scale = tensor_data_dict["packed_floatx_data"], tensor_data_dict["scale"] - layout_type, = tensor_attributes - return cls(packed_floatx_data, scale, layout_type) + _layout, = tensor_attributes + return cls(packed_floatx_data, scale, _layout) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: - unpacked_floatx_data = unpack_tc_floatx(self.packed_floatx_data, 1 + self.layout_type.ebits + self.layout_type.mbits) + unpacked_floatx_data = unpack_tc_floatx(self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits) return unpacked_floatx_data, self.scale @classmethod @@ -431,7 +431,7 @@ def from_plain( unpacked_floatx_data: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor], - layout_type: LayoutType, + _layout: Layout, ): """ Format for `unpacked_floatx_data` will be: @@ -440,20 +440,20 @@ def from_plain( For example for fp6_e3_m2, the format will be: `00SEEEMM`, where S is sign bit, E is exponent bit, M is mantissa bit """ - assert isinstance(layout_type, FloatxTensorCoreLayoutType) - packed_floatx_data = pack_tc_floatx(unpacked_floatx_data, 1 + layout_type.ebits + layout_type.mbits) - return cls(packed_floatx_data, scale, layout_type) + assert isinstance(_layout, FloatxTensorCoreLayout) + packed_floatx_data = pack_tc_floatx(unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits) + return cls(packed_floatx_data, scale, _layout) def __repr__(self): unpacked_floatx_data, scale = self.get_plain() - layout_type = self.get_layout_type() - return f"{self.__class__.__name__}(unpacked_floatx_data={unpacked_floatx_data}, scale={scale}, layout_type={layout_type})" + _layout = self.get_layout() + return f"{self.__class__.__name__}(unpacked_floatx_data={unpacked_floatx_data}, scale={scale}, _layout={_layout})" def _apply_fn_to_data(self, fn): return self.__class__( fn(self.packed_floatx_data), fn(self.scale), - self.layout_type, + self._layout, ) def to(self, *args, **kwargs): @@ -462,7 +462,7 @@ def to(self, *args, **kwargs): return self.__class__( self.packed_floatx_data.to(device), self.scale.to(device), - self.layout_type, + self._layout, ) @classmethod @@ -488,5 +488,5 @@ def __torch_dispatch__(cls, func, types, args, kwargs): __torch_function__ = torch._C._disabled_torch_function_impl - def get_layout_type(self) -> LayoutType: - return self.layout_type + def get_layout(self) -> Layout: + return self._layout diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 5caaa8b1b..c44803f6d 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1 +1 @@ -from .uintx import UintxTensor, 
UintxLayoutType, UintxAQTTensorImpl, to_uintx, _DTYPE_TO_BIT_WIDTH +from .uintx import UintxTensor, UintxLayout, UintxAQTTensorImpl, to_uintx, _DTYPE_TO_BIT_WIDTH diff --git a/torchao/dtypes/uintx/uintx.py b/torchao/dtypes/uintx/uintx.py index eb63fc619..a48faee8d 100644 --- a/torchao/dtypes/uintx/uintx.py +++ b/torchao/dtypes/uintx/uintx.py @@ -5,7 +5,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from .bitpacking import pack, unpack from torchao.dtypes.utils import ( - LayoutType, + Layout, ) from torchao.utils import TorchAOBaseTensor from torchao.dtypes.affine_quantized_tensor import PlainAQTTensorImpl, register_layout @@ -187,14 +187,14 @@ def _(func, types, args, kwargs): to_uintx = UintxTensor.from_uint8 @dataclass(frozen=True) -class UintxLayoutType(LayoutType): +class UintxLayout(Layout): dtype: torch.dtype pack_dim: int = -1 def post_process(self, input: torch.Tensor) -> torch.Tensor: return to_uintx(input, self.dtype, self.pack_dim) -@register_layout(UintxLayoutType) +@register_layout(UintxLayout) class UintxAQTTensorImpl(PlainAQTTensorImpl): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -206,7 +206,7 @@ def from_plain( int_data: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, ): - assert isinstance(layout_type, UintxLayoutType) - return cls(int_data, scale, zero_point, layout_type) + assert isinstance(_layout, UintxLayout) + return cls(int_data, scale, zero_point, _layout) diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 4a6b3a0bb..d17231c1e 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -3,7 +3,7 @@ from dataclasses import dataclass """ -Base class for different LayoutType, should not be instantiated directly +Base class for different Layout, should not be instantiated directly used to allow users to pass around configurations for the tensor impl, e.g. inner_k_tiles for int4 tensor core tiled tensor impl @@ -12,7 +12,7 @@ behaviors when running the same operator, e.g. transpose, quantized_linear. 
""" @dataclass(frozen=True) -class LayoutType: +class Layout: def pre_process(self, input: torch.Tensor) -> torch.Tensor: return input @@ -29,10 +29,10 @@ def extra_repr(self) -> str: return "" """ -Plain LayoutType, the most basic LayoutType, also has no extra metadata, will typically be the default +Plain Layout, the most basic Layout, also has no extra metadata, will typically be the default """ @dataclass(frozen=True) -class PlainLayoutType(LayoutType): +class PlainLayout(Layout): pass def is_device(target_device_str: str, device: Union[str, torch.device]): diff --git a/torchao/prototype/autoround/core.py b/torchao/prototype/autoround/core.py index c602473c5..1b7c86f6d 100644 --- a/torchao/prototype/autoround/core.py +++ b/torchao/prototype/autoround/core.py @@ -7,7 +7,7 @@ import torchao.prototype.autoround.utils as ar_utils import torchao.quantization as ao_quant -from torchao.dtypes import TensorCoreTiledLayoutType, to_affine_quantized_intx_static +from torchao.dtypes import TensorCoreTiledLayout, to_affine_quantized_intx_static from torchao.prototype.autoround.multi_tensor import _multi_tensor_config, MultiTensor from torchao.quantization.quant_primitives import ZeroPointDomain from torchao.utils import find_multiple @@ -183,7 +183,7 @@ def to_uintx_weight(input_float): block_size = (1, observed_linear.group_size) from torchao.dtypes.uintx.uintx import ( _BIT_WIDTH_TO_DTYPE, - UintxLayoutType, + UintxLayout, ) from torchao.quantization.quant_primitives import ZeroPointDomain @@ -192,7 +192,7 @@ def to_uintx_weight(input_float): ), f"Invalid bits: {_auto_round_config.bits}" dtype = _BIT_WIDTH_TO_DTYPE[_auto_round_config.bits] pack_dim = -1 - layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim) + _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim) return to_affine_quantized_intx_static( input_float=input_float, scale=scale.to(input_float.dtype), @@ -202,7 +202,7 @@ def to_uintx_weight(input_float): quant_min=quant_min, quant_max=quant_max, zero_point_domain=ZeroPointDomain.INT, - layout_type=layout_type, + _layout=_layout, ) def to_int4_tinygemm_weight(input_float): @@ -256,7 +256,7 @@ def to_int4_tinygemm_weight(input_float): quant_min=quant_min, quant_max=quant_max, zero_point_domain=ZeroPointDomain.FLOAT, - layout_type=TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles), + _layout=TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles), ) # TODO(Yi): better way to select the weight quantization function diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index fc1b04f94..0a26ab98d 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -9,10 +9,10 @@ ) from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter -from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType +from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayout from torchao.dtypes import( to_affine_quantized_intx, - TensorCoreTiledLayoutType, + TensorCoreTiledLayout, ) from .core import( AWQObserver, @@ -106,14 +106,14 @@ def weight_quant_func(observed_linear): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - layout_type = TensorCoreTiledLayoutType(inner_k_tiles=8) + _layout = TensorCoreTiledLayout(inner_k_tiles=8) else: target_dtype = torch.uint8 eps = torch.finfo(torch.float32).eps preserve_zero = True zero_point_dtype = torch.int64 zero_point_domain = ZeroPointDomain.INT - 
layout_type = UintxLayoutType(quant_dtype) + _layout = UintxLayout(quant_dtype) mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) @@ -128,12 +128,10 @@ def weight_quant_func(observed_linear): zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, - layout_type=layout_type, + _layout=_layout, use_hqq=use_hqq ) return to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale) return _observed_linear_subclass_inserter(weight_quant_func) - - diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 725c168f9..034d73639 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType +from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayout from torchao.dtypes import to_affine_quantized_intx from torchao.quantization.granularity import Granularity from torchao.quantization.quant_primitives import ( @@ -110,7 +110,7 @@ def calculate_qparams(self): ratio = i * 1 / self.scale_options scales = self.average.pow(ratio).to(self.weight.dtype) scales = scales / (scales.max() * scales.min()).sqrt() - layout = UintxLayoutType(self.target_dtype) + layout = UintxLayout(self.target_dtype) # regardless of weight dtype, we have to store as packed uint8 tensors tensor_dtype = torch.uint8 w = to_affine_quantized_intx( @@ -125,7 +125,7 @@ def calculate_qparams(self): zero_point_dtype = self.zero_point_dtype, preserve_zero = self.preserve_zero, zero_point_domain = self.zero_point_domain, - layout_type = layout + _layout = layout ) loss = 0 for i in range(self.n_validation_examples): diff --git a/torchao/prototype/hqq/example.py b/torchao/prototype/hqq/example.py index f410a11cd..07d5dea20 100644 --- a/torchao/prototype/hqq/example.py +++ b/torchao/prototype/hqq/example.py @@ -4,9 +4,9 @@ to_affine_quantized_intx, ZeroPointDomain, PlainAQTTensorImpl, - PlainLayoutType, + PlainLayout, TensorCoreTiledAQTTensorImpl, - TensorCoreTiledLayoutType, + TensorCoreTiledLayout, MappingType, ) @@ -34,7 +34,7 @@ preserve_zero = False zero_point_domain = ZeroPointDomain.FLOAT zero_point_dtype = compute_dtype -layout_type = PlainLayoutType() +_layout = PlainLayout() for nbits in list(range(2, 9))[::-1]: print('------------------------------------------------------------------------------') @@ -47,7 +47,7 @@ quant_max=2**nbits - 1, zero_point_domain= zero_point_domain, preserve_zero=preserve_zero, - layout_type=layout_type, + _layout=_layout, ) linear_layer.weight = q_tensor_default @@ -66,7 +66,7 @@ quant_max=2**nbits - 1, zero_point_domain=zero_point_domain, preserve_zero=preserve_zero, - layout_type=layout_type, + _layout=_layout, use_hqq=True, ) @@ -87,7 +87,7 @@ nbits = 4 target_dtype = torch.int32 inner_k_tiles = 8 -layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles) +_layout = TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles) int4_weight_only_patch_fct = int4_weight_only(group_size=group_size, inner_k_tiles=inner_k_tiles) linear_layer_default = torch.nn.Linear(in_features, out_features, bias=False, device=device) @@ -108,7 +108,7 @@ quant_max=2**nbits - 1, zero_point_domain=zero_point_domain, preserve_zero=preserve_zero, - layout_type=layout_type, + _layout=_layout, use_hqq=True, ) linear_layer.weight = q_tensor_hqq diff --git a/torchao/quantization/GPTQ_MT.py b/torchao/quantization/GPTQ_MT.py 
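[Editor's note, not part of the patch] The call sites above and below all make the same one-line change: the `layout_type=` keyword of the low-level `to_affine_quantized_*` constructors becomes `_layout=`. A self-contained sketch mirroring the int8 per-channel weight path used elsewhere in this diff; the tensor and its shape are made up.

import torch

from torchao.dtypes import PlainLayout, to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType

weight = torch.randn(64, 128)
qweight = to_affine_quantized_intx(
    weight,
    MappingType.SYMMETRIC,
    (1, weight.shape[1]),               # per-channel block size
    torch.int8,
    eps=torch.finfo(torch.float32).eps,
    zero_point_dtype=torch.int64,
    _layout=PlainLayout(),              # was: layout_type=PlainLayoutType()
)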
index 6545650c5..890b707ec 100644 --- a/torchao/quantization/GPTQ_MT.py +++ b/torchao/quantization/GPTQ_MT.py @@ -12,7 +12,7 @@ from torchao.dtypes import ( to_affine_quantized_intx_static, - TensorCoreTiledLayoutType + TensorCoreTiledLayout ) from torchao.quantization.quant_primitives import ( MappingType, @@ -613,7 +613,7 @@ def make_qtensor(q, qparams): quant_min = 0 quant_max = 15 zero_point_domain = ZeroPointDomain.FLOAT - layout_type = TensorCoreTiledLayoutType(inner_k_tiles=8) + _layout = TensorCoreTiledLayout(inner_k_tiles=8) # at least the big up to here should be a util quantized_tensor = to_affine_quantized_intx_static( @@ -625,7 +625,7 @@ def make_qtensor(q, qparams): quant_min=quant_min, quant_max=quant_max, zero_point_domain=zero_point_domain, - layout_type=layout_type, + _layout=_layout, ) return quantized_tensor self.make_qtensor = make_qtensor @@ -690,4 +690,3 @@ def _replace_with_custom_fn_if_matches_filter( if new_child is not child: setattr(model, name, new_child) return model - diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 40c587aa3..9a0b793ea 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -9,7 +9,7 @@ Int8WeightOnlyQuantizedLinearWeight, QuantizedLinearWeightBase, ) -from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType, Float8LayoutType +from torchao.dtypes import AffineQuantizedTensor, PlainLayout, TensorCoreTiledLayout, Float8Layout from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor from torch.utils._python_dispatch import return_and_correct_aliasing from .granularity import ( @@ -311,11 +311,11 @@ def get_per_token_block_size(x): input_eps = 1e-5 input_quant_min = -127 input_quant_max = 127 - layout_type = PlainLayoutType() + _layout = PlainLayout() input_quant_func = lambda x: to_affine_quantized_intx(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) block_size = get_weight_block_size(weight) - weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type) + weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, _layout=_layout) weight = super(AQInt8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func) return weight @@ -437,7 +437,7 @@ class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): @classmethod def from_float(cls, weight): group_size = cls.group_size - layout_type = TensorCoreTiledLayoutType(inner_k_tiles=8) + _layout = TensorCoreTiledLayout(inner_k_tiles=8) if weight.shape[-1] % group_size != 0: return weight @@ -451,7 +451,7 @@ def from_float(cls, weight): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq) + return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, 
preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, _layout=_layout, use_hqq=use_hqq) class AQInt4G64WeightOnlyQuantizedLinearWeight(AQInt4G32WeightOnlyQuantizedLinearWeight): group_size: int = 64 @@ -494,7 +494,7 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): @classmethod def from_float(cls, weight): block_size = (1, weight.shape[1]) - return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType()) + return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, _layout=Float8Layout()) class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor): """ @@ -520,7 +520,7 @@ def get_per_token_block_size(x): return block_size input_target_dtype = torch.float8_e4m3fn - layout_type = Float8LayoutType(mm_config=Float8MMConfig(use_fast_accum=True)) + _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True)) input_quant_func = lambda x: _input_activation_quant_func_fp8( x=x, activation_granularity=cls.activation_granularity, @@ -531,7 +531,7 @@ def get_per_token_block_size(x): input_float=weight, block_size=block_size, target_dtype=target_dtype, - layout_type=layout_type, + _layout=_layout, scale_dtype=torch.float32, ) weight = super(AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_throughput.py b/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_throughput.py index d82473433..a182ea0b2 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_throughput.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_throughput.py @@ -31,7 +31,7 @@ from torchao._models.llama.tokenizer import get_tokenizer from torchao._models._eval import TransformerEvalWrapper, InputRecorder -from torchao.dtypes import TensorCoreTiledLayoutType +from torchao.dtypes import TensorCoreTiledLayout from torchao._models.llama.generate import ( device_sync, diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 4f4aa099a..91803fe3f 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -23,17 +23,17 @@ from typing import Any, Callable, Union, Dict, Optional, Literal, Tuple import types -from torchao.dtypes.uintx.uintx import UintxLayoutType +from torchao.dtypes.uintx.uintx import UintxLayout from torchao.dtypes import ( to_affine_quantized_intx, to_affine_quantized_floatx, to_affine_quantized_floatx_static, - TensorCoreTiledLayoutType, - PlainLayoutType, + TensorCoreTiledLayout, + PlainLayout, AffineQuantizedTensor, - SemiSparseLayoutType, - Float8LayoutType, - MarlinSparseLayoutType, + SemiSparseLayout, + Float8Layout, + MarlinSparseLayout, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_4, @@ -514,7 +514,7 @@ def int8_dynamic_activation_int4_weight(group_size=32, mapping_type=MappingType. 
return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size, mapping_type=mapping_type) -def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=False): +def int4_weight_only(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False): """ Applies uint4 weight-only asymmetric per-group quantization to linear layers, using "tensor_core_tiled" layout for speedup with tinygemm kernel @@ -530,7 +530,7 @@ def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner Args: `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [256, 128, 64, 32] - `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)` + `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)` `use_hqq`: whether to use hqq or default quantization mode, default is False """ def apply_int4_weight_only_quant(weight): @@ -553,12 +553,12 @@ def apply_int4_weight_only_quant(weight): # Sparse Marlin only supports symmetric quantization. # NOTE: If we start having lots of layouts that require different configurations, # we should consider moving this logic somewhere else. - if isinstance(layout_type, MarlinSparseLayoutType): + if isinstance(layout, MarlinSparseLayout): mapping_type = MappingType.SYMMETRIC preserve_zero = True zero_point_domain = ZeroPointDomain.INT - return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq) + return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, _layout=layout, use_hqq=use_hqq) return _get_linear_subclass_inserter(apply_int4_weight_only_quant) @@ -586,7 +586,7 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor: return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) -def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()): +def int8_dynamic_activation_int8_weight(layout=PlainLayout()): """ Applies int8 dynamic symmetric per-token activation and int8 per-channel weight quantization to linear layers @@ -612,7 +612,7 @@ def get_weight_block_size(x): input_quant_func = _int8_symm_per_token_reduced_range_quant block_size = get_weight_block_size(weight) - weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type) + weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, _layout=layout) weight = to_linear_activation_quantized(weight, input_quant_func) return weight @@ -624,12 +624,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight quantization + 2:4 sparsity to linear layers. """ - warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. 
Please use the layout_type kwarg in int8_dynamic_activation_int8_weight instead. + warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead. - from torchao.dtypes import SemiSparseLayoutType - int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()""") + from torchao.dtypes import SemiSparseLayout + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""") - return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()) + return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn): @@ -652,7 +652,7 @@ def apply_float8wo_quant(weight): block_size=block_size, target_dtype=weight_dtype, scale_dtype=None, - layout_type=Float8LayoutType(mm_config=None), + _layout=Float8Layout(mm_config=None), ) return _get_linear_subclass_inserter(apply_float8wo_quant) @@ -709,7 +709,7 @@ def _input_activation_quant_func_fp8( block_size=block_size, target_dtype=activation_dtype, scale_dtype=torch.float32, - layout_type=Float8LayoutType(mm_config=None), # Config is stored on weight + _layout=Float8Layout(mm_config=None), # Config is stored on weight ) else: assert isinstance(activation_granularity, PerTensor), "Static quantization only supports PerTensor granularity" @@ -718,7 +718,7 @@ def _input_activation_quant_func_fp8( block_size=block_size, scale=scale, target_dtype=activation_dtype, - layout_type=Float8LayoutType(mm_config=None), # Config is stored on weight + _layout=Float8Layout(mm_config=None), # Config is stored on weight ) return activation @@ -762,7 +762,7 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor): block_size=block_size, target_dtype=weight_dtype, scale_dtype=torch.float32, - layout_type=Float8LayoutType(mm_config=mm_config), + _layout=Float8Layout(mm_config=mm_config), ) input_quant_func = partial( @@ -812,7 +812,7 @@ def apply_float8_static_activation_quant(weight: torch.Tensor): block_size=block_size, target_dtype=weight_dtype, scale_dtype=torch.float32, - layout_type=Float8LayoutType(mm_config=mm_config), + _layout=Float8Layout(mm_config=mm_config), ) input_quant_func = _input_activation_quant_func_fp8 @@ -863,14 +863,14 @@ def apply_uintx_weight_only_quant(weight, dtype): zero_point_dtype = None zero_point_domain = ZeroPointDomain.FLOAT preserve_zero = False - layout_type = PlainLayoutType() + _layout = PlainLayout() else: quant_min, quant_max = None, None eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int32 zero_point_domain = ZeroPointDomain.INT preserve_zero = True - layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim) + _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim) return to_affine_quantized_intx( weight, mapping_type, block_size, dtype, @@ -878,7 +878,7 @@ def apply_uintx_weight_only_quant(weight, dtype): eps=eps, zero_point_dtype=zero_point_dtype, zero_point_domain=zero_point_domain, preserve_zero=preserve_zero, - layout_type=layout_type, + _layout=_layout, use_hqq=use_hqq, ) @@ -896,7 +896,7 @@ def fpx_weight_only(ebits: int, mbits: int): """ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: - from torchao.dtypes.floatx import FloatxTensorCoreLayoutType + from torchao.dtypes.floatx import FloatxTensorCoreLayout from torchao.dtypes import to_affine_quantized_fpx assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}" @@ -908,8 +908,8 @@ def 
apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: "expected in_dim % 64 == 0 and out_dim % 256 == 0") return weight - layout_type = FloatxTensorCoreLayoutType(ebits, mbits) - return to_affine_quantized_fpx(weight, layout_type) + _layout = FloatxTensorCoreLayout(ebits, mbits) + return to_affine_quantized_fpx(weight, _layout) return _get_linear_subclass_inserter(apply_quant_llm) diff --git a/torchao/sparsity/README.md b/torchao/sparsity/README.md index 7748c93be..e644bd16d 100644 --- a/torchao/sparsity/README.md +++ b/torchao/sparsity/README.md @@ -53,11 +53,11 @@ Sparse-Marlin 2:4 is an optimized GPU kernel that extends the Mixed Auto-Regress ```py from torchao.quantization.quant_api import quantize_, int4_weight_only -from torchao.dtypes import MarlinSparseLayoutType +from torchao.dtypes import MarlinSparseLayout # Your FP16 model model = model.cuda().half() -quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) +quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) ``` Note the existing API results in an extremely high accuracy degredation and is intended to be used in concert with an already sparsified+finetuned checkpoint where possible until we develop @@ -69,17 +69,17 @@ We support composing int8 dynaic quantization with 2:4 sparsity. We fuse one of ```py from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight -from torchao.dtypes import SemiSparseLayoutType +from torchao.dtypes import SemiSparseLayout model = model.cuda() -quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())) +quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())) ``` ### 2:4 sparsity ```py from torchao.sparsity.sparse_api import sparsify_, semi_sparse_weight -from torchao.dtypes import SemiSparseLayoutType +from torchao.dtypes import SemiSparseLayout model = model.cuda() sparsify_(model, semi_sparse_weight()) diff --git a/torchao/sparsity/prototype/superblock/utils.py b/torchao/sparsity/prototype/superblock/utils.py index cf865fd36..e0cf4a177 100644 --- a/torchao/sparsity/prototype/superblock/utils.py +++ b/torchao/sparsity/prototype/superblock/utils.py @@ -146,12 +146,12 @@ def accelerate_with_sparsity(model, args): if args.sparsity == "bsr": apply_sparsity(model) if args.quantization: - from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType + from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout quantize_( model, int8_dynamic_activation_int8_weight( - layout_type=BlockSparseLayoutType(blocksize=args.bsr) + _layout=BlockSparseLayout(blocksize=args.bsr) ), superblock_only, ) @@ -160,11 +160,11 @@ def accelerate_with_sparsity(model, args): sparsify_(model, block_sparse_weight(blocksize=args.bsr), superblock_only) elif args.sparsity == "semi_structured": if args.quantization: - from torchao.dtypes.affine_quantized_tensor import SemiSparseLayoutType + from torchao.dtypes.affine_quantized_tensor import SemiSparseLayout quantize_( model, - int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()), + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), mlp_0_only, ) sparsify_(model, semi_sparse_weight(), mlp_3_only) diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index ae343add9..98bfa3b30 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -48,8 +48,8 @@ def sparsify_( Currently, we support three options for sparsity: - semi-structured (2:4) sparsity with `semi_sparse_weight` 
- - int8 dynamic quantization + 2:4 sparsity with `layout_type=SemiSparseLayoutType` - - int4 weight-only quantization + 2:4 sparsity with `layout_type=SparseMarlinLayoutType` + - int8 dynamic quantization + 2:4 sparsity with `layout=SemiSparseLayout` + - int4 weight-only quantization + 2:4 sparsity with `layout=SparseMarlinLayout` Args: model (torch.nn.Module): input model @@ -72,8 +72,8 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: m = sparsify_(m, semi_sparse_weight(), filter_fn) # for int8 dynamic quantization + 2:4 sparsity - from torchao.dtypes import SemiSparseLayoutType - m = quantize_(m, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType), filter_fn) + from torchao.dtypes import SemiSparseLayout + m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn) """ _replace_with_custom_fn_if_matches_filter( model, diff --git a/torchao/utils.py b/torchao/utils.py index 36bc1be36..bcfb23348 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -392,49 +392,49 @@ class MyTensor(torch.Tensor): kwarg_types = {k: type(arg) for k, arg in kwargs} raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}") -def _register_layout(cls: Callable, layout_type_class: Callable): +def _register_layout(tensor_class: Callable, layout_class: Callable): """Helper function for layout registrations, this is used to implement register_layout decorator for each tensor subclass, see aqt.py for example usage Args: - cls: Tensor subclass type - layout_type_class: the class type of subclass of `LayoutType`, e.g. `PlainLayoutType` + tensor_class: Tensor subclass type + layout_class: the class type of subclass of `Layout`, e.g. `PlainLayout` Returns: a decorator that registers the tensor impl constructor in the table """ - # cls._LAYOUT_CONSTRUCTOR_TABLE is a map from layout_type_class like TensorCoreTiledLayout + # tensor_class._LAYOUT_CONSTRUCTOR_TABLE is a map from layout_class like TensorCoreTiledLayout # to tensor_impl class constructor like TensorCoreTiledAQTTensorImpl.from_plain that can construct a tensor_impl # from plain data like (quantized, unpacked) `data`, `scale`, `zero_point` - if not hasattr(cls, "_LAYOUT_CONSTRUCTOR_TABLE"): - cls._LAYOUT_CONSTRUCTOR_TABLE = {} + if not hasattr(tensor_class, "_LAYOUT_CONSTRUCTOR_TABLE"): + tensor_class._LAYOUT_CONSTRUCTOR_TABLE = {} def decorator(tensor_impl_class): - cls._LAYOUT_CONSTRUCTOR_TABLE[layout_type_class] = tensor_impl_class.from_plain + tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class] = tensor_impl_class.from_plain if TORCH_VERSION_AT_LEAST_2_5: # Allow serialization to work for models uses this tensor impl subclass - torch.serialization.add_safe_globals([layout_type_class, tensor_impl_class]) + torch.serialization.add_safe_globals([layout_class, tensor_impl_class]) return tensor_impl_class return decorator -def _get_tensor_impl_constructor(cls: Callable, layout_type_class: Callable) -> Callable: - """Get TensorImpl class constructor (TensorImplClass.from_plain) for `cls` based on `layout_type_class` - `layout_type_class` means the class type of subclass of `LayoutType`, e.g. `PlainLayoutType` +def _get_tensor_impl_constructor(tensor_class: Callable, layout_class: Callable) -> Callable: + """Get TensorImpl class constructor (TensorImplClass.from_plain) for `tensor_class` based on `layout_class` + `layout_class` means the class type of subclass of `Layout`, e.g. 
`PlainLayout` Args: - cls: Tensor subclass type - layout_type_class: the class type of subclass of `LayoutType`, e.g. `PlainLayoutType` + tensor_class: Tensor subclass type + layout_class: the class type of subclass of `Layout`, e.g. `PlainLayout` Returns: - tensor impl subclass constructor for the layout_type_class + tensor impl subclass constructor for the layout_class """ - if not hasattr(cls, "_LAYOUT_CONSTRUCTOR_TABLE"): - raise ValueError(f"no registered tensor_impl class constructor for: {cls}") - if layout_type_class not in cls._LAYOUT_CONSTRUCTOR_TABLE: - raise ValueError(f"layout_name: {layout_type_class} is not supported yet for {cls}") + if not hasattr(tensor_class, "_LAYOUT_CONSTRUCTOR_TABLE"): + raise ValueError(f"no registered tensor_impl class constructor for: {tensor_class}") + if layout_class not in tensor_class._LAYOUT_CONSTRUCTOR_TABLE: + raise ValueError(f"layout_name: {layout_class} is not supported yet for {tensor_class}") - return cls._LAYOUT_CONSTRUCTOR_TABLE[layout_type_class] + return tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class] class TorchAOBaseTensor(torch.Tensor): @@ -460,15 +460,15 @@ def _(func, types, args, kwargs): `register_layout`: register_layout = MyTensor.register_layout - @register_layout(PlainLayoutType) + @register_layout(PlainLayout) class PlainAQTTensorImpl(...): ... `get_tensor_impl_constructor`: get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor # in constructor of MyTensor: - tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) - tensor_impl = tensor_impl_ctr(data, scale, zero_point, layout_type) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout) """ implements = classmethod(_implements) diff --git a/tutorials/calibration_flow/awq_like.py b/tutorials/calibration_flow/awq_like.py index b71933e3b..003959742 100644 --- a/tutorials/calibration_flow/awq_like.py +++ b/tutorials/calibration_flow/awq_like.py @@ -14,7 +14,7 @@ from torchao.dtypes import ( to_affine_quantized_intx_static, to_affine_quantized_floatx_static, - Float8LayoutType, + Float8Layout, ) from torchao.quantization.utils import compute_error from torchao.quantization import quantize_ @@ -73,7 +73,7 @@ def weight_quant_func(weight): if target_dtype == torch.uint8: return to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype) elif target_dtype == torch.float8_e4m3fn: - return to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8LayoutType(mm_config=None)) + return to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8Layout(mm_config=None)) else: raise ValueError(f"Unsupported target dtype {target_dtype}") linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py index 31d2be201..b96d7da05 100644 --- a/tutorials/calibration_flow/static_quant.py +++ b/tutorials/calibration_flow/static_quant.py @@ -9,7 +9,7 @@ from torchao.dtypes import ( to_affine_quantized_intx_static, to_affine_quantized_floatx_static, - Float8LayoutType, + Float8Layout, ) from torchao.quantization.utils import compute_error from torchao.quantization import quantize_ @@ -68,7 +68,7 @@ def weight_quant_func(weight): if target_dtype == torch.uint8: return 
to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype) elif target_dtype == torch.float8_e4m3fn: - return to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8LayoutType(mm_config=None)) + return to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8Layout(mm_config=None)) else: raise ValueError(f"Unsupported target dtype {target_dtype}") linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) @@ -82,7 +82,7 @@ def weight_quant_func(weight): if target_dtype == torch.uint8: input_quant_func = lambda x: to_affine_quantized_intx_static(x, act_scale, act_zero_point, x.shape, target_dtype) elif target_dtype == torch.float8_e4m3fn: - input_quant_func = lambda x: to_affine_quantized_floatx_static(x, act_scale, x.shape, target_dtype, Float8LayoutType(mm_config=None)) + input_quant_func = lambda x: to_affine_quantized_floatx_static(x, act_scale, x.shape, target_dtype, Float8Layout(mm_config=None)) else: raise ValueError(f"Unsupported target dtype {target_dtype}") linear.weight = torch.nn.Parameter(to_linear_activation_quantized(linear.weight, input_quant_func), requires_grad=False) @@ -104,7 +104,7 @@ def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module if self.target_dtype == torch.uint8: self.qweight = to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, self.target_dtype) elif self.target_dtype == torch.float8_e4m3fn: - self.qweight = to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8LayoutType(mm_config=None)) + self.qweight = to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8Layout(mm_config=None)) else: raise ValueError(f"Unsupported target dtype {self.target_dtype}") @@ -113,7 +113,7 @@ def forward(self, input: Tensor): if self.target_dtype == torch.uint8: qinput = to_affine_quantized_intx_static(input, self.act_scale, self.act_zero_point, block_size, self.target_dtype) elif self.target_dtype == torch.float8_e4m3fn: - qinput = to_affine_quantized_floatx_static(input, self.act_scale, block_size, self.target_dtype, Float8LayoutType(mm_config=None)) + qinput = to_affine_quantized_floatx_static(input, self.act_scale, block_size, self.target_dtype, Float8Layout(mm_config=None)) else: raise ValueError(f"Unsupported target dtype {self.target_dtype}") return F.linear(qinput, self.qweight, self.bias) diff --git a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py index c714df2a7..69c0bf956 100644 --- a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py +++ b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py @@ -22,8 +22,8 @@ dequantize_affine, ) from torchao.dtypes.utils import ( - LayoutType, - PlainLayoutType, + Layout, + PlainLayout, ) from torchao.utils import ( TorchAOBaseTensor, @@ -43,25 +43,25 @@ class MyDTypeTensorImpl(torch.Tensor): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: return self.int_data, self.scale - def get_layout_type(self) -> LayoutType: - return self.layout_type + def get_layout(self) -> Layout: + return self._layout @classmethod def from_plain( cls, int_data: torch.Tensor, scale: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, ): - """Construct a tensor impl from plain tensors and a layout_type, which main contain + 
"""Construct a tensor impl from plain tensors and a _layout, which main contain extra metadata for packing etc. """ pass def __repr__(self): int_data, scale = self.get_plain() - layout_type = self.get_layout_type() - return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, layout_type={layout_type})" + _layout = self.get_layout() + return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, _layout={_layout})" __torch_function__ = torch._C._disabled_torch_function_impl @@ -145,23 +145,23 @@ def __tensor_unflatten__( def from_float( cls, input_float: torch.Tensor, - layout_type: LayoutType = PlainLayoutType(), + _layout: Layout = PlainLayout(), ): mapping_type = MappingType.SYMMETRIC block_size = (1, input_float.shape[-1]) dtype = torch.int16 scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype) int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype) - tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) - tensor_impl = tensor_impl_ctr(int_data, scale, layout_type) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(int_data, scale, _layout) return cls(tensor_impl, input_float.shape) """[Optional] We can overwrite layout property of the Tensor to represent different packing formats """ @property - def layout_type(self) -> LayoutType: - return self.tensor_impl.layout_type + def _layout(self) -> Layout: + return self.tensor_impl._layout def dequantize(self, output_dtype=None): """We can define a dequantize method to convert the quantized tensor to a floating point tensor""" @@ -206,20 +206,20 @@ def _apply_fn_to_data(self, fn): """ ###################################################### -# LayoutType and TensorImpl Subclass Registration # +# Layout and TensorImpl Subclass Registration # ###################################################### register_layout = MyDTypeTensor.register_layout get_tensor_impl_constructor = MyDTypeTensor.get_tensor_impl_constructor -@register_layout(PlainLayoutType) +@register_layout(PlainLayout) class PlainMyDTypeTensorImpl(MyDTypeTensorImpl): def __new__( cls, int_data: torch.Tensor, scale: torch.Tensor, transposed: bool, - layout_type: LayoutType, + _layout: Layout, ): kwargs = {} kwargs["device"] = int_data.device @@ -236,43 +236,43 @@ def __init__( int_data: torch.Tensor, scale: torch.Tensor, transposed: bool, - layout_type: LayoutType, + _layout: Layout, ): self.int_data = int_data self.scale = scale self.transposed = transposed - self.layout_type = layout_type + self._layout = _layout def __tensor_flatten__(self): - return ["int_data", "scale"], [self.transposed, self.layout_type] + return ["int_data", "scale"], [self.transposed, self._layout] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): int_data, scale = tensor_data_dict["int_data"], tensor_data_dict["scale"] - transposed, layout_type, = tensor_attributes - return cls(int_data, scale, transposed, layout_type) + transposed, _layout, = tensor_attributes + return cls(int_data, scale, transposed, _layout) @classmethod def from_plain( cls, int_data: torch.Tensor, scale: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, ): - """Construct a tensor impl from plain tensors and a layout_type, which main contain + """Construct a tensor impl from plain tensors and a _layout, which main contain extra metadata for packing etc. 
""" - assert isinstance(layout_type, PlainLayoutType) - return cls(int_data, scale, False, layout_type) + assert isinstance(_layout, PlainLayout) + return cls(int_data, scale, False, _layout) def _apply_fn_to_data(self, fn): return self.__class__( fn(self.int_data), fn(self.scale), self.transposed, - self.layout_type, + self._layout, ) @classmethod @@ -292,11 +292,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): elif func is aten.split.Tensor: int_data_list = func(args[0].int_data, *args[1:], **kwargs) scale_list = func(args[0].scale, *args[1:], **kwargs) - out = [PlainMyDTypeTensorImpl(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)] + out = [PlainMyDTypeTensorImpl(int_data, scale, args[0].transposed, args[0]._layout) for int_data, scale in zip(int_data_list, scale_list)] return out elif func is aten.empty_like.default: int_data_empty_like = func(args[0].int_data, *args[1:], **kwargs) - return PlainMyDTypeTensorImpl(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type) + return PlainMyDTypeTensorImpl(int_data_empty_like, args[0].scale, args[0].transposed, args[0]._layout) elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: @@ -304,11 +304,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) ) elif dim == 1: - return PlainMyDTypeTensorImpl(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.transposed, self.layout_type) + return PlainMyDTypeTensorImpl(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.transposed, self._layout) else: raise NotImplementedError(f"PlainMyDTypeTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") elif func is aten.t.default: - return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeTensorImpl(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type)) + return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeTensorImpl(args[0].int_data, args[0].scale, not args[0].transposed, args[0]._layout)) # Tensor parallel support END diff --git a/tutorials/developer_api_guide/my_trainable_tensor_subclass.py b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py index 59e72efb6..bc37c9783 100644 --- a/tutorials/developer_api_guide/my_trainable_tensor_subclass.py +++ b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py @@ -14,7 +14,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType -from torchao.dtypes.utils import LayoutType, PlainLayoutType +from torchao.dtypes.utils import Layout, PlainLayout from my_dtype_tensor_subclass import MyDTypeLayout, MyDTypeTensor aten = torch.ops.aten @@ -33,7 +33,7 @@ class MyTrainableDTypeTensor(MyDTypeTensor): def _quantize( cls, input_float: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, ) -> MyDTypeLayout: """ Convert from a floating point tensor (fp32/fp16/bf16) to the desired dtype. 
@@ -43,14 +43,14 @@ def _quantize( dtype = torch.int16 scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype) int_data = (input_float / scale).to(torch.int8) - tensor_impl_ctr = cls.get_tensor_impl_constructor(type(layout_type)) - return tensor_impl_ctr(int_data, scale, layout_type) + tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout)) + return tensor_impl_ctr(int_data, scale, _layout) @classmethod def from_float( cls, input_float: torch.Tensor, - layout_type: LayoutType = PlainLayoutType(), + _layout: Layout = PlainLayout(), ) -> "MyTrainableDTypeTensor": """ Main entry point for creating a `MyTrainableDTypeTensor`. @@ -58,7 +58,7 @@ def from_float( This instantiates the tensor subclass in a differentiable constructor to ensure gradients are passed to the tensor subclass properly during training. """ - return _ToMyTrainableDTypeTensor.apply(input_float, layout_type) + return _ToMyTrainableDTypeTensor.apply(input_float, _layout) class _ToMyTrainableDTypeTensor(torch.autograd.Function): """ @@ -69,9 +69,9 @@ class _ToMyTrainableDTypeTensor(torch.autograd.Function): def forward( ctx: torch.autograd.function.FunctionCtx, input_float: torch.Tensor, - layout_type: LayoutType, + _layout: Layout, ) -> "MyTrainableDTypeTensor": - tensor_impl = MyTrainableDTypeTensor._quantize(input_float, layout_type) + tensor_impl = MyTrainableDTypeTensor._quantize(input_float, _layout) return MyTrainableDTypeTensor( tensor_impl, input_float.shape, @@ -143,7 +143,7 @@ def _(func, types, args, kwargs): new_value = torch.add(float0, float1, **kwargs) new_tensor_impl = MyTrainableDTypeTensor._quantize( new_value, - args[0].tensor_impl.get_layout_type(), + args[0].tensor_impl.get_layout(), ) args[0].tensor_impl = new_tensor_impl return return_and_correct_aliasing(func, args, kwargs, args[0])
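[Editor's note, not part of the patch] For downstream code, the whole change condenses to a handful of renames: `*LayoutType` classes become `*Layout` (`PlainLayoutType` -> `PlainLayout`, `TensorCoreTiledLayoutType` -> `TensorCoreTiledLayout`, and so on), the public `layout_type=` keyword becomes `layout=`, and the internal constructor keyword becomes `_layout=`. A minimal before/after sketch, mirroring the README hunk above; the toy model is illustrative.

import torch

from torchao.dtypes import PlainLayout  # formerly PlainLayoutType
from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight, quantize_

# before this patch:
#   quantize_(model, int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()))
# after this patch:
model = torch.nn.Sequential(torch.nn.Linear(128, 128))
quantize_(model, int8_dynamic_activation_int8_weight(layout=PlainLayout()))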