Skip to content

Commit

Permalink
support blis for gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Sep 3, 2023
1 parent bf3e91b commit 0539ca8
Show file tree
Hide file tree
Showing 6 changed files with 431 additions and 266 deletions.
30 changes: 22 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ OPTION(USE_KSYMM "K Symmetry" OFF)
OPTION(USE_COMPLEX "Complex Number" OFF)
OPTION(USE_SG "General Spin Symmetry" OFF)
OPTION(USE_SINGLE_PREC "Single Precision" OFF)
OPTION(USE_BLIS "BLIS" OFF)
OPTION(ARCH_ARM64 "MacOS arch arm64" OFF)
OPTION(APPLE_ACC_SINGLE_PREC "Fix Apple Accelerate single prec" ON)
OPTION(BUILD_LIB "Build python block2.so" OFF)
Expand Down Expand Up @@ -462,6 +463,14 @@ ELSE()
SET(MKL_OMP_VALUE 1)
ENDIF()

# Optional BLIS backend (enabled with -DUSE_BLIS=ON).
# Locates the directory that contains blis/blis.h and the BLIS library, then
# sets BLIS_FLAG so the C++ sources are compiled with _HAS_BLIS defined.
# Search locations: $MKLROOT, /usr/local, $BLIS_PREFIX and the Python include
# path, in the order listed below.
IF (${USE_BLIS})
    FIND_PATH(BLIS_INCLUDE_DIR NAMES blis/blis.h HINTS $ENV{MKLROOT}/include /usr/local/include
        $ENV{BLIS_PREFIX}/include ${PYTHON_INCLUDE_PATH})
    # "libblis.so" is tried first to keep the previous preference; the plain
    # name "blis" lets CMake apply the platform prefix/suffix so that
    # libblis.dylib / libblis.a installations are found as well.
    FIND_LIBRARY(BLIS_LIBS NAMES libblis.so blis PATHS $ENV{MKLROOT}/lib /usr/local/lib /usr/lib64
        $ENV{BLIS_PREFIX}/lib ${PYTHON_INCLUDE_PATH})
    # Stop here with a message that names BLIS, instead of letting a
    # *-NOTFOUND value reach the target include/link commands.
    IF (NOT BLIS_INCLUDE_DIR)
        MESSAGE(FATAL_ERROR "USE_BLIS is ON but blis/blis.h was not found; set the BLIS_PREFIX environment variable.")
    ENDIF()
    IF (NOT BLIS_LIBS)
        MESSAGE(FATAL_ERROR "USE_BLIS is ON but the BLIS library was not found; set the BLIS_PREFIX environment variable.")
    ENDIF()
    SET(BLIS_FLAG "-D_HAS_BLIS")
ENDIF()

