From bc851466da1ac690977159509b0a74a7390081bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= Date: Mon, 9 Sep 2024 21:32:06 +0200 Subject: [PATCH] W4A8 based on CUTLASS --- docs/source/api_ref_quantization.rst | 1 + setup.py | 7 + test/dtypes/test_affine_quantized.py | 2 + test/quantization/test_quant_api.py | 1 + test/test_s8s4_linear_cutlass.py | 51 ++ .../s8s4_linear_cutlass.cu | 487 ++++++++++++++++++ torchao/csrc/s8s4_linear_cutlass.cpp | 8 + torchao/dtypes/affine_quantized_tensor.py | 61 +++ torchao/ops.py | 33 ++ torchao/quantization/quant_api.py | 47 ++ 10 files changed, 698 insertions(+) create mode 100644 test/test_s8s4_linear_cutlass.py create mode 100644 torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu create mode 100644 torchao/csrc/s8s4_linear_cutlass.cpp diff --git a/docs/source/api_ref_quantization.rst b/docs/source/api_ref_quantization.rst index f6c842c88..ee1100843 100644 --- a/docs/source/api_ref_quantization.rst +++ b/docs/source/api_ref_quantization.rst @@ -18,6 +18,7 @@ torchao.quantization Int4WeightOnlyQuantizer quantize_ int8_dynamic_activation_int4_weight + int8_dynamic_activation_int4_weight_cutlass int8_dynamic_activation_int8_weight int4_weight_only int8_weight_only diff --git a/setup.py b/setup.py index 229e18eec..400655fdf 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,12 @@ def get_extensions(): extension = CUDAExtension if use_cuda else CppExtension if not IS_WINDOWS: + import cutlass_library + cutlass_library_dir = os.path.dirname(cutlass_library.__file__) + cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include") + # FIXME: remove this once CUTLASS package updated to include int4/int8 MM + cutlass_include_dir = "/data/quansight/scratch/cutlass/include" + extra_link_args = [] extra_compile_args = { "cxx": [ @@ -74,6 +80,7 @@ def get_extensions(): "nvcc": [ "-O3" if not debug_mode else "-O0", "-t=0", + "-I" + cutlass_include_dir, ] } diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 2265be31e..148a7cbf5 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -6,6 +6,7 @@ int4_weight_only, int8_weight_only, int8_dynamic_activation_int4_weight, + int8_dynamic_activation_int4_weight_cutlass, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int8_semi_sparse_weight, float8_weight_only, @@ -25,6 +26,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool): base_functions = [ int8_weight_only(), int8_dynamic_activation_int4_weight(), + int8_dynamic_activation_int4_weight_cutlass(), int8_dynamic_activation_int8_weight(), ] if do_int4: diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 310f51f89..74113b334 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -39,6 +39,7 @@ Quantizer, TwoStepQuantizer, int8_dynamic_activation_int4_weight, + int8_dynamic_activation_int4_weight_cutlass, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py new file mode 100644 index 000000000..4eeae25da --- /dev/null +++ b/test/test_s8s4_linear_cutlass.py @@ -0,0 +1,51 @@ +# FIXME: move this test to the appropriate test file!!! 
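+# The test below quantizes a small two-layer model with
+# int8_dynamic_activation_int4_weight_cutlass() and checks its output against
+# the unquantized float16 reference within a loose tolerance.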
+ +import copy + +from torchao.quantization import quantize_ +from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight_cutlass + +import torch +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +import pytest + + +class ToyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(128, 256) + self.linear2 = torch.nn.Linear(256, 128, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = torch.nn.functional.relu(x) + x = self.linear2(x) + return x + + +class TestS8S4LinearCUTLASS(TestCase): + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_s8s4_linear_cutlass_(self): + # FIXME: remove this! + torch.manual_seed(0) + + dtype = torch.float16 # torch.bfloat16 + + input = torch.rand((64, 128)).to(dtype).cuda() + model = ToyModel().to(dtype).cuda() + + output_ref = model(input) + + modelq = copy.deepcopy(model) + quantize_(modelq, int8_dynamic_activation_int4_weight_cutlass()) + output = modelq(input) + + assert torch.allclose(output, output_ref, rtol=1e-1, atol=0) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu new file mode 100644 index 000000000..e60a69a7e --- /dev/null +++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu @@ -0,0 +1,487 @@ +#include + +#include +#include +#include + +#if defined(_MSC_VER) || (CUDA_VERSION < 11080) +#else +#include +#include +#include +#include +#include + +#define CUTLASS_STATUS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + __func__, " : Got CUTLASS error: ", \ + cutlassGetStatusString(status)); \ + } +#endif + +namespace torchao { + +#if defined(_MSC_VER) || (CUDA_VERSION < 11080) +#else +template< + typename ElementA, + typename ElementAScale, + typename ElementB, + typename ElementBScale, + typename ElementC, + typename ElementAccumulator, + typename ElementEpilogue, + typename ElementOutput, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + bool use_tensor_c> +void s8s4_linear_kernel_cutlass( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + const int m = tensor_a.size(0); + const int n = tensor_b.size(0); + const int k = tensor_a.size(1); + + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentAScale = + 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentBScale = + 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentOutput = + 128 / cutlass::sizeof_bits::value; + + // FIXME: re-check this!!! + // Check for current CUTLASS limitations w.r.t. alignments. 
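+  // For example, with ElementA = int8_t, ElementB = int4b_t and a 16-bit
+  // ElementC (half/bfloat16), as instantiated below, AlignmentA = 128 / 8 = 16,
+  // AlignmentB = 128 / 4 = 32 and AlignmentC = 128 / 16 = 8, so k must be
+  // divisible by 32 and n by 8.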
+ TORCH_CHECK(k % AlignmentA == 0, + __func__, " : Number of columns of tensor A must be divisible ", + "by ", AlignmentA); + TORCH_CHECK(k % AlignmentB == 0, + __func__, " : Number of columns of tensor B must be divisible ", + "by ", AlignmentB); + TORCH_CHECK(n % AlignmentC == 0, + __func__, " : Number of columns of tensor C must be divisible ", + "by ", AlignmentC); + + using SmArch = cutlass::arch::Sm80; + using ThreadblockSwizzle = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; + constexpr auto NumStages = 4; + + constexpr auto NumEVTEpilogueStages = 1; + + using TensorAScaleTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementAScale, + AlignmentAScale, + NumEVTEpilogueStages>; + using TensorBScaleTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementBScale, + AlignmentBScale, + NumEVTEpilogueStages>; + using TensorCTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementC, + AlignmentC, + NumEVTEpilogueStages>; + using OutputTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementOutput, + AlignmentOutput, + NumEVTEpilogueStages>; + + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + using TensorAScale = + cutlass::epilogue::threadblock::VisitorColBroadcast< + TensorAScaleTileThreadMap, + ElementAScale, + cute::Stride>; + using TensorAScaleArguments = typename TensorAScale::Arguments; + + using TensorBScale = + cutlass::epilogue::threadblock::VisitorRowBroadcast< + TensorBScaleTileThreadMap, + ElementBScale, + cute::Stride>; + using TensorBScaleArguments = typename TensorBScale::Arguments; + + using TensorCScalar = + cutlass::epilogue::threadblock::VisitorScalarBroadcast; + using TensorCTensor = + cutlass::epilogue::threadblock::VisitorRowBroadcast< + TensorCTileThreadMap, + ElementC, + cute::Stride>; + using TensorC = + std::conditional_t; + using TensorCArguments = typename TensorC::Arguments; + + using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementEpilogue, ElementEpilogue, + cutlass::FloatRoundStyle::round_to_nearest + >; + using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT< + ApplyAScale, + Accum, + TensorAScale>; + + using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementEpilogue, ElementEpilogue, + cutlass::FloatRoundStyle::round_to_nearest + >; + using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT< + ApplyBScale, + EVTApplyAScale, + TensorBScale>; + + using ApplySum = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, ElementEpilogue, ElementEpilogue, + cutlass::FloatRoundStyle::round_to_nearest + >; + using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT< + ApplySum, + EVTApplyBScale, + TensorC>; + + using Output = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementOutput, + cutlass::FloatRoundStyle::round_to_nearest, + cute::Stride // StrideMNL + >; + + using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT< + Output, + EVTApplySum>; + + using EVTKernel = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, + ElementC, LayoutC, AlignmentC, + ElementAccumulator, + ElementEpilogue, + 
cutlass::arch::OpClassTensorOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EVTOutput, + ThreadblockSwizzle, + NumStages, + cutlass::arch::OpMultiplyAddMixedInputUpcast, + NumEVTEpilogueStages + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalBase; + + cutlass::gemm::GemmCoord problem_size(m, n, k); + constexpr auto SplitKFactor = 1; + + TensorAScaleArguments tensor_a_scale_arguments{ + (ElementAScale*)tensor_a_scale.data_ptr(), + ElementAScale(1), + {cute::_1{}, cute::_0{}, problem_size.m()} + }; + TensorBScaleArguments tensor_b_scale_arguments{ + (ElementBScale*)tensor_b_scale.data_ptr(), + ElementBScale(1), + {cute::_0{}, cute::_1{}, problem_size.n()} + }; + TensorCArguments tensor_c_arguments{ + [&]() -> TensorCArguments { + if constexpr (use_tensor_c) { + return {(ElementC*)tensor_c.data_ptr(), + ElementC(0), + {cute::_0{}, cute::_1{}, problem_size.n()}}; + } else { + return {ElementC(0)}; + } + }() + }; + typename Output::Arguments output_arguments{ + (ElementOutput*)tensor_d.data_ptr(), + {problem_size.n(), cute::_1{}, problem_size.mn().product()} + }; + typename EVTOutput::Arguments callback_arguments{ + { + { + { + {}, // Accum + tensor_a_scale_arguments, // TensorAScale + {} // ApplyAScale + }, // EVTApplyAScale + tensor_b_scale_arguments, // TensorBScale + {}, // ApplyBScale + }, // EVTApplyBScale + tensor_c_arguments, // TensorC + {} // ApplySum + }, // EVTApplySum + output_arguments // Output + }; // EVTOutput + constexpr auto AvailSms = -1; + + typename Gemm::Arguments arguments( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + SplitKFactor, + callback_arguments, // arguments of EVT callbacks + (ElementA*)tensor_a.data_ptr(), + (ElementB*)tensor_b.data_ptr(), + nullptr, // ptr C (unused) + nullptr, // ptr D (unused) + problem_size.mk().product(), // batch stride A + problem_size.nk().product(), // batch stride B + 0, // batch stride C (unused) + 0, // batch stride D (unused) + problem_size.k(), // stride A + problem_size.k(), // stride B + 0, // stride C (unused) + 0, // stride D (unused) + AvailSms); + + Gemm gemm_op; + + cutlass::Status status; + + // Verify that GEMM operation with given arguments can be performed + // by CUTLASS. + status = gemm_op.can_implement(arguments); + CUTLASS_STATUS_CHECK(status); + + // Allocate workspace for CUTLASS mixed datatypes GEMM kernel. + const auto workspace_size = Gemm::get_workspace_size(arguments); + auto workspace = tensor_a.new_empty({(int64_t)workspace_size}, + at::TensorOptions().dtype(at::kByte)); + + // Initialize CUTLASS mixed datatypes GEMM object. + status = gemm_op.initialize(arguments, workspace.data_ptr(), + at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status); + + // Perform mixed datatypes GEMM operation. + status = gemm_op.run(at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status); +} +#endif + +// Perform linear operation, using corresponding CUTLASS mixed +// data-types GEMM kernel, to given arguments: +// result = (input * input_scale) @ (weight * weight_scale).T + bias +// Notes: The "input_scale" tensor is expected to be a vector, of size +// equal to number of rows of "input" tensor. The "weight_scale" +// tensor is expected to be a vector, of size equal to number of rows +// of "weight" tensor. The "bias" tensor is expected to be a vector, +// of size equal to number of rows of "weight" tensor. 
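+// As an illustration of the expected shapes: for a 2D "input" of shape
+// (m, k), the "weight" is of shape (n, k / 2) (two int4 values packed into
+// each int8 element), "input_scale" has m elements, "weight_scale" and
+// "bias" have n elements each, and the result is of shape (m, n).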
+at::Tensor +s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale, + const at::Tensor& weight, const at::Tensor& weight_scale, + const at::Tensor& bias) { +#if defined(_MSC_VER) || (CUDA_VERSION < 11080) + AT_ERROR(__func__, " : CUTLASS not supported"); + return at::Tensor{}; +#else + // For now, only CC 8.x devices are supported. + const auto dprops = at::cuda::getCurrentDeviceProperties(); + const auto is_sm8x = dprops->major == 8; + TORCH_CHECK(is_sm8x, + __func__, " : Supported only on GPUs with compute capability " + "8.x"); + + // Validate datatypes of arguments. + TORCH_CHECK(input.dtype() == at::kChar, + __func__, " : The input datatype ", input.dtype(), + " not supported"); + TORCH_CHECK(input_scale.dtype() == at::kHalf || + input_scale.dtype() == at::kBFloat16, + __func__, " : The input scale datatype ", input_scale.dtype(), + " not supported"); + TORCH_CHECK(weight.dtype() == at::kChar, " : The weight datatype ", + weight.dtype(), " not supported"); + TORCH_CHECK(weight_scale.dtype() == input_scale.dtype(), + __func__, " : Expected weight scale datatype ", + input_scale.dtype(), ", got ", weight_scale.dtype()); + if (bias.numel() > 0) { + TORCH_CHECK(bias.dtype() == input_scale.dtype(), + __func__, " : Expected bias datatype ", input_scale.dtype(), + ", got ", bias.dtype()); + } + + // Validate layouts of arguments. + TORCH_CHECK(input.dim() >= 2, + __func__, " : Expected input argument to be 2D or " + "higher-dimensional tensor, got ", input.dim(), " dims"); + TORCH_CHECK(input.layout() == at::Layout::Strided, + __func__, " : Expected input argument to be strided, got layout ", + input.layout()); + TORCH_CHECK(input_scale.dim() == input.dim() - 1, + __func__, " : Expected input scale argument to be ", + input.dim() - 1, "D tensor, got ", input_scale.dim(), " dims"); + TORCH_CHECK(input_scale.layout() == at::Layout::Strided, + __func__, " : Expected input scale argument to be strided, got " + "layout ", input_scale.layout()); + TORCH_CHECK(weight.dim() == 2, + __func__, " : Expected weight argument to be 2D tensor, got ", + weight.dim(), " dims"); + TORCH_CHECK(weight.layout() == at::Layout::Strided, + __func__, + " : Expected weight argument to be strided, got layout ", + weight.layout()); + TORCH_CHECK(weight_scale.dim() == 1 || weight_scale.dim() == 2, + __func__, " : Expected weight scale argument to be 1D tensor, ", + "got ", weight_scale.dim(), " dims"); + TORCH_CHECK(weight_scale.layout() == at::Layout::Strided, + __func__, " : Expected weight scale argument to be strided, got " + "layout ", weight_scale.layout()); + if (bias.numel() > 0) { + TORCH_CHECK(bias.dim() == 1, + __func__, " : Expected bias argument to be 1D tensor, got ", + bias.dim(), " dims"); + TORCH_CHECK(bias.layout() == at::Layout::Strided, + __func__, " : Expected bias argument to be strided, got ", + "layout ", bias.layout()); + } + + // Squash the input tensor to 2D tensor. + const auto input_sizes = input.sizes().vec(); + const auto input_2d = input.reshape({-1, input_sizes.back()}); + const auto input_scale_sizes = input_scale.sizes().vec(); + const auto input_scale_1d = input_scale.reshape({-1}); + + // Validate sizes of arguments. 
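+  // Note that the weight tensor holds two int4 values packed into each of
+  // its int8 elements, which is why the input is expected to have twice as
+  // many columns as the (packed) weight.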
+ TORCH_CHECK(input_2d.size(1) == 2 * weight.size(1), + __func__, " : Expected input argument to have ", + 2 * weight.size(1), " columns, but got ", input_2d.size(1)); + for (auto i = 0; i < input_scale_sizes.size(); ++i) + TORCH_CHECK(input_scale_sizes[i] == input_sizes[i], + __func__, " : Expected input scale argument size at position ", + i, " to be ", input_sizes[i], ", but got ", + input_scale_sizes[i]); + TORCH_CHECK(weight_scale.numel() == weight.size(0), + __func__, " : Expected weight scale argument to have ", + weight.size(0), " elements, got ", weight_scale.numel(), + " elements"); + if (bias.numel() > 0) { + TORCH_CHECK(bias.numel() == weight.size(0), + __func__, " : Expected bias argument to have ", weight.size(0), + " elements, got ", bias.numel(), " elements"); + } + + // Validate strides of arguments. + const auto input_2d_strides = input_2d.strides(); + TORCH_CHECK(input_2d_strides[0] >= 1 && input_2d_strides[1] == 1, + __func__, " : Expected input argument in row-major layout"); + const auto input_scale_1d_strides = input_scale_1d.strides(); + TORCH_CHECK(input_scale_1d_strides[0] == 1, + __func__, " : Expected input scale argument to be contiguous"); + const auto weight_strides = weight.strides(); + TORCH_CHECK(weight_strides[0] >= 1 && weight_strides[1] == 1, + __func__, " : Expected weight argument in row-major layout"); + const auto weight_scale_strides = weight_scale.strides(); + TORCH_CHECK(weight_scale_strides[0] == 1, + __func__, " : Expected weight scale argument to be contiguous"); + if (bias.numel() > 0) { + const auto bias_strides = bias.strides(); + TORCH_CHECK(bias_strides[0] == 1, + __func__, " : Expected bias argument to be contiguous"); + } + + // Introduce alias names for arguments, according to the CUTLASS + // naming conventions. + const auto& tensor_a = input_2d; + const auto& tensor_a_scale = input_scale_1d; + const auto& tensor_b = weight; + const auto& tensor_b_scale = weight_scale; + const auto& tensor_c = bias; + + // Create output tensor. 
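+  // The output is allocated in the scale dtype (half or bfloat16), with one
+  // row per row of the (2D) input and one column per row of the weight.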
+ at::Tensor tensor_d = + tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)}); + + using ElementA = int8_t; + using ElementB = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + AT_DISPATCH_SWITCH( + input_scale.scalar_type(), + "s8s4_linear_cutlass", + AT_DISPATCH_CASE( + at::ScalarType::Half, + [&]() { + using ElementAScale = cutlass::half_t; + using ElementBScale = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementEpilogue = float; + using ElementOutput = cutlass::half_t; + if (bias.numel() > 0) { + s8s4_linear_kernel_cutlass< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, + ThreadblockShape, WarpShape, InstructionShape, true>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else { + s8s4_linear_kernel_cutlass< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, + ThreadblockShape, WarpShape, InstructionShape, false>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } + }) + // FIXME: this won't build unless CUTLASS patch applied + AT_DISPATCH_CASE( + at::ScalarType::BFloat16, + [&]() { + using ElementAScale = cutlass::bfloat16_t; + using ElementBScale = cutlass::bfloat16_t; + using ElementC = cutlass::bfloat16_t; + using ElementEpilogue = float; + using ElementOutput = cutlass::bfloat16_t; + if (bias.numel() > 0) { + s8s4_linear_kernel_cutlass< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, + ThreadblockShape, WarpShape, InstructionShape, true>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else { + s8s4_linear_kernel_cutlass< + ElementA, ElementAScale, ElementB, ElementBScale, ElementC, + ElementAccumulator, ElementEpilogue, ElementOutput, + ThreadblockShape, WarpShape, InstructionShape, false>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } + })); + + auto tensor_d_sizes = input_sizes; + tensor_d_sizes.back() = weight.size(0); + return tensor_d.reshape(tensor_d_sizes); +#endif +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::s8s4_linear_cutlass", &s8s4_linear_cutlass); +} + +} // namespace torchao diff --git a/torchao/csrc/s8s4_linear_cutlass.cpp b/torchao/csrc/s8s4_linear_cutlass.cpp new file mode 100644 index 000000000..cc82ff5bf --- /dev/null +++ b/torchao/csrc/s8s4_linear_cutlass.cpp @@ -0,0 +1,8 @@ +#include +#include +#include + +TORCH_LIBRARY_FRAGMENT(torchao, m) { + m.impl_abstract_pystub("torchao.ops"); + m.def("s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"); +} diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index c2c8e3c0b..062c614c0 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1332,6 +1332,14 @@ def _aqt_is_tensor_core_tile_uint4(aqt): aqt.quant_max == 15 ) +def _aqt_is_tensor_core_tile_int4(aqt): + """Check if an AffineQuantizedTensor is int4 quantized Tensor""" + # TODO: use torch.int4 + return ( + aqt.layout_tensor.dtype == torch.int8 and + aqt.quant_min == -8 and + aqt.quant_max == 7 + ) implements = 
AffineQuantizedTensor.implements @@ -1346,6 +1354,7 @@ def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): isinstance(input_tensor, AffineQuantizedTensor) and _aqt_is_int8_reduced_range(input_tensor) and isinstance(weight_tensor, AffineQuantizedTensor) and + _aqt_is_int8(weight_tensor) and weight_tensor.is_cuda and input_tensor.dtype == weight_tensor.dtype and isinstance(input_tensor.layout_type, PlainLayoutType) and @@ -1721,6 +1730,57 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b return out +def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): + # FIXME: refine these checks!!! + return ( + isinstance(input_tensor, AffineQuantizedTensor) and + _aqt_is_int8_reduced_range(input_tensor) and + input_tensor.dtype in (torch.float16, torch.bfloat16) and + len(input_tensor.shape) >= 2 and + isinstance(weight_tensor, AffineQuantizedTensor) and + _aqt_is_tensor_core_tile_int4(weight_tensor) and + weight_tensor.dtype == input_tensor.dtype and + len(weight_tensor.shape) == 2 and + (bias is None or bias.dtype == input_tensor.dtype) and + (bias is None or len(bias.shape) == 1) + ) + +def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): + from torchao.ops import s8s4_linear_cutlass + + assert isinstance(input_tensor, AffineQuantizedTensor) + assert isinstance(weight_tensor, AffineQuantizedTensor) + + input = input_tensor.layout_tensor.int_data + input_scale = input_tensor.layout_tensor.scale + + weight = weight_tensor.layout_tensor.int_data + weight_scale = weight_tensor.layout_tensor.scale + + out = s8s4_linear_cutlass( + input, input_scale, weight, weight_scale, bias + ) + + # FIXME: remove this! + """ + input_shape = input.shape + input_2d = input.view(-1, input.shape[-1]) + input_scale_1d = input_scale.view(-1) + m, k = input_2d.shape + n, _ = weight.shape + weight_orig = torch.stack(((weight << 4) >> 4, weight >> 4), dim=2).view(n, k) + # This is the calculation (well, not completely exact - as the matrix + # multiplication is actually over integers there) that + # s8s4_linear_cutlass() is performing. 
+ out = (input_2d.to(input_scale.dtype) @ weight_orig.to(input_scale.dtype).T) * input_scale_1d.view(m, 1).expand(m, n) * weight_scale.expand(m, n) + if bias is not None: + out += bias + out = out.view(*input_shape[:-1], n) + """ + + return out + + def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), @@ -1732,6 +1792,7 @@ def _register_aqt_quantized_linear_dispatches(): (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), (_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl), (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl), + (_linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/ops.py b/torchao/ops.py index 79c02dfd8..44b31b32a 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -275,3 +275,36 @@ def _( torch._check(workspace.numel() >= min_workspace_size, lambda: f"workspace.numel = {workspace.numel()} is below min_workspace_size = {min_workspace_size}") return torch.empty((x.size(0), s.size(1)), dtype=x.dtype, device=x.device) + + +def s8s4_linear_cutlass( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_scale: Tensor, + bias: Tensor, +) -> Tensor: + # FIXME: write docs!!! + """ + """ + + return torch.ops.torchao.s8s4_linear_cutlass.default( + input, input_scale, weight, weight_scale, bias + ) + + +@register_custom_op(f"torchao::s8s4_linear_cutlass") +def _( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_scale: Tensor, + bias: Tensor, +) -> Tensor: + # FIXME: implement all checks from s8s4_linear_cutlass() here!!! + + return torch.empty( + (*input.shape[:-1], weight.size(0)), + dtype=input_scale.dtype, + device=input.device + ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 6c4142506..7ee92ce09 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -85,6 +85,7 @@ "_get_subclass_inserter", "quantize_", "int8_dynamic_activation_int4_weight", + "int8_dynamic_activation_int4_weight_cutlass", "int8_dynamic_activation_int8_weight", "int8_dynamic_activation_int8_semi_sparse_weight", "int4_weight_only", @@ -430,6 +431,7 @@ def quantize_( # also customizable with arguments # currently options are # int8_dynamic_activation_int4_weight (for executorch) + # int8_dynamic_activation_int4_weight_cutlass (for correspodning W4A8 CUTLASS-based kernel) # int8_dynamic_activation_int8_weight (optimized with int8 mm op and torch.compile) # int4_weight_only (optimized with int4 tinygemm kernel and torch.compile) # int8_weight_only (optimized with int8 mm op and torch.compile @@ -511,6 +513,41 @@ def int8_dynamic_activation_int4_weight(group_size=32, mapping_type=MappingType. 
     return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size, mapping_type=mapping_type)
 
 
+def apply_int8_dynamic_activation_int4_weight_quant_cutlass(weight):
+    """This is defined here instead of as a local function to support serialization
+    """
+    # weight settings
+    mapping_type = MappingType.SYMMETRIC
+    block_size = (1, weight.shape[-1])
+    target_dtype = torch.int8
+    eps = torch.finfo(torch.float32).eps
+    quant_min = -8
+    quant_max = 7
+
+    # input settings
+    input_quant_func = _int8_symm_per_token_reduced_range_quant_cutlass
+
+    weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
+    weight = to_linear_activation_quantized(weight, input_quant_func)
+
+    # FIXME: this should be done by to_affine_quantized_intx; also, the
+    # dtype of the quantized tensor should maybe be set to quint4x2, and
+    # then corresponding changes made in
+    # _linear_int8_act_int4_weight_cutlass_check and to the check in
+    # the CUTLASS kernel!!!
+    weight.original_weight_tensor.layout_tensor.int_data = (
+        (weight.original_weight_tensor.layout_tensor.int_data[:, 1::2] & 0xF) << 4
+    ) | (weight.original_weight_tensor.layout_tensor.int_data[:, 0::2] & 0xF)
+
+    return weight
+
+def int8_dynamic_activation_int4_weight_cutlass():
+    """Applies int8 dynamic per-token symmetric activation quantization and int4 per-row symmetric weight quantization to linear layers.
+    This is used to produce a model for the CUTLASS-based W4A8 kernel.
+    """
+    return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant_cutlass)
+
+
 def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=False):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
@@ -580,8 +617,18 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
     eps = 1e-5
     quant_min = -127
     quant_max = 127
+
     return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
 
+def _int8_symm_per_token_reduced_range_quant_cutlass(x: torch.Tensor) -> torch.Tensor:
+    mapping_type = MappingType.SYMMETRIC
+    target_dtype = torch.int8
+    eps = 1e-5
+    quant_min = -127
+    quant_max = 127
+
+    return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float16 if x.dtype == torch.float16 else None)
+
 def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()):
     """