Skip to content

Commit

Permalink
Kleidi 4b blockwise gemv prototype
Browse files Browse the repository at this point in the history
Differential Revision: D64194844

Pull Request resolved: #997
  • Loading branch information
digantdesai authored Oct 11, 2024
1 parent 5277507 commit db72dd1
Show file tree
Hide file tree
Showing 11 changed files with 842 additions and 10 deletions.
9 changes: 7 additions & 2 deletions torchao/experimental/build_torchao_ops.sh
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
#!/bin/bash
#!/bin/bash -eu
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

if [[ $# -ne 1 ]]; then
echo "Usage: $0 <aten|executorch>";
exit 1;
fi
TARGET="${1}"
export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
export CMAKE_OUT=/tmp/cmake-out/torchao
cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
-DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \
-DTORCHAO_OP_TARGET="$1" \
-DTORCHAO_OP_TARGET="${TARGET}" \
-DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \
-DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \
-S . \
Expand Down
23 changes: 23 additions & 0 deletions torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

include(FetchContent)

# KleidiAI is an open-source library that provides optimized
# performance-critical routines, also known as micro-kernels, for artificial
# intelligence (AI) workloads tailored for Arm® CPUs.
FetchContent_Declare(kleidiai
GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git
GIT_TAG 35e156d62d1d7e4d27a39f56ed7770a665628b31) # same as xnnpack for now, TODO - revisit this

FetchContent_MakeAvailable(kleidiai)

# Disabled by default. Force enable if we are on a suitable system.
# TODO: Introduce ISA specific flags for i8mm.
CMAKE_DEPENDENT_OPTION(BUILD_KLEIDI "Download, build, and link against Arm KleidiAI library"
OFF "CMAKE_SYSTEM_PROCESSOR STREQUAL \"arm64\"" ON)

if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
add_library(
torchao_kernels_aarch64
Expand All @@ -12,6 +28,13 @@ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
)
if (BUILD_KLEIDI)
# Temporarily exposing this to the parent scope until we wire
# this up properly from the top level
set(TORCHAO_ENABLE_KLEIDI ON PARENT_SCOPE)
message(STATUS "Building with Kleidi")
target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai)
endif()
endif()

install(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
#!/bin/bash
#!/bin/bash -eu
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Call script with sh build_and_run_benchmarks.sh {BENCHAMRK}
set -eu

if [[ $# -ne 1 ]]; then
echo "Usage: $0 <quantization|bitpacking|linear>";
exit 1;
fi

BENCHMARK_TYPE="${1}"
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)

export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../..
export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks

# Build
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/benchmarks \
-B ${CMAKE_OUT}

cmake --build ${CMAKE_OUT}

# Run
case "$1" in
case "${BENCHMARK_TYPE}" in
quantization) ${CMAKE_OUT}/benchmark_quantization; ;;
bitpacking) ${CMAKE_OUT}/benchmark_bitpacking; ;;
linear) ${CMAKE_OUT}/benchmark_linear; ;;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {

namespace neon_dotprod_1x4x32 {
const Ukernel get_ukernel() {
return Ukernel{
.get_m_step =
kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_n_step =
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_mr =
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_nr =
kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_kr =
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_sr =
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_lhs_packed_offset =
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_rhs_packed_offset =
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_dst_offset =
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.get_dst_size =
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
.run_matmul =
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod};
}

size_t activation_data_size(int m, int k, int group_size) {
(void)group_size; // unused
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
get_ukernel(), m, k);
}

void prepare_activation_data(
void* activation_data,
int m,
int k,
int group_size,
const float* activations) {
(void)group_size; // unused
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
get_ukernel(), activation_data, m, k, activations);
}

size_t weight_data_size(int n, int k, int group_size) {
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
get_ukernel(), n, k, group_size);
}

void prepare_weight_data(
void* weight_data,
int n,
int k,
int group_size,
const int8_t* weight_qvals,
const float* weight_scales,
const int8_t* weight_zeros) {
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
get_ukernel(),
weight_data,
n,
k,
group_size,
weight_qvals,
weight_scales,
weight_zeros);
}

void kernel(
float32_t* output,
int output_m_stride,
int m,
int n,
int k,
int group_size,
const void* weight_data,
const void* activation_data,
const float* bias,
float clamp_min,
float clamp_max) {
(void)bias; // TODO(T203756650) - unused - needs API fixing
assert(output_m_stride == n);
if (clamp_min == 0 && clamp_max == 0) {
clamp_min = std::numeric_limits<float>::lowest();
clamp_max = std::numeric_limits<float>::max();
}

auto ukernel = get_ukernel();
ukernel.run_matmul(
m,
n,
k,
group_size,
activation_data,
weight_data,
output,
/*dst_stride_row=*/n * sizeof(float),
/*dst_stride_col=*/sizeof(float),
clamp_min,
clamp_max);
}

size_t get_preferred_alignement() {
return 16;
}
} // namespace neon_dotprod_1x4x32
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
namespace neon_dotprod_1x8x32 {
const Ukernel get_ukernel() {
return Ukernel{
.get_m_step =
kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_n_step =
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_mr =
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_nr =
kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_kr =
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_sr =
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_lhs_packed_offset =
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_rhs_packed_offset =
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_dst_offset =
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.get_dst_size =
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
.run_matmul =
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod};
}

size_t activation_data_size(int m, int k, int group_size) {
(void) group_size; // unused
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(get_ukernel(), m, k);
}

void prepare_activation_data(
void* activation_data,
int m,
int k,
int group_size,
const float* activations) {
(void) group_size; // unused
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
get_ukernel(),
activation_data,
m,
k,
activations);
}

size_t weight_data_size(int n, int k, int group_size) {
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(get_ukernel(), n, k, group_size);
}

void prepare_weight_data(
void* weight_data,
int n,
int k,
int group_size,
const int8_t* weight_qvals,
const float* weight_scales,
const int8_t* weight_zeros) {
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
get_ukernel(),
weight_data,
n,
k,
group_size,
weight_qvals,
weight_scales,
weight_zeros);
}

void kernel(
float32_t* output,
int output_m_stride,
int m,
int n,
int k,
int group_size,
const void* weight_data,
const void* activation_data,
const float* bias,
float clamp_min,
float clamp_max) {
(void) bias; // TODO(T203756650) - unused - needs API fixing
assert(output_m_stride == n);
if (clamp_min == 0 && clamp_max == 0) {
clamp_min = std::numeric_limits<float>::lowest();
clamp_max = std::numeric_limits<float>::max();
}

auto ukernel = get_ukernel();
ukernel.run_matmul(
m,
n,
k,
group_size,
activation_data,
weight_data,
output,
/*dst_stride_row=*/ n * sizeof(float),
/*dst_stride_col=*/ sizeof(float),
clamp_min,
clamp_max);
}

size_t get_preferred_alignement() {
return 16;
}
} // namespace neon_dotprod_1x4x32
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi
Loading

0 comments on commit db72dd1

Please sign in to comment.