Skip to content

Commit

Permalink
Header bug fix
Browse files Browse the repository at this point in the history
Differential Revision: D64370707

Pull Request resolved: #1079
  • Loading branch information
metascroy authored Oct 15, 2024
1 parent 48bc81c commit 6ea36c5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,18 @@ torchao::ops::PackedWeightsHeader get_packed_weights_header_universal(
int nr,
int kr,
int version = 1) {
TORCHAO_CHECK(
version >= 0 && version < 256, "version must be between 0 and 255");
TORCHAO_CHECK(
weight_nbit >= 1 && weight_nbit < 256,
"weight_nbit must be between 1 and 255");
return torchao::ops::PackedWeightsHeader(
torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal,
{((static_cast<unsigned short>(version) << 8) |
static_cast<unsigned short>(weight_nbit)),
((static_cast<unsigned short>(has_weight_zeros) << 8) |
static_cast<unsigned short>(has_bias)),
static_cast<unsigned short>(nr),
static_cast<unsigned short>(kr),
{version,
weight_nbit,
has_weight_zeros,
has_bias,
nr,
kr,
0,
0,
0,
0,
0,
0,
0,
Expand Down
28 changes: 17 additions & 11 deletions torchao/experimental/ops/packed_weights_header.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,47 +7,53 @@
#pragma once
#include <array>

#include <stdint.h>
#include <cassert>

namespace torchao::ops {

enum PackedWeightsFormat : unsigned short {
enum class PackedWeightsFormat : uint32_t {
unknown = 0,
linear_8bit_act_xbit_weight_universal = 1
};

class PackedWeightsHeader {
public:
using params_type = std::array<unsigned short, 7>;
using params_type = std::array<int, 14>;
const static int magic = 6712;
PackedWeightsFormat format;

// 14 bytes of format specific params
params_type params;

PackedWeightsHeader(
PackedWeightsFormat format = PackedWeightsFormat::unknown,
params_type params = {0, 0, 0, 0, 0, 0, 0})
params_type params = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
: format{format}, params{params} {}

inline static constexpr int size() {
static_assert(sizeof(format) + sizeof(params) == 16);
return 16;
static_assert(sizeof(magic) + sizeof(format) + sizeof(params) == 64);
return 64;
}

inline void write(void* packed_weights) const {
auto header = (unsigned short*)(packed_weights);
header[0] = (unsigned short)format;
auto header = reinterpret_cast<int*>(packed_weights);
header[0] = magic;
header[1] = static_cast<int>(format);
for (int i = 0; i < params.size(); i++) {
header[i + 1] = params[i];
header[i + 2] = params[i];
}
}

static PackedWeightsHeader read(const void* packed_weights) {
auto header = (unsigned short*)(packed_weights);
auto header = reinterpret_cast<const int*>(packed_weights);
assert(header[0] == PackedWeightsHeader::magic);
params_type params;
for (int i = 0; i < params.size(); i++) {
params[i] = header[i + 1];
params[i] = header[i + 2];
}
return PackedWeightsHeader((PackedWeightsFormat)header[0], params);
return PackedWeightsHeader(
static_cast<PackedWeightsFormat>(header[1]), params);
}

bool operator==(const PackedWeightsHeader& other) const {
Expand Down

0 comments on commit 6ea36c5

Please sign in to comment.