Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/sycl' into HEAD
Browse files Browse the repository at this point in the history
  • Loading branch information
sarnex committed Dec 26, 2024
2 parents 437fe59 + ee6969f commit ae21aaf
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 38 deletions.
7 changes: 5 additions & 2 deletions compiler-rt/lib/builtins/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -850,9 +850,12 @@ else ()
if (CAN_TARGET_${arch})
cmake_push_check_state()
# TODO: we should probably make most of the checks in builtin-config depend on the target flags.
message(STATUS "Performing additional configure checks with target flags: ${TARGET_${arch}_CFLAGS}")
set(BUILTIN_CFLAGS_${arch} ${BUILTIN_CFLAGS})
list(APPEND CMAKE_REQUIRED_FLAGS ${TARGET_${arch}_CFLAGS} ${BUILTIN_CFLAGS_${arch}})
# CMAKE_REQUIRED_FLAGS must be a space separated string but unlike TARGET_${arch}_CFLAGS,
# BUILTIN_CFLAGS_${arch} is a CMake list, so we have to join it to create a valid command line.
list(JOIN BUILTIN_CFLAGS " " CMAKE_REQUIRED_FLAGS)
set(CMAKE_REQUIRED_FLAGS "${TARGET_${arch}_CFLAGS} ${BUILTIN_CFLAGS_${arch}}")
message(STATUS "Performing additional configure checks with target flags: ${CMAKE_REQUIRED_FLAGS}")
# For ARM archs, exclude any VFP builtins if VFP is not supported
if (${arch} MATCHES "^(arm|armhf|armv7|armv7s|armv7k|armv7m|armv7em|armv8m.main|armv8.1m.main)$")
string(REPLACE ";" " " _TARGET_${arch}_CFLAGS "${TARGET_${arch}_CFLAGS}")
Expand Down
132 changes: 113 additions & 19 deletions llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,74 @@ namespace {
static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain";
static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR";

// Peels every enclosing ArrayType layer off Ty and returns the first
// non-array element type found. A type that is not an array is returned
// unchanged.
Type *getInnermostType(Type *Ty) {
  for (;;) {
    auto *AsArray = dyn_cast<ArrayType>(Ty);
    if (!AsArray)
      return Ty;
    Ty = AsArray->getElementType();
  }
}

// Rebuilds Ty with the same nesting of array layers (and the same element
// counts) but with NewInnermostTy substituted for the innermost non-array
// type. When Ty is not an array, NewInnermostTy itself is returned.
Type *replaceInnermostType(Type *Ty, Type *NewInnermostTy) {
  auto *AsArray = dyn_cast<ArrayType>(Ty);
  if (!AsArray)
    return NewInnermostTy;

  // Rewrite the element type first, then wrap it in an array of equal length.
  Type *NewElementTy =
      replaceInnermostType(AsArray->getElementType(), NewInnermostTy);
  return ArrayType::get(NewElementTy, AsArray->getNumElements());
}

// Simplified local variant of stripPointerCastsAndOffsets from Value.cpp.
// It walks from V towards its underlying pointer, looking through GEPs
// (including ones with non-zero indices), bitcasts, address space casts and
// calls that return one of their arguments. With StopOnGEP set, the walk
// stops at the nearest GetElementPtrInst and returns it instead of looking
// through it. Non-pointer values are returned unchanged.
Value *stripPointerCastsAndOffsets(Value *V, bool StopOnGEP = false) {
  if (!V->getType()->isPointerTy())
    return V;

  // PHI nodes are not looked through, but the value may still belong to an
  // unreachable block that sits on a cycle, so track what was already seen
  // to guarantee termination.
  SmallPtrSet<Value *, 4> Seen;
  Seen.insert(V);

  for (bool IsNew = true; IsNew; IsNew = Seen.insert(V).second) {
    Value *Next = nullptr;
    if (auto *GEP = dyn_cast<GEPOperator>(V)) {
      if (StopOnGEP && isa<GetElementPtrInst>(GEP))
        return V;
      Next = GEP->getPointerOperand();
    } else if (auto *BC = dyn_cast<BitCastOperator>(V)) {
      Next = BC->getOperand(0);
      // Do not step from a pointer onto a non-pointer source.
      if (!Next->getType()->isPointerTy())
        return V;
    } else if (auto *ASC = dyn_cast<AddrSpaceCastOperator>(V)) {
      Next = ASC->getOperand(0);
    } else {
      // Nothing left to strip unless this is a call whose callee hands one
      // of its arguments back as the return value.
      auto *Call = dyn_cast<CallBase>(V);
      Value *Returned = Call ? Call->getReturnedArgOperand() : nullptr;
      if (!Returned)
        return V;
      // Keep stripping from the returned argument.
      V = Returned;
      continue;
    }
    V = Next;
    assert(V->getType()->isPointerTy() && "Unexpected operand type!");
  }

  return V;
}

// Returns the first element of WrapperMatrixTy when that element is a
// TargetExtType whose name equals MATRIX_TYPE. Returns nullptr when
// WrapperMatrixTy is null, when the struct has no elements, or when its first
// element is anything else.
TargetExtType *extractMatrixType(StructType *WrapperMatrixTy) {
  // getElementType(0) is only meaningful for a struct with at least one
  // element, so reject structs without elements before indexing into them.
  if (!WrapperMatrixTy || WrapperMatrixTy->getNumElements() == 0)
    return nullptr;

  auto *MatrixTy = dyn_cast<TargetExtType>(WrapperMatrixTy->getElementType(0));
  if (!MatrixTy || MatrixTy->getName() != MATRIX_TYPE)
    return nullptr;
  return MatrixTy;
}

// This function finds all calls to __spirv_AccessChain function and transforms
// its users and operands to make LLVM IR more SPIR-V friendly.
bool transformAccessChain(Function *F) {
Expand Down Expand Up @@ -60,33 +128,59 @@ bool transformAccessChain(Function *F) {
// from sycl::joint_matrix class object if it's used in __spirv_AccessChain
// function call. It's necessary because otherwise OpAccessChain indices
// would be wrong.
Instruction *Ptr =
dyn_cast<Instruction>(CI->getArgOperand(0)->stripPointerCasts());
Instruction *Ptr = dyn_cast<Instruction>(
stripPointerCastsAndOffsets(CI->getArgOperand(0)));
if (!Ptr || !isa<AllocaInst>(Ptr))
continue;
StructType *WrapperMatrixTy =
dyn_cast<StructType>(cast<AllocaInst>(Ptr)->getAllocatedType());
if (!WrapperMatrixTy)
continue;
TargetExtType *MatrixTy =
dyn_cast<TargetExtType>(WrapperMatrixTy->getElementType(0));
if (!MatrixTy)

Type *AllocaTy = cast<AllocaInst>(Ptr)->getAllocatedType();
// It may happen that the sycl::joint_matrix class object is wrapped into
// nested arrays. We need to find the innermost type to extract the matrix
// type from it.
if (StructType *WrapperMatrixTy =
dyn_cast<StructType>(getInnermostType(AllocaTy))) {
TargetExtType *MatrixTy = extractMatrixType(WrapperMatrixTy);
if (!MatrixTy)
continue;

AllocaInst *Alloca = nullptr;
{
IRBuilder Builder(CI);
IRBuilderBase::InsertPointGuard IG(Builder);
Builder.SetInsertPointPastAllocas(CI->getFunction());
Alloca = Builder.CreateAlloca(replaceInnermostType(AllocaTy, MatrixTy));
Alloca->takeName(Ptr);
}
Ptr->replaceAllUsesWith(Alloca);
Ptr->dropAllReferences();
Ptr->eraseFromParent();
ModuleChanged = true;
}

// In case spirv.CooperativeMatrixKHR is used in arrays, we also need to
// insert a GEP to get a pointer to the target extension type and use it
// instead of the pointer to the sycl::joint_matrix class object when it is
// passed to __spirv_AccessChain.
// First we check if the argument came from a GEP instruction
GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(
stripPointerCastsAndOffsets(CI->getArgOperand(0), /*StopOnGEP=*/true));
if (!GEP)
continue;
StringRef Name = MatrixTy->getName();
if (Name != MATRIX_TYPE)

// Check if GEP return type is a pointer to sycl::joint_matrix class object
StructType *WrapperMatrixTy =
dyn_cast<StructType>(GEP->getResultElementType());
if (!extractMatrixType(WrapperMatrixTy))
continue;

AllocaInst *Alloca = nullptr;
// Insert GEP right before the __spirv_AccessChain call
{
IRBuilder Builder(CI);
IRBuilderBase::InsertPointGuard IG(Builder);
Builder.SetInsertPointPastAllocas(CI->getFunction());
Alloca = Builder.CreateAlloca(MatrixTy);
Value *NewGEP =
Builder.CreateInBoundsGEP(WrapperMatrixTy, CI->getArgOperand(0),
{Builder.getInt64(0), Builder.getInt32(0)});
CI->setArgOperand(0, NewGEP);
ModuleChanged = true;
}
Ptr->replaceAllUsesWith(Alloca);
Ptr->dropAllReferences();
Ptr->eraseFromParent();
ModuleChanged = true;
}
return ModuleChanged;
}
Expand Down
65 changes: 55 additions & 10 deletions llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,69 @@

; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s

; CHECK: %[[#Alloca:]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
; CHECK: %[[#Cast:]] = addrspacecast ptr %[[#Alloca]] to ptr addrspace(4)
; CHECK: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef %[[#Cast]], i64 noundef 0)

; ModuleID = 'test.bc'
source_filename = "test.cpp"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
target triple = "spir64-unknown-unknown"

%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) }
%"struct.sycl::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) }
%"struct.sycl::_V1::long" = type { i64 }

define weak_odr dso_local spir_kernel void @test(i64 %ind) {
; CHECK-LABEL: define weak_odr dso_local spir_kernel void @test(
; CHECK-SAME: i64 [[IND:%.*]]) {

; non-matrix alloca not touched
; CHECK: [[NOT_MATR:%.*]] = alloca [2 x [4 x %"struct.sycl::_V1::long"]]
; both matrix-related allocas updated to use target extension types
; CHECK-NEXT: [[MATR:%.*]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
; CHECK-NEXT: [[MATR_ARR:%.*]] = alloca [2 x [4 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]]

; CHECK-NEXT: [[ASCAST:%.*]] = addrspacecast ptr [[MATR]] to ptr addrspace(4)
; no gep inserted, since not needed
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[ASCAST]], i64 noundef 0)

; CHECK: [[GEP:%.*]] = getelementptr inbounds [2 x [4 x %"struct.sycl::joint_matrix"]], ptr [[MATR_ARR]], i64 0, i64 [[IND]], i64 [[IND]]
; CHECK-NEXT: [[ASCAST_1:%.*]] = addrspacecast ptr [[GEP]] to ptr addrspace(4)
; CHECK-NEXT: [[ASCAST_2:%.*]] = addrspacecast ptr [[GEP]] to ptr addrspace(4)
; gep is inserted for each of the accesschain calls to extract target extension type
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds %"struct.sycl::joint_matrix", ptr addrspace(4) [[ASCAST_1]], i64 0, i32 0
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP2]], i64 noundef 0)
; CHECK: [[TMP5:%.*]] = getelementptr inbounds %"struct.sycl::joint_matrix", ptr addrspace(4) [[ASCAST_2]], i64 0, i32 0
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP5]], i64 noundef 0)

