Skip to content

Commit

Permalink
XeGPU integration tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Chang, Liangliang authored and silee2 committed Nov 15, 2023
1 parent 8d4ecbd commit a082e62
Show file tree
Hide file tree
Showing 9 changed files with 368 additions and 17 deletions.
18 changes: 5 additions & 13 deletions lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,6 @@ void lookupOrInsertIntrinsic(ConversionPatternRewriter &rewriter, Operation *op,
}
}

/// @brief
/// convert the tensor descriptor to [2xi64] which is of the format
/// -> [base pointer: i64, offsetX: i32, offsetY: i32] for 2D tensor desc
/// -> [base pointer: i64, unused] for 1D and scattered tensor desc
class CreateNdDescToVCPattern : public OpConversionPattern<CreateNdDescOp> {
public:
using OpConversionPattern<CreateNdDescOp>::OpConversionPattern;
Expand Down Expand Up @@ -498,7 +494,7 @@ class LoadStorePrefetchNdToLsc : public OpConversionPattern<OpType> {
}
};

std::optional<xegpu::CreateNdDescOp> findDescOp(mlir::Value val) {
xegpu::CreateNdDescOp findDescOp(mlir::Value val) {
if (auto op = val.getDefiningOp()) {
if (auto descOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
return descOp;
Expand All @@ -510,9 +506,9 @@ std::optional<xegpu::CreateNdDescOp> findDescOp(mlir::Value val) {
auto forOp = cast<scf::ForOp>(ownerOp);
auto init = forOp.getInits()[arg.getArgNumber() - 1];
return findDescOp(init);
}
// Add more support
return std::nullopt;
} else {
assert(0 && "add more support");
}
}

template <typename OpType>
Expand Down Expand Up @@ -606,12 +602,8 @@ class LoadStorePrefetchNdToRawSend : public OpConversionPattern<OpType> {
}
auto msg = createIntConstant(i32Type, rawSendMsg);
// payload
// payload is v8i32 = [base:i64, surfaceWidth:i32, surfaceHeight:i32,
// surfacePitch:i32, offsetX:i32, offsetY:i32, blockInfo:i32]
// the base/surfaceInfo/blockInfo are taken statically from the tensor desc
// while the offsetX/Y are dynamically updated
auto insertPoint = rewriter.saveInsertionPoint();
CreateNdDescOp createDescOp = *findDescOp(op.template getTensorDesc());
CreateNdDescOp createDescOp = findDescOp(op.template getTensorDesc());
rewriter.setInsertionPointAfter(createDescOp);
auto v8i32 = VectorType::get(8, i32Type);
auto v4i64 = VectorType::get(4, i64Type);
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/XeTileToXeGPU/SCFOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ bool isLegalSCFOp(mlir::Operation *op) {

if (llvm::isa<mlir::scf::YieldOp>(op)) {
auto yieldOp = llvm::cast<mlir::scf::YieldOp>(op);
for (const auto &arg : yieldOp.getResults()) {
for (auto arg : yieldOp.getResults()) {
auto type = arg.getType();
result &= !type.isa<imex::xetile::TileType>();
if (type.isa<mlir::VectorType>())
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ struct SgUpdateTileOffsetOpPattern
auto tiles = adaptor.getTile();

llvm::SmallVector<mlir::Value> xegpuOps;
for (const auto &tile : tiles) {
for (auto tile : tiles) {
auto xegpuTile = rewriter.create<xegpu::UpdateNDOffsetOp>(
op.getLoc(), tile.getType(), tile, mlir::ValueRange{offsetX, offsetY},
imex::xegpu::Mode::VC);
Expand Down
111 changes: 111 additions & 0 deletions test/Integration/Dialect/XeGPU/gemm_1024x1016x1016_f16_f16_f32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
module @gemm attributes {gpu.container_module} {
// Module-level buffers used by @main.
// The two f16 globals start as all-ones and are not declared `constant`;
// @main stores into them before running the kernel.
memref.global "private" @__constant_1024x1016xf16 : memref<1024x1016xf16> = dense<1.0>
memref.global "private" @__constant_1016x1016xf16_ : memref<1016x1016xf16> = dense<1.0>
// Zero-initialized f32 buffer; @main accumulates the host reference result into it.
memref.global "private" @__constant_1024x1016xf32 : memref<1024x1016xf32> = dense<0.0>
// Host-side driver for the GEMM kernel.
// Stages both f16 operands in host_shared GPU allocations, launches
// @test_kernel::@test_kernel on a 128x64x1 grid of 1x1x1 blocks, and returns
// the f32 output buffer.
//   %arg0: 1024x1016 f16 operand (first dpas operand in the kernel)
//   %arg1: 1016x1016 f16 operand (second dpas operand in the kernel)
//   returns: the 1024x1016 f32 buffer the kernel stores into
func.func @test(%arg0: memref<1024x1016xf16>, %arg1: memref<1016x1016xf16>) -> memref<1024x1016xf32> attributes {llvm.emit_c_interface} {
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
// Copy the first operand into a host_shared allocation.
%memref = gpu.alloc host_shared () : memref<1024x1016xf16>
memref.copy %arg0, %memref : memref<1024x1016xf16> to memref<1024x1016xf16>
// Copy the second operand into a host_shared allocation.
%memref_0 = gpu.alloc host_shared () : memref<1016x1016xf16>
memref.copy %arg1, %memref_0 : memref<1016x1016xf16> to memref<1016x1016xf16>
// Output buffer.
// NOTE(review): the kernel loads from this buffer to seed its accumulator,
// but nothing in this function writes it before the launch. Confirm that the
// allocation is guaranteed to be zero-filled, or initialize it explicitly.
%memref_1 = gpu.alloc host_shared () : memref<1024x1016xf32>
gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1016xf16>, %memref_0 : memref<1016x1016xf16>, %memref_1 : memref<1024x1016xf32>)
// Only the two input staging buffers are released here; %memref_1 is handed
// back to the caller.
gpu.dealloc %memref : memref<1024x1016xf16>
gpu.dealloc %memref_0 : memref<1016x1016xf16>
return %memref_1 : memref<1024x1016xf32>
}
// Device module holding the GEMM kernel.
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
// Each block produces one 8x16 f32 tile of %arg2 whose origin is
// (block_id.x * 8, block_id.y * 16).
//   %arg0: 1024x1016 f16 operand, read in 8x16 tiles
//   %arg1: 1016x1016 f16 operand, read in 16x16 tiles
//   %arg2: 1024x1016 f32 buffer, read once as the initial accumulator and
//          written once with the final tile
gpu.func @test_kernel(%arg0: memref<1024x1016xf16>, %arg1: memref<1016x1016xf16>, %arg2: memref<1024x1016xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 128, 64, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c8 = arith.constant 8 : index
// NOTE(review): %c1024 is not referenced anywhere in this kernel.
%c1024 = arith.constant 1024 : index
%c1016 = arith.constant 1016 : index
%0 = gpu.block_id x
%1 = gpu.block_id y
// Row and column origin of this block's output tile.
%2 = arith.muli %0, %c8 : index
%3 = arith.muli %1, %c16 : index
// The same descriptor is used for the initial accumulator load and the final store.
%4 = xegpu.create_nd_tdesc %arg2[%2, %3] {mode = vc} : memref<1024x1016xf32> -> !xegpu.tensor_desc<8x16xf32>
%5 = xegpu.load_nd %4 {mode = vc} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
// each work-group has 1 subgroup. the subgroup calculates a [8x16 = 8x1016 * 1016x16] block
// The reduction dimension is walked in steps of 16. 1016 is not a multiple of
// 16, so the last step (and the last column of blocks, starting at column
// 1008) addresses tiles that extend past the memref bounds.
// NOTE(review): presumably this relies on the nd load/store handling of
// out-of-bounds elements; confirm.
%6 = scf.for %arg3 = %c0 to %c1016 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) {
// New descriptors are created on every iteration at the current reduction offset.
%7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] {mode = vc} : memref<1024x1016xf16> -> !xegpu.tensor_desc<8x16xf16>
%8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] {mode = vc} : memref<1016x1016xf16> -> !xegpu.tensor_desc<16x16xf16>
// Both loads carry a vnni_axis attribute and produce rank-3 vectors whose
// innermost dimension is 2.
%9 = xegpu.load_nd %7 {mode = vc, vnni_axis = 1}: !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16>
%10 = xegpu.load_nd %8 {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
// Accumulate this step's contribution into the loop-carried f32 tile.
%11 = xegpu.dpas %9, %10, %arg4 : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32>
scf.yield %11 : vector<8x16xf32>
}
xegpu.store_nd %6, %4 {mode = vc}: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
gpu.return
}
}
// Test entry point: fills the inputs, computes a host reference result,
// runs the GPU path through @test, and compares the two f32 buffers with
// @printAllcloseF32.
func.func @main() attributes {llvm.emit_c_interface} {
%0 = memref.get_global @__constant_1024x1016xf16 : memref<1024x1016xf16>
%1 = memref.get_global @__constant_1016x1016xf16_ : memref<1016x1016xf16>
%ref = memref.get_global @__constant_1024x1016xf32 : memref<1024x1016xf32>
// NOTE(review): %init, %c8 and %c16 are not referenced in this function.
%init = arith.constant 0.0 : f16
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c128 = arith.constant 128 : index
%c1024 = arith.constant 1024 : index
%c1016 = arith.constant 1016 : index
// fill the top-left block 128x128
// A matrix: row-major, start from 0.0, increase 0.01 per element
// B matrix: A matrix + 1.0
// Elements outside this 128x128 block keep the value the globals were
// initialized with.
scf.for %arg0 = %c0 to %c128 step %c1 {
scf.for %arg1 = %c0 to %c128 step %c1 {
// Linear index row * 128 + col, computed in i16; the largest value is
// 127 * 128 + 127 = 16383.
%int0 = arith.index_cast %arg0 : index to i16
%int1 = arith.index_cast %arg1 : index to i16
%c128_i16 = arith.constant 128 : i16
%idx0 = arith.muli %int0, %c128_i16 : i16
%idx1 = arith.addi %int1, %idx0 : i16
%fp = arith.uitofp %idx1 : i16 to f16
// Scale the index by 1/100 to obtain the A value; B is that value plus 1.0.
%cst100 = arith.constant 100.0 : f16
%val0 = arith.divf %fp, %cst100 : f16
%cst1 = arith.constant 1.0 : f16
%val1 = arith.addf %val0, %cst1 : f16
memref.store %val0, %0[%arg0, %arg1] : memref<1024x1016xf16>
memref.store %val1, %1[%arg0, %arg1] : memref<1016x1016xf16>
}
}
// calculate the result C matrix
// Host reference: for each output element, accumulate over the full
// reduction range of 1016 starting from the value already stored in %ref.
scf.for %arg0 = %c0 to %c1024 step %c1 {
scf.for %arg1 = %c0 to %c1016 step %c1 {
%acc = memref.load %ref[%arg0, %arg1] : memref<1024x1016xf32>
%res = scf.for %arg2 = %c0 to %c1016 step %c1 iter_args(%arg3 = %acc) -> f32 {
%a = memref.load %0[%arg0, %arg2] : memref<1024x1016xf16>
%b = memref.load %1[%arg2, %arg1] : memref<1016x1016xf16>
// The product is formed in f16 and only then widened to f32 for the sum.
%c = arith.mulf %a, %b : f16
%cc = arith.extf %c : f16 to f32
%ccc = arith.addf %cc, %arg3 : f32
scf.yield %ccc : f32
}
memref.store %res, %ref[%arg0, %arg1] : memref<1024x1016xf32>
}
}

// Run the GPU path and compare it against the host reference.
%2 = call @test(%0, %1) : (memref<1024x1016xf16>, memref<1016x1016xf16>) -> memref<1024x1016xf32>
%cast = memref.cast %2 : memref<1024x1016xf32> to memref<*xf32>
// call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
%cast_ref = memref.cast %ref : memref<1024x1016xf32> to memref<*xf32>
// call @printMemrefF32(%cast_ref) : (memref<*xf32>) -> ()
// CHECK: [ALLCLOSE: TRUE]
call @printAllcloseF32(%cast, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> ()
// NOTE(review): the buffer returned by @test is not deallocated here.
return
}
// External helper declarations; no bodies are defined in this file.
// NOTE(review): presumably resolved from the shared libraries listed on the
// runner command line; confirm.
// @printMemrefF32 is only referenced from commented-out calls in @main.
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
// Called by @main with the GPU result and the host reference.
func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
}
9 changes: 7 additions & 2 deletions test/Integration/Dialect/XeGPU/lit.local.cfg
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
local_excludes = ['gemm_1024x1024xf16.mlir']

# File names in this directory that are added to config.excludes below.
# NOTE(review): presumably lit then skips these files during test discovery; confirm.
local_excludes = [
'gemm_1024x1024xf16.mlir',
'gemm_1024x1016x1016_f16_f16_f32.mlir',
'load2d_dpas_store2d.mlir',
'load2d-padding-f32.mlir',
'load2d-padding.mlir'
]
config.excludes.update(local_excludes)
66 changes: 66 additions & 0 deletions test/Integration/Dialect/XeGPU/load2d-padding-f32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
module @gemm attributes {gpu.container_module} {
// Read-only 8x16 f32 input filled with 1.0; @main passes it to @test.
memref.global "private" constant @__constant_8x16xf32 : memref<8x16xf32> = dense<1.0>
// Host-side driver for the padding kernel.
// Copies the input into a host_shared allocation, launches
// @test_kernel::@test_padding_f32 with a single 1x1x1 block, and prints the
// output buffer.
//   %arg0: 8x16 f32 input
//   %arg1: offset forwarded to the kernel, which uses it for both load coordinates
// NOTE(review): neither %memref_0 nor %memref_1 is deallocated in this function.
func.func @test(%arg0: memref<8x16xf32>,%arg1:index)attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%memref_0 = gpu.alloc host_shared () : memref<8x16xf32>
memref.copy %arg0, %memref_0 : memref<8x16xf32> to memref<8x16xf32>
// Output buffer; the kernel overwrites the full 8x16 tile.
%memref_1 = gpu.alloc host_shared () : memref<8x16xf32>
gpu.launch_func @test_kernel::@test_padding_f32 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_0 : memref<8x16xf32>, %memref_1 : memref<8x16xf32>, %arg1:index)
// The printed memref is what the FileCheck expectations at the end of the file match.
%cast1 = memref.cast %memref_1 : memref<8x16xf32> to memref<*xf32>
call @printMemrefF32(%cast1) : (memref<*xf32>) -> ()
return
}

// Device module holding the padding kernel.
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {

// Loads one 8x16 f32 tile from %arg0 at offset [%arg3, %arg3] and stores it
// to %arg1 at offset [0, 0].
// With a non-zero %arg3 the tile extends past the 8x16 source; the expected
// output at the end of this file has zeros in those positions.
gpu.func @test_padding_f32(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>, %arg3:index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
// Source descriptor, shifted by %arg3 in both dimensions.
%0 = xegpu.create_nd_tdesc %arg0[%arg3, %arg3] {mode = vc}
: memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// Destination descriptor at the origin of the output buffer.
%1 = xegpu.create_nd_tdesc %arg1[0, 0] {mode = vc}
: memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
%3 = xegpu.load_nd %0 {mode = vc}: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
xegpu.store_nd %3,%1 {mode = vc}: vector<8x16xf32>,!xegpu.tensor_desc<8x16xf32>
gpu.return
}

}
// Test entry point: runs @test twice on the all-ones input, first with
// offset 1 and then with offset 2. Each call prints one 8x16 result.
func.func @main() attributes {llvm.emit_c_interface} {
%0 = memref.get_global @__constant_8x16xf32 : memref<8x16xf32>
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
call @test(%0, %c1) : (memref<8x16xf32>, index)-> ()
call @test(%0, %c2) : (memref<8x16xf32>, index)-> ()
return
}
// External print helper called by @test; no body is defined in this file.
// NOTE(review): presumably resolved from the shared libraries listed on the
// runner command line; confirm.
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
}


// CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}}
// CHECK-SAME: rank = 2 offset = 0 sizes = [8, 16] strides = [16, 1] data =
// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
// CHECK: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
// CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}}
// CHECK-SAME: rank = 2 offset = 0 sizes = [8, 16] strides = [16, 1] data =
// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
// CHECK: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
// CHECK: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
Loading

0 comments on commit a082e62

Please sign in to comment.