Skip to content

Commit

Permalink
Add quantized embedding kernels to torchao (#1018)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1018

This diff adds lowbit embedding kernels to torchao.  These reuse the same bitpacking code as the linear kernels.

Reviewed By: digantdesai

Differential Revision: D63839255
  • Loading branch information
metascroy authored and facebook-github-bot committed Oct 17, 2024
1 parent 7849875 commit 7af0756
Show file tree
Hide file tree
Showing 13 changed files with 609 additions and 18 deletions.
12 changes: 6 additions & 6 deletions torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
#if defined(__aarch64__) || defined(__ARM_NEON)

#include <arm_neon.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h>
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
#include <cassert>

namespace torchao {
Expand Down Expand Up @@ -142,7 +142,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
break;
case 6:
torchao::bitpacking::internal::vec_pack_32_uint6_values(
packed, shifted0, shifted1);
packed, shifted0, shifted1);
break;
default:
assert(false);
Expand All @@ -153,7 +153,7 @@ template <int nbit>
TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
int8x16_t& unpacked0,
int8x16_t& unpacked1,
uint8_t* packed) {
const uint8_t* packed) {
static_assert(nbit < 8);
static_assert(nbit >= 1);

Expand Down Expand Up @@ -217,7 +217,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
break;
case 6:
torchao::bitpacking::internal::vec_unpack_32_uint6_values(
shifted0, shifted1, packed);
shifted0, shifted1, packed);
break;
default:
assert(false);
Expand Down Expand Up @@ -288,7 +288,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
int8x16_t& unpacked1,
int8x16_t& unpacked2,
int8x16_t& unpacked3,
uint8_t* packed) {
const uint8_t* packed) {
static_assert(nbit < 8);
static_assert(nbit >= 1);

Expand Down Expand Up @@ -443,7 +443,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
int8x16_t& unpacked5,
int8x16_t& unpacked6,
int8x16_t& unpacked7,
uint8_t* packed) {
const uint8_t* packed) {
static_assert(nbit < 8);
static_assert(nbit >= 1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#if defined(__aarch64__) || defined(__ARM_NEON)
#include <arm_neon.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>

// This file contains bitpacking and unpacking methods for uint1.
// These are not intended to be used outside of the bitpacking directory.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#if defined(__aarch64__) || defined(__ARM_NEON)

#include <arm_neon.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>

// This file contains bitpacking and unpacking methods for uint4.
// These are not intended to be used outside of the bitpacking directory.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#if defined(__aarch64__) || defined(__ARM_NEON)

#include <arm_neon.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>

// This file contains bitpacking and unpacking methods for uint3.
// These are not intended to be used outside of the bitpacking directory.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#if defined(__aarch64__) || defined(__ARM_NEON)

#include <arm_neon.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>

// This file contains bitpacking and unpacking methods for uint4.
// These are not intended to be used outside of the bitpacking directory.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#if defined(__aarch64__) || defined(__ARM_NEON)

#include <arm_neon.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>

// This file contains bitpacking and unpacking methods for uint5.
// These are not intended to be used outside of the bitpacking directory.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#if defined(__aarch64__) || defined(__ARM_NEON)

#include <arm_neon.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>

// This file contains bitpacking and unpacking methods for uint5.
// These are not intended to be used outside of the bitpacking directory.
Expand Down
Loading

0 comments on commit 7af0756

Please sign in to comment.