Rename AQT#2 LayoutType -> Layout (#1049)
jainapurva authored Oct 10, 2024
1 parent 10601b3 commit 7038f8b
Showing 34 changed files with 358 additions and 361 deletions.
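
This commit mechanically renames the abstract `LayoutType` class to `Layout`, every concrete `*LayoutType` subclass to `*Layout`, and the `layout_type=` keyword argument to `layout=` throughout the tree; the 358 additions vs. 361 deletions are almost entirely one-for-one substitutions. A minimal before/after sketch of the user-facing change (the toy model below is an illustrative assumption, not code from this commit):

import torch
from torch import nn
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchao.dtypes import TensorCoreTiledLayout  # formerly TensorCoreTiledLayoutType

# Toy module to quantize; int4 weight-only quantization targets bf16 CUDA weights.
model = nn.Sequential(nn.Linear(256, 256)).to(device="cuda", dtype=torch.bfloat16)

# Before this commit:
#   quantize_(model, int4_weight_only(layout_type=TensorCoreTiledLayoutType(inner_k_tiles=4)))
# After this commit:
quantize_(model, int4_weight_only(layout=TensorCoreTiledLayout(inner_k_tiles=4)))

Constructor arguments (e.g. `inner_k_tiles`, `blocksize`, the floatx `ebits`/`mbits`) and quantization behavior are unchanged; only the names move.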
4 changes: 2 additions & 2 deletions benchmarks/benchmark_fp6.py
@@ -2,14 +2,14 @@
 import pandas as pd
 import torch.nn.functional as F
 from torchao.dtypes import to_affine_quantized_fpx
-from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayoutType
+from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayout
 from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm


 def benchmark(m: int, k: int, n: int):
     float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
-    fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayoutType(3, 2))
+    fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayout(3, 2))
     fp16_weight = fp6_weight.dequantize(torch.half)

     fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
4 changes: 2 additions & 2 deletions test/dtypes/test_affine_quantized.py
@@ -10,7 +10,7 @@
     int8_dynamic_activation_int8_semi_sparse_weight,
     float8_weight_only,
 )
-from torchao.dtypes import SemiSparseLayoutType
+from torchao.dtypes import SemiSparseLayout
 from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -31,7 +31,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
         base_functions.append(int4_weight_only(group_size=32))

     if do_sparse:
-        base_functions.append(int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))
+        base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))

     if is_cuda_8_9:
         base_functions.append(float8_weight_only())
6 changes: 3 additions & 3 deletions test/dtypes/test_floatx.py
@@ -10,7 +10,7 @@
 )
 from torchao.dtypes.floatx import (
     FloatxTensorCoreAQTTensorImpl,
-    FloatxTensorCoreLayoutType,
+    FloatxTensorCoreLayout,
     to_scaled_tc_floatx,
     from_scaled_tc_floatx,
 )
@@ -81,8 +81,8 @@ def test_to_copy_device(self, ebits, mbits):
         x = torch.randn(256, 64)
         scale = choose_qparams_affine_floatx(x, ebits, mbits)
         x = quantize_affine_floatx(x, scale, ebits, mbits)
-        layout_type = FloatxTensorCoreLayoutType(ebits, mbits)
-        floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, layout_type).cuda()
+        _layout = FloatxTensorCoreLayout(ebits, mbits)
+        floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, _layout).cuda()
         assert floatx_tensor_impl.device.type == "cuda"
         floatx_tensor_impl = floatx_tensor_impl.cpu()
         assert floatx_tensor_impl.device.type == "cpu"
4 changes: 2 additions & 2 deletions test/hqq/test_hqq_affine.py
@@ -4,9 +4,9 @@
     to_affine_quantized_intx,
     ZeroPointDomain,
     PlainAQTTensorImpl,
-    PlainLayoutType,
+    PlainLayout,
     TensorCoreTiledAQTTensorImpl,
-    TensorCoreTiledLayoutType,
+    TensorCoreTiledLayout,
     MappingType,
 )

6 changes: 3 additions & 3 deletions test/integration/test_integration.py
@@ -19,7 +19,7 @@
 from torchao.quantization.dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
-from torchao.dtypes import TensorCoreTiledLayoutType
+from torchao.dtypes import TensorCoreTiledLayout
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only,
@@ -876,7 +876,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
                 for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)}
+                    kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}

                    def api(mod):
                        kwargs_copy = kwargs.copy()
@@ -888,7 +888,7 @@ def api(mod):
                            unwrap_tensor_subclass(mod)
                        else:
                            kwargs_copy["inner_k_tiles"] = inner_k_tiles
-                            del kwargs_copy["layout_type"]
+                            del kwargs_copy["layout"]
                            change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy)

                 self._test_lin_weight_subclass_api_impl(
2 changes: 1 addition & 1 deletion test/quantization/test_qat.py
@@ -13,7 +13,7 @@
 import torch
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 from torchao.dtypes import (
-    TensorCoreTiledLayoutType,
+    TensorCoreTiledLayout,
 )
 from torchao.quantization.prototype.qat.api import (
     ComposableQATQuantizer,
6 changes: 3 additions & 3 deletions test/sparsity/test_marlin.py
@@ -5,7 +5,7 @@
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
-from torchao.dtypes import MarlinSparseLayoutType
+from torchao.dtypes import MarlinSparseLayout
 from torchao.sparsity.sparse_api import apply_fake_sparsity
 from torchao.quantization.quant_api import int4_weight_only, quantize_
 from torchao.sparsity.marlin import (
@@ -50,7 +50,7 @@ def test_quant_sparse_marlin_layout_eager(self):
         dense_result = model_copy(self.input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         sparse_result = self.model(self.input)

         assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
@@ -67,7 +67,7 @@ def test_quant_sparse_marlin_layout_compile(self):
         dense_result = model_copy(self.input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         self.model.forward = torch.compile(self.model.forward, fullgraph=True)
         sparse_result = self.model(self.input)

10 changes: 5 additions & 5 deletions test/sparsity/test_sparse_api.py
@@ -5,7 +5,7 @@
 import torch
 from torch import nn
 from torch.testing._internal import common_utils
-from torchao.dtypes import MarlinSparseLayoutType, SemiSparseLayoutType
+from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_dynamic_activation_int8_weight,
@@ -74,7 +74,7 @@ def test_quant_semi_sparse(self, compile):

         quantize_(
             model,
-            int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()),
+            int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()),
         )
         if compile:
             model = torch.compile(model)
@@ -108,7 +108,7 @@ def test_sparse_marlin(self, compile):
         dense_result = model_copy(input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if compile:
             model = torch.compile(model)
         sparse_result = model(input)
@@ -185,12 +185,12 @@ def test_sparse(self, compile):
         quantize_(model_copy, int8_dynamic_activation_int8_weight())
         reference = model_copy(input)

-        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
+        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout

         quantize_(
             model,
             int8_dynamic_activation_int8_weight(
-                layout_type=BlockSparseLayoutType(blocksize=64)
+                layout=BlockSparseLayout(blocksize=64)
             ),
         )
         if compile:
4 changes: 2 additions & 2 deletions torchao/_models/llama/eval.py
@@ -100,8 +100,8 @@ def run_evaluation(
             group_size = int(_quant_args[2])
             quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
         if "marlin" in quantization:
-            from torchao.dtypes import MarlinSparseLayoutType
-            quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+            from torchao.dtypes import MarlinSparseLayout
+            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if "int4wo" in quantization and "gptq" in quantization:
             # avoid circular imports
             from torchao._models._eval import InputRecorder
4 changes: 2 additions & 2 deletions torchao/_models/llama/generate.py
@@ -230,8 +230,8 @@ def main(
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             quantize_(model, int4_weight_only(group_size=groupsize))
         if "marlin" in quantization:
-            from torchao.dtypes import MarlinSparseLayoutType
-            quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+            from torchao.dtypes import MarlinSparseLayout
+            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
         if quantization.startswith("awq"):
8 changes: 4 additions & 4 deletions torchao/_models/sam/eval_combo.py
@@ -11,7 +11,7 @@

 from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only
 from torchao.sparsity import sparsify_, apply_fake_sparsity, semi_sparse_weight
-from torchao.dtypes import SemiSparseLayoutType, MarlinSparseLayoutType
+from torchao.dtypes import SemiSparseLayout, MarlinSparseLayout
 from torchao.utils import unwrap_tensor_subclass
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -315,7 +315,7 @@ def mlp_only(mod, name):
                       int8_dynamic_activation_int8_weight(),
                       attn_only)
             quantize_(predictor.model.image_encoder,
-                      int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()),
+                      int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()),
                       mlp_lin1_only)
             sparsify_(predictor.model.image_encoder,
                       semi_sparse_weight(),
@@ -326,11 +326,11 @@ def mlp_only(mod, name):
             # apply sparsify first to set qparams
             apply_fake_sparsity(predictor.model.image_encoder,
                                 filter_fn=mlp_only)
-            from torchao.dtypes import MarlinSparseLayoutType
+            from torchao.dtypes import MarlinSparseLayout
             quantize_(predictor.model.image_encoder,
                       int8_dynamic_activation_int8_weight(),
                       attn_only)
-            quantize_(predictor.model.image_encoder, int4_weight_only(layout_type=MarlinSparseLayoutType()), mlp_lin1_only)
+            quantize_(predictor.model.image_encoder, int4_weight_only(layout=MarlinSparseLayout()), mlp_lin1_only)
             sparsify_(predictor.model.image_encoder,
                       semi_sparse_weight(),
                       mlp_lin2_only)
24 changes: 12 additions & 12 deletions torchao/dtypes/__init__.py
@@ -9,13 +9,13 @@
     to_affine_quantized_fpx,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
-    LayoutType,
-    PlainLayoutType,
-    SemiSparseLayoutType,
-    TensorCoreTiledLayoutType,
-    Float8LayoutType,
+    Layout,
+    PlainLayout,
+    SemiSparseLayout,
+    TensorCoreTiledLayout,
+    Float8Layout,
     Float8AQTTensorImpl,
-    MarlinSparseLayoutType,
+    MarlinSparseLayout,
 )

 __all__ = [
@@ -28,11 +28,11 @@
     "to_affine_quantized_fpx",
     "to_affine_quantized_floatx",
     "to_affine_quantized_floatx_static",
-    "LayoutType",
-    "PlainLayoutType",
-    "SemiSparseLayoutType",
-    "TensorCoreTiledLayoutType",
-    "Float8LayoutType",
+    "Layout",
+    "PlainLayout",
+    "SemiSparseLayout",
+    "TensorCoreTiledLayout",
+    "Float8Layout",
     "Float8AQTTensorImpl",
-    "MarlinSparseLayoutType",
+    "MarlinSparseLayout",
 ]