; negative test - not touching non-matrix code
; CHECK: [[GEP_1:%.*]] = getelementptr inbounds [2 x [4 x %"struct.sycl::_V1::long"]], ptr [[NOT_MATR]], i64 0, i64 [[IND]], i64 [[IND]]
; CHECK-NEXT: [[ASCAST_3:%.*]] = addrspacecast ptr [[GEP_1]] to ptr addrspace(4)
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[ASCAST_3]], i64 noundef 0)

define weak_odr dso_local spir_kernel void @test() {
entry:
%0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
%1 = addrspacecast ptr %0 to ptr addrspace(4)
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0)
; allocas
%matr = alloca %"struct.sycl::joint_matrix", align 8
%matr.arr = alloca [2 x [4 x %"struct.sycl::joint_matrix"]], align 8
%not.matr = alloca [2 x [4 x %"struct.sycl::_V1::long"]], align 8

; simple case
%ascast = addrspacecast ptr %matr to ptr addrspace(4)
%0 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast, i64 noundef 0)
%1 = load i8, ptr addrspace(4) %0

; gep with non-zero indices and multiple access chains per 1 alloca
%gep = getelementptr inbounds [2 x [4 x %"struct.sycl::joint_matrix"]], ptr %matr.arr, i64 0, i64 %ind, i64 %ind
%ascast.1 = addrspacecast ptr %gep to ptr addrspace(4)
%ascast.2 = addrspacecast ptr %gep to ptr addrspace(4)
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.1, i64 noundef 0)
%3 = load i8, ptr addrspace(4) %2
%4 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.2, i64 noundef 0)
%5 = load i8, ptr addrspace(4) %4

