Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use fewer instructions when unpacking uint6s. #1109

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,20 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values(
uint8x8_t packed1 = vld1_u8(packed + 8);
uint8x8_t packed2 = vld1_u8(packed + 16);

// unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
// ((packed[1] & 0b1100'0000u) >> 4) |
// ((packed[2] & 0b1100'0000u) >> 2);
const uint8x8_t high = vdup_n_u8(0b1100'0000u);
uint8x8_t unpacked3;
unpacked3 = vorr_u8(
vshr_n_u8(vand_u8(packed0, high), 6),
vshr_n_u8(vand_u8(packed1, high), 4));
unpacked3 = vorr_u8(unpacked3, vshr_n_u8(vand_u8(packed2, high), 2));
// We want to extract bits 123456 and place them in unpacked3.
// Packed structure is:
//
// packed0: 56 | abcdef
// packed1: 34 | ghijkl
// packed2: 12 | mnopqr
//
// unpacked3 = 1234 ghij
unpacked3 = vsri_n_u8(packed2, packed1, 2);
// unpacked3 = 1234 56ab
unpacked3 = vsri_n_u8(unpacked3, packed0, 4);
// unpacked3 = 0012 3456
unpacked3 = vshr_n_u8(unpacked3, 2);

// unpacked[i] = packed[i] & 0b11'1111u;
const uint8x8_t mask = vdup_n_u8(0b11'1111u);
Expand Down Expand Up @@ -183,14 +188,19 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values(
unpacked1 = vld1q_u8(packed + 16);
unpacked2 = vld1q_u8(packed + 32);

// unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
// ((packed[1] & 0b1100'0000u) >> 4) |
// ((packed[2] & 0b1100'0000u) >> 2);
const uint8x16_t high = vdupq_n_u8(0b1100'0000u);
unpacked3 = vorrq_u8(
vshrq_n_u8(vandq_u8(unpacked0, high), 6),
vshrq_n_u8(vandq_u8(unpacked1, high), 4));
unpacked3 = vorrq_u8(unpacked3, vshrq_n_u8(vandq_u8(unpacked2, high), 2));
// We want to extract bits 123456 and place them in unpacked3.
// Packed structure is:
//
// packed0: 56 | abcdef
// packed1: 34 | ghijkl
// packed2: 12 | mnopqr
//
// unpacked3 = 1234 ghij
unpacked3 = vsriq_n_u8(unpacked2, unpacked1, 2);
// unpacked3 = 1234 56ab
unpacked3 = vsriq_n_u8(unpacked3, unpacked0, 4);
// unpacked3 = 0012 3456
unpacked3 = vshrq_n_u8(unpacked3, 2);

// unpacked[i] = packed[i] & 0b11'1111u;
const uint8x16_t mask = vdupq_n_u8(0b11'1111u);
Expand Down
Loading