diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h index 5357c39f0..b2563b26e 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h @@ -17,19 +17,18 @@ torchao::ops::PackedWeightsHeader get_packed_weights_header_universal( int nr, int kr, int version = 1) { - TORCHAO_CHECK( - version >= 0 && version < 256, "version must be between 0 and 255"); - TORCHAO_CHECK( - weight_nbit >= 1 && weight_nbit < 256, - "weight_nbit must be between 1 and 255"); return torchao::ops::PackedWeightsHeader( torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal, - {((static_cast(version) << 8) | - static_cast(weight_nbit)), - ((static_cast(has_weight_zeros) << 8) | - static_cast(has_bias)), - static_cast(nr), - static_cast(kr), + {version, + weight_nbit, + has_weight_zeros, + has_bias, + nr, + kr, + 0, + 0, + 0, + 0, 0, 0, 0, diff --git a/torchao/experimental/ops/packed_weights_header.h b/torchao/experimental/ops/packed_weights_header.h index f77e9ff1f..8a09e6200 100644 --- a/torchao/experimental/ops/packed_weights_header.h +++ b/torchao/experimental/ops/packed_weights_header.h @@ -7,17 +7,20 @@ #pragma once #include +#include #include + namespace torchao::ops { -enum PackedWeightsFormat : unsigned short { +enum class PackedWeightsFormat : uint32_t { unknown = 0, linear_8bit_act_xbit_weight_universal = 1 }; class PackedWeightsHeader { public: - using params_type = std::array; + using params_type = std::array; + const static int magic = 6712; PackedWeightsFormat format; // 14 bytes of format specific params @@ -25,29 +28,32 @@ class PackedWeightsHeader { PackedWeightsHeader( PackedWeightsFormat format = PackedWeightsFormat::unknown, - params_type params = {0, 0, 0, 0, 0, 0, 0}) + params_type params = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) : format{format}, params{params} {} inline static constexpr int size() { - static_assert(sizeof(format) + sizeof(params) == 16); - return 16; + static_assert(sizeof(magic) + sizeof(format) + sizeof(params) == 64); + return 64; } inline void write(void* packed_weights) const { - auto header = (unsigned short*)(packed_weights); - header[0] = (unsigned short)format; + auto header = reinterpret_cast(packed_weights); + header[0] = magic; + header[1] = static_cast(format); for (int i = 0; i < params.size(); i++) { - header[i + 1] = params[i]; + header[i + 2] = params[i]; } } static PackedWeightsHeader read(const void* packed_weights) { - auto header = (unsigned short*)(packed_weights); + auto header = reinterpret_cast(packed_weights); + assert(header[0] == PackedWeightsHeader::magic); params_type params; for (int i = 0; i < params.size(); i++) { - params[i] = header[i + 1]; + params[i] = header[i + 2]; } - return PackedWeightsHeader((PackedWeightsFormat)header[0], params); + return PackedWeightsHeader( + static_cast(header[1]), params); } bool operator==(const PackedWeightsHeader& other) const {