[xegpu][spirv] Add xegpu.simt to spirv JointMatrixINTEL lowering and an E2E XeGPU.SIMT GEMM test case

Supported ops:

xegpu.create_nd_descriptor
xegpu.update_nd_offset
xegpu.load_nd
xegpu.store_nd
xegpu.dpas
Add an end-to-end GEMM test case for XeGPU.SIMT

GEMM parameters in the test case:
Matrix A = 1024x1024xf16
Matrix B = 1024x1024xf16
Matrix C = 1024x1024xf32
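For reference, the test exercises a plain f16 x f16 -> f32 GEMM at these sizes. A minimal host-side reference sketch of that computation (not part of this commit; the actual test harness, data layout, and tolerances are assumptions here):

#include <vector>

// Naive reference GEMM: C (1024x1024, f32) = A (1024x1024, f16) * B (1024x1024, f16).
// Inputs are assumed to be pre-widened to float for the host-side check.
void referenceGemm(const std::vector<float> &A, const std::vector<float> &B,
                   std::vector<float> &C, int M = 1024, int N = 1024,
                   int K = 1024) {
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k)
        acc += A[i * K + k] * B[k * N + j];
      C[i * N + j] = acc; // accumulate in f32, matching the f32 output matrix
    }
}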
mshahneo authored and silee2 committed Nov 21, 2023
1 parent c7f4448 commit 5f21589
Showing 20 changed files with 782 additions and 46 deletions.
6 changes: 5 additions & 1 deletion include/imex/Conversion/Passes.td
@@ -251,7 +251,11 @@ memref, arith and math.
   let constructor = "imex::createConvertGPUXToSPIRVPass()";
   let dependentDialects = ["::mlir::spirv::SPIRVDialect"];
   let options = [
-    Option<"enableVCIntrinsic", "enable-vc-intrinsic","bool", "true",
+    Option<"enableJointMatrix", "enable-joint-matrix","bool", "false",
+           "Enable XeGPU SIMT mode Ops lowered to JointMatrix based Ops">,
+    Option<"enableGenISAIntrinsic", "enable-genisa-intrinsic","bool", "false",
+           "Enable XeGPU SIMT mode Ops lowered to GenISA Intrinsics">,
+    Option<"enableVCIntrinsic", "enable-vc-intrinsic","bool", "false",
            "Enable XeGPU Ops lowered to intel vc Intrinsics">
   ];
 }
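As the pass changes further below enforce, exactly one of enable-joint-matrix, enable-genisa-intrinsic, and enable-vc-intrinsic should be set. A hypothetical way to enable the new JointMatrix path programmatically, assuming the pass is registered under the "imex-convert-gpu-to-spirv" argument used in its own diagnostic; everything else in this sketch is illustrative:

#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"

// Builds a pipeline roughly equivalent to the command line
//   imex-opt --imex-convert-gpu-to-spirv='enable-joint-matrix=true' input.mlir
// (command line shown for illustration only).
mlir::LogicalResult buildJointMatrixPipeline(mlir::PassManager &pm) {
  return mlir::parsePassPipeline(
      "imex-convert-gpu-to-spirv{enable-joint-matrix=true}", pm);
}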
3 changes: 3 additions & 0 deletions include/imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h
@@ -30,6 +30,9 @@ void populateXeGPUToVCIntrinsicsPatterns(
 // XeGPU to genISA Intrinsics pattern
 void populateXeGPUToGenISAPatterns(mlir::SPIRVTypeConverter &typeConverter,
                                    mlir::RewritePatternSet &patterns);
+// XeGPU to JointMatrix pattern
+void populateXeGPUToJointMatrixPatterns(mlir::SPIRVTypeConverter &typeConverter,
+                                        mlir::RewritePatternSet &patterns);
 } // namespace imex

 #endif // IMEX_CONVERSION_XEGPUTOSPIRV_H
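The new entry point mirrors the existing VC-intrinsic and GenISA ones. A rough sketch of how such a populate function is typically wired up in MLIR-based lowerings; the pattern class name below is a placeholder, not the one used in this commit, and its body is intentionally empty:

#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Transforms/DialectConversion.h"

namespace {
// Placeholder conversion pattern: matches xegpu.load_nd by name and would
// rewrite it into the corresponding SPIR-V joint-matrix load.
struct LoadNdToJointMatrixPattern : public mlir::ConversionPattern {
  LoadNdToJointMatrixPattern(mlir::SPIRVTypeConverter &typeConverter,
                             mlir::MLIRContext *ctx)
      : mlir::ConversionPattern(typeConverter, "xegpu.load_nd",
                                /*benefit=*/1, ctx) {}
  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // A real pattern would build the SPIR-V joint-matrix load here.
    return mlir::failure();
  }
};
} // namespace

namespace imex {
void populateXeGPUToJointMatrixPatterns(mlir::SPIRVTypeConverter &typeConverter,
                                        mlir::RewritePatternSet &patterns) {
  patterns.add<LoadNdToJointMatrixPattern>(typeConverter,
                                           patterns.getContext());
}
} // namespace imex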
109 changes: 80 additions & 29 deletions lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -156,35 +156,79 @@ void GPUXToSPIRVPass::runOnOperation() {
       eraseOp->erase();
     }
     target->addIllegalDialect<imex::xegpu::XeGPUDialect>();
-    typeConverter.addConversion([&](xegpu::NbarrierType type) -> ::mlir::Type {
-      auto i32Type = ::mlir::IntegerType::get(context, 32);
-      return mlir::VectorType::get(8, i32Type);
-    });
-    typeConverter.addConversion(
-        [&](xegpu::TensorDescType type) -> ::mlir::Type {
-          auto i32Type = ::mlir::IntegerType::get(context, 32);
-          return ::mlir::VectorType::get(8, i32Type);
-        });
-    typeConverter.addConversion([&](::mlir::VectorType type) -> ::mlir::Type {
-      unsigned rank = type.getRank();
-      auto elemType = type.getElementType();
-      if (rank < 1)
-        return type;
-      else {
-        // load2d/store2d is vnni format with 3 dims
-        if (rank == 3 && elemType.getIntOrFloatBitWidth() < 32) {
-          elemType = ::mlir::IntegerType::get(context, 32);
-          rank--;
-        }
-        unsigned sum = 1;
-        for (unsigned i = 0; i < rank; i++) {
-          sum *= type.getShape()[i];
-        }
-        if (llvm::isa<mlir::IndexType>(elemType))
-          elemType = ::mlir::IntegerType::get(context, 64);
-        return ::mlir::VectorType::get(sum, elemType);
-      }
-    });
+    // Only one of the following options should be enabled.
+    if ((this->enableVCIntrinsic && this->enableGenISAIntrinsic) ||
+        (this->enableVCIntrinsic && this->enableJointMatrix) ||
+        (this->enableGenISAIntrinsic && this->enableJointMatrix))
+      return signalPassFailure();
+    if (this->enableJointMatrix) {
+      // Tensor descriptor conversion pattern for SIMT JointMatrix
+      typeConverter.addConversion(
+          [&](xegpu::TensorDescType type) -> ::mlir::spirv::StructType {
+            llvm::SmallVector<::mlir::Type, 4> memberTypes;
+            auto i64Type = ::mlir::IntegerType::get(context, 64);
+            // Default storage class is spirv::StorageClass::CrossWorkgroup
+            auto spirvStorageClass =
+                ::mlir::spirv::StorageClass::CrossWorkgroup;
+            if (type.getMemoryScope() == xegpu::MemoryScope::SLM)
+              spirvStorageClass = ::mlir::spirv::StorageClass::Workgroup;
+            auto baseAddressType = ::mlir::spirv::PointerType::get(
+                type.getElementType(), spirvStorageClass);
+            memberTypes.push_back(baseAddressType);
+            memberTypes.push_back(i64Type);
+
+            for (int i = 0; i < type.getRank(); i++) {
+              memberTypes.push_back(i64Type);
+            }
+            return ::mlir::spirv::StructType::get(memberTypes);
+          });
+      typeConverter.addConversion([&](::mlir::VectorType type) -> ::mlir::Type {
+        unsigned rank = type.getRank();
+        auto elemType = type.getElementType();
+        if (rank < 1)
+          return type;
+        else {
+          unsigned sum = 1;
+          for (unsigned i = 0; i < rank; i++) {
+            sum *= type.getShape()[i];
+          }
+          if (llvm::isa<mlir::IndexType>(elemType))
+            elemType = ::mlir::IntegerType::get(context, 64);
+          return ::mlir::VectorType::get(sum, elemType);
+        }
+      });
+    } else {
+      typeConverter.addConversion(
+          [&](xegpu::TensorDescType type) -> ::mlir::Type {
+            auto i32Type = ::mlir::IntegerType::get(context, 32);
+            return ::mlir::VectorType::get(8, i32Type);
+          });
+      typeConverter.addConversion([&](::mlir::VectorType type) -> ::mlir::Type {
+        unsigned rank = type.getRank();
+        auto elemType = type.getElementType();
+        if (rank < 1)
+          return type;
+        else {
+          // load2d/store2d is vnni format with 3 dims
+          if (rank == 3 && elemType.getIntOrFloatBitWidth() < 32) {
+            elemType = ::mlir::IntegerType::get(context, 32);
+            rank--;
+          }
+          unsigned sum = 1;
+          for (unsigned i = 0; i < rank; i++) {
+            sum *= type.getShape()[i];
+          }
+          if (llvm::isa<mlir::IndexType>(elemType))
+            elemType = ::mlir::IntegerType::get(context, 64);
+          return ::mlir::VectorType::get(sum, elemType);
+        }
+      });
+      typeConverter.addConversion(
+          [&](xegpu::NbarrierType type) -> ::mlir::Type {
+            auto i32Type = ::mlir::IntegerType::get(context, 32);
+            return mlir::VectorType::get(8, i32Type);
+          });
+    }

     //------- Upstream Conversion------------
     mlir::populateGPUToSPIRVPatterns(typeConverter, patterns);
@@ -200,9 +244,16 @@ void GPUXToSPIRVPass::runOnOperation() {
     mlir::populateMathToSPIRVPatterns(typeConverter, patterns);
     if (this->enableVCIntrinsic)
       imex::populateXeGPUToVCIntrinsicsPatterns(typeConverter, patterns);
-    else
+    else if (this->enableJointMatrix)
+      imex::populateXeGPUToJointMatrixPatterns(typeConverter, patterns);
+    else if (this->enableGenISAIntrinsic)
       imex::populateXeGPUToGenISAPatterns(typeConverter, patterns);

+    else
+      module.emitOpError(
+          "'-imex-convert-gpu-to-spirv' pass must be run with one of the "
+          "following options to be 'true': "
+          "'enable-vc-intrinsic', 'enable-joint-matrix', "
+          "'enable-genisa-intrinsic'");
     if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
       return signalPassFailure();
   }
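To make the new type conversions above concrete, here is what they produce for shapes typical of the GEMM test; the meaning of the second i64 struct member is an assumption (the diff only shows one i64 plus one i64 per dimension following the base pointer):

// JointMatrix path, 2-D f16 tensor descriptor in global memory:
//   rank-2 f16 tensor_desc  ->  !spirv.struct<(
//       !spirv.ptr<f16, CrossWorkgroup>,   // base address
//       i64,                               // assumed: linear base offset
//       i64, i64)>                         // one i64 per descriptor dimension
// An SLM descriptor uses !spirv.ptr<f16, Workgroup> instead.
//
// Vector flattening (both paths): an N-D vector collapses to 1-D with the
// element count multiplied out, and index elements widen to i64, e.g.
//   vector<8x16xf16>   -> vector<128xf16>
// In the non-JointMatrix path, a 3-D VNNI-packed vector with sub-32-bit
// elements is first reinterpreted as 2-D i32 before flattening, e.g.
//   vector<8x16x2xf16> -> vector<128xi32>   // 8*16 = 128 lanes of i32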
(Diffs for the remaining 17 changed files are not shown here.)