; negative test - not touching non-matrix code
%gep.1 = getelementptr inbounds [2 x [4 x %"struct.sycl::_V1::long"]], ptr %not.matr, i64 0, i64 %ind, i64 %ind
%ascast.3 = addrspacecast ptr %gep.1 to ptr addrspace(4)
%6 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.3, i64 noundef 0)
%7 = load i8, ptr addrspace(4) %6

ret void
}

declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef, i64 noundef)
declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef, i64 noundef)
14 changes: 7 additions & 7 deletions sycl/cmake/modules/UnifiedRuntimeTag.cmake
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# commit ea0f3a1f5f15f9af7bf40bd13669afeb9ada569c
# Merge: bb64b3e9f6d3 4a89e1c69a65
# Author: Martin Grant <martin.morrisongrant@codeplay.com>
# Date: Thu Dec 19 11:26:01 2024 +0000
# Merge pull request #2277 from igchor/cooperative_fix
# [Spec] fix urKernelSuggestMaxCooperativeGroupCountExp
set(UNIFIED_RUNTIME_TAG ea0f3a1f5f15f9af7bf40bd13669afeb9ada569c)
# commit 232e62f5221d565ec40d051d3c640b836ca91244
# Merge: 76a96238 59b37e3f
# Author: aarongreig <aaron.greig@codeplay.com>
# Date: Mon Dec 23 18:26:58 2024 +0000
# Merge pull request #2498 from Bensuo/fabio/fix_l0_old_loader_no_translate
# Update usage of zeCommandListImmediateAppendCommandListsExp to use dlsym
set(UNIFIED_RUNTIME_TAG 232e62f5221d565ec40d051d3c640b836ca91244)

0 comments on commit ae21aaf

Please sign in to comment.