From 67c12ddab209bd462c9e50d98cdcf14bf2945856 Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Fri, 18 Oct 2024 15:24:49 -0700
Subject: [PATCH] Move common ET/Aten op stuff to ops/library.h

Differential Revision: D63983380

Pull Request resolved: https://github.com/pytorch/ao/pull/1116
---
 torchao/experimental/ops/library.h            | 34 +++++++
 .../Linear8BitActXBitWeightOperator.h         |  2 +-
 .../linear_8bit_act_xbit_weight.cpp           |  2 +-
 .../op_linear_8bit_act_xbit_weight-impl.h     | 89 ++++++++-----------
 .../packed_weights_header.h                   |  4 +-
 torchao/experimental/ops/macro.h              | 13 ---
 torchao/experimental/quant_api.py             |  6 +-
 ...t_linear_8bit_act_xbit_weight_quantizer.py |  6 +-
 8 files changed, 81 insertions(+), 75 deletions(-)
 create mode 100644 torchao/experimental/ops/library.h
 delete mode 100644 torchao/experimental/ops/macro.h

diff --git a/torchao/experimental/ops/library.h b/torchao/experimental/ops/library.h
new file mode 100644
index 000000000..182441f9c
--- /dev/null
+++ b/torchao/experimental/ops/library.h
@@ -0,0 +1,34 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#if defined(USE_ATEN) && !defined(USE_EXECUTORCH)
+#pragma message("USE_ATEN")
+#include <torch/library.h>
+#include <torch/script.h>
+#include <torch/torch.h>
+using Tensor = at::Tensor;
+#define TORCHAO_CHECK(cond, msg) TORCH_CHECK(cond, msg)
+
+#elif defined(USE_EXECUTORCH) && !defined(USE_ATEN)
+#pragma message("USE_EXECUTORCH")
+#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+using Tensor = torch::executor::Tensor;
+using RuntimeContext = torch::executor::KernelRuntimeContext;
+#define TORCHAO_CHECK(cond, msg) ET_CHECK_MSG(cond, msg)
+
+#elif !defined(USE_EXECUTORCH) && !defined(USE_ATEN)
+#pragma message("Neither USE_ATEN or USE_EXECUTORCH defined")
+#include <stdexcept>
+
+#define TORCHAO_CHECK(cond, message)   \
+  if (!(cond)) {                       \
+    throw std::runtime_error(message); \
+  }
+
+#else
+#error "Cannot define both USE_ATEN or USE_EXECUTORCH"
+#endif
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/Linear8BitActXBitWeightOperator.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/Linear8BitActXBitWeightOperator.h
index 7d4e28f44..2250a6070 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/Linear8BitActXBitWeightOperator.h
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/Linear8BitActXBitWeightOperator.h
@@ -6,7 +6,7 @@
 #pragma once
 
 #include 
-#include <torchao/experimental/ops/macro.h>
+#include <torchao/experimental/ops/library.h>
 #include 
 #include 
 #include 
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp
index e2fcbaa2f..32793b8f2 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp
@@ -6,7 +6,7 @@
 
 #include 
 #include 
-#include <torchao/experimental/ops/macro.h>
+#include <torchao/experimental/ops/library.h>
 #include 
 #include 
 #include 
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h
index 46de55446..d635afa36 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h
@@ -10,31 +10,13 @@
 #include 
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
 
+#include <torchao/experimental/ops/library.h>
 #include 
 #include 
 #include 
 #include 
 #include 
 
-#if defined(USE_ATEN) && !defined(USE_EXECUTORCH)
-#pragma message("USE_ATEN")
-#include <torch/library.h>
-#include <torch/script.h>
-#include <torch/torch.h>
-using Tensor = at::Tensor;
-#define CHECK_MSG(cond, msg) TORCH_CHECK(cond, msg)
-
-#elif defined(USE_EXECUTORCH) && !defined(USE_ATEN)
-#pragma message("USE_EXECUTORCH")
-#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
-using Tensor = torch::executor::Tensor;
-using RuntimeContext = torch::executor::KernelRuntimeContext;
-#define CHECK_MSG(cond, msg) ET_CHECK_MSG(cond, msg)
-#else
-#error "Must define either USE_ATEN or USE_EXECUTORCH"
-#endif
-
 namespace {
 
 // This selects a UkernelConfig based on the packed weight header
@@ -52,7 +34,7 @@ get_ukernel_config(torchao::ops::PackedWeightsHeader header) {
           channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
 
       // Check packing params match the kernel
-      CHECK_MSG(
+      TORCHAO_CHECK(
          header ==
          torchao::ops::linear_8bit_act_xbit_weight::
              get_packed_weights_header_universal(
@@ -80,9 +62,9 @@ get_ukernel_config(torchao::ops::PackedWeightsHeader header) {
           &ukernel::kernel<weight_nbit, has_weight_zeros, false, false>;
       return config;
       break;
-    default:
-      CHECK_MSG(false, "Unsupported packed weights format");
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
+    default:
+      TORCHAO_CHECK(false, "Unsupported packed weights format");
   }
 }
 
@@ -103,8 +85,9 @@ Tensor pack_weights_cpu(
     const Tensor& weight_scales,
     const std::optional<Tensor>& weight_zeros,
     int64_t group_size) {
-  CHECK_MSG(weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8");
-  CHECK_MSG(weight_qvals.dim() == 2, "weight_qvals must be 2D");
+  TORCHAO_CHECK(
+      weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8");
+  TORCHAO_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D");
 
   // In PyTorch, weights are nxk in row-major format (with activations being
   // right-multiplied).
@@ -114,25 +97,25 @@ Tensor pack_weights_cpu(
   int n = weight_qvals.size(0);
   int k = weight_qvals.size(1);
 
-  CHECK_MSG(
+  TORCHAO_CHECK(
       weight_scales.dtype() == torch::kFloat32,
       "weight_scales must be float32");
-  CHECK_MSG(weight_scales.dim() == 1, "weight_scales must be 1D");
-  CHECK_MSG(group_size >= 1, "group_size must be >= 1");
-  CHECK_MSG(
+  TORCHAO_CHECK(weight_scales.dim() == 1, "weight_scales must be 1D");
+  TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1");
+  TORCHAO_CHECK(
       weight_scales.size(0) == ((n * k) / group_size),
       "expected 1 scale per group");
-  CHECK_MSG(
+  TORCHAO_CHECK(
      has_weight_zeros == weight_zeros.has_value(),
      "has_weight_zeros must match weight_zeros.has_value()");
 
   const int8_t* weight_zeros_ptr = nullptr;
   if constexpr (has_weight_zeros) {
-    CHECK_MSG(
+    TORCHAO_CHECK(
         weight_zeros.value().dtype() == torch::kInt8,
         "weight_zeros must be int8");
-    CHECK_MSG(weight_zeros.value().dim() == 1, "weight_zeros must be 1D");
-    CHECK_MSG(
+    TORCHAO_CHECK(weight_zeros.value().dim() == 1, "weight_zeros must be 1D");
+    TORCHAO_CHECK(
         weight_zeros.value().size(0) == ((n * k) / group_size),
         "expected 1 zero per group");
     weight_zeros_ptr = weight_zeros.value().const_data_ptr<int8_t>();
@@ -206,7 +189,7 @@ Tensor pack_weights_meta(
     const Tensor& weight_scales,
     const std::optional<Tensor>& weight_zeros,
     int64_t group_size) {
-  CHECK_MSG(group_size >= 1, "group_size must be >= 1");
+  TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1");
 
   int n = weight_qvals.size(0);
   int k = weight_qvals.size(1);
@@ -269,22 +252,23 @@ Tensor linear_out_cpu(
   int n = n_tensor.size(1);
   int k = k_tensor.size(1);
   int group_size = group_size_tensor.size(1);
 
-  CHECK_MSG(n >= 1, "n must be >= 1");
-  CHECK_MSG(k >= 1, "k must be >= 1");
-  CHECK_MSG(group_size >= 1, "group_size must be >= 1");
+  TORCHAO_CHECK(n >= 1, "n must be >= 1");
+  TORCHAO_CHECK(k >= 1, "k must be >= 1");
+  TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1");
 
 #ifdef USE_ATEN
-  CHECK_MSG(
+  TORCHAO_CHECK(
       activations.dtype() == torch::kFloat32, "activations must be float32");
 #endif // USE_ATEN
-  CHECK_MSG(activations.dim() == 2, "activations must be 2D");
+  TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D");
   int m = activations.size(0);
   int k_ = activations.size(1);
-  CHECK_MSG(k == k_, "activation shape is incompatible with packed weights.");
+  TORCHAO_CHECK(
+      k == k_, "activation shape is incompatible with packed weights.");
 
 #ifdef USE_ATEN
-  CHECK_MSG(out.dtype() == torch::kFloat32, "out must be float32");
+  TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32");
 #endif // USE_ATEN
 
 #ifdef USE_ATEN
@@ -292,23 +276,23 @@ Tensor linear_out_cpu(
 #endif // USE_ATEN
 
 #ifdef USE_EXECUTORCH
-  CHECK_MSG(out.dim() == 2, "out must be 2D");
-  CHECK_MSG(out.size(0) == m, "out shape is incorrect");
-  CHECK_MSG(out.size(1) == n, "out shape is incorrect");
+  TORCHAO_CHECK(out.dim() == 2, "out must be 2D");
+  TORCHAO_CHECK(out.size(0) == m, "out shape is incorrect");
+  TORCHAO_CHECK(out.size(1) == n, "out shape is incorrect");
 #endif // USE_EXECUTORCH
 
   using namespace torchao::ops::linear_8bit_act_xbit_weight;
 
-  CHECK_MSG(packed_weights.dim() == 1, "packed_weights must be 1D");
+  TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D");
 #ifdef USE_ATEN
-  CHECK_MSG(
+  TORCHAO_CHECK(
       packed_weights.dtype() == torch::kInt8, "packed_weights must be int8");
 #endif // USE_ATEN
-  CHECK_MSG(
+  TORCHAO_CHECK(
       packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(),
       "packed_weights is not big enough to read the header.");
read the header."); - auto header = torchao::ops::PackedWeightsHeader::read( - packed_weights.const_data_ptr()); + auto header = + torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr()); auto ukernel_config = get_ukernel_config< weight_nbit, @@ -392,13 +376,14 @@ Tensor linear_meta( const Tensor& k_tensor) { int n = n_tensor.size(1); int k = k_tensor.size(1); - CHECK_MSG(n >= 1, "n must be >= 1"); - CHECK_MSG(k >= 1, "k must be >= 1"); + TORCHAO_CHECK(n >= 1, "n must be >= 1"); + TORCHAO_CHECK(k >= 1, "k must be >= 1"); - CHECK_MSG(activations.dim() == 2, "activations must be 2D"); + TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D"); int m = activations.size(0); int k_ = activations.size(1); - CHECK_MSG(k == k_, "activation shape is incompatible with packed weights."); + TORCHAO_CHECK( + k == k_, "activation shape is incompatible with packed weights."); return torch::empty({m, n}).to("meta"); } #endif // USE_ATEN diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h index b2563b26e..d86a42946 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h @@ -5,12 +5,12 @@ // LICENSE file in the root directory of this source tree. #pragma once -#include +#include #include namespace torchao::ops::linear_8bit_act_xbit_weight { -torchao::ops::PackedWeightsHeader get_packed_weights_header_universal( +inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal( int weight_nbit, bool has_weight_zeros, bool has_bias, diff --git a/torchao/experimental/ops/macro.h b/torchao/experimental/ops/macro.h deleted file mode 100644 index a291d5d34..000000000 --- a/torchao/experimental/ops/macro.h +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. 
-
-#pragma once
-#include <stdexcept>
-
-#define TORCHAO_CHECK(cond, message) \
-  if (!(cond)) { \
-    throw std::runtime_error(message); \
-  }
diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
index 7666f8b78..e22c97e05 100644
--- a/torchao/experimental/quant_api.py
+++ b/torchao/experimental/quant_api.py
@@ -262,7 +262,7 @@ def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}):
     )
 
 
-class Int8DynActIntxWeightQuantizer:
+class Int8DynActIntxWeightLinearQuantizer:
     def __init__(
         self,
         device,
@@ -274,14 +274,14 @@ def __init__(
     ):
         if device != "cpu":
             raise NotImplementedError(
-                "Only device=cpu is currently supported in Int8DynActLowbitWeightQuantizer"
+                "Only device=cpu is currently supported in Int8DynActIntxWeightLinearQuantizer"
             )
         else:
             self.device = device
 
         if precision != torch.float32:
             raise NotImplementedError(
-                "Only precision=torch.float32 is currently supported in Int8DynActLowbitWeightQuantizer"
+                "Only precision=torch.float32 is currently supported in Int8DynActIntxWeightLinearQuantizer"
             )
         else:
             self.precision = precision
diff --git a/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py
index 631f18200..87d81c3d9 100644
--- a/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py
+++ b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py
@@ -19,7 +19,7 @@
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 from quant_api import (
     _Int8DynActIntxWeightQuantizedLinearFallback,
-    Int8DynActIntxWeightQuantizer,
+    Int8DynActIntxWeightLinearQuantizer,
 )
 
 
@@ -74,7 +74,7 @@ def test_accuracy(self):
         for has_weight_zeros in [True, False]:
             print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
             quantized_model = copy.deepcopy(model)
-            quantizer = Int8DynActIntxWeightQuantizer(
+            quantizer = Int8DynActIntxWeightLinearQuantizer(
                 device="cpu",
                 precision=torch.float32,
                 bitwidth=nbit,
@@ -122,7 +122,7 @@ def test_export_compile_aoti(self):
         activations = torch.randn(m, k0, dtype=torch.float32)
 
         print("Quantizing model")
-        quantizer = Int8DynActIntxWeightQuantizer(
+        quantizer = Int8DynActIntxWeightLinearQuantizer(
            device="cpu",
            precision=torch.float32,
            bitwidth=nbit,
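
Note on usage: the point of the new ops/library.h is that an op implementation can be written once against the `Tensor` alias and the `TORCHAO_CHECK` macro, then built for either runtime by defining exactly one of `USE_ATEN` or `USE_EXECUTORCH`. Below is a minimal sketch of a consumer; it is not part of this patch, and the `check_activations` helper is hypothetical:

    // Hypothetical consumer of torchao/experimental/ops/library.h (not part of
    // this patch). Build with exactly one of -DUSE_ATEN or -DUSE_EXECUTORCH.
    // With neither defined, only TORCHAO_CHECK (backed by std::runtime_error)
    // is available; the Tensor alias exists only in the two backend branches.
    #include <torchao/experimental/ops/library.h>

    // The same source compiles against both backends: Tensor resolves to
    // at::Tensor under USE_ATEN and to torch::executor::Tensor under
    // USE_EXECUTORCH, and both types expose dim() and size().
    void check_activations(const Tensor& activations) {
      TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D");
      TORCHAO_CHECK(activations.size(1) >= 1, "k must be >= 1");
    }

This shared-header setup is also why `get_packed_weights_header_universal` gains the `inline` keyword in packed_weights_header.h: its definition now lives in a header included from multiple translation units (the ATen and ExecuTorch op registrations), and a non-inline function definition there would violate the one-definition rule at link time.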