From a082e6229a173ba1e45cd2957fb2be343d876153 Mon Sep 17 00:00:00 2001 From: "Chang, Liangliang" Date: Wed, 15 Nov 2023 14:55:34 -0800 Subject: [PATCH] XeGPU integration tests. --- lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp | 18 +-- .../XeTileToXeGPU/SCFOpConversion.cpp | 2 +- .../XeTileToXeGPU/XeTileOpConversion.cpp | 2 +- .../gemm_1024x1016x1016_f16_f16_f32.mlir | 111 ++++++++++++++++++ test/Integration/Dialect/XeGPU/lit.local.cfg | 9 +- .../Dialect/XeGPU/load2d-padding-f32.mlir | 66 +++++++++++ .../Dialect/XeGPU/load2d-padding.mlir | 73 ++++++++++++ .../Dialect/XeGPU/load2d_dpas_store2d.mlir | 97 +++++++++++++++ .../Dialect/XeGPU/xegpu-to-llvm.pp | 7 ++ 9 files changed, 368 insertions(+), 17 deletions(-) create mode 100644 test/Integration/Dialect/XeGPU/gemm_1024x1016x1016_f16_f16_f32.mlir create mode 100644 test/Integration/Dialect/XeGPU/load2d-padding-f32.mlir create mode 100644 test/Integration/Dialect/XeGPU/load2d-padding.mlir create mode 100644 test/Integration/Dialect/XeGPU/load2d_dpas_store2d.mlir diff --git a/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp b/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp index 4a5d5246e..1d8937796 100644 --- a/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp +++ b/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp @@ -219,10 +219,6 @@ void lookupOrInsertIntrinsic(ConversionPatternRewriter &rewriter, Operation *op, } } -/// @brief -/// convert the tensor descriptor to [2xi64] which is of the format -/// -> [base pointer: i64, offsetX: i32, offsetY: i32] for 2D tensor desc -/// -> [base pointer: i64, unused] for 1D and scattered tensor desc class CreateNdDescToVCPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -498,7 +494,7 @@ class LoadStorePrefetchNdToLsc : public OpConversionPattern { } }; -std::optional findDescOp(mlir::Value val) { +xegpu::CreateNdDescOp findDescOp(mlir::Value val) { if (auto op = val.getDefiningOp()) { if (auto descOp = dyn_cast(op)) { return descOp; @@ -510,9 +506,9 @@ std::optional findDescOp(mlir::Value val) { auto forOp = cast(ownerOp); auto init = forOp.getInits()[arg.getArgNumber() - 1]; return findDescOp(init); - } - // Add more support - return std::nullopt; + } else { + assert(0 && "add more support"); + } } template @@ -606,12 +602,8 @@ class LoadStorePrefetchNdToRawSend : public OpConversionPattern { } auto msg = createIntConstant(i32Type, rawSendMsg); // payload - // payload is v8i32 = [base:i64, surfaceWidth:i32, surfaceHeight:i32, - // surefacePitch:i32, offsetX:i32, offsetY:i32, blockInfo:i32] - // the base/surfaceInfo/blockInfo are staticly from the tensor desc - // while the offsetX/Y are dynamicly udpated auto insertPoint = rewriter.saveInsertionPoint(); - CreateNdDescOp createDescOp = *findDescOp(op.template getTensorDesc()); + CreateNdDescOp createDescOp = findDescOp(op.template getTensorDesc()); rewriter.setInsertionPointAfter(createDescOp); auto v8i32 = VectorType::get(8, i32Type); auto v4i64 = VectorType::get(4, i64Type); diff --git a/lib/Conversion/XeTileToXeGPU/SCFOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/SCFOpConversion.cpp index 24350b6b5..8b83ca066 100644 --- a/lib/Conversion/XeTileToXeGPU/SCFOpConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/SCFOpConversion.cpp @@ -100,7 +100,7 @@ bool isLegalSCFOp(mlir::Operation *op) { if (llvm::isa(op)) { auto yieldOp = llvm::cast(op); - for (const auto &arg : yieldOp.getResults()) { + for (auto arg : yieldOp.getResults()) { auto type = arg.getType(); result &= !type.isa(); if (type.isa()) diff --git 
a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp index 2fe695760..2e5fc75bf 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp @@ -318,7 +318,7 @@ struct SgUpdateTileOffsetOpPattern auto tiles = adaptor.getTile(); llvm::SmallVector xegpuOps; - for (const auto &tile : tiles) { + for (auto tile : tiles) { auto xegpuTile = rewriter.create( op.getLoc(), tile.getType(), tile, mlir::ValueRange{offsetX, offsetY}, imex::xegpu::Mode::VC); diff --git a/test/Integration/Dialect/XeGPU/gemm_1024x1016x1016_f16_f16_f32.mlir b/test/Integration/Dialect/XeGPU/gemm_1024x1016x1016_f16_f16_f32.mlir new file mode 100644 index 000000000..90826dec7 --- /dev/null +++ b/test/Integration/Dialect/XeGPU/gemm_1024x1016x1016_f16_f16_f32.mlir @@ -0,0 +1,111 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +module @gemm attributes {gpu.container_module} { + memref.global "private" @__constant_1024x1016xf16 : memref<1024x1016xf16> = dense<1.0> + memref.global "private" @__constant_1016x1016xf16_ : memref<1016x1016xf16> = dense<1.0> + memref.global "private" @__constant_1024x1016xf32 : memref<1024x1016xf32> = dense<0.0> + func.func @test(%arg0: memref<1024x1016xf16>, %arg1: memref<1016x1016xf16>) -> memref<1024x1016xf32> attributes {llvm.emit_c_interface} { + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %memref = gpu.alloc host_shared () : memref<1024x1016xf16> + memref.copy %arg0, %memref : memref<1024x1016xf16> to memref<1024x1016xf16> + %memref_0 = gpu.alloc host_shared () : memref<1016x1016xf16> + memref.copy %arg1, %memref_0 : memref<1016x1016xf16> to memref<1016x1016xf16> + %memref_1 = gpu.alloc host_shared () : memref<1024x1016xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1016xf16>, %memref_0 : memref<1016x1016xf16>, %memref_1 : memref<1024x1016xf32>) + gpu.dealloc %memref : memref<1024x1016xf16> + gpu.dealloc %memref_0 : memref<1016x1016xf16> + return %memref_1 : memref<1024x1016xf32> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%arg0: memref<1024x1016xf16>, %arg1: memref<1016x1016xf16>, %arg2: memref<1024x1016xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c8 = arith.constant 8 : index + %c1024 = arith.constant 1024 : index + %c1016 = arith.constant 1016 : index + %0 = gpu.block_id x + %1 = gpu.block_id y + %2 = arith.muli %0, %c8 : index + %3 = arith.muli %1, %c16 : index + %4 = xegpu.create_nd_tdesc %arg2[%2, %3] {mode = vc} : memref<1024x1016xf32> -> !xegpu.tensor_desc<8x16xf32> + %5 = xegpu.load_nd 
%4 {mode = vc} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
      // each work-group has 1 subgroup. the subgroup calculates a [8x16 = 8x1016 * 1016x16] block
      %6 = scf.for %arg3 = %c0 to %c1016 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) {
        %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] {mode = vc} : memref<1024x1016xf16> -> !xegpu.tensor_desc<8x16xf16>
        %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] {mode = vc} : memref<1016x1016xf16> -> !xegpu.tensor_desc<16x16xf16>
        %9 = xegpu.load_nd %7 {mode = vc, vnni_axis = 1}: !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16>
        %10 = xegpu.load_nd %8 {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
        %11 = xegpu.dpas %9, %10, %arg4 : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32>
        scf.yield %11 : vector<8x16xf32>
      }
      xegpu.store_nd %6, %4 {mode = vc}: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
      gpu.return
    }
  }
  func.func @main() attributes {llvm.emit_c_interface} {
    %0 = memref.get_global @__constant_1024x1016xf16 : memref<1024x1016xf16>
    %1 = memref.get_global @__constant_1016x1016xf16_ : memref<1016x1016xf16>
    %ref = memref.get_global @__constant_1024x1016xf32 : memref<1024x1016xf32>
    %init = arith.constant 0.0 : f16
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c8 = arith.constant 8 : index
    %c16 = arith.constant 16 : index
    %c128 = arith.constant 128 : index
    %c1024 = arith.constant 1024 : index
    %c1016 = arith.constant 1016 : index
    // fill the top-left 128x128 block
    // A matrix: row-major, start from 0.0, increase 0.01 per element
    // B matrix: A matrix + 1.0
    scf.for %arg0 = %c0 to %c128 step %c1 {
      scf.for %arg1 = %c0 to %c128 step %c1 {
        %int0 = arith.index_cast %arg0 : index to i16
        %int1 = arith.index_cast %arg1 : index to i16
        %c128_i16 = arith.constant 128 : i16
        %idx0 = arith.muli %int0, %c128_i16 : i16
        %idx1 = arith.addi %int1, %idx0 : i16
        %fp = arith.uitofp %idx1 : i16 to f16
        %cst100 = arith.constant 100.0 : f16
        %val0 = arith.divf %fp, %cst100 : f16
        %cst1 = arith.constant 1.0 : f16
        %val1 = arith.addf %val0, %cst1 : f16
        memref.store %val0, %0[%arg0, %arg1] : memref<1024x1016xf16>
        memref.store %val1, %1[%arg0, %arg1] : memref<1016x1016xf16>
      }
    }
    // calculate the reference C matrix on the host
    scf.for %arg0 = %c0 to %c1024 step %c1 {
      scf.for %arg1 = %c0 to %c1016 step %c1 {
        %acc = memref.load %ref[%arg0, %arg1] : memref<1024x1016xf32>
        %res = scf.for %arg2 = %c0 to %c1016 step %c1 iter_args(%arg3 = %acc) -> f32 {
          %a = memref.load %0[%arg0, %arg2] : memref<1024x1016xf16>
          %b = memref.load %1[%arg2, %arg1] : memref<1016x1016xf16>
          %c = arith.mulf %a, %b : f16
          %cc = arith.extf %c : f16 to f32
          %ccc = arith.addf %cc, %arg3 : f32
          scf.yield %ccc : f32
        }
        memref.store %res, %ref[%arg0, %arg1] : memref<1024x1016xf32>
      }
    }

    %2 = call @test(%0, %1) : (memref<1024x1016xf16>, memref<1016x1016xf16>) -> memref<1024x1016xf32>
    %cast = memref.cast %2 : memref<1024x1016xf32> to memref<*xf32>
    // call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
    %cast_ref = memref.cast %ref : memref<1024x1016xf32> to memref<*xf32>
    // call @printMemrefF32(%cast_ref) : (memref<*xf32>) -> ()
    // CHECK: [ALLCLOSE: TRUE]
    call @printAllcloseF32(%cast, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> ()
    return
  }
  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
  func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
}
diff --git
a/test/Integration/Dialect/XeGPU/lit.local.cfg b/test/Integration/Dialect/XeGPU/lit.local.cfg index 50449795e..cf920ae42 100644 --- a/test/Integration/Dialect/XeGPU/lit.local.cfg +++ b/test/Integration/Dialect/XeGPU/lit.local.cfg @@ -1,3 +1,8 @@ -local_excludes = ['gemm_1024x1024xf16.mlir'] - +local_excludes = [ + 'gemm_1024x1024xf16.mlir', + 'gemm_1024x1016x1016_f16_f16_f32.mlir', + 'load2d_dpas_store2d.mlir', + 'load2d-padding-f32.mlir', + 'load2d-padding.mlir' + ] config.excludes.update(local_excludes) diff --git a/test/Integration/Dialect/XeGPU/load2d-padding-f32.mlir b/test/Integration/Dialect/XeGPU/load2d-padding-f32.mlir new file mode 100644 index 000000000..897659768 --- /dev/null +++ b/test/Integration/Dialect/XeGPU/load2d-padding-f32.mlir @@ -0,0 +1,66 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +module @gemm attributes {gpu.container_module} { + memref.global "private" constant @__constant_8x16xf32 : memref<8x16xf32> = dense<1.0> + func.func @test(%arg0: memref<8x16xf32>,%arg1:index)attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %memref_0 = gpu.alloc host_shared () : memref<8x16xf32> + memref.copy %arg0, %memref_0 : memref<8x16xf32> to memref<8x16xf32> + %memref_1 = gpu.alloc host_shared () : memref<8x16xf32> + gpu.launch_func @test_kernel::@test_padding_f32 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_0 : memref<8x16xf32>, %memref_1 : memref<8x16xf32>, %arg1:index) + %cast1 = memref.cast %memref_1 : memref<8x16xf32> to memref<*xf32> + call @printMemrefF32(%cast1) : (memref<*xf32>) -> () + return + } + + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + + gpu.func @test_padding_f32(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>, %arg3:index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %0 = xegpu.create_nd_tdesc %arg0[%arg3, %arg3] {mode = vc} + : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %arg1[0, 0] {mode = vc} + : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %0 {mode = vc}: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + xegpu.store_nd %3,%1 {mode = vc}: vector<8x16xf32>,!xegpu.tensor_desc<8x16xf32> + gpu.return + } + + } + func.func @main() attributes {llvm.emit_c_interface} { + %0 = memref.get_global @__constant_8x16xf32 : memref<8x16xf32> + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + call @test(%0, %c1) : (memref<8x16xf32>, index)-> () + call @test(%0, %c2) : (memref<8x16xf32>, index)-> () + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} + + +// CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}} +// CHECK-SAME: rank = 2 offset = 0 sizes = [8, 16] strides = [16, 1] data = +// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] +// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], +// CHECK: [1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], +// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], +// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], +// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], +// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], +// CHECK: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] +// CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}} +// CHECK-SAME: rank = 2 offset = 0 sizes = [8, 16] strides = [16, 1] data = +// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], +// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], +// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], +// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], +// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], +// CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], +// CHECK: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +// CHECK: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] diff --git a/test/Integration/Dialect/XeGPU/load2d-padding.mlir b/test/Integration/Dialect/XeGPU/load2d-padding.mlir new file mode 100644 index 000000000..c86282770 --- /dev/null +++ b/test/Integration/Dialect/XeGPU/load2d-padding.mlir @@ -0,0 +1,73 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +module @gemm attributes {gpu.container_module} { + // memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<1.0> + memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<1.0> + + memref.global "private" constant @__constant_16x16xf16 : memref<16x16xf16> = dense<[ +[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], +[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], +[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], +[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], +[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], +[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], +[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], +[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], +[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], +[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], +[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], +[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], +[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], +[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], +[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], +[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]]> + + func.func @test(%arg0: memref<8x16xf16>,%arg3:index) -> memref<8x16xf16> attributes 
{llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %memref = gpu.alloc host_shared () : memref<8x16xf16> + memref.copy %arg0, %memref : memref<8x16xf16> to memref<8x16xf16> + %memref_1 = gpu.alloc host_shared () : memref<8x16xf16> + gpu.launch_func @test_kernel::@test_padding blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf16>, %memref_1 : memref<8x16xf16>, %arg3:index) + + gpu.dealloc %memref : memref<8x16xf16> + return %memref_1 : memref<8x16xf16> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_padding(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>,%arg3:index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %0 = xegpu.create_nd_tdesc %arg0[%arg3, %arg3] {mode = vc} + : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %2 = xegpu.create_nd_tdesc %arg1[0, 0] {mode = vc} + : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %3 = xegpu.load_nd %0 {mode = vc,vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + %7 = vector.shape_cast %3: vector<8x8x2xf16> to vector<8x16xf16> + xegpu.store_nd %7,%2 {mode = vc}: vector<8x16xf16>,!xegpu.tensor_desc<8x16xf16> + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %0 = memref.get_global @__constant_8x16xf16 : memref<8x16xf16> + %1 = memref.get_global @__constant_16x16xf16 : memref<16x16xf16> + %f32identity = memref.get_global @__constant_8x16xf16 : memref<8x16xf16> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %2 = call @test(%0, %c1) : (memref<8x16xf16>, index) -> memref<8x16xf16> + %3 = call @test(%0, %c2) : (memref<8x16xf16>, index) -> memref<8x16xf16> + + %c7 = arith.constant 7 : index + %vector_0 = vector.load %2[%c7,%c0] :memref<8x16xf16>, vector<16xf16> +// CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + vector.print %vector_0 : vector<16xf16> + + %vector_1 = vector.load %3[%c0,%c0] :memref<8x16xf16>, vector<16xf16> +// CHECK: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0 ) + vector.print %vector_1 : vector<16xf16> + return + } +} diff --git a/test/Integration/Dialect/XeGPU/load2d_dpas_store2d.mlir b/test/Integration/Dialect/XeGPU/load2d_dpas_store2d.mlir new file mode 100644 index 000000000..85f50177a --- /dev/null +++ b/test/Integration/Dialect/XeGPU/load2d_dpas_store2d.mlir @@ -0,0 +1,97 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +module @gemm attributes {gpu.container_module} { + memref.global "private" @__constant_8x16xf16 : memref<8x16xf16> = dense<1.0> + memref.global "private" @__constant_16x16xf16 : memref<16x16xf16> = dense<1.0> + memref.global "private" @__constant_16x16xf32 : memref<16x16xf32> = dense<0.0> + + func.func @test(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + 
%memref = gpu.alloc host_shared () : memref<8x16xf16>
    memref.copy %arg0, %memref : memref<8x16xf16> to memref<8x16xf16>
    %memref_0 = gpu.alloc host_shared () : memref<16x16xf16>
    memref.copy %arg1, %memref_0 : memref<16x16xf16> to memref<16x16xf16>
    %memref_1 = gpu.alloc host_shared () : memref<8x16xf32>
    gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf16>, %memref_0 : memref<16x16xf16>, %memref_1 : memref<8x16xf32>)
    gpu.dealloc %memref : memref<8x16xf16>
    gpu.dealloc %memref_0 : memref<16x16xf16>
    return %memref_1 : memref<8x16xf32>
  }

  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} {
    gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %0 = xegpu.create_nd_tdesc %arg0[0, 0] {mode = vc} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
      %1 = xegpu.create_nd_tdesc %arg1[0, 0] {mode = vc} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
      %2 = xegpu.create_nd_tdesc %arg2[0, 0] {mode = vc} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
      %3 = xegpu.load_nd %0 {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16>
      %4 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
      %5 = xegpu.load_nd %2 {mode = vc} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
      %6 = xegpu.dpas %3, %4, %5 : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32>
      xegpu.store_nd %6,%2 {mode = vc} : vector<8x16xf32>,!xegpu.tensor_desc<8x16xf32>
      gpu.return
    }
  }
  func.func @main() attributes {llvm.emit_c_interface} {
    %0 = memref.get_global @__constant_8x16xf16 : memref<8x16xf16>
    %1 = memref.get_global @__constant_16x16xf16 : memref<16x16xf16>
    %ref = memref.get_global @__constant_16x16xf32 : memref<16x16xf32>

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c8 = arith.constant 8 : index
    %c16 = arith.constant 16 : index
    %c128 = arith.constant 128 : index
    %c1024 = arith.constant 1024 : index

    scf.for %arg0 = %c0 to %c8 step %c1 {
      scf.for %arg1 = %c0 to %c16 step %c1 {
        %int0 = arith.index_cast %arg0 : index to i16
        %int1 = arith.index_cast %arg1 : index to i16
        %c16_i16 = arith.constant 16 : i16
        %idx0 = arith.muli %int0, %c16_i16 : i16
        %idx1 = arith.addi %int1, %idx0 : i16
        %fp = arith.uitofp %idx1 : i16 to f16
        %cst100 = arith.constant 1.0 : f16
        %val0 = arith.divf %fp, %cst100 : f16
        %cst1 = arith.constant 1.0 : f16
        %val1 = arith.addf %val0, %cst1 : f16
        memref.store %val0, %0[%arg0, %arg1] : memref<8x16xf16>
        memref.store %val1, %1[%arg0, %arg1] : memref<16x16xf16>
      }
    }
    // calculate the reference C matrix on the host
    scf.for %arg0 = %c0 to %c8 step %c1 {
      scf.for %arg1 = %c0 to %c16 step %c1 {
        %acc = memref.load %ref[%arg0, %arg1] : memref<16x16xf32>
        %res = scf.for %arg2 = %c0 to %c16 step %c1 iter_args(%arg3 = %acc) -> f32 {
          %a = memref.load %0[%arg0, %arg2] : memref<8x16xf16>
          %b = memref.load %1[%arg2, %arg1] : memref<16x16xf16>
          %c = arith.mulf %a, %b : f16
          %cc = arith.extf %c : f16 to f32
          %ccc = arith.addf %cc, %arg3 : f32
          scf.yield %ccc : f32
        }
        memref.store %res, %ref[%arg0, %arg1] : memref<16x16xf32>
      }
    }

    %cast_ref = memref.cast %ref : memref<16x16xf32> to memref<*xf32>
    %2 = call @test(%0, %1) : (memref<8x16xf16>,
memref<16x16xf16>) -> memref<8x16xf32> + %cast = memref.cast %2 : memref<8x16xf32> to memref<*xf32> + // call @printMemrefF32(%cast) : (memref<*xf32>) -> () + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> () + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeGPU/xegpu-to-llvm.pp b/test/Integration/Dialect/XeGPU/xegpu-to-llvm.pp index 5a66cfc36..bc7826608 100644 --- a/test/Integration/Dialect/XeGPU/xegpu-to-llvm.pp +++ b/test/Integration/Dialect/XeGPU/xegpu-to-llvm.pp @@ -1,17 +1,24 @@ +// linalg dialect to gpu dialect lowering pipeline +// Ready for vulkan runner or narrow scope l0/sycl runner starting from GPU dialect. builtin.module( imex-convert-gpu-to-spirv spirv.module(spirv-lower-abi-attrs spirv-update-vce) func.func(llvm-request-c-wrappers) serialize-spirv + convert-vector-to-scf convert-gpu-to-gpux convert-scf-to-cf convert-cf-to-llvm + convert-vector-to-llvm + convert-index-to-llvm convert-arith-to-llvm convert-func-to-llvm convert-math-to-llvm convert-gpux-to-llvm + convert-index-to-llvm expand-strided-metadata lower-affine finalize-memref-to-llvm reconcile-unrealized-casts) +// End
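
The new tests share one verification pattern: main() fills the input memrefs on the host, computes a scalar reference with plain scf.for loops, launches the XeGPU kernel through @test, and hands both buffers to printAllcloseF32, which prints the [ALLCLOSE: TRUE] line that FileCheck matches (the padding tests instead print the loaded vectors directly). A rough NumPy sketch of the host-side reference and the final comparison for load2d_dpas_store2d.mlir is shown below; NumPy, the variable names, and the placeholder for the kernel output are assumptions for illustration, not part of the patch:

    import numpy as np

    # A[i, j] = i*16 + j, mirroring the fill loop (val0 = idx / 1.0).
    A = np.arange(8 * 16, dtype=np.float16).reshape(8, 16)
    # B starts as dense<1.0>; the fill loop overwrites its first 8 rows with A + 1.0.
    B = np.ones((16, 16), dtype=np.float16)
    B[:8, :] = A + np.float16(1.0)
    # The MLIR reference loop multiplies in f16 and accumulates the widened products in f32;
    # this sketch simply performs the whole product in f32.
    C_ref = A.astype(np.float32) @ B.astype(np.float32)

    # c_gpu stands in for the memref returned by @test; printAllcloseF32 does the analogous check.
    c_gpu = C_ref.copy()
    print("[ALLCLOSE: TRUE]" if np.allclose(c_gpu, C_ref) else "[ALLCLOSE: FALSE]")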