Create header for packed weight ops
Differential Revision: D63498956

Pull Request resolved: #1072
metascroy authored Oct 15, 2024
1 parent afc0a02 commit ff65188
Showing 4 changed files with 185 additions and 25 deletions.
@@ -7,6 +7,7 @@
 #pragma once
 #include <stdint.h>
 #include <stddef.h>
+#include <torchao/experimental/ops/packed_weights_header.h>

 namespace torchao::ops::linear_8bit_act_xbit_weight {

@@ -59,6 +60,8 @@ struct UKernelConfig {
   kernel_fn_type kernel_fn{nullptr};
   int mr{0};
   int nr{0};
+
+  torchao::ops::PackedWeightsHeader packed_weights_header;
 };

 // Pack weight functions
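The new packed_weights_header member lets a kernel configuration carry a description of the packing layout its packed weight buffers are expected to follow. A minimal sketch of how the field might be populated, using the get_packed_weights_header_universal helper added later in this commit; the helper function name and parameter values below are illustrative only:

#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h>

// Hypothetical example: a config records the packing params (weight_nbit,
// zero-point/bias flags, nr, kr) that its kernel assumes.
inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig
make_example_config() {
  torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config;
  config.mr = 1;
  config.nr = 8;
  config.packed_weights_header = torchao::ops::linear_8bit_act_xbit_weight::
      get_packed_weights_header_universal(
          /*weight_nbit=*/4, /*has_weight_zeros=*/true, /*has_bias=*/false,
          /*nr=*/8, /*kr=*/16);
  return config;
}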
@@ -11,6 +11,8 @@
 #endif // defined(__aarch64__) || defined(__ARM_NEON)

 #include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
+#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h>
+#include <torchao/experimental/ops/packed_weights_header.h>
 #include <optional>
 #include <vector>

@@ -35,31 +37,63 @@ using RuntimeContext = torch::executor::KernelRuntimeContext;

 namespace {

+// This selects a UkernelConfig based on the packed weight header
 template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
 inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig
-get_ukernel_config() {
+get_ukernel_config(torchao::ops::PackedWeightsHeader header) {
   torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config;

+  switch (header.format) {
 #if defined(__aarch64__) || defined(__ARM_NEON)
-  namespace ukernel = torchao::kernels::cpu::aarch64::linear::
-      channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
-  config.mr = 1;
-  config.nr = 8;
-  config.activation_data_size_fn =
-      &ukernel::activation_data_size<has_weight_zeros>;
-  config.preferred_activation_data_alignment = 16; // size of neon register
-  config.prepare_activation_data_fn =
-      &ukernel::prepare_activation_data<has_weight_zeros>;
-  config.weight_data_size_fn =
-      &ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
-  config.preferred_weight_data_alignment = 16; // size of neon register
-  config.prepare_weight_data_fn =
-      &ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
-  config.kernel_fn =
-      &ukernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>;
+    case torchao::ops::PackedWeightsFormat::
+        linear_8bit_act_xbit_weight_universal:
+      namespace ukernel
+          = torchao::kernels::cpu::aarch64::linear::
+              channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
+
+      // Check packing params match the kernel
+      CHECK_MSG(
+          header ==
+              torchao::ops::linear_8bit_act_xbit_weight::
+                  get_packed_weights_header_universal(
+                      weight_nbit,
+                      has_weight_zeros,
+                      has_bias,
+                      /*nr=*/8,
+                      /*kr=*/16),
+          "Packing params do not match what kernel supports");
+
+      config.packed_weights_header = header;
+      config.mr = 1;
+      config.nr = 8;
+      config.activation_data_size_fn =
+          &ukernel::activation_data_size<has_weight_zeros>;
+      config.preferred_activation_data_alignment = 16; // size of neon register
+      config.prepare_activation_data_fn =
+          &ukernel::prepare_activation_data<has_weight_zeros>;
+      config.weight_data_size_fn =
+          &ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
+      config.preferred_weight_data_alignment = 16; // size of neon register
+      config.prepare_weight_data_fn =
+          &ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
+      config.kernel_fn =
+          &ukernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>;
+      return config;
+      break;
+    default:
+      CHECK_MSG(false, "Unsupported packed weights format");
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
+  }
+}

-  return config;
+template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
+inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig
+get_ukernel_config() {
+  auto header = torchao::ops::linear_8bit_act_xbit_weight::
+      get_packed_weights_header_universal(
+          weight_nbit, has_weight_zeros, has_bias, /*nr=*/8, /*kr=*/16);
+  return get_ukernel_config<weight_nbit, has_weight_zeros, has_bias, has_clamp>(
+      header);
 }
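A minimal usage sketch of the header-taking overload, assuming a pointer to a packed weight buffer whose first 16 bytes hold the header; the helper name and template arguments below are hypothetical:

// Sketch: dispatch on the header read from the front of a packed buffer.
inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig
select_config_for_buffer(const void* packed_weights_ptr) {
  auto header = torchao::ops::PackedWeightsHeader::read(packed_weights_ptr);
  // Illustrative template arguments; if they disagree with the packing params
  // encoded in the header, the CHECK_MSG in get_ukernel_config fires.
  return get_ukernel_config<
      /*weight_nbit=*/4,
      /*has_weight_zeros=*/true,
      /*has_bias=*/false,
      /*has_clamp=*/false>(header);
}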

#ifdef USE_ATEN
@@ -114,13 +148,17 @@ Tensor pack_weights_cpu(
   auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params(
       ukernel_config, n, /*target_panels_per_thread=*/1);

-  auto packed_weight_data_size =
+  auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() +
       get_packed_weight_data_size(ukernel_config, n, k, group_size);
-  Tensor packed_weights = torch::empty({static_cast<int64_t>(packed_weight_data_size)}, torch::kInt8);
+  Tensor packed_weights = torch::empty(
+      {static_cast<int64_t>(packed_weight_data_size)}, torch::kInt8);
+  ukernel_config.packed_weights_header.write(
+      packed_weights.mutable_data_ptr<int8_t>());
   pack_weight_data_operator(
       ukernel_config,
       pack_weight_tiling_params,
-      packed_weights.mutable_data_ptr<int8_t>(),
+      packed_weights.mutable_data_ptr<int8_t>() +
+          torchao::ops::PackedWeightsHeader::size(),
       n,
       k,
       group_size,
@@ -180,9 +218,10 @@ Tensor pack_weights_meta(
       false /*has_bias*/,
       false /*has_clamp*/>();

-  auto packed_weight_data_size =
+  auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() +
       get_packed_weight_data_size(ukernel_config, n, k, group_size);
-  return torch::empty({static_cast<int64_t>(packed_weight_data_size)}).to("meta");
+  return torch::empty({static_cast<int64_t>(packed_weight_data_size)})
+      .to("meta");
 }
 #endif // USE_ATEN
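The packed tensor produced above is laid out as a 16-byte header followed by the kernel's packed weight data. A minimal sketch of how the two regions are addressed; the helper function below is hypothetical and only restates the offset arithmetic used in the diff:

#include <cstdint>
#include <torchao/experimental/ops/packed_weights_header.h>

// Layout of the int8 tensor returned by pack_weights_cpu:
//   [ PackedWeightsHeader (16 bytes) | kernel-specific packed weight data ]
inline int8_t* packed_weight_data_ptr(int8_t* buffer) {
  // The header occupies the first PackedWeightsHeader::size() == 16 bytes;
  // the kernel's packed weights start immediately after it.
  return buffer + torchao::ops::PackedWeightsHeader::size();
}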

@@ -260,11 +299,23 @@ Tensor linear_out_cpu(

   using namespace torchao::ops::linear_8bit_act_xbit_weight;

+  CHECK_MSG(packed_weights.dim() == 1, "packed_weights must be 1D");
+#ifdef USE_ATEN
+  CHECK_MSG(
+      packed_weights.dtype() == torch::kInt8, "packed_weights must be int8");
+#endif // USE_ATEN
+  CHECK_MSG(
+      packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(),
+      "packed_weights is not big enough to read the header.");
+  auto header = torchao::ops::PackedWeightsHeader::read(
+      packed_weights.const_data_ptr());
+
   auto ukernel_config = get_ukernel_config<
       weight_nbit,
       has_weight_zeros /*has_weight_zeros*/,
       false /*has_bias*/,
-      false /*has_clamp*/>();
+      false /*has_clamp*/>(header);

   auto linear_tiling_params = get_default_linear_tiling_params(
       ukernel_config,
       m,
@@ -292,7 +343,8 @@ Tensor linear_out_cpu(
       n,
       k,
       group_size,
-      packed_weights.const_data_ptr<int8_t>(),
+      packed_weights.const_data_ptr<int8_t>() +
+          torchao::ops::PackedWeightsHeader::size(),
       activations.const_data_ptr<float>(),
       /*bias=*/nullptr,
       // Clamp parameters are ignored because config is created from
@@ -0,0 +1,39 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <torchao/experimental/ops/macro.h>
#include <torchao/experimental/ops/packed_weights_header.h>

namespace torchao::ops::linear_8bit_act_xbit_weight {

torchao::ops::PackedWeightsHeader get_packed_weights_header_universal(
    int weight_nbit,
    bool has_weight_zeros,
    bool has_bias,
    int nr,
    int kr,
    int version = 1) {
  TORCHAO_CHECK(
      version >= 0 && version < 256, "version must be between 0 and 255");
  TORCHAO_CHECK(
      weight_nbit >= 1 && weight_nbit < 256,
      "weight_nbit must be between 1 and 255");
  return torchao::ops::PackedWeightsHeader(
      torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal,
      {((static_cast<unsigned short>(version) << 8) |
        static_cast<unsigned short>(weight_nbit)),
       ((static_cast<unsigned short>(has_weight_zeros) << 8) |
        static_cast<unsigned short>(has_bias)),
       static_cast<unsigned short>(nr),
       static_cast<unsigned short>(kr),
       0,
       0,
       0});
}

} // namespace torchao::ops::linear_8bit_act_xbit_weight
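The universal header packs version/weight_nbit and the zero-point/bias flags into the high and low bytes of the first two params slots, with nr and kr following. For reference, a minimal sketch that unpacks them again; the struct and helper below are hypothetical and not part of the commit:

#include <torchao/experimental/ops/packed_weights_header.h>

// Hypothetical decoder for the encoding above:
//   params[0] = (version << 8) | weight_nbit
//   params[1] = (has_weight_zeros << 8) | has_bias
//   params[2] = nr, params[3] = kr, remaining slots are zero.
struct UniversalPackingParams {
  int version;
  int weight_nbit;
  bool has_weight_zeros;
  bool has_bias;
  int nr;
  int kr;
};

inline UniversalPackingParams decode_universal_header(
    const torchao::ops::PackedWeightsHeader& header) {
  return UniversalPackingParams{
      header.params[0] >> 8,
      header.params[0] & 0xFF,
      (header.params[1] >> 8) != 0,
      (header.params[1] & 0xFF) != 0,
      header.params[2],
      header.params[3]};
}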
66 changes: 66 additions & 0 deletions torchao/experimental/ops/packed_weights_header.h
@@ -0,0 +1,66 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <array>

#include <cassert>
namespace torchao::ops {

enum PackedWeightsFormat : unsigned short {
  unknown = 0,
  linear_8bit_act_xbit_weight_universal = 1
};

class PackedWeightsHeader {
 public:
  using params_type = std::array<unsigned short, 7>;
  PackedWeightsFormat format;

  // 14 bytes of format specific params
  params_type params;

  PackedWeightsHeader(
      PackedWeightsFormat format = PackedWeightsFormat::unknown,
      params_type params = {0, 0, 0, 0, 0, 0, 0})
      : format{format}, params{params} {}

  inline static constexpr int size() {
    static_assert(sizeof(format) + sizeof(params) == 16);
    return 16;
  }

  inline void write(void* packed_weights) const {
    auto header = (unsigned short*)(packed_weights);
    header[0] = (unsigned short)format;
    for (int i = 0; i < params.size(); i++) {
      header[i + 1] = params[i];
    }
  }

  static PackedWeightsHeader read(const void* packed_weights) {
    auto header = (unsigned short*)(packed_weights);
    params_type params;
    for (int i = 0; i < params.size(); i++) {
      params[i] = header[i + 1];
    }
    return PackedWeightsHeader((PackedWeightsFormat)header[0], params);
  }

  bool operator==(const PackedWeightsHeader& other) const {
    if (format != other.format) {
      return false;
    }
    for (int i = 0; i < params.size(); i++) {
      if (params[i] != other.params[i]) {
        return false;
      }
    }
    return true;
  }
};

} // namespace torchao::ops
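A minimal end-to-end sketch of the intended round trip: write the header at the front of a packed buffer and recover it on the read side. The function name, buffer size, and params values here are illustrative only:

#include <cassert>
#include <cstdint>
#include <vector>
#include <torchao/experimental/ops/packed_weights_header.h>

void header_round_trip_sketch() {
  using torchao::ops::PackedWeightsFormat;
  using torchao::ops::PackedWeightsHeader;

  // Illustrative buffer: 16 header bytes followed by (fake) packed data.
  std::vector<int8_t> packed(PackedWeightsHeader::size() + 1024);

  // e.g. version 1 with 4-bit weights (0x0104), weight zeros but no bias
  // (0x0100), nr = 8, kr = 16.
  PackedWeightsHeader header(
      PackedWeightsFormat::linear_8bit_act_xbit_weight_universal,
      /*params=*/{260, 256, 8, 16, 0, 0, 0});

  header.write(packed.data()); // serialize at offset 0
  auto read_back = PackedWeightsHeader::read(packed.data());
  assert(read_back == header); // format and params survive the round trip
}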
