diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h index 2cff22326..340f295e5 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h @@ -9,13 +9,13 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include #include #include #include #include #include #include +#include #include namespace torchao { @@ -142,7 +142,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values( break; case 6: torchao::bitpacking::internal::vec_pack_32_uint6_values( - packed, shifted0, shifted1); + packed, shifted0, shifted1); break; default: assert(false); @@ -153,7 +153,7 @@ template TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values( int8x16_t& unpacked0, int8x16_t& unpacked1, - uint8_t* packed) { + const uint8_t* packed) { static_assert(nbit < 8); static_assert(nbit >= 1); @@ -217,7 +217,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values( break; case 6: torchao::bitpacking::internal::vec_unpack_32_uint6_values( - shifted0, shifted1, packed); + shifted0, shifted1, packed); break; default: assert(false); @@ -288,7 +288,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values( int8x16_t& unpacked1, int8x16_t& unpacked2, int8x16_t& unpacked3, - uint8_t* packed) { + const uint8_t* packed) { static_assert(nbit < 8); static_assert(nbit >= 1); @@ -443,7 +443,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values( int8x16_t& unpacked5, int8x16_t& unpacked6, int8x16_t& unpacked7, - uint8_t* packed) { + const uint8_t* packed) { static_assert(nbit < 8); static_assert(nbit >= 1); diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h index 78d0f76e8..de999a53d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h @@ -8,7 +8,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint1. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h index d036c6ebc..630bc2279 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint4. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h index 9a42bdb00..a808ee3a2 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint3. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h index 3b1352d91..fba626ea5 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint4. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h index 0eceb56b7..06a8d63c2 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint5. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h index 53dc9ec2c..87712f7bc 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint5. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h b/torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h new file mode 100644 index 000000000..eb8ee3849 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h @@ -0,0 +1,327 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::embedding { + +namespace internal { + +TORCHAO_ALWAYS_INLINE inline void vec_dequantize_i32_fp32( + float32x4_t& out, + const int32x4_t& qvals, + const float32x4_t& scales, + const float32x4_t& zeros) { + out = vcvtq_f32_s32(qvals); + out = vsubq_f32(out, zeros); + out = vmulq_f32(out, scales); +} + +TORCHAO_ALWAYS_INLINE inline void vec_dequantize_and_store_16_values( + float* out, + const int8x16_t& input, + float scale, + float zero) { + float32x4_t dequant; + float32x4_t scales = vdupq_n_f32(scale); + float32x4_t zeros = vdupq_n_f32(zero); + + int16x8_t low8 = vmovl_s8(vget_low_s8(input)); + int32x4_t qvals = vmovl_s16(vget_low_s16(low8)); + vec_dequantize_i32_fp32(dequant, qvals, scales, zeros); + vst1q_f32(out, dequant); + + qvals = vmovl_s16(vget_high_s16(low8)); + vec_dequantize_i32_fp32(dequant, qvals, scales, zeros); + vst1q_f32(out + 4, dequant); + + int16x8_t high8 = vmovl_s8(vget_high_s8(input)); + qvals = vmovl_s16(vget_low_s16(high8)); + vec_dequantize_i32_fp32(dequant, qvals, scales, zeros); + vst1q_f32(out + 8, dequant); + + qvals = vmovl_s16(vget_high_s16(high8)); + vec_dequantize_i32_fp32(dequant, qvals, scales, zeros); + vst1q_f32(out + 12, dequant); +} + +} // namespace internal + +template +inline void pack_embedding_weight_qvals_( + // Output + void* packed_qvals, + // Inputs + int embedding_dim, + const int8_t* qvals) { + assert(embedding_dim % 32 == 0); + + constexpr int bytes_per_packed_128_values = (128 * weight_nbit) / 8; + constexpr int bytes_per_packed_64_values = (64 * weight_nbit) / 8; + constexpr int bytes_per_packed_32_values = (32 * weight_nbit) / 8; + auto packed_qvals_byte_ptr = reinterpret_cast(packed_qvals); + + int8x16_t qvals0; + int8x16_t qvals1; + int8x16_t qvals2; + int8x16_t qvals3; + int8x16_t qvals4; + int8x16_t qvals5; + int8x16_t qvals6; + int8x16_t qvals7; + + int packed_offset = 0; + int i = 0; + for (; i + 128 - 1 < embedding_dim; i += 128) { + qvals0 = vld1q_s8(qvals + i); + qvals1 = vld1q_s8(qvals + i + 16); + qvals2 = vld1q_s8(qvals + i + 32); + qvals3 = vld1q_s8(qvals + i + 48); + qvals4 = vld1q_s8(qvals + i + 64); + qvals5 = vld1q_s8(qvals + i + 80); + qvals6 = vld1q_s8(qvals + i + 96); + qvals7 = vld1q_s8(qvals + i + 112); + torchao::bitpacking::vec_pack_128_lowbit_values( + packed_qvals_byte_ptr + packed_offset, + qvals0, + qvals1, + qvals2, + qvals3, + qvals4, + qvals5, + qvals6, + qvals7); + packed_offset += bytes_per_packed_128_values; + } + + if (i + 64 - 1 < embedding_dim) { + qvals0 = vld1q_s8(qvals + i); + qvals1 = vld1q_s8(qvals + i + 16); + qvals2 = vld1q_s8(qvals + i + 32); + qvals3 = vld1q_s8(qvals + i + 48); + torchao::bitpacking::vec_pack_64_lowbit_values( + packed_qvals_byte_ptr + packed_offset, qvals0, qvals1, qvals2, qvals3); + packed_offset += bytes_per_packed_64_values; + i += 64; + } + + if (i + 32 - 1 < embedding_dim) { + qvals0 = vld1q_s8(qvals + i); + qvals1 = vld1q_s8(qvals + i + 16); + torchao::bitpacking::vec_pack_32_lowbit_values( + packed_qvals_byte_ptr + packed_offset, qvals0, qvals1); + packed_offset += bytes_per_packed_32_values; + i += 32; + } + + assert(i == embedding_dim); // because 32 | embedding_dim +} + +template +inline void embedding_( + // Output + float* out, + // Inputs + int embedding_dim, + int group_size, + const void* packed_weight_qvals, + const float* weight_scales, + // If weight_zeros is nullptr, they are assumed zeros + const int8_t* weight_zeros) { + assert(embedding_dim % 32 == 0); + + constexpr int bytes_per_packed_128_values = (128 * weight_nbit) / 8; + constexpr int bytes_per_packed_64_values = (64 * weight_nbit) / 8; + constexpr int bytes_per_packed_32_values = (32 * weight_nbit) / 8; + auto packed_weight_qvals_byte_ptr = + reinterpret_cast(packed_weight_qvals); + + int8x16_t qvals0; + int8x16_t qvals1; + int8x16_t qvals2; + int8x16_t qvals3; + int8x16_t qvals4; + int8x16_t qvals5; + int8x16_t qvals6; + int8x16_t qvals7; + + int packed_offset = 0; + int i = 0; + for (; i + 128 - 1 < embedding_dim; i += 128) { + torchao::bitpacking::vec_unpack_128_lowbit_values( + qvals0, + qvals1, + qvals2, + qvals3, + qvals4, + qvals5, + qvals6, + qvals7, + packed_weight_qvals_byte_ptr + packed_offset); + packed_offset += bytes_per_packed_128_values; + + // Dequantize and store first 32 values + int group_idx = i / group_size; + float scale = weight_scales[group_idx]; + float zero = 0.0; + if (weight_zeros != nullptr) { + zero = weight_zeros[group_idx]; + } + internal::vec_dequantize_and_store_16_values(out + i, qvals0, scale, zero); + internal::vec_dequantize_and_store_16_values( + out + i + 16, qvals1, scale, zero); + + // Dequantize and store second 32 values + group_idx = (i + 32) / group_size; + scale = weight_scales[group_idx]; + if (weight_zeros != nullptr) { + zero = weight_zeros[group_idx]; + } + internal::vec_dequantize_and_store_16_values( + out + i + 32, qvals2, scale, zero); + internal::vec_dequantize_and_store_16_values( + out + i + 48, qvals3, scale, zero); + + // Dequantize and store third 32 values + group_idx = (i + 64) / group_size; + scale = weight_scales[group_idx]; + if (weight_zeros != nullptr) { + zero = weight_zeros[group_idx]; + } + internal::vec_dequantize_and_store_16_values( + out + i + 64, qvals4, scale, zero); + internal::vec_dequantize_and_store_16_values( + out + i + 80, qvals5, scale, zero); + + // Dequantize and store fourth 32 values + group_idx = (i + 96) / group_size; + scale = weight_scales[group_idx]; + if (weight_zeros != nullptr) { + zero = weight_zeros[group_idx]; + } + internal::vec_dequantize_and_store_16_values( + out + i + 96, qvals6, scale, zero); + internal::vec_dequantize_and_store_16_values( + out + i + 112, qvals7, scale, zero); + } + + if (i + 64 - 1 < embedding_dim) { + torchao::bitpacking::vec_unpack_64_lowbit_values( + qvals0, + qvals1, + qvals2, + qvals3, + packed_weight_qvals_byte_ptr + packed_offset); + packed_offset += bytes_per_packed_64_values; + + // Dequantize and store first 32 values + int group_idx = i / group_size; + float scale = weight_scales[group_idx]; + float zero = 0.0; + if (weight_zeros != nullptr) { + zero = weight_zeros[group_idx]; + } + internal::vec_dequantize_and_store_16_values(out + i, qvals0, scale, zero); + internal::vec_dequantize_and_store_16_values( + out + i + 16, qvals1, scale, zero); + + // Dequantize and store second 32 values + group_idx = (i + 32) / group_size; + scale = weight_scales[group_idx]; + if (weight_zeros != nullptr) { + zero = weight_zeros[group_idx]; + } + internal::vec_dequantize_and_store_16_values( + out + i + 32, qvals2, scale, zero); + internal::vec_dequantize_and_store_16_values( + out + i + 48, qvals3, scale, zero); + + i += 64; + } + + if (i + 32 - 1 < embedding_dim) { + torchao::bitpacking::vec_unpack_32_lowbit_values( + qvals0, qvals1, packed_weight_qvals_byte_ptr + packed_offset); + packed_offset += bytes_per_packed_32_values; + + int group_idx = i / group_size; + float scale = weight_scales[group_idx]; + float zero = 0.0; + if (weight_zeros != nullptr) { + zero = weight_zeros[group_idx]; + } + internal::vec_dequantize_and_store_16_values(out + i, qvals0, scale, zero); + internal::vec_dequantize_and_store_16_values( + out + i + 16, qvals1, scale, zero); + + i += 32; + } + + assert(i == embedding_dim); // because 32 | embedding_dim +} + +template +inline void embedding( + // Output + float* out, + // Inputs + int embedding_dim, + int group_size, + const void* packed_weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros, + int index) { + assert(group_size % 32 == 0); + assert(embedding_dim % group_size == 0); + + auto packed_weight_qvals_byte_ptr = + reinterpret_cast(packed_weight_qvals); + + int groups_per_embedding = embedding_dim / group_size; + int packed_bytes_per_embedding = embedding_dim * weight_nbit / 8; + + packed_weight_qvals_byte_ptr += (index * packed_bytes_per_embedding); + weight_scales += index * groups_per_embedding; + if (weight_zeros != nullptr) { + weight_zeros += index * groups_per_embedding; + } + embedding_( + out, + embedding_dim, + group_size, + packed_weight_qvals_byte_ptr, + weight_scales, + weight_zeros); +} + +template +inline void pack_embedding_weight_qvals( + // Output + void* packed_qvals, + // Inputs + int embedding_dim, + const int8_t* qvals, + int index) { + assert(embedding_dim % 8 == 0); + int packed_bytes_per_embedding = embedding_dim * weight_nbit / 8; + auto packed_qvals_byte_ptr = reinterpret_cast(packed_qvals); + + pack_embedding_weight_qvals_( + packed_qvals_byte_ptr + index * packed_bytes_per_embedding, + embedding_dim, + qvals + index * embedding_dim); +} + +} // namespace torchao::kernels::cpu::aarch64::embedding + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h b/torchao/experimental/kernels/cpu/aarch64/macro.h similarity index 100% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h rename to torchao/experimental/kernels/cpu/aarch64/macro.h diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt index c9799eadd..c0dedb0f2 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -84,9 +84,18 @@ target_link_libraries( dep ) +add_executable(test_embedding test_embedding.cpp) +target_link_libraries( + test_embedding + PRIVATE + GTest::gtest_main + dep +) + include(GoogleTest) gtest_discover_tests(test_quantization) gtest_discover_tests(test_reduction) gtest_discover_tests(test_bitpacking) gtest_discover_tests(test_linear) gtest_discover_tests(test_valpacking) +gtest_discover_tests(test_embedding) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh index 27c584ce2..3bffb2ddc 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh @@ -19,3 +19,4 @@ ${CMAKE_OUT}/test_reduction ${CMAKE_OUT}/test_bitpacking ${CMAKE_OUT}/test_linear ${CMAKE_OUT}/test_valpacking +${CMAKE_OUT}/test_embedding diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_embedding.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_embedding.cpp new file mode 100644 index 000000000..a6d6ac8a8 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_embedding.cpp @@ -0,0 +1,155 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include +#include + +float kTol = 0.0001; + +template +void test_embedding( + int num_embeddings, + int embedding_dim, + int group_size, + bool has_weight_zeros) { + auto test_case = torchao::lowbit_embedding_test_case::generate( + num_embeddings, embedding_dim, group_size, has_weight_zeros); + + auto packed = std::vector( + num_embeddings * embedding_dim * weight_nbit / 8, 0); + auto output = std::vector(num_embeddings * embedding_dim, 0.0); + + for (int i = 0; i < num_embeddings; i++) { + torchao::kernels::cpu::aarch64::embedding::pack_embedding_weight_qvals< + weight_nbit>( + packed.data(), embedding_dim, test_case.weight_qvals.data(), i); + } + + int8_t* weight_zeros = nullptr; + if (has_weight_zeros) { + weight_zeros = test_case.weight_zeros.data(); + } + + for (int i = 0; i < num_embeddings; i++) { + torchao::kernels::cpu::aarch64::embedding::embedding( + output.data() + i * embedding_dim, + embedding_dim, + group_size, + packed.data(), + test_case.weight_scales.data(), + weight_zeros, + i); + } + + for (int i = 0; i < num_embeddings * embedding_dim; i++) { + EXPECT_NEAR(output[i], test_case.expected_outputs[i], kTol); + } +} + +TEST(test_embedding, NBit1) { + constexpr int num_embeddings = 5; + constexpr int group_size = 128 * 3 + 64 + 32; + constexpr int embedding_dim = group_size * 7; + + test_embedding<1>( + num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/true); + test_embedding<1>( + num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/false); +} + +TEST(test_embedding, NBit2) { + constexpr int num_embeddings = 5; + constexpr int group_size = 128 * 3 + 64 + 32; + constexpr int embedding_dim = group_size * 7; + + test_embedding<2>( + num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/true); + test_embedding<2>( + num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/false); +} + +TEST(test_embedding, NBit3) { + constexpr int num_embeddings = 5; + constexpr int group_size = 128 * 3 + 64 + 32; + constexpr int embedding_dim = group_size * 7; + + test_embedding<3>( + num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/true); + test_embedding<3>( + num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/false); +} + +TEST(test_embedding, NBit4) { + constexpr int num_embeddings = 5; + constexpr int group_size = 128 * 3 + 64 + 32; + constexpr int embedding_dim = group_size * 7; + + test_embedding<4>( + num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/true); + test_embedding<4>( + num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/false); + + // More detailed testing for 4-bit case + + test_embedding<4>( + num_embeddings, + /*embedding_dim=*/256, + /*group_size=*/32, + /*has_weight_zeros=*/true); + test_embedding<4>( + num_embeddings, + /*embedding_dim=*/256, + /*group_size=*/32, + /*has_weight_zeros=*/false); + test_embedding<4>( + num_embeddings, + /*embedding_dim=*/256, + /*group_size=*/64, + /*has_weight_zeros=*/true); + test_embedding<4>( + num_embeddings, + /*embedding_dim=*/256, + /*group_size=*/64, + /*has_weight_zeros=*/false); + test_embedding<4>( + num_embeddings, + /*embedding_dim=*/256, + /*group_size=*/128, + /*has_weight_zeros=*/true); + test_embedding<4>( + num_embeddings, + /*embedding_dim=*/256, + /*group_size=*/128, + /*has_weight_zeros=*/false); +} + +TEST(test_embedding, NBit5) { + constexpr int num_embeddings = 5; + constexpr int group_size = 128 * 3 + 64 + 32; + constexpr int embedding_dim = group_size * 7; + + test_embedding<5>( + num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/true); + test_embedding<5>( + num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/false); +} + +TEST(test_embedding, NBit6) { + constexpr int num_embeddings = 5; + constexpr int group_size = 128 * 3 + 64 + 32; + constexpr int embedding_dim = group_size * 7; + + test_embedding<6>( + num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/true); + test_embedding<6>( + num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/false); +} + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index e9f36e14a..1bf9cf85b 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -43,9 +43,29 @@ inline std::vector get_random_lowbit_vector(int size, int nbit) { return res; } -// TODO move these to a common utils -inline uint16_t -get_bf16_from_float(float f) { +inline std::vector get_random_signed_lowbit_vector(int size, int nbit) { + assert(nbit >= 1); + assert(nbit <= 8); + + int min = 0; + int max = (1 << nbit) - 1; + int offset = (1 << (nbit - 1)); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto dist = std::bind(std::uniform_int_distribution<>(min, max), rng); + + std::vector res(size); + std::vector tmp(size); + std::generate(tmp.begin(), tmp.end(), std::ref(dist)); + for (int i = 0; i < size; i++) { + res[i] = tmp[i] - offset; + } + return res; +} + +// TODO move these to a common utils +inline uint16_t get_bf16_from_float(float f) { uint16_t bf16; #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ memcpy(&bf16, &f, sizeof(uint16_t)); @@ -57,8 +77,7 @@ get_bf16_from_float(float f) { return bf16; } -inline float -get_float_from_bf16(uint16_t bf16) { +inline float get_float_from_bf16(uint16_t bf16) { float f; const uint32_t i32 = (bf16 << 16); memcpy(&f, &i32, sizeof(uint32_t)); @@ -158,7 +177,7 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { bool has_weight_zeros, bool has_bias, bool has_clamp, - bool round_weight_scales_to_bf16=false) { + bool round_weight_scales_to_bf16 = false) { // activations is m x k (stored in row-major) // weights is k x n (stored in column-major) @@ -303,6 +322,86 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { } }; +template +struct lowbit_embedding_test_case { + int num_embeddings; + int embedding_dim; + int group_size; + std::vector weight_qvals; + std::vector weight_scales; + std::vector weight_zeros; + std::vector expected_outputs; + + lowbit_embedding_test_case( + int num_embeddings, + int embedding_dim, + int group_size, + std::vector weight_qvals, + std::vector weight_scales, + std::vector weight_zeros, + std::vector expected_outputs) + : num_embeddings{num_embeddings}, + embedding_dim{embedding_dim}, + group_size{group_size}, + weight_qvals{weight_qvals}, + weight_scales{weight_scales}, + weight_zeros{weight_zeros}, + expected_outputs{expected_outputs} { + assert(embedding_dim % group_size == 0); + assert(weight_qvals.size() == num_embeddings * embedding_dim); + assert( + weight_scales.size() == num_embeddings * (embedding_dim / group_size)); + assert( + weight_zeros.size() == num_embeddings * (embedding_dim / group_size)); + assert(expected_outputs.size() == num_embeddings * embedding_dim); + } + + static lowbit_embedding_test_case generate( + int num_embeddings, + int embedding_dim, + int group_size, + bool has_weight_zeros) { + int groups_per_embedding = embedding_dim / group_size; + + auto weight_qvals = get_random_signed_lowbit_vector( + num_embeddings * embedding_dim, weight_nbit); + auto weight_scales = + get_random_vector(num_embeddings * groups_per_embedding, 0.1, 1.0); + + std::vector weight_zeros; + if (has_weight_zeros) { + weight_zeros = get_random_signed_lowbit_vector( + num_embeddings * groups_per_embedding, weight_nbit); + } else { + weight_zeros = + std::vector(num_embeddings * groups_per_embedding, 0); + } + + auto expected_outputs = std::vector(num_embeddings * embedding_dim); + for (int embedding_idx = 0; embedding_idx < num_embeddings; + embedding_idx++) { + for (int j = 0; j < embedding_dim; j++) { + auto qval = weight_qvals[embedding_idx * embedding_dim + j]; + auto scale = weight_scales + [embedding_idx * groups_per_embedding + j / group_size]; + auto zero = + weight_zeros[embedding_idx * groups_per_embedding + j / group_size]; + expected_outputs[embedding_idx * embedding_dim + j] = + scale * (qval - zero); + } + } + + return lowbit_embedding_test_case( + num_embeddings, + embedding_dim, + group_size, + weight_qvals, + weight_scales, + weight_zeros, + expected_outputs); + } +}; + } // namespace torchao #endif // defined(__aarch64__) || defined(__ARM_NEON)