From ff65188b748e7514c43557b22b9a29085ed5dec0 Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Mon, 14 Oct 2024 18:10:49 -0700
Subject: [PATCH] Create header for packed weight ops

Differential Revision: D63498956

Pull Request resolved: https://github.com/pytorch/ao/pull/1072
---
 .../linear_8bit_act_xbit_weight.h             |   3 +
 .../op_linear_8bit_act_xbit_weight-impl.h     | 102 +++++++++++++-----
 .../packed_weights_header.h                   |  39 +++++++
 .../experimental/ops/packed_weights_header.h  |  66 ++++++++++++
 4 files changed, 185 insertions(+), 25 deletions(-)
 create mode 100644 torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h
 create mode 100644 torchao/experimental/ops/packed_weights_header.h

diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h
index 6ec098314..568da46bf 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h
@@ -7,6 +7,7 @@
 #pragma once
 #include
 #include
+#include <torchao/experimental/ops/packed_weights_header.h>

 namespace torchao::ops::linear_8bit_act_xbit_weight {

@@ -59,6 +60,8 @@ struct UKernelConfig {
   kernel_fn_type kernel_fn{nullptr};
   int mr{0};
   int nr{0};
+
+  torchao::ops::PackedWeightsHeader packed_weights_header;
 };

 // Pack weight functions
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h
index ba732d526..46de55446 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h
@@ -11,6 +11,8 @@
 #endif // defined(__aarch64__) || defined(__ARM_NEON)

 #include
+#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h>
+#include <torchao/experimental/ops/packed_weights_header.h>
 #include
 #include

@@ -35,31 +37,63 @@ using RuntimeContext = torch::executor::KernelRuntimeContext;

 namespace {

+// Selects a UKernelConfig based on the packed weights header
 template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
 inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig
-get_ukernel_config() {
+get_ukernel_config(torchao::ops::PackedWeightsHeader header) {
   torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config;

+  switch (header.format) {
 #if defined(__aarch64__) || defined(__ARM_NEON)
-  namespace ukernel = torchao::kernels::cpu::aarch64::linear::
-      channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
-  config.mr = 1;
-  config.nr = 8;
-  config.activation_data_size_fn =
-      &ukernel::activation_data_size<has_weight_zeros>;
-  config.preferred_activation_data_alignment = 16; // size of neon register
-  config.prepare_activation_data_fn =
-      &ukernel::prepare_activation_data<has_weight_zeros>;
-  config.weight_data_size_fn =
-      &ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
-  config.preferred_weight_data_alignment = 16; // size of neon register
-  config.prepare_weight_data_fn =
-      &ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
-  config.kernel_fn =
-      &ukernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>;
+    case torchao::ops::PackedWeightsFormat::
+        linear_8bit_act_xbit_weight_universal:
+      namespace ukernel
+          = torchao::kernels::cpu::aarch64::linear::
+              channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
+
+      // Check packing params match the kernel
+      CHECK_MSG(
+          header ==
+              torchao::ops::linear_8bit_act_xbit_weight::
+                  get_packed_weights_header_universal(
+                      weight_nbit,
+                      has_weight_zeros,
+                      has_bias,
+                      /*nr=*/8,
+                      /*kr=*/16),
+          "Packing params do not match what kernel supports");
+
+      config.packed_weights_header = header;
+      config.mr = 1;
+      config.nr = 8;
+      config.activation_data_size_fn =
+          &ukernel::activation_data_size<has_weight_zeros>;
+      config.preferred_activation_data_alignment = 16; // size of neon register
+      config.prepare_activation_data_fn =
+          &ukernel::prepare_activation_data<has_weight_zeros>;
+      config.weight_data_size_fn =
+          &ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
+      config.preferred_weight_data_alignment = 16; // size of neon register
+      config.prepare_weight_data_fn =
+          &ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
+      config.kernel_fn =
+          &ukernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>;
+      return config;
+      break;
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
+    default:
+      CHECK_MSG(false, "Unsupported packed weights format");
+  }
+}

-  return config;
+template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
+inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig
+get_ukernel_config() {
+  auto header = torchao::ops::linear_8bit_act_xbit_weight::
+      get_packed_weights_header_universal(
+          weight_nbit, has_weight_zeros, has_bias, /*nr=*/8, /*kr=*/16);
+  return get_ukernel_config<
+      weight_nbit, has_weight_zeros, has_bias, has_clamp>(
+      header);
 }

 #ifdef USE_ATEN
@@ -114,13 +148,17 @@ Tensor pack_weights_cpu(
   auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params(
       ukernel_config, n, /*target_panels_per_thread=*/1);

-  auto packed_weight_data_size =
+  auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() +
       get_packed_weight_data_size(ukernel_config, n, k, group_size);
-  Tensor packed_weights = torch::empty({static_cast<int64_t>(packed_weight_data_size)}, torch::kInt8);
+  Tensor packed_weights = torch::empty(
+      {static_cast<int64_t>(packed_weight_data_size)}, torch::kInt8);
+  ukernel_config.packed_weights_header.write(
+      packed_weights.mutable_data_ptr<int8_t>());
   pack_weight_data_operator(
       ukernel_config,
       pack_weight_tiling_params,
-      packed_weights.mutable_data_ptr<int8_t>(),
+      packed_weights.mutable_data_ptr<int8_t>() +
+          torchao::ops::PackedWeightsHeader::size(),
       n,
       k,
       group_size,
@@ -180,9 +218,10 @@ Tensor pack_weights_meta(
       false /*has_bias*/,
       false /*has_clamp*/>();

-  auto packed_weight_data_size =
+  auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() +
       get_packed_weight_data_size(ukernel_config, n, k, group_size);
-  return torch::empty({static_cast<int64_t>(packed_weight_data_size)}).to("meta");
+  return torch::empty({static_cast<int64_t>(packed_weight_data_size)})
+      .to("meta");
 }
 #endif // USE_ATEN
@@ -260,11 +299,23 @@ Tensor linear_out_cpu(

   using namespace torchao::ops::linear_8bit_act_xbit_weight;

+  CHECK_MSG(packed_weights.dim() == 1, "packed_weights must be 1D");
+#ifdef USE_ATEN
+  CHECK_MSG(
+      packed_weights.dtype() == torch::kInt8, "packed_weights must be int8");
+#endif // USE_ATEN
+  CHECK_MSG(
+      packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(),
+      "packed_weights is not big enough to read the header.");
+  auto header = torchao::ops::PackedWeightsHeader::read(
+      packed_weights.const_data_ptr<int8_t>());
+
   auto ukernel_config = get_ukernel_config<
       weight_nbit,
       has_weight_zeros /*has_weight_zeros*/,
       false /*has_bias*/,
-      false /*has_clamp*/>();
+      false /*has_clamp*/>(header);
+
   auto linear_tiling_params = get_default_linear_tiling_params(
       ukernel_config,
       m,
@@ -292,7 +343,8 @@ Tensor linear_out_cpu(
       n,
       k,
       group_size,
-      packed_weights.const_data_ptr<int8_t>(),
+      packed_weights.const_data_ptr<int8_t>() +
+          torchao::ops::PackedWeightsHeader::size(),
       activations.const_data_ptr<float>(),
       /*bias=*/nullptr,
       // Clamp parameters are ignored because config is created from
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h
new file mode 100644
index 000000000..5357c39f0
--- /dev/null
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h
@@ -0,0 +1,39 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include <torchao/experimental/ops/library.h>
+#include <torchao/experimental/ops/packed_weights_header.h>
+
+namespace torchao::ops::linear_8bit_act_xbit_weight {
+
+inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal(
+    int weight_nbit,
+    bool has_weight_zeros,
+    bool has_bias,
+    int nr,
+    int kr,
+    int version = 1) {
+  TORCHAO_CHECK(
+      version >= 0 && version < 256, "version must be between 0 and 255");
+  TORCHAO_CHECK(
+      weight_nbit >= 1 && weight_nbit < 256,
+      "weight_nbit must be between 1 and 255");
+  return torchao::ops::PackedWeightsHeader(
+      torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal,
+      {static_cast<unsigned short>(
+           (static_cast<unsigned short>(version) << 8) |
+           static_cast<unsigned short>(weight_nbit)),
+       static_cast<unsigned short>(
+           (static_cast<unsigned short>(has_weight_zeros) << 8) |
+           static_cast<unsigned short>(has_bias)),
+       static_cast<unsigned short>(nr),
+       static_cast<unsigned short>(kr),
+       0,
+       0,
+       0});
+}
+
+} // namespace torchao::ops::linear_8bit_act_xbit_weight
diff --git a/torchao/experimental/ops/packed_weights_header.h b/torchao/experimental/ops/packed_weights_header.h
new file mode 100644
index 000000000..f77e9ff1f
--- /dev/null
+++ b/torchao/experimental/ops/packed_weights_header.h
@@ -0,0 +1,66 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include <array>
+
+#include <cstdint>
+namespace torchao::ops {
+
+enum PackedWeightsFormat : unsigned short {
+  unknown = 0,
+  linear_8bit_act_xbit_weight_universal = 1
+};
+
+class PackedWeightsHeader {
+ public:
+  using params_type = std::array<unsigned short, 7>;
+  PackedWeightsFormat format;
+
+  // 14 bytes of format specific params
+  params_type params;
+
+  PackedWeightsHeader(
+      PackedWeightsFormat format = PackedWeightsFormat::unknown,
+      params_type params = {0, 0, 0, 0, 0, 0, 0})
+      : format{format}, params{params} {}
+
+  inline static constexpr int size() {
+    static_assert(sizeof(format) + sizeof(params) == 16);
+    return 16;
+  }
+
+  inline void write(void* packed_weights) const {
+    auto header = static_cast<unsigned short*>(packed_weights);
+    header[0] = static_cast<unsigned short>(format);
+    for (size_t i = 0; i < params.size(); i++) {
+      header[i + 1] = params[i];
+    }
+  }
+
+  static PackedWeightsHeader read(const void* packed_weights) {
+    auto header = static_cast<const unsigned short*>(packed_weights);
+    params_type params;
+    for (size_t i = 0; i < params.size(); i++) {
+      params[i] = header[i + 1];
+    }
+    return PackedWeightsHeader(
+        static_cast<PackedWeightsFormat>(header[0]), params);
+  }
+
+  bool operator==(const PackedWeightsHeader& other) const {
+    if (format != other.format) {
+      return false;
+    }
+    for (size_t i = 0; i < params.size(); i++) {
+      if (params[i] != other.params[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
+};
+
+} // namespace torchao::ops
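
For readers of this patch, the core change is the buffer layout: packed weights now start with a fixed 16-byte header (a 2-byte format tag plus seven 2-byte params), written by pack_weights_cpu and read back by linear_out_cpu, with the kernel's packed data starting at PackedWeightsHeader::size(). The standalone sketch below mirrors that layout and round-trips it through the front of a buffer; Header, Format, kUniversal, and the payload size are illustrative stand-ins, not the torchao API.

// Standalone sketch: mirrors the 16-byte PackedWeightsHeader layout and the
// write-at-front / read-at-front round trip established by this patch.
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

enum Format : unsigned short { kUnknown = 0, kUniversal = 1 };

struct Header {
  Format format{kUnknown};
  std::array<unsigned short, 7> params{}; // 14 bytes of format-specific params

  static constexpr int size() { return 16; } // 2-byte tag + 14 bytes of params

  void write(void* dst) const {
    auto* p = static_cast<unsigned short*>(dst);
    p[0] = static_cast<unsigned short>(format);
    for (std::size_t i = 0; i < params.size(); i++) {
      p[i + 1] = params[i];
    }
  }

  static Header read(const void* src) {
    auto* p = static_cast<const unsigned short*>(src);
    Header h;
    h.format = static_cast<Format>(p[0]);
    for (std::size_t i = 0; i < h.params.size(); i++) {
      h.params[i] = p[i + 1];
    }
    return h;
  }
};

int main() {
  // Buffer layout from the patch: [16-byte header][kernel packed weight data].
  std::size_t packed_weight_data_size = 1024; // stand-in for the kernel's size fn
  std::vector<unsigned char> buf(Header::size() + packed_weight_data_size);

  Header h;
  h.format = kUniversal;
  // version=1, nbit=4 | has_weight_zeros=1, has_bias=0 | nr=8 | kr=16
  h.params = {(1 << 8) | 4, (1 << 8) | 0, 8, 16, 0, 0, 0};
  h.write(buf.data());                                      // header first ...
  unsigned char* weight_dst = buf.data() + Header::size();  // ... payload after
  (void)weight_dst;

  Header back = Header::read(buf.data());
  assert(back.format == kUniversal && back.params == h.params);
  std::printf("version=%d nbit=%d nr=%d kr=%d\n",
              back.params[0] >> 8, back.params[0] & 0xFF,
              back.params[2], back.params[3]);
  return 0;
}

Note that producer and consumer only have to agree on the fixed 16-byte offset; everything after it stays opaque kernel data, and get_ukernel_config rejects a header whose format or params it does not recognize.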
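The params encoding used by get_packed_weights_header_universal is also worth spelling out: two 8-bit fields share each of the first two 16-bit slots (version with weight_nbit, has_weight_zeros with has_bias), then nr and kr each take a slot of their own. A minimal sketch of that packing follows, assuming this byte order; the pack_pair/hi_byte/lo_byte helpers are hypothetical, since the patch itself only builds headers and compares them whole via operator==.

// Sketch of the universal header's param packing:
//   params[0] = (version << 8) | weight_nbit
//   params[1] = (has_weight_zeros << 8) | has_bias
//   params[2] = nr, params[3] = kr, remaining slots = 0
#include <cassert>

unsigned short pack_pair(int hi, int lo) {
  assert(hi >= 0 && hi < 256 && lo >= 0 && lo < 256); // each field must fit in one byte
  return static_cast<unsigned short>((hi << 8) | lo);
}

int hi_byte(unsigned short v) { return v >> 8; }
int lo_byte(unsigned short v) { return v & 0xFF; }

int main() {
  unsigned short p0 = pack_pair(/*version=*/1, /*weight_nbit=*/4);
  unsigned short p1 = pack_pair(/*has_weight_zeros=*/1, /*has_bias=*/0);

  // Round trip: the original fields are recoverable with shifts and masks.
  assert(hi_byte(p0) == 1 && lo_byte(p0) == 4);
  assert(hi_byte(p1) == 1 && lo_byte(p1) == 0);

  // This sharing is why the TORCHAO_CHECKs bound version to [0, 255] and
  // weight_nbit to [1, 255]: each shares a 16-bit slot with another field.
  return 0;
}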