diff --git a/include/imex/Conversion/Passes.td b/include/imex/Conversion/Passes.td index b5d43eed8..2766c91c5 100644 --- a/include/imex/Conversion/Passes.td +++ b/include/imex/Conversion/Passes.td @@ -251,8 +251,8 @@ memref, arith and math. let constructor = "imex::createConvertGPUXToSPIRVPass()"; let dependentDialects = ["::mlir::spirv::SPIRVDialect"]; let options = [ - Option<"enableSimtIntrinsic", "enable-simt-intrinsic","bool", "false", - "Enable XeGPU.simt Ops lowered to intel genISA simt Intrinsics"> + Option<"enableVCIntrinsic", "enable-vc-intrinsic","bool", "true", + "Enable XeGPU Ops lowered to intel vc Intrinsics"> ]; } diff --git a/include/imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h b/include/imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h index f048a0e02..91615dbad 100644 --- a/include/imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h +++ b/include/imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h @@ -27,6 +27,9 @@ namespace imex { // XeGPU to VC Intrinsics pattern void populateXeGPUToVCIntrinsicsPatterns( mlir::SPIRVTypeConverter &typeConverter, mlir::RewritePatternSet &patterns); +// XeGPU to genISA Intrinsics pattern +void populateXeGPUToGenISAPatterns(mlir::SPIRVTypeConverter &typeConverter, + mlir::RewritePatternSet &patterns); } // namespace imex #endif // IMEX_CONVERSION_XEGPUTOSPIRV_H diff --git a/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp index c149f5aaa..589b3b033 100644 --- a/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -162,8 +162,8 @@ void GPUXToSPIRVPass::runOnOperation() { }); typeConverter.addConversion( [&](xegpu::TensorDescType type) -> ::mlir::Type { - auto i64Type = ::mlir::IntegerType::get(context, 64); - return ::mlir::VectorType::get(2, i64Type); + auto i32Type = ::mlir::IntegerType::get(context, 32); + return ::mlir::VectorType::get(8, i32Type); }); typeConverter.addConversion([&](::mlir::VectorType type) -> ::mlir::Type { unsigned rank = type.getRank(); @@ -198,7 +198,10 @@ void GPUXToSPIRVPass::runOnOperation() { mlir::populateSCFToSPIRVPatterns(typeConverter, scfToSpirvCtx, patterns); mlir::cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns); mlir::populateMathToSPIRVPatterns(typeConverter, patterns); - imex::populateXeGPUToVCIntrinsicsPatterns(typeConverter, patterns); + if (this->enableVCIntrinsic) + imex::populateXeGPUToVCIntrinsicsPatterns(typeConverter, patterns); + else + imex::populateXeGPUToGenISAPatterns(typeConverter, patterns); if (failed(applyFullConversion(gpuModule, *target, std::move(patterns)))) return signalPassFailure(); diff --git a/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp b/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp index 7edb0ecda..bf8700548 100644 --- a/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp +++ b/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp @@ -200,7 +200,8 @@ unsigned encodeOpcode(xegpu::AtomicRMWKind kind) { } void lookupOrInsertIntrinsic(ConversionPatternRewriter &rewriter, Operation *op, - std::string name, FunctionType funcType) { + std::string name, FunctionType funcType, + bool isVC = true) { auto funcAttr = StringAttr::get(rewriter.getContext(), name); Operation *found = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); if (!found) { @@ -215,14 +216,17 @@ void lookupOrInsertIntrinsic(ConversionPatternRewriter &rewriter, Operation *op, auto linkage = spirv::LinkageAttributesAttr::get(rewriter.getContext(), nameAttr, linkageTypeAttr); func.setLinkageAttributesAttr(linkage); - 
func->setAttr("VectorComputeFunctionINTEL", rewriter.getUnitAttr()); + if (isVC) + func->setAttr("VectorComputeFunctionINTEL", rewriter.getUnitAttr()); } } /// @brief -/// convert the tensor descriptor to [2xi64] which is of the format -/// -> [base pointer: i64, offsetX: i32, offsetY: i32] for 2D tensor desc -/// -> [base pointer: i64, unused] for 1D and scattered tensor desc +/// assemble the tensor descriptor payload[8xi32] which is of the format +/// -> [base pointer, surface width, surface height, surface pitch, +/// offsetX, offsetY, blockInfo] for 2D tensor desc +/// -> [base pointer, unused] for 1D and scattered tensor desc +/// only base pointer is i64, others are i32 class CreateNdDescToVCPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -233,9 +237,9 @@ class CreateNdDescToVCPattern : public OpConversionPattern { auto i32Type = rewriter.getI32Type(); auto i64Type = rewriter.getI64Type(); // payload - auto v4i32 = VectorType::get(4, i32Type); - auto v2i64 = VectorType::get(2, i64Type); - Value payLoad = rewriter.create(loc, v2i64); + auto v8i32 = VectorType::get(8, i32Type); + auto v4i64 = VectorType::get(4, i64Type); + Value payLoad = rewriter.create(loc, v4i64); auto createIntConstant = [&](Type type, unsigned value) { auto attr = rewriter.getIntegerAttr(type, value); return rewriter.create(loc, type, attr); @@ -245,10 +249,34 @@ class CreateNdDescToVCPattern : public OpConversionPattern { auto idx0 = createIntConstant(i32Type, 0); payLoad = rewriter.create(loc, payLoad, base, idx0); + payLoad = rewriter.create(loc, v8i32, payLoad); auto tileType = op.getTensorDesc().getType(); auto rank = tileType.getRank(); if (rank == 2) { - payLoad = rewriter.create(loc, v4i32, payLoad); + auto idx2 = createIntConstant(i32Type, 2); + auto idx3 = createIntConstant(i32Type, 3); + auto idx4 = createIntConstant(i32Type, 4); + auto idx5 = createIntConstant(i32Type, 5); + auto idx6 = createIntConstant(i32Type, 6); + auto idx7 = createIntConstant(i32Type, 7); + auto blockWidth = tileType.getShape()[1]; + auto blockHeight = tileType.getShape()[0]; + // fixme: support memref for now + auto memType = cast(op.getSource().getType()); + unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth(); + auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1; + auto surfaceHeight = memType.getShape()[0] - 1; + // fixme: pitch = width for now + auto surfacePitch = surfaceWidth; + auto surfaceW = createIntConstant(i32Type, surfaceWidth); + auto surfaceH = createIntConstant(i32Type, surfaceHeight); + auto surfaceP = createIntConstant(i32Type, surfacePitch); + payLoad = rewriter.create(loc, payLoad, + surfaceW, idx2); + payLoad = rewriter.create(loc, payLoad, + surfaceH, idx3); + payLoad = rewriter.create(loc, payLoad, + surfaceP, idx4); auto createOffset = [&](unsigned idx) -> Value { Value val; if (ShapedType::isDynamic(op.getStaticOffsets()[idx])) { @@ -261,13 +289,14 @@ class CreateNdDescToVCPattern : public OpConversionPattern { }; auto offsetX = createOffset(1); auto offsetY = createOffset(0); - auto idx2 = createIntConstant(i32Type, 2); - auto idx3 = createIntConstant(i32Type, 3); payLoad = rewriter.create(loc, payLoad, - offsetX, idx2); + offsetX, idx5); payLoad = rewriter.create(loc, payLoad, - offsetY, idx3); - payLoad = rewriter.create(loc, v2i64, payLoad); + offsetY, idx6); + unsigned blockVal = ((blockHeight - 1) << 8) | (blockWidth - 1); + auto blockInfo = createIntConstant(i32Type, blockVal); + payLoad = rewriter.create(loc, payLoad, 
+ blockInfo, idx7); } rewriter.replaceOp(op, payLoad); return success(); @@ -281,32 +310,29 @@ class UpdateNDOffsetToVCPattern : public OpConversionPattern { matchAndRewrite(UpdateNDOffsetOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto desc = adaptor.getTensorDesc(); auto i32Type = rewriter.getI32Type(); - auto v4i32 = VectorType::get(4, i32Type); - auto v2i64 = VectorType::get(2, rewriter.getI64Type()); - Value cast = rewriter.create(loc, v4i32, desc); auto offsets = adaptor.getOffsets(); + auto desc = adaptor.getTensorDesc(); for (auto i = 0; i < offsets.size(); i++) { auto offset = offsets[i]; if (auto cst = dyn_cast(offset.getDefiningOp())) if (auto attr = dyn_cast(cst.getValue()); attr && attr.getInt() == 0) continue; - auto idx2 = rewriter.create( - loc, i32Type, rewriter.getIntegerAttr(i32Type, 2)); - auto idx3 = rewriter.create( - loc, i32Type, rewriter.getIntegerAttr(i32Type, 3)); - Value idx = i == 0 ? idx3 : idx2; + auto idx5 = rewriter.create( + loc, i32Type, rewriter.getIntegerAttr(i32Type, 5)); + auto idx6 = rewriter.create( + loc, i32Type, rewriter.getIntegerAttr(i32Type, 6)); + Value idx = i == 0 ? idx6 : idx5; auto oldOffset = - rewriter.create(loc, cast, idx); + rewriter.create(loc, desc, idx); offset = rewriter.create(loc, i32Type, offset); auto newOffset = rewriter.create(loc, i32Type, oldOffset, offset); - cast = rewriter.create(loc, v4i32, cast, - newOffset, idx); + desc = rewriter.create(loc, desc, newOffset, + idx); } - rewriter.replaceOpWithNewOp(op, v2i64, cast); + rewriter.replaceOp(op, desc); return success(); } }; @@ -320,14 +346,16 @@ class CreateDescToVCPattern : public OpConversionPattern { auto loc = op.getLoc(); auto i32Type = rewriter.getI32Type(); auto i64Type = rewriter.getI64Type(); - auto v2i64 = VectorType::get(2, i64Type); - Value payLoad = rewriter.create(loc, v2i64); + auto v8i32 = VectorType::get(8, i32Type); + auto v4i64 = VectorType::get(4, i64Type); + Value payLoad = rewriter.create(loc, v4i64); auto base = rewriter.create(loc, i64Type, adaptor.getSource()); auto idx0 = rewriter.create( loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)); payLoad = rewriter.create(loc, payLoad, base, idx0); + payLoad = rewriter.create(loc, v8i32, payLoad); rewriter.replaceOp(op, payLoad); return success(); } @@ -367,6 +395,8 @@ class LoadStorePrefetchNdToLsc : public OpConversionPattern { auto i8Type = rewriter.getI8Type(); auto i16Type = rewriter.getI16Type(); auto i32Type = rewriter.getI32Type(); + auto i64Type = rewriter.getI64Type(); + auto v4i64 = VectorType::get(4, i64Type); auto vnni = false; auto transpose = false; if constexpr (isLoad) { @@ -394,11 +424,9 @@ class LoadStorePrefetchNdToLsc : public OpConversionPattern { auto nBlks = createIntConstant(i8Type, 1); auto tensorDesc = adaptor.getTensorDesc(); auto idx0 = createIntConstant(i32Type, 0); - auto base = - rewriter.create(loc, tensorDesc, idx0); - std::string typeStr; - VectorType newType; - std::tie(typeStr, newType) = encodeVectorType(rewriter, vecType, rank == 1); + auto cast = rewriter.create(loc, v4i64, tensorDesc); + auto base = rewriter.create(loc, cast, idx0); + auto [typeStr, newType] = encodeVectorType(rewriter, vecType, rank == 1); SmallVector args; if (rank == 2) { auto blockWidth = tileType.getShape()[1]; @@ -409,7 +437,7 @@ class LoadStorePrefetchNdToLsc : public OpConversionPattern { // static memref for now auto createDescOp = op.getTensorDesc().template getDefiningOp(); - auto memType = 
cast(createDescOp.getSource().getType()); + auto memType = llvm::cast(createDescOp.getSource().getType()); unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth(); auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1; auto surfaceHeight = memType.getShape()[0] - 1; @@ -418,14 +446,12 @@ class LoadStorePrefetchNdToLsc : public OpConversionPattern { auto surfaceW = createIntConstant(i32Type, surfaceWidth); auto surfaceH = createIntConstant(i32Type, surfaceHeight); auto surfaceP = createIntConstant(i32Type, surfacePitch); - auto v4i32 = VectorType::get(4, i32Type); - tensorDesc = rewriter.create(loc, v4i32, tensorDesc); - auto idx2 = createIntConstant(i32Type, 2); - auto idx3 = createIntConstant(i32Type, 3); + auto idx5 = createIntConstant(i32Type, 5); + auto idx6 = createIntConstant(i32Type, 6); auto offsetX = - rewriter.create(loc, tensorDesc, idx2); + rewriter.create(loc, tensorDesc, idx5); auto offsetY = - rewriter.create(loc, tensorDesc, idx3); + rewriter.create(loc, tensorDesc, idx6); args.assign({pred, l1CacheHint, l3CacheHint, dataum, trans, nBlks, blockW, blockH, transform, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY}); @@ -537,7 +563,6 @@ class LoadStorePrefetchNdToRawSend : public OpConversionPattern { auto i1Type = rewriter.getI1Type(); auto i8Type = rewriter.getI8Type(); auto i32Type = rewriter.getI32Type(); - auto i64Type = rewriter.getI64Type(); auto vnni = false; auto transpose = false; if constexpr (isLoad) { @@ -605,66 +630,7 @@ class LoadStorePrefetchNdToRawSend : public OpConversionPattern { rawSendMsg |= 1 << 25; } auto msg = createIntConstant(i32Type, rawSendMsg); - // payload - // payload is v8i32 = [base:i64, surfaceWidth:i32, surfaceHeight:i32, - // surefacePitch:i32, offsetX:i32, offsetY:i32, blockInfo:i32] - // the base/surfaceInfo/blockInfo are staticly from the tensor desc - // while the offsetX/Y are dynamicly udpated - auto insertPoint = rewriter.saveInsertionPoint(); - CreateNdDescOp createDescOp = *findDescOp(op.getTensorDesc()); - rewriter.setInsertionPointAfter(createDescOp); - auto v8i32 = VectorType::get(8, i32Type); - auto v4i64 = VectorType::get(4, i64Type); - Value payLoad = rewriter.create(loc, v4i64); - auto idx0 = createIntConstant(i32Type, 0); - auto desc = rewriter.getRemappedValue(createDescOp); - auto base = rewriter.create(loc, desc, idx0); - payLoad = - rewriter.create(loc, payLoad, base, idx0); - payLoad = rewriter.create(loc, v8i32, payLoad); - if (rank == 2) { - auto idx2 = createIntConstant(i32Type, 2); - auto idx3 = createIntConstant(i32Type, 3); - auto idx4 = createIntConstant(i32Type, 4); - auto idx5 = createIntConstant(i32Type, 5); - auto idx6 = createIntConstant(i32Type, 6); - auto idx7 = createIntConstant(i32Type, 7); - auto blockWidth = tileType.getShape()[1]; - auto blockHeight = tileType.getShape()[0]; - // fixme: support memref for now - auto memType = cast(createDescOp.getSource().getType()); - unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth(); - auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1; - auto surfaceHeight = memType.getShape()[0] - 1; - // fixme: pitch = width for now - auto surfacePitch = surfaceWidth; - auto surfaceW = createIntConstant(i32Type, surfaceWidth); - auto surfaceH = createIntConstant(i32Type, surfaceHeight); - auto surfaceP = createIntConstant(i32Type, surfacePitch); - payLoad = rewriter.create(loc, payLoad, - surfaceW, idx2); - payLoad = rewriter.create(loc, payLoad, - surfaceH, idx3); - payLoad = rewriter.create(loc, payLoad, - surfaceP, 
idx4); - unsigned blockVal = ((blockHeight - 1) << 8) | (blockWidth - 1); - auto blockInfo = createIntConstant(i32Type, blockVal); - payLoad = rewriter.create(loc, payLoad, - blockInfo, idx7); - rewriter.restoreInsertionPoint(insertPoint); - auto v4i32 = VectorType::get(4, i32Type); - auto tensorDesc = adaptor.getTensorDesc(); - tensorDesc = rewriter.create(loc, v4i32, tensorDesc); - auto offsetX = - rewriter.create(loc, tensorDesc, idx2); - auto offsetY = - rewriter.create(loc, tensorDesc, idx3); - payLoad = rewriter.create(loc, payLoad, - offsetX, idx5); - payLoad = rewriter.create(loc, payLoad, - offsetY, idx6); - } - rewriter.restoreInsertionPoint(insertPoint); + auto payLoad = adaptor.getTensorDesc(); SmallVector args{modifier, execSize, pred, numSrc1, numDst, sfid, extMsg, msg, payLoad}; if constexpr (isLoad) { @@ -794,8 +760,10 @@ class GatherScatterToRawSend : public OpConversionPattern { auto i8Type = rewriter.getI8Type(); auto i32Type = rewriter.getI32Type(); auto i64Type = rewriter.getI64Type(); + auto v4i64 = VectorType::get(4, i64Type); auto tensorDesc = adaptor.getTensorDesc(); auto idx0 = createIntConstant(i32Type, 0); + tensorDesc = rewriter.create(loc, v4i64, tensorDesc); auto base = rewriter.create(loc, tensorDesc, idx0); VectorType newType = VectorType::get(1, i32Type); @@ -906,6 +874,7 @@ class AtomicToLsc : public OpConversionPattern { auto i16Type = rewriter.getI16Type(); auto i32Type = rewriter.getI32Type(); auto i64Type = rewriter.getI64Type(); + auto v4i64 = VectorType::get(4, i64Type); VectorType vecType = cast(op.getResult().getType()); std::string funcName = "llvm_genx_lsc_xatomic_stateless_"; auto [typeStr, newType] = encodeVectorType(rewriter, vecType, false, true); @@ -934,6 +903,7 @@ class AtomicToLsc : public OpConversionPattern { auto mask = createIntConstant(i8Type, 0); auto tensorDesc = adaptor.getTensorDesc(); + tensorDesc = rewriter.create(loc, v4i64, tensorDesc); auto idx0 = createIntConstant(i32Type, 0); auto base = rewriter.create(loc, tensorDesc, idx0); @@ -1246,3 +1216,275 @@ void imex::populateXeGPUToVCIntrinsicsPatterns( LoadStorePrefetchNdToRawSend>( typeConverter, patterns.getContext()); } + +/// below is for XeGPU to SPIRV genISA Intrinsic + +/// @brief encodeVectorType(xxx, 8x8x2xf16, false) returns ["v64i32", 64xi32] +std::pair +encodeGenISAVectorType(ConversionPatternRewriter &rewriter, VectorType type, + bool use32bitData = true) { + auto elemType = type.getElementType(); + auto bitWidth = elemType.getIntOrFloatBitWidth(); + int size = type.getNumElements() * bitWidth / 16; + if (use32bitData) { + size /= 2; + } + std::string str = "v"; + str += std::to_string(size); + if (!use32bitData) { + str += "i16"; + elemType = rewriter.getI16Type(); + } else if (elemType == rewriter.getF32Type()) + str += "f32"; + else if (elemType == rewriter.getF16Type()) { + str += "i32"; + elemType = rewriter.getI32Type(); + } else + assert(0 && "add more support"); + auto newType = VectorType::get(size, elemType); + return std::make_pair(str, newType); +} + +class CreateNdDescToGenISA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(CreateNdDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto i32Type = rewriter.getI32Type(); + auto i64Type = rewriter.getI64Type(); + auto v4i32 = VectorType::get(4, i32Type); + auto v2i64 = VectorType::get(2, i64Type); + Value payLoad = rewriter.create(loc, v2i64); + auto createIntConstant 
= [&](Type type, unsigned value) { + auto attr = rewriter.getIntegerAttr(type, value); + return rewriter.create(loc, type, attr); + }; + auto base = rewriter.create(loc, i64Type, + adaptor.getSource()); + auto idx0 = createIntConstant(i32Type, 0); + payLoad = + rewriter.create(loc, payLoad, base, idx0); + auto tileType = op.getTensorDesc().getType(); + auto rank = tileType.getRank(); + if (rank == 2) { + payLoad = rewriter.create(loc, v4i32, payLoad); + auto createOffset = [&](unsigned idx) -> Value { + Value val; + if (ShapedType::isDynamic(op.getStaticOffsets()[idx])) { + val = op.getOffsets()[idx]; + val = rewriter.create(loc, i32Type, val); + } else { + val = createIntConstant(i32Type, op.getStaticOffsets()[idx]); + } + return val; + }; + auto offsetX = createOffset(1); + auto offsetY = createOffset(0); + auto idx2 = createIntConstant(i32Type, 2); + auto idx3 = createIntConstant(i32Type, 3); + payLoad = rewriter.create(loc, payLoad, + offsetX, idx2); + payLoad = rewriter.create(loc, payLoad, + offsetY, idx3); + payLoad = rewriter.create(loc, v2i64, payLoad); + } + rewriter.replaceOp(op, payLoad); + return success(); + } +}; + +template +class LoadStorePrefetchNdToGenISA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tileType = op.getTensorDesc().getType(); + int rank = tileType.getRank(); + assert(rank <= 2 && "only support 1d/2d load/store/prefetch for now"); + auto loc = op.getLoc(); + ::mlir::VectorType vecType; + std::string funcName; + constexpr bool isLoad = std::is_same_v; + constexpr bool isPrefetch = std::is_same_v; + if constexpr (isLoad) { + vecType = cast(op.getResult().getType()); + funcName = rank == 2 ? "llvm.genx.GenISA.LSC2DBlockRead." + : "llvm.genx.GenISA.LSCLoadBlock."; + } else if constexpr (isPrefetch) { + vecType = VectorType::get({8, 16}, rewriter.getF32Type()); + funcName = rank == 2 ? "llvm.genx.GenISA.LSC2DPrefetch." + : "llvm.genx.GenISA.LSCPrefetch"; + } else { + vecType = cast(op.getValue().getType()); + funcName = rank == 2 ? "llvm.genx.GenISA.LSC2DBlockWrite." + : "llvm.genx.GenISA.LSCStoreBlock"; + } + auto createIntConstant = [&](Type type, unsigned value) { + auto attr = rewriter.getIntegerAttr(type, value); + return rewriter.create(loc, type, attr); + }; + auto i1Type = rewriter.getI1Type(); + auto i8Type = rewriter.getI8Type(); + auto i32Type = rewriter.getI32Type(); + auto vnni = false; + auto transpose = false; + if constexpr (isLoad) { + auto vnniValue = op.getVnniAxis(); + vnni = vnniValue.has_value() && vnniValue.value() == 0 ? true : false; + auto transposeValue = op.getTranspose(); + transpose = transposeValue.has_value() && transposeValue.value()[0] == 1 + ? true + : false; + } + unsigned dataSize = vecType.getElementType().getIntOrFloatBitWidth(); + auto elemSize = createIntConstant(i8Type, dataSize); + auto trans = createIntConstant(i1Type, transpose ? 
1 : 0); + // number of blocks(1 for now) + auto nBlks = createIntConstant(i8Type, 1); + auto tensorDesc = adaptor.getTensorDesc(); + auto idx0 = createIntConstant(i32Type, 0); + auto base = + rewriter.create(loc, tensorDesc, idx0); + auto [typeStr, newType] = encodeGenISAVectorType(rewriter, vecType, false); + SmallVector args; + if (rank == 2) { + auto blockWidth = tileType.getShape()[1]; + auto blockHeight = tileType.getShape()[0]; + auto blockW = createIntConstant(i32Type, blockWidth); + auto blockH = createIntConstant(i32Type, blockHeight); + auto transform = createIntConstant(i1Type, vnni ? 1 : 0); + // static memref for now + auto createDescOp = + op.getTensorDesc().template getDefiningOp(); + auto memType = cast(createDescOp.getSource().getType()); + unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth(); + auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1; + auto surfaceHeight = memType.getShape()[0] - 1; + // pitch = width for now + auto surfacePitch = surfaceWidth; + auto surfaceW = createIntConstant(i32Type, surfaceWidth); + auto surfaceH = createIntConstant(i32Type, surfaceHeight); + auto surfaceP = createIntConstant(i32Type, surfacePitch); + auto v4i32 = VectorType::get(4, i32Type); + tensorDesc = rewriter.create(loc, v4i32, tensorDesc); + auto idx2 = createIntConstant(i32Type, 2); + auto idx3 = createIntConstant(i32Type, 3); + auto offsetX = + rewriter.create(loc, tensorDesc, idx2); + auto offsetY = + rewriter.create(loc, tensorDesc, idx3); + args.assign({base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, + elemSize, blockW, blockH, nBlks, trans, transform}); + if constexpr (!isLoad && !isPrefetch) { + args.push_back(adaptor.getValue()); + } + } + if constexpr (isLoad) + funcName += typeStr; + else if constexpr (!isPrefetch) + funcName += "isVoid"; + if constexpr (isLoad) { + auto funcType = + rewriter.getFunctionType(ValueRange(args).getTypes(), newType); + Operation *opPtr = op; + lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true); + auto funcOp = + rewriter.create(loc, newType, funcName, args); + auto castTy = this->getTypeConverter()->convertType(op.getType()); + auto cast = + rewriter.create(loc, castTy, funcOp->getResult(0)); + rewriter.replaceOp(op, cast); + } else { + auto funcType = rewriter.getFunctionType(ValueRange(args).getTypes(), {}); + Operation *opPtr = op; + lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true); + rewriter.create(loc, TypeRange(), funcName, args); + rewriter.eraseOp(op); + } + return success(); + } +}; + +class DpasToGenISA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(DpasOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + { + OpBuilder::InsertionGuard guard(rewriter); + auto func = op->getParentOfType(); + rewriter.setInsertionPointAfter(func); + rewriter.create( + op.getLoc(), func, spirv::ExecutionMode::SubgroupSize, 16); + } + auto i32Type = rewriter.getI32Type(); + auto createIntConstant = [&](Type type, unsigned value) { + auto attr = rewriter.getIntegerAttr(type, value); + return rewriter.create(loc, type, attr); + }; + auto encodePrecision = [&](Type type) -> uint8_t { + if (type == rewriter.getBF16Type()) + return 9; + else if (type == rewriter.getF16Type()) + return 10; + else if (type == rewriter.getTF32Type()) + return 12; + else { + assert(0 && "add more support"); + return 0; + } + }; + auto lType = op.getLhs().getType().cast(); + auto 
rType = op.getRhs().getType().cast(); + auto resultType = op.getResultType().cast(); + auto [lhsStr, lhsType] = encodeGenISAVectorType(rewriter, lType, false); + auto [rhsStr, rhsType] = encodeGenISAVectorType(rewriter, rType, false); + auto [newStr, newType] = encodeGenISAVectorType(rewriter, resultType); + auto lhs = + rewriter.create(loc, lhsType, adaptor.getLhs()); + auto rhs = + rewriter.create(loc, rhsType, adaptor.getRhs()); + uint8_t preca = encodePrecision(lType.getElementType()); + uint8_t precb = encodePrecision(rType.getElementType()); + auto precA = createIntConstant(i32Type, preca); + auto precB = createIntConstant(i32Type, precb); + // fixed for now + auto rc = createIntConstant(i32Type, 8); + auto sd = createIntConstant(i32Type, 8); + auto dpasW = createIntConstant(rewriter.getI1Type(), 0); + Value acc = op.getAcc() ? adaptor.getAcc() + : rewriter.create(loc, newType); + SmallVector args{acc, lhs, rhs, precA, precB, sd, rc, dpasW}; + std::string funcName = "llvm.genx.GenISA.sub.group.dpas."; + funcName += newStr; + funcName += "."; + funcName += newStr; + funcName += "."; + funcName += lhsStr; + funcName += "."; + funcName += rhsStr; + auto funcType = + rewriter.getFunctionType(ValueRange(args).getTypes(), newType); + Operation *opPtr = op; + lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true); + auto funcOp = + rewriter.create(loc, newType, funcName, args); + rewriter.replaceOp(op, funcOp); + return success(); + } +}; + +void imex::populateXeGPUToGenISAPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add, + LoadStorePrefetchNdToGenISA, + LoadStorePrefetchNdToGenISA>( + typeConverter, patterns.getContext()); +} diff --git a/test/Conversion/XeGPUToSPIRV/atomic_basic.mlir b/test/Conversion/XeGPUToSPIRV/atomic_basic.vc.mlir similarity index 100% rename from test/Conversion/XeGPUToSPIRV/atomic_basic.mlir rename to test/Conversion/XeGPUToSPIRV/atomic_basic.vc.mlir diff --git a/test/Conversion/XeGPUToSPIRV/barrier_basic.mlir b/test/Conversion/XeGPUToSPIRV/barrier_basic.vc.mlir similarity index 100% rename from test/Conversion/XeGPUToSPIRV/barrier_basic.mlir rename to test/Conversion/XeGPUToSPIRV/barrier_basic.vc.mlir diff --git a/test/Conversion/XeGPUToSPIRV/gemm_basic.mlir b/test/Conversion/XeGPUToSPIRV/gemm_basic.mlir index a4c12ec46..ac4d5e699 100644 --- a/test/Conversion/XeGPUToSPIRV/gemm_basic.mlir +++ b/test/Conversion/XeGPUToSPIRV/gemm_basic.mlir @@ -1,5 +1,8 @@ -// RUN: imex-opt -imex-convert-gpu-to-spirv %s | FileCheck %s -// RUN: IMEX_NOT_PREFER_RAWSEND=1 imex-opt -imex-convert-gpu-to-spirv %s | FileCheck %s --check-prefix=LSC +// RUN: imex-opt -imex-convert-gpu-to-spirv='enable-vc-intrinsic=false' %s | FileCheck %s + +#sg_map_fp16_a = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}> +#sg_map_fp16_b = #xegpu.sg_map<{mma_block_size = [16, 16], wi_layout = [1, 16], wi_data = [1, 1]}> +#sg_map_fp16_c = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}> module @gemm attributes {gpu.container_module} { memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<5.000000e-01> memref.global "private" constant @__constant_16x16xf16 : memref<16x16xf16> = dense<1.099610e+00> @@ -15,35 +18,25 @@ module @gemm attributes {gpu.container_module} { gpu.dealloc %memref_0 : memref<16x16xf16> return %memref_1 : memref<8x16xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { 
- gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - // LSC: spirv.FunctionCall @llvm_genx_lsc_prefetch2d_stateless_i1_i64 - // LSC: spirv.FunctionCall @llvm_genx_lsc_prefetch2d_stateless_i1_i64 - // LSC: spirv.FunctionCall @llvm_genx_lsc_load2d_stateless_v64i32_i1_i64 - // LSC: spirv.FunctionCall @llvm_genx_lsc_load2d_stateless_v128i32_i1_i64 - // LSC: spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32 - // LSC: spirv.FunctionCall @llvm_genx_lsc_store2d_stateless_i1_i64_v128f32 - // CHECK: %[[BASE:.*]] = spirv.ConvertPtrToU %arg0 : !spirv.ptr, CrossWorkgroup> to i64 - // CHECK: %[[BASE1:.*]] = spirv.VectorInsertDynamic %[[BASE]] - // CHECK: %[[BASE2:.*]] = spirv.Bitcast %[[BASE1]] - // CHECK: spirv.VectorInsertDynamic - // CHECK: spirv.VectorInsertDynamic - // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_noresult_i1_v8i32 - // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_noresult_i1_v8i32 - // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_v64i32_i1_v8i32 - // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_v128i32_i1_v8i32 - // CHECK: spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32 - // CHECK: spirv.FunctionCall @llvm_genx_raw_sends2_noresult_i1_v8i32_v128f32 - %0 = xegpu.create_nd_tdesc %arg0[0, 0] {mode = vc} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %arg1[0, 0] {mode = vc} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> - %2 = xegpu.create_nd_tdesc %arg2[0, 0] {mode = vc} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.prefetch_nd %0 {mode = vc} : !xegpu.tensor_desc<8x16xf16> - xegpu.prefetch_nd %1 {mode = vc} : !xegpu.tensor_desc<16x16xf16> - - %3 = xegpu.load_nd %0 {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> - %4 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %5 = xegpu.dpas %3, %4 {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> - xegpu.store_nd %5, %2 {mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} { + // CHECK: %[[a:.*]] = spirv.FunctionCall @llvm.genx.GenISA.LSC2DBlockRead.v8i16 + // CHECK: %[[a0:.*]] = spirv.Bitcast %[[a]] + // CHECK: %[[b:.*]] = spirv.FunctionCall @llvm.genx.GenISA.LSC2DBlockRead.v16i16 + // CHECK: %[[b0:.*]] = spirv.Bitcast %[[b]] + // CHECK: %[[A:.*]] = spirv.Bitcast %[[a0]] + // CHECK: %[[B:.*]] = spirv.Bitcast %[[b0]] + // CHECK: %[[C:.*]] = spirv.FunctionCall @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v16i16 + // CHECK-SAME: %[[A]], %[[B]] + // CHECK: spirv.FunctionCall @llvm.genx.GenISA.LSC2DBlockWrite.isVoid + // CHECK-SAME: %[[C]] + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_a> + %1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #sg_map_fp16_b> + %2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c> + %3 = xegpu.load_nd %0 {vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_a> -> vector<4x1x2xf16> + %4 = xegpu.load_nd %1 {vnni_axis = 0} : 
!xegpu.tensor_desc<16x16xf16, #sg_map_fp16_b> -> vector<8x1x2xf16> + %5 = xegpu.dpas %3, %4 : vector<4x1x2xf16>, vector<8x1x2xf16> -> vector<8x1xf32> + xegpu.store_nd %5, %2 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c> gpu.return } } diff --git a/test/Conversion/XeGPUToSPIRV/gemm_basic.vc.mlir b/test/Conversion/XeGPUToSPIRV/gemm_basic.vc.mlir new file mode 100644 index 000000000..a4c12ec46 --- /dev/null +++ b/test/Conversion/XeGPUToSPIRV/gemm_basic.vc.mlir @@ -0,0 +1,59 @@ +// RUN: imex-opt -imex-convert-gpu-to-spirv %s | FileCheck %s +// RUN: IMEX_NOT_PREFER_RAWSEND=1 imex-opt -imex-convert-gpu-to-spirv %s | FileCheck %s --check-prefix=LSC +module @gemm attributes {gpu.container_module} { + memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<5.000000e-01> + memref.global "private" constant @__constant_16x16xf16 : memref<16x16xf16> = dense<1.099610e+00> + func.func @test(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %memref = gpu.alloc host_shared () : memref<8x16xf16> + memref.copy %arg0, %memref : memref<8x16xf16> to memref<8x16xf16> + %memref_0 = gpu.alloc host_shared () : memref<16x16xf16> + memref.copy %arg1, %memref_0 : memref<16x16xf16> to memref<16x16xf16> + %memref_1 = gpu.alloc host_shared () : memref<8x16xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf16>, %memref_0 : memref<16x16xf16>, %memref_1 : memref<8x16xf32>) + gpu.dealloc %memref : memref<8x16xf16> + gpu.dealloc %memref_0 : memref<16x16xf16> + return %memref_1 : memref<8x16xf32> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + // LSC: spirv.FunctionCall @llvm_genx_lsc_prefetch2d_stateless_i1_i64 + // LSC: spirv.FunctionCall @llvm_genx_lsc_prefetch2d_stateless_i1_i64 + // LSC: spirv.FunctionCall @llvm_genx_lsc_load2d_stateless_v64i32_i1_i64 + // LSC: spirv.FunctionCall @llvm_genx_lsc_load2d_stateless_v128i32_i1_i64 + // LSC: spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32 + // LSC: spirv.FunctionCall @llvm_genx_lsc_store2d_stateless_i1_i64_v128f32 + // CHECK: %[[BASE:.*]] = spirv.ConvertPtrToU %arg0 : !spirv.ptr, CrossWorkgroup> to i64 + // CHECK: %[[BASE1:.*]] = spirv.VectorInsertDynamic %[[BASE]] + // CHECK: %[[BASE2:.*]] = spirv.Bitcast %[[BASE1]] + // CHECK: spirv.VectorInsertDynamic + // CHECK: spirv.VectorInsertDynamic + // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_noresult_i1_v8i32 + // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_noresult_i1_v8i32 + // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_v64i32_i1_v8i32 + // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_v128i32_i1_v8i32 + // CHECK: spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32 + // CHECK: spirv.FunctionCall @llvm_genx_raw_sends2_noresult_i1_v8i32_v128f32 + %0 = xegpu.create_nd_tdesc %arg0[0, 0] {mode = vc} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %arg1[0, 0] {mode = vc} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + %2 = xegpu.create_nd_tdesc %arg2[0, 0] {mode = vc} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.prefetch_nd %0 {mode = vc} : 
!xegpu.tensor_desc<8x16xf16> + xegpu.prefetch_nd %1 {mode = vc} : !xegpu.tensor_desc<16x16xf16> + + %3 = xegpu.load_nd %0 {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + %4 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %5 = xegpu.dpas %3, %4 {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + xegpu.store_nd %5, %2 {mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %0 = memref.get_global @__constant_8x16xf16 : memref<8x16xf16> + %1 = memref.get_global @__constant_16x16xf16 : memref<16x16xf16> + %2 = call @test(%0, %1) : (memref<8x16xf16>, memref<16x16xf16>) -> memref<8x16xf32> + %cast = memref.cast %2 : memref<8x16xf32> to memref<*xf32> + //call @printMemrefF32(%cast) : (memref<*xf32>) -> () + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Conversion/XeGPUToSPIRV/gemm_basic_1d.mlir b/test/Conversion/XeGPUToSPIRV/gemm_basic_1d.vc.mlir similarity index 100% rename from test/Conversion/XeGPUToSPIRV/gemm_basic_1d.mlir rename to test/Conversion/XeGPUToSPIRV/gemm_basic_1d.vc.mlir diff --git a/test/Conversion/XeGPUToSPIRV/gemm_basic_gather.mlir b/test/Conversion/XeGPUToSPIRV/gemm_basic_gather.vc.mlir similarity index 100% rename from test/Conversion/XeGPUToSPIRV/gemm_basic_gather.mlir rename to test/Conversion/XeGPUToSPIRV/gemm_basic_gather.vc.mlir diff --git a/test/Conversion/XeGPUToSPIRV/update_offset.mlir b/test/Conversion/XeGPUToSPIRV/update_offset.vc.mlir similarity index 100% rename from test/Conversion/XeGPUToSPIRV/update_offset.mlir rename to test/Conversion/XeGPUToSPIRV/update_offset.vc.mlir diff --git a/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.using.updateoffset.mlir b/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.using.updateoffset.mlir new file mode 100644 index 000000000..c58d2a6c1 --- /dev/null +++ b/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.using.updateoffset.mlir @@ -0,0 +1,111 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +module @gemm attributes {gpu.container_module} { + memref.global "private" @__constant_1024x1024xf16 : memref<1024x1024xf16> = dense<0.0> + memref.global "private" @__constant_1024x1024xf16_ : memref<1024x1024xf16> = dense<0.0> + memref.global "private" @__constant_1024x1024xf32 : memref<1024x1024xf32> = dense<0.0> + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %memref = gpu.alloc host_shared () : memref<1024x1024xf16> + memref.copy %arg0, %memref : memref<1024x1024xf16> to memref<1024x1024xf16> + %memref_0 = gpu.alloc host_shared () : memref<1024x1024xf16> + memref.copy %arg1, %memref_0 : memref<1024x1024xf16> to 
memref<1024x1024xf16> + %memref_1 = gpu.alloc host_shared () : memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + return %memref_1 : memref<1024x1024xf32> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c8 = arith.constant 8 : index + %c1024 = arith.constant 1024 : index + %0 = gpu.block_id x + %1 = gpu.block_id y + %2 = arith.muli %0, %c8 : index + %3 = arith.muli %1, %c16 : index + %4 = xegpu.create_nd_tdesc %arg2[%2, %3] {mode = vc} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %5 = xegpu.load_nd %4 {mode = vc} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + // each work-group has 1 subgroup. the subgroup caculates a [8x16 = 8x1024 * 1024x16] block + %7 = xegpu.create_nd_tdesc %arg0[%2, %c0] {mode=vc}: memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %8 = xegpu.create_nd_tdesc %arg1[%c0, %3] {mode=vc}: memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5, %subA = %7, %subB = %8) -> (vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>) { + %9 = xegpu.load_nd %subA {mode=vc, vnni_axis = 1}: !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + %10 = xegpu.load_nd %subB {mode=vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %11 = xegpu.dpas %9, %10, %arg4 {mode=vc}: vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %12 = xegpu.update_nd_offset %subA, [%c0, %c16] {mode=vc}: !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %13 = xegpu.update_nd_offset %subB, [%c16, %c0] {mode=vc}: !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + scf.yield %11, %12, %13: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16> + } + xegpu.store_nd %6#0, %4 {mode = vc}: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %0 = memref.get_global @__constant_1024x1024xf16 : memref<1024x1024xf16> + %1 = memref.get_global @__constant_1024x1024xf16_ : memref<1024x1024xf16> + %ref = memref.get_global @__constant_1024x1024xf32 : memref<1024x1024xf32> + %init = arith.constant 0.0 : f16 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + // fill the top-left block 128x128 + // A matrix: row-major, start from 0.0, increase 0.01 per element + // B matrix: A matrix + 1.0 + scf.for %arg0 = %c0 to %c128 step %c1 { + scf.for %arg1 = %c0 to %c128 step %c1 { + %int0 = arith.index_cast %arg0 : index to i16 + %int1 = arith.index_cast %arg1 : index to i16 + %c128_i16 = arith.constant 128 : i16 + %idx0 = arith.muli %int0, %c128_i16 : i16 + %idx1 = arith.addi %int1, %idx0 : i16 + %fp = 
arith.uitofp %idx1 : i16 to f16 + %cst100 = arith.constant 100.0 : f16 + %val0 = arith.divf %fp, %cst100 : f16 + %cst1 = arith.constant 1.0 : f16 + %val1 = arith.addf %val0, %cst1 : f16 + memref.store %val0, %0[%arg0, %arg1] : memref<1024x1024xf16> + memref.store %val1, %1[%arg0, %arg1] : memref<1024x1024xf16> + } + } + // caculate the result C matrix + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %acc = memref.load %ref[%arg0, %arg1] : memref<1024x1024xf32> + %res = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %acc) -> f32 { + %a = memref.load %0[%arg0, %arg2] : memref<1024x1024xf16> + %b = memref.load %1[%arg2, %arg1] : memref<1024x1024xf16> + %c = arith.mulf %a, %b : f16 + %cc = arith.extf %c : f16 to f32 + %ccc = arith.addf %cc, %arg3 : f32 + scf.yield %ccc : f32 + } + memref.store %res, %ref[%arg0, %arg1] : memref<1024x1024xf32> + } + } + + %2 = call @test(%0, %1) : (memref<1024x1024xf16>, memref<1024x1024xf16>) -> memref<1024x1024xf32> + %cast = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> + //call @printMemrefF32(%cast) : (memref<*xf32>) -> () + %cast_ref = memref.cast %ref : memref<1024x1024xf32> to memref<*xf32> + //call @printMemrefF32(%cast_ref) : (memref<*xf32>) -> () + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> () + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeGPU/lit.local.cfg b/test/Integration/Dialect/XeGPU/lit.local.cfg index cf920ae42..e084b0d12 100644 --- a/test/Integration/Dialect/XeGPU/lit.local.cfg +++ b/test/Integration/Dialect/XeGPU/lit.local.cfg @@ -1,5 +1,6 @@ local_excludes = [ 'gemm_1024x1024xf16.mlir', + 'gemm_1024x1024xf16.using.updateoffset.mlir', 'gemm_1024x1016x1016_f16_f16_f32.mlir', 'load2d_dpas_store2d.mlir', 'load2d-padding-f32.mlir',
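Note on the descriptor encoding introduced above: the VC path now assembles a v8i32 payload of [base pointer (i64, occupying lanes 0-1), surface width, surface height, surface pitch, offsetX, offsetY, blockInfo], with pitch set equal to width for now and blockInfo packing ((blockHeight - 1) << 8) | (blockWidth - 1). The following minimal, standalone C++ sketch mirrors that packing for illustration only; the helper name packNdDesc and the host-side framing are assumptions for exposition and are not part of the patch, which builds the same payload lane by lane with spirv.VectorInsertDynamic.

// Illustrative sketch of the v8i32 2D block-descriptor layout used above.
// Assumes a little-endian host so the i64 base maps onto i32 lanes 0-1.
#include <array>
#include <cstdint>
#include <cstring>
#include <cstdio>

std::array<uint32_t, 8> packNdDesc(uint64_t base, unsigned rows, unsigned cols,
                                   unsigned elemBytes, unsigned offsetX,
                                   unsigned offsetY, unsigned blockH,
                                   unsigned blockW) {
  std::array<uint32_t, 8> payload{};
  std::memcpy(payload.data(), &base, sizeof(base)); // lanes 0-1: base pointer
  payload[2] = cols * elemBytes - 1;                // surface width in bytes - 1
  payload[3] = rows - 1;                            // surface height - 1
  payload[4] = payload[2];                          // pitch = width (as in the patch)
  payload[5] = offsetX;                             // offsetX
  payload[6] = offsetY;                             // offsetY
  payload[7] = ((blockH - 1) << 8) | (blockW - 1);  // blockInfo
  return payload;
}

int main() {
  // An 8x16 f16 tile of a 1024x1024 f16 memref at offset (0, 0),
  // matching the gemm tests above: width/pitch = 2047, height = 1023,
  // blockInfo = 0x70f.
  auto p = packNdDesc(/*base=*/0x1000, 1024, 1024, /*elemBytes=*/2,
                      /*offsetX=*/0, /*offsetY=*/0, /*blockH=*/8, /*blockW=*/16);
  std::printf("surfaceW=%u surfaceH=%u pitch=%u blockInfo=0x%x\n",
              p[2], p[3], p[4], p[7]);
}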