From 05e7fbcbb184566a48ff156305d78f0016e1105b Mon Sep 17 00:00:00 2001 From: Dewei Wang Date: Mon, 20 Nov 2023 11:50:52 -0800 Subject: [PATCH] add XeGPU createDesc/load2d/store2d/dpas to spirv genISA lowering add a pass option to differentiate between vc-intrinsic and genIsa intrinsic lower load2d/store2d/dpas to corresponding genISA --- include/imex/Conversion/Passes.td | 4 +- .../Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h | 3 + lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp | 5 +- lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp | 278 +++++++++++++++++- ...atomic_basic.mlir => atomic_basic.vc.mlir} | 0 ...rrier_basic.mlir => barrier_basic.vc.mlir} | 0 test/Conversion/XeGPUToSPIRV/gemm_basic.mlir | 55 ++-- .../XeGPUToSPIRV/gemm_basic.vc.mlir | 59 ++++ ...mm_basic_1d.mlir => gemm_basic_1d.vc.mlir} | 0 ..._gather.mlir => gemm_basic_gather.vc.mlir} | 0 ...date_offset.mlir => update_offset.vc.mlir} | 0 11 files changed, 368 insertions(+), 36 deletions(-) rename test/Conversion/XeGPUToSPIRV/{atomic_basic.mlir => atomic_basic.vc.mlir} (100%) rename test/Conversion/XeGPUToSPIRV/{barrier_basic.mlir => barrier_basic.vc.mlir} (100%) create mode 100644 test/Conversion/XeGPUToSPIRV/gemm_basic.vc.mlir rename test/Conversion/XeGPUToSPIRV/{gemm_basic_1d.mlir => gemm_basic_1d.vc.mlir} (100%) rename test/Conversion/XeGPUToSPIRV/{gemm_basic_gather.mlir => gemm_basic_gather.vc.mlir} (100%) rename test/Conversion/XeGPUToSPIRV/{update_offset.mlir => update_offset.vc.mlir} (100%) diff --git a/include/imex/Conversion/Passes.td b/include/imex/Conversion/Passes.td index b5d43eed8..2766c91c5 100644 --- a/include/imex/Conversion/Passes.td +++ b/include/imex/Conversion/Passes.td @@ -251,8 +251,8 @@ memref, arith and math. let constructor = "imex::createConvertGPUXToSPIRVPass()"; let dependentDialects = ["::mlir::spirv::SPIRVDialect"]; let options = [ - Option<"enableSimtIntrinsic", "enable-simt-intrinsic","bool", "false", - "Enable XeGPU.simt Ops lowered to intel genISA simt Intrinsics"> + Option<"enableVCIntrinsic", "enable-vc-intrinsic","bool", "true", + "Enable XeGPU Ops lowered to intel vc Intrinsics"> ]; } diff --git a/include/imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h b/include/imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h index f048a0e02..91615dbad 100644 --- a/include/imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h +++ b/include/imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h @@ -27,6 +27,9 @@ namespace imex { // XeGPU to VC Intrinsics pattern void populateXeGPUToVCIntrinsicsPatterns( mlir::SPIRVTypeConverter &typeConverter, mlir::RewritePatternSet &patterns); +// XeGPU to genISA Intrinsics pattern +void populateXeGPUToGenISAPatterns(mlir::SPIRVTypeConverter &typeConverter, + mlir::RewritePatternSet &patterns); } // namespace imex #endif // IMEX_CONVERSION_XEGPUTOSPIRV_H diff --git a/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp index c149f5aaa..a92005eb3 100644 --- a/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -198,7 +198,10 @@ void GPUXToSPIRVPass::runOnOperation() { mlir::populateSCFToSPIRVPatterns(typeConverter, scfToSpirvCtx, patterns); mlir::cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns); mlir::populateMathToSPIRVPatterns(typeConverter, patterns); - imex::populateXeGPUToVCIntrinsicsPatterns(typeConverter, patterns); + if (this->enableVCIntrinsic) + imex::populateXeGPUToVCIntrinsicsPatterns(typeConverter, patterns); + else + imex::populateXeGPUToGenISAPatterns(typeConverter, patterns); if (failed(applyFullConversion(gpuModule, *target, std::move(patterns)))) return signalPassFailure(); diff --git a/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp b/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp index 7edb0ecda..67f29668b 100644 --- a/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp +++ b/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp @@ -200,7 +200,8 @@ unsigned encodeOpcode(xegpu::AtomicRMWKind kind) { } void lookupOrInsertIntrinsic(ConversionPatternRewriter &rewriter, Operation *op, - std::string name, FunctionType funcType) { + std::string name, FunctionType funcType, + bool isVC = true) { auto funcAttr = StringAttr::get(rewriter.getContext(), name); Operation *found = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); if (!found) { @@ -215,7 +216,8 @@ void lookupOrInsertIntrinsic(ConversionPatternRewriter &rewriter, Operation *op, auto linkage = spirv::LinkageAttributesAttr::get(rewriter.getContext(), nameAttr, linkageTypeAttr); func.setLinkageAttributesAttr(linkage); - func->setAttr("VectorComputeFunctionINTEL", rewriter.getUnitAttr()); + if (isVC) + func->setAttr("VectorComputeFunctionINTEL", rewriter.getUnitAttr()); } } @@ -1246,3 +1248,275 @@ void imex::populateXeGPUToVCIntrinsicsPatterns( LoadStorePrefetchNdToRawSend>( typeConverter, patterns.getContext()); } + +/// below is for XeGPU to SPIRV genISA Intrinsic + +/// @brief encodeVectorType(xxx, 8x8x2xf16, false) returns ["v64i32", 64xi32] +std::pair +encodeGenISAVectorType(ConversionPatternRewriter &rewriter, VectorType type, + bool use32bitData = true) { + auto elemType = type.getElementType(); + auto bitWidth = elemType.getIntOrFloatBitWidth(); + int size = type.getNumElements() * bitWidth / 16; + if (use32bitData) { + size /= 2; + } + std::string str = "v"; + str += std::to_string(size); + if (!use32bitData) { + str += "i16"; + elemType = rewriter.getI16Type(); + } else if (elemType == rewriter.getF32Type()) + str += "f32"; + else if (elemType == rewriter.getF16Type()) { + str += "i32"; + elemType = rewriter.getI32Type(); + } else + assert(0 && "add more support"); + auto newType = VectorType::get(size, elemType); + return std::make_pair(str, newType); +} + +class CreateNdDescToGenISA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(CreateNdDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto i32Type = rewriter.getI32Type(); + auto i64Type = rewriter.getI64Type(); + auto v4i32 = VectorType::get(4, i32Type); + auto v2i64 = VectorType::get(2, i64Type); + Value payLoad = rewriter.create(loc, v2i64); + auto createIntConstant = [&](Type type, unsigned value) { + auto attr = rewriter.getIntegerAttr(type, value); + return rewriter.create(loc, type, attr); + }; + auto base = rewriter.create(loc, i64Type, + adaptor.getSource()); + auto idx0 = createIntConstant(i32Type, 0); + payLoad = + rewriter.create(loc, payLoad, base, idx0); + auto tileType = op.getTensorDesc().getType(); + auto rank = tileType.getRank(); + if (rank == 2) { + payLoad = rewriter.create(loc, v4i32, payLoad); + auto createOffset = [&](unsigned idx) -> Value { + Value val; + if (ShapedType::isDynamic(op.getStaticOffsets()[idx])) { + val = op.getOffsets()[idx]; + val = rewriter.create(loc, i32Type, val); + } else { + val = createIntConstant(i32Type, op.getStaticOffsets()[idx]); + } + return val; + }; + auto offsetX = createOffset(1); + auto offsetY = createOffset(0); + auto idx2 = createIntConstant(i32Type, 2); + auto idx3 = createIntConstant(i32Type, 3); + payLoad = rewriter.create(loc, payLoad, + offsetX, idx2); + payLoad = rewriter.create(loc, payLoad, + offsetY, idx3); + payLoad = rewriter.create(loc, v2i64, payLoad); + } + rewriter.replaceOp(op, payLoad); + return success(); + } +}; + +template +class LoadStorePrefetchNdToGenISA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tileType = op.getTensorDesc().getType(); + int rank = tileType.getRank(); + assert(rank <= 2 && "only support 1d/2d load/store/prefetch for now"); + auto loc = op.getLoc(); + ::mlir::VectorType vecType; + std::string funcName; + constexpr bool isLoad = std::is_same_v; + constexpr bool isPrefetch = std::is_same_v; + if constexpr (isLoad) { + vecType = cast(op.getResult().getType()); + funcName = rank == 2 ? "llvm.genx.GenISA.LSC2DBlockRead." + : "llvm.genx.GenISA.LSCLoadBlock."; + } else if constexpr (isPrefetch) { + vecType = VectorType::get({8, 16}, rewriter.getF32Type()); + funcName = rank == 2 ? "llvm.genx.GenISA.LSC2DPrefetch." + : "llvm.genx.GenISA.LSCPrefetch"; + } else { + vecType = cast(op.getValue().getType()); + funcName = rank == 2 ? "llvm.genx.GenISA.LSC2DBlockWrite." + : "llvm.genx.GenISA.LSCStoreBlock"; + } + auto createIntConstant = [&](Type type, unsigned value) { + auto attr = rewriter.getIntegerAttr(type, value); + return rewriter.create(loc, type, attr); + }; + auto i1Type = rewriter.getI1Type(); + auto i8Type = rewriter.getI8Type(); + auto i32Type = rewriter.getI32Type(); + auto vnni = false; + auto transpose = false; + if constexpr (isLoad) { + auto vnniValue = op.getVnniAxis(); + vnni = vnniValue.has_value() && vnniValue.value() == 0 ? true : false; + auto transposeValue = op.getTranspose(); + transpose = transposeValue.has_value() && transposeValue.value()[0] == 1 + ? true + : false; + } + unsigned dataSize = vecType.getElementType().getIntOrFloatBitWidth(); + auto elemSize = createIntConstant(i8Type, dataSize); + auto trans = createIntConstant(i1Type, transpose ? 1 : 0); + // number of blocks(1 for now) + auto nBlks = createIntConstant(i8Type, 1); + auto tensorDesc = adaptor.getTensorDesc(); + auto idx0 = createIntConstant(i32Type, 0); + auto base = + rewriter.create(loc, tensorDesc, idx0); + auto [typeStr, newType] = encodeGenISAVectorType(rewriter, vecType, false); + SmallVector args; + if (rank == 2) { + auto blockWidth = tileType.getShape()[1]; + auto blockHeight = tileType.getShape()[0]; + auto blockW = createIntConstant(i32Type, blockWidth); + auto blockH = createIntConstant(i32Type, blockHeight); + auto transform = createIntConstant(i1Type, vnni ? 1 : 0); + // static memref for now + auto createDescOp = + op.getTensorDesc().template getDefiningOp(); + auto memType = cast(createDescOp.getSource().getType()); + unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth(); + auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1; + auto surfaceHeight = memType.getShape()[0] - 1; + // pitch = width for now + auto surfacePitch = surfaceWidth; + auto surfaceW = createIntConstant(i32Type, surfaceWidth); + auto surfaceH = createIntConstant(i32Type, surfaceHeight); + auto surfaceP = createIntConstant(i32Type, surfacePitch); + auto v4i32 = VectorType::get(4, i32Type); + tensorDesc = rewriter.create(loc, v4i32, tensorDesc); + auto idx2 = createIntConstant(i32Type, 2); + auto idx3 = createIntConstant(i32Type, 3); + auto offsetX = + rewriter.create(loc, tensorDesc, idx2); + auto offsetY = + rewriter.create(loc, tensorDesc, idx3); + args.assign({base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, + elemSize, blockW, blockH, nBlks, trans, transform}); + if constexpr (!isLoad && !isPrefetch) { + args.push_back(adaptor.getValue()); + } + } + if constexpr (isLoad) + funcName += typeStr; + else if constexpr (!isPrefetch) + funcName += "isVoid"; + if constexpr (isLoad) { + auto funcType = + rewriter.getFunctionType(ValueRange(args).getTypes(), newType); + Operation *opPtr = op; + lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true); + auto funcOp = + rewriter.create(loc, newType, funcName, args); + auto castTy = this->getTypeConverter()->convertType(op.getType()); + auto cast = + rewriter.create(loc, castTy, funcOp->getResult(0)); + rewriter.replaceOp(op, cast); + } else { + auto funcType = rewriter.getFunctionType(ValueRange(args).getTypes(), {}); + Operation *opPtr = op; + lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true); + rewriter.create(loc, TypeRange(), funcName, args); + rewriter.eraseOp(op); + } + return success(); + } +}; + +class DpasToGenISA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(DpasOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + { + OpBuilder::InsertionGuard guard(rewriter); + auto func = op->getParentOfType(); + rewriter.setInsertionPointAfter(func); + rewriter.create( + op.getLoc(), func, spirv::ExecutionMode::SubgroupSize, 16); + } + auto i32Type = rewriter.getI32Type(); + auto createIntConstant = [&](Type type, unsigned value) { + auto attr = rewriter.getIntegerAttr(type, value); + return rewriter.create(loc, type, attr); + }; + auto encodePrecision = [&](Type type) -> uint8_t { + if (type == rewriter.getBF16Type()) + return 9; + else if (type == rewriter.getF16Type()) + return 10; + else if (type == rewriter.getTF32Type()) + return 12; + else { + assert(0 && "add more support"); + return 0; + } + }; + auto lType = op.getLhs().getType().cast(); + auto rType = op.getRhs().getType().cast(); + auto resultType = op.getResultType().cast(); + auto [lhsStr, lhsType] = encodeGenISAVectorType(rewriter, lType, false); + auto [rhsStr, rhsType] = encodeGenISAVectorType(rewriter, rType, false); + auto [newStr, newType] = encodeGenISAVectorType(rewriter, resultType); + auto lhs = + rewriter.create(loc, lhsType, adaptor.getLhs()); + auto rhs = + rewriter.create(loc, rhsType, adaptor.getRhs()); + uint8_t preca = encodePrecision(lType.getElementType()); + uint8_t precb = encodePrecision(rType.getElementType()); + auto precA = createIntConstant(i32Type, preca); + auto precB = createIntConstant(i32Type, precb); + // fixed for now + auto rc = createIntConstant(i32Type, 8); + auto sd = createIntConstant(i32Type, 8); + auto dpasW = createIntConstant(rewriter.getI1Type(), 0); + Value acc = op.getAcc() ? adaptor.getAcc() + : rewriter.create(loc, newType); + SmallVector args{acc, lhs, rhs, precA, precB, sd, rc, dpasW}; + std::string funcName = "llvm.genx.GenISA.sub.group.dpas."; + funcName += newStr; + funcName += "."; + funcName += newStr; + funcName += "."; + funcName += lhsStr; + funcName += "."; + funcName += rhsStr; + auto funcType = + rewriter.getFunctionType(ValueRange(args).getTypes(), newType); + Operation *opPtr = op; + lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true); + auto funcOp = + rewriter.create(loc, newType, funcName, args); + rewriter.replaceOp(op, funcOp); + return success(); + } +}; + +void imex::populateXeGPUToGenISAPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add, + LoadStorePrefetchNdToGenISA, + LoadStorePrefetchNdToGenISA>( + typeConverter, patterns.getContext()); +} diff --git a/test/Conversion/XeGPUToSPIRV/atomic_basic.mlir b/test/Conversion/XeGPUToSPIRV/atomic_basic.vc.mlir similarity index 100% rename from test/Conversion/XeGPUToSPIRV/atomic_basic.mlir rename to test/Conversion/XeGPUToSPIRV/atomic_basic.vc.mlir diff --git a/test/Conversion/XeGPUToSPIRV/barrier_basic.mlir b/test/Conversion/XeGPUToSPIRV/barrier_basic.vc.mlir similarity index 100% rename from test/Conversion/XeGPUToSPIRV/barrier_basic.mlir rename to test/Conversion/XeGPUToSPIRV/barrier_basic.vc.mlir diff --git a/test/Conversion/XeGPUToSPIRV/gemm_basic.mlir b/test/Conversion/XeGPUToSPIRV/gemm_basic.mlir index a4c12ec46..ac4d5e699 100644 --- a/test/Conversion/XeGPUToSPIRV/gemm_basic.mlir +++ b/test/Conversion/XeGPUToSPIRV/gemm_basic.mlir @@ -1,5 +1,8 @@ -// RUN: imex-opt -imex-convert-gpu-to-spirv %s | FileCheck %s -// RUN: IMEX_NOT_PREFER_RAWSEND=1 imex-opt -imex-convert-gpu-to-spirv %s | FileCheck %s --check-prefix=LSC +// RUN: imex-opt -imex-convert-gpu-to-spirv='enable-vc-intrinsic=false' %s | FileCheck %s + +#sg_map_fp16_a = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}> +#sg_map_fp16_b = #xegpu.sg_map<{mma_block_size = [16, 16], wi_layout = [1, 16], wi_data = [1, 1]}> +#sg_map_fp16_c = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}> module @gemm attributes {gpu.container_module} { memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<5.000000e-01> memref.global "private" constant @__constant_16x16xf16 : memref<16x16xf16> = dense<1.099610e+00> @@ -15,35 +18,25 @@ module @gemm attributes {gpu.container_module} { gpu.dealloc %memref_0 : memref<16x16xf16> return %memref_1 : memref<8x16xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - // LSC: spirv.FunctionCall @llvm_genx_lsc_prefetch2d_stateless_i1_i64 - // LSC: spirv.FunctionCall @llvm_genx_lsc_prefetch2d_stateless_i1_i64 - // LSC: spirv.FunctionCall @llvm_genx_lsc_load2d_stateless_v64i32_i1_i64 - // LSC: spirv.FunctionCall @llvm_genx_lsc_load2d_stateless_v128i32_i1_i64 - // LSC: spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32 - // LSC: spirv.FunctionCall @llvm_genx_lsc_store2d_stateless_i1_i64_v128f32 - // CHECK: %[[BASE:.*]] = spirv.ConvertPtrToU %arg0 : !spirv.ptr, CrossWorkgroup> to i64 - // CHECK: %[[BASE1:.*]] = spirv.VectorInsertDynamic %[[BASE]] - // CHECK: %[[BASE2:.*]] = spirv.Bitcast %[[BASE1]] - // CHECK: spirv.VectorInsertDynamic - // CHECK: spirv.VectorInsertDynamic - // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_noresult_i1_v8i32 - // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_noresult_i1_v8i32 - // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_v64i32_i1_v8i32 - // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_v128i32_i1_v8i32 - // CHECK: spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32 - // CHECK: spirv.FunctionCall @llvm_genx_raw_sends2_noresult_i1_v8i32_v128f32 - %0 = xegpu.create_nd_tdesc %arg0[0, 0] {mode = vc} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %arg1[0, 0] {mode = vc} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> - %2 = xegpu.create_nd_tdesc %arg2[0, 0] {mode = vc} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.prefetch_nd %0 {mode = vc} : !xegpu.tensor_desc<8x16xf16> - xegpu.prefetch_nd %1 {mode = vc} : !xegpu.tensor_desc<16x16xf16> - - %3 = xegpu.load_nd %0 {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> - %4 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %5 = xegpu.dpas %3, %4 {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> - xegpu.store_nd %5, %2 {mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} { + // CHECK: %[[a:.*]] = spirv.FunctionCall @llvm.genx.GenISA.LSC2DBlockRead.v8i16 + // CHECK: %[[a0:.*]] = spirv.Bitcast %[[a]] + // CHECK: %[[b:.*]] = spirv.FunctionCall @llvm.genx.GenISA.LSC2DBlockRead.v16i16 + // CHECK: %[[b0:.*]] = spirv.Bitcast %[[b]] + // CHECK: %[[A:.*]] = spirv.Bitcast %[[a0]] + // CHECK: %[[B:.*]] = spirv.Bitcast %[[b0]] + // CHECK: %[[C:.*]] = spirv.FunctionCall @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v16i16 + // CHECK-SAME: %[[A]], %[[B]] + // CHECK: spirv.FunctionCall @llvm.genx.GenISA.LSC2DBlockWrite.isVoid + // CHECK-SAME: %[[C]] + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_a> + %1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #sg_map_fp16_b> + %2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c> + %3 = xegpu.load_nd %0 {vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_a> -> vector<4x1x2xf16> + %4 = xegpu.load_nd %1 {vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16, #sg_map_fp16_b> -> vector<8x1x2xf16> + %5 = xegpu.dpas %3, %4 : vector<4x1x2xf16>, vector<8x1x2xf16> -> vector<8x1xf32> + xegpu.store_nd %5, %2 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c> gpu.return } } diff --git a/test/Conversion/XeGPUToSPIRV/gemm_basic.vc.mlir b/test/Conversion/XeGPUToSPIRV/gemm_basic.vc.mlir new file mode 100644 index 000000000..a4c12ec46 --- /dev/null +++ b/test/Conversion/XeGPUToSPIRV/gemm_basic.vc.mlir @@ -0,0 +1,59 @@ +// RUN: imex-opt -imex-convert-gpu-to-spirv %s | FileCheck %s +// RUN: IMEX_NOT_PREFER_RAWSEND=1 imex-opt -imex-convert-gpu-to-spirv %s | FileCheck %s --check-prefix=LSC +module @gemm attributes {gpu.container_module} { + memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<5.000000e-01> + memref.global "private" constant @__constant_16x16xf16 : memref<16x16xf16> = dense<1.099610e+00> + func.func @test(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %memref = gpu.alloc host_shared () : memref<8x16xf16> + memref.copy %arg0, %memref : memref<8x16xf16> to memref<8x16xf16> + %memref_0 = gpu.alloc host_shared () : memref<16x16xf16> + memref.copy %arg1, %memref_0 : memref<16x16xf16> to memref<16x16xf16> + %memref_1 = gpu.alloc host_shared () : memref<8x16xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf16>, %memref_0 : memref<16x16xf16>, %memref_1 : memref<8x16xf32>) + gpu.dealloc %memref : memref<8x16xf16> + gpu.dealloc %memref_0 : memref<16x16xf16> + return %memref_1 : memref<8x16xf32> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + // LSC: spirv.FunctionCall @llvm_genx_lsc_prefetch2d_stateless_i1_i64 + // LSC: spirv.FunctionCall @llvm_genx_lsc_prefetch2d_stateless_i1_i64 + // LSC: spirv.FunctionCall @llvm_genx_lsc_load2d_stateless_v64i32_i1_i64 + // LSC: spirv.FunctionCall @llvm_genx_lsc_load2d_stateless_v128i32_i1_i64 + // LSC: spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32 + // LSC: spirv.FunctionCall @llvm_genx_lsc_store2d_stateless_i1_i64_v128f32 + // CHECK: %[[BASE:.*]] = spirv.ConvertPtrToU %arg0 : !spirv.ptr, CrossWorkgroup> to i64 + // CHECK: %[[BASE1:.*]] = spirv.VectorInsertDynamic %[[BASE]] + // CHECK: %[[BASE2:.*]] = spirv.Bitcast %[[BASE1]] + // CHECK: spirv.VectorInsertDynamic + // CHECK: spirv.VectorInsertDynamic + // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_noresult_i1_v8i32 + // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_noresult_i1_v8i32 + // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_v64i32_i1_v8i32 + // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_v128i32_i1_v8i32 + // CHECK: spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32 + // CHECK: spirv.FunctionCall @llvm_genx_raw_sends2_noresult_i1_v8i32_v128f32 + %0 = xegpu.create_nd_tdesc %arg0[0, 0] {mode = vc} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %arg1[0, 0] {mode = vc} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + %2 = xegpu.create_nd_tdesc %arg2[0, 0] {mode = vc} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.prefetch_nd %0 {mode = vc} : !xegpu.tensor_desc<8x16xf16> + xegpu.prefetch_nd %1 {mode = vc} : !xegpu.tensor_desc<16x16xf16> + + %3 = xegpu.load_nd %0 {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + %4 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %5 = xegpu.dpas %3, %4 {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + xegpu.store_nd %5, %2 {mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %0 = memref.get_global @__constant_8x16xf16 : memref<8x16xf16> + %1 = memref.get_global @__constant_16x16xf16 : memref<16x16xf16> + %2 = call @test(%0, %1) : (memref<8x16xf16>, memref<16x16xf16>) -> memref<8x16xf32> + %cast = memref.cast %2 : memref<8x16xf32> to memref<*xf32> + //call @printMemrefF32(%cast) : (memref<*xf32>) -> () + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Conversion/XeGPUToSPIRV/gemm_basic_1d.mlir b/test/Conversion/XeGPUToSPIRV/gemm_basic_1d.vc.mlir similarity index 100% rename from test/Conversion/XeGPUToSPIRV/gemm_basic_1d.mlir rename to test/Conversion/XeGPUToSPIRV/gemm_basic_1d.vc.mlir diff --git a/test/Conversion/XeGPUToSPIRV/gemm_basic_gather.mlir b/test/Conversion/XeGPUToSPIRV/gemm_basic_gather.vc.mlir similarity index 100% rename from test/Conversion/XeGPUToSPIRV/gemm_basic_gather.mlir rename to test/Conversion/XeGPUToSPIRV/gemm_basic_gather.vc.mlir diff --git a/test/Conversion/XeGPUToSPIRV/update_offset.mlir b/test/Conversion/XeGPUToSPIRV/update_offset.vc.mlir similarity index 100% rename from test/Conversion/XeGPUToSPIRV/update_offset.mlir rename to test/Conversion/XeGPUToSPIRV/update_offset.vc.mlir