IF (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
SET(OPT_FLAG ${OMP_FLAG} -bigobj -MP)
ELSEIF("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
Expand Down Expand Up @@ -541,8 +550,11 @@ ELSE()
IF (CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
SET(BLA_VENDOR Apple)
ENDIF()
# Check LAPACK and BLAS
FIND_PACKAGE(BLAS REQUIRED)

# The generic BLAS package is only required when BLIS has not been selected
# through the USE_BLIS option.
IF (NOT ${USE_BLIS})
FIND_PACKAGE(BLAS REQUIRED)
ENDIF()

FIND_PACKAGE(LAPACK REQUIRED)

SET(MKL_INCLUDE_DIR "")
Expand Down Expand Up @@ -643,7 +655,7 @@ ELSE()
ENDIF()

TARGET_LINK_LIBRARIES(${PROJECT_NAME} PRIVATE ${OMP_LIB_NAME} ${PTHREAD})
TARGET_LINK_LIBRARIES(${PROJECT_NAME} PRIVATE ${PTHREAD} ${LAPACK_LIBRARIES} ${BLAS_LIBRARIES} ${MKL_LIBS} ${MPI_LIBS} ${TBB_LIBS})
TARGET_LINK_LIBRARIES(${PROJECT_NAME} PRIVATE ${PTHREAD} ${LAPACK_LIBRARIES} ${BLAS_LIBRARIES} ${MKL_LIBS} ${MPI_LIBS} ${TBB_LIBS} ${BLIS_LIBS})
SET_TARGET_PROPERTIES(${PROJECT_NAME} PROPERTIES LINK_FLAGS "${ARCH_LINK_FLAGS} ${MPI_LINK_FLAGS}")

MESSAGE(STATUS "SRCS = ${SRCS}")
Expand All @@ -669,14 +681,16 @@ MESSAGE(STATUS "MPI_FLAG = ${MPI_FLAG}")
MESSAGE(STATUS "OMP_LIB = ${OMP_LIB_NAME}")
MESSAGE(STATUS "MKL_OMP_LIB_NAME = ${MKL_OMP_LIB_NAME}")
MESSAGE(STATUS "TBB_LIBS = ${TBB_LIBS}")
MESSAGE(STATUS "BLIS_LIBS = ${BLIS_LIBS}")
MESSAGE(STATUS "BLIS_FLAG = ${BLIS_FLAG}")

TARGET_INCLUDE_DIRECTORIES(${PROJECT_NAME} PUBLIC ${PYTHON_INCLUDE_DIRS} ${PYBIND_INCLUDE_DIRS}
${MKL_INCLUDE_DIR} ${MPI_INCLUDE_DIR} ${TBB_INCLUDE_DIR})
${MKL_INCLUDE_DIR} ${MPI_INCLUDE_DIR} ${TBB_INCLUDE_DIR} ${BLIS_INCLUDE_DIR})
TARGET_COMPILE_OPTIONS(${PROJECT_NAME} BEFORE PRIVATE ${OPT_FLAG} ${ARCH_FLAG})
TARGET_COMPILE_OPTIONS(${PROJECT_NAME} BEFORE PUBLIC ${MKL_FLAG} ${MPI_FLAG}
${TMPL_FLAG} ${BOND_FLAG} ${SCI_FLAG} ${CORE_FLAG} ${DMRG_FLAG} ${BIG_SITE_FLAG}
${SP_DMRG_FLAG} ${IC_FLAG} ${KSYMM_FLAG} ${SG_FLAG} ${COMPLEX_FLAG} ${SINGLE_PREC_FLAG}
${TBB_FLAG} ${SU2SZ_FLAG} ${SANY_FLAG})
${TBB_FLAG} ${SU2SZ_FLAG} ${SANY_FLAG} ${BLIS_FLAG})

IF (${BUILD_TEST})
ENABLE_TESTING()
Expand All @@ -691,11 +705,11 @@ IF (${BUILD_TEST})
ENDIF()

ADD_EXECUTABLE(${PROJECT_NAME}_tests ${TSRCS} ${SRCS})
TARGET_INCLUDE_DIRECTORIES(${PROJECT_NAME}_tests PUBLIC src ${MKL_INCLUDE_DIR} ${MPI_INCLUDE_DIR} ${TBB_INCLUDE_DIR})
TARGET_LINK_LIBRARIES(${PROJECT_NAME}_tests ${GTEST_BOTH_LIBRARIES} ${PTHREAD} ${MPI_LIBS} ${TBB_LIBS})
TARGET_INCLUDE_DIRECTORIES(${PROJECT_NAME}_tests PUBLIC src ${MKL_INCLUDE_DIR} ${MPI_INCLUDE_DIR} ${TBB_INCLUDE_DIR} ${BLIS_INCLUDE_DIR})
TARGET_LINK_LIBRARIES(${PROJECT_NAME}_tests ${GTEST_BOTH_LIBRARIES} ${PTHREAD} ${MPI_LIBS} ${TBB_LIBS} ${BLIS_LIBS})
TARGET_COMPILE_OPTIONS(${PROJECT_NAME}_tests BEFORE PUBLIC ${ARCH_FLAG} ${OPT_FLAG} ${MKL_FLAG} ${MPI_FLAG}
${TMPL_FLAG} ${BOND_FLAG} ${SCI_FLAG} ${CORE_FLAG} ${DMRG_FLAG} ${BIG_SITE_FLAG}
${SP_DMRG_FLAG} ${IC_FLAG} ${KSYMM_FLAG} ${SG_FLAG} ${COMPLEX_FLAG} ${SINGLE_PREC_FLAG} ${TBB_FLAG}
${SP_DMRG_FLAG} ${IC_FLAG} ${KSYMM_FLAG} ${SG_FLAG} ${COMPLEX_FLAG} ${SINGLE_PREC_FLAG} ${TBB_FLAG} ${BLIS_FLAG}
${SU2SZ_FLAG} ${SANY_FLAG})
SET_TARGET_PROPERTIES(${PROJECT_NAME}_tests PROPERTIES LINK_FLAGS "${ARCH_LINK_FLAGS} ${MPI_LINK_FLAGS}")

Expand Down
60 changes: 45 additions & 15 deletions src/core/batch_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,15 @@ inline void threaded_xgemm_batch(
const char *tra =
TransA_Array[ig] == CblasNoTrans
? "n"
: (TransA_Array[ig] == CblasConjTrans ? "c" : "t");
: (TransA_Array[ig] == CblasConjTrans
? "c"
: (TransA_Array[ig] == CblasTrans ? "t" : "x"));
const char *trb =
TransB_Array[ig] == CblasNoTrans
? "n"
: (TransB_Array[ig] == CblasConjTrans ? "c" : "t");
: (TransB_Array[ig] == CblasConjTrans
? "c"
: (TransB_Array[ig] == CblasTrans ? "t" : "x"));
const MKL_INT m = M_Array[ig], n = N_Array[ig], k = K_Array[ig];
const FL alpha = alpha_Array[ig], beta = beta_Array[ig];
const MKL_INT lda = lda_Array[ig], ldb = ldb_Array[ig],
Expand All @@ -210,12 +214,18 @@ single_xgemm(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE *TransA_Array,
const MKL_INT *ldc_Array, const MKL_INT *group_size,
FL scale = 1.0) {
const MKL_INT ig = 0, i = 0;
const char *tra = TransA_Array[ig] == CblasNoTrans
? "n"
: (TransA_Array[ig] == CblasConjTrans ? "c" : "t");
const char *trb = TransB_Array[ig] == CblasNoTrans
? "n"
: (TransB_Array[ig] == CblasConjTrans ? "c" : "t");
const char *tra =
TransA_Array[ig] == CblasNoTrans
? "n"
: (TransA_Array[ig] == CblasConjTrans
? "c"
: (TransA_Array[ig] == CblasTrans ? "t" : "x"));
const char *trb =
TransB_Array[ig] == CblasNoTrans
? "n"
: (TransB_Array[ig] == CblasConjTrans
? "c"
: (TransB_Array[ig] == CblasTrans ? "t" : "x"));
const MKL_INT m = M_Array[ig], n = N_Array[ig], k = K_Array[ig];
const FL alpha = alpha_Array[ig] * scale, beta = beta_Array[ig];
const MKL_INT lda = lda_Array[ig], ldb = ldb_Array[ig], ldc = ldc_Array[ig];
Expand Down Expand Up @@ -279,11 +289,16 @@ template <typename FL> struct BatchGEMM {
void xgemm_group(uint8_t conja, uint8_t conjb, MKL_INT m, MKL_INT n,
MKL_INT k, FL alpha, MKL_INT lda, MKL_INT ldb, FL beta,
MKL_INT ldc, MKL_INT gc) {
ta.push_back(conja == 3 ? CblasConjTrans
: (conja ? CblasTrans : CblasNoTrans));
tb.push_back(conjb == 3 ? CblasConjTrans
: (conjb ? CblasTrans : CblasNoTrans));
assert(lda >= (conja ? m : k) && ldb >= (conjb ? k : n) && ldc >= n);
ta.push_back(conja == 3
? CblasConjTrans
: (conja == 2 ? (CBLAS_TRANSPOSE)(CblasConjTrans + 1)
: (conja ? CblasTrans : CblasNoTrans)));
tb.push_back(conjb == 3
? CblasConjTrans
: (conjb == 2 ? (CBLAS_TRANSPOSE)(CblasConjTrans + 1)
: (conjb ? CblasTrans : CblasNoTrans)));
assert(lda >= ((conja & 1) ? m : k) && ldb >= ((conjb & 1) ? k : n) &&
ldc >= n);
this->m.push_back(m), this->n.push_back(n), this->k.push_back(k);
this->alpha.push_back(alpha), this->beta.push_back(beta);
this->lda.push_back(lda), this->ldb.push_back(ldb),
Expand Down Expand Up @@ -313,25 +328,32 @@ template <typename FL> struct BatchGEMM {
// [c] = [a] x [b]
// Queues a single matrix product through this->xgemm.
//   a, b    : input matrices; their data pointers are passed with leading
//             dimensions a.n and b.n.
//   conja/b : operation selector for each operand. Odd values (1, 3) are the
//             ones xgemm_group maps to CblasTrans / CblasConjTrans. The value
//             2 is rejected unless _HAS_BLIS is defined (presumably it means
//             conjugate-without-transpose -- confirm against xgemm).
//   c       : output matrix; c.m rows, leading dimension c.n.
//   scale, cfactor : forwarded unchanged to xgemm (presumably the GEMM alpha
//             and beta factors -- confirm against xgemm).
void multiply(const GMatrix<FL> &a, uint8_t conja, const GMatrix<FL> &b,
              uint8_t conjb, const GMatrix<FL> &c, FL scale, FL cfactor) {
#ifndef _HAS_BLIS
    // Only the BLIS build accepts the selector value 2.
    assert(conja != 2 && conjb != 2);
#endif
    // The low bit of conjb decides whether b is used transposed, which swaps
    // the roles of b.m and b.n when deriving the n and k dimensions.
    const bool transb = (conjb & 1) != 0;
    // Exactly one GEMM is queued per call, for both build configurations.
    this->xgemm(conja, conjb, c.m, transb ? b.m : b.n, transb ? b.n : b.m,
                scale, a.data, a.n, b.data, b.n, cfactor, c.data, c.n);
}
// Execute the queued GEMM operation groups starting at group index ii.
//   ii : offset into the per-group arrays (ta, tb, m, n, k, alpha, lda, ldb,
//        beta, ldc, gp).
//   kk : offset into the matrix pointer arrays (a, b, c).
//   nn : number of groups to run; 0 means gp.size() groups.
// Nothing is executed when nn == 0 and no group has been recorded in gp.
void perform(MKL_INT ii = 0, MKL_INT kk = 0, MKL_INT nn = 0) {
if (nn != 0 || gp.size() != 0) {
// Without _HAS_BLIS the Quanta bit of the threading type selects between
// threaded_xgemm_batch and cblas_xgemm_batch. With _HAS_BLIS both the test
// and the else branch are compiled out, so threaded_xgemm_batch is always
// called.
#ifndef _HAS_BLIS
if (threading->type & ThreadingTypes::Quanta)
#endif
threaded_xgemm_batch<FL>(
layout, &ta[ii], &tb[ii], &m[ii], &n[ii], &k[ii],
&alpha[ii], &a[kk], &lda[ii], &b[kk], &ldb[ii], &beta[ii],
&c[kk], &ldc[ii], nn == 0 ? (MKL_INT)gp.size() : nn,
&gp[ii]);
#ifndef _HAS_BLIS
else
cblas_xgemm_batch<FL>(
layout, &ta[ii], &tb[ii], &m[ii], &n[ii], &k[ii],
&alpha[ii], &a[kk], &lda[ii], &b[kk], &ldb[ii], &beta[ii],
&c[kk], &ldc[ii], nn == 0 ? (MKL_INT)gp.size() : nn,
&gp[ii]);
#endif
}
}
inline void perform_single(MKL_INT ii, const FL *a, const FL *b, FL *c,
Expand Down Expand Up @@ -563,8 +585,11 @@ struct AdvancedGEMM<FL, typename enable_if<is_complex<FL>::value>::type> {
const GMatrix<FL> &a, uint8_t conja,
const GMatrix<FL> &b, uint8_t conjb,
const GMatrix<FL> &c, FL scale, FL cfactor) {
#ifndef _HAS_BLIS
if (conja != 2 && conjb != 2)
#endif
batch->multiply(a, conja, b, conjb, c, scale, cfactor);
#ifndef _HAS_BLIS
else if (conja == 2 && conjb != 2) {
batch->xgemm_group(3, conjb, 1, (conjb & 1) ? b.m : b.n,
(conjb & 1) ? b.n : b.m, scale, 1, b.n, cfactor,
Expand All @@ -580,6 +605,7 @@ struct AdvancedGEMM<FL, typename enable_if<is_complex<FL>::value>::type> {
} else {
assert(false);
}
#endif
}
// [c] = [a] * (scalar b) or [c] = (scalar a) * [b] or [c] = [a] \otimes [b]
static void tensor_product(const shared_ptr<BatchGEMM<FL>> &batch,
Expand Down Expand Up @@ -807,7 +833,11 @@ struct AdvancedGEMM<FL, typename enable_if<is_complex<FL>::value>::type> {
}
// Appends one entry to batch->acidxs: the constant 0 when the build defines
// _HAS_BLIS, otherwise the caller-supplied value x.
static void post_three_rotate(const shared_ptr<BatchGEMM<FL>> &batch,
                              uint8_t x) {
#ifdef _HAS_BLIS
    // BLIS build: x is ignored and 0 is recorded.
    batch->acidxs.push_back(0);
#else
    batch->acidxs.push_back(x);
#endif
}
};

Expand Down
Loading

0 comments on commit 0539ca8

Please sign in to comment.