From 8337785b53e94033459c76ea11b298b5653da481 Mon Sep 17 00:00:00 2001 From: Dewei Wang Date: Tue, 21 Nov 2023 06:06:59 +0000 Subject: [PATCH] [fix rebase] adjust code on top of new tensor descriptor --- lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp | 82 ++++---------------- test/Conversion/XeGPUToSPIRV/gemm_basic.mlir | 2 +- 2 files changed, 17 insertions(+), 67 deletions(-) diff --git a/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp b/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp index bf8700548..c803c3b92 100644 --- a/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp +++ b/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp @@ -227,7 +227,7 @@ void lookupOrInsertIntrinsic(ConversionPatternRewriter &rewriter, Operation *op, /// offsetX, offsetY, blockInfo] for 2D tensor desc /// -> [base pointer, unused] for 1D and scattered tensor desc /// only base pointer is i64, others are i32 -class CreateNdDescToVCPattern : public OpConversionPattern { +class CreateNdDescToSPIRV : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult @@ -313,7 +313,7 @@ class UpdateNDOffsetToVCPattern : public OpConversionPattern { auto i32Type = rewriter.getI32Type(); auto offsets = adaptor.getOffsets(); auto desc = adaptor.getTensorDesc(); - for (auto i = 0; i < offsets.size(); i++) { + for (size_t i = 0; i < offsets.size(); i++) { auto offset = offsets[i]; if (auto cst = dyn_cast(offset.getDefiningOp())) if (auto attr = dyn_cast(cst.getValue()); @@ -1198,7 +1198,7 @@ struct VectorShapeCast final : public OpConversionPattern { void imex::populateXeGPUToVCIntrinsicsPatterns( SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(CreateNdDescOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto i32Type = rewriter.getI32Type(); - auto i64Type = rewriter.getI64Type(); - auto v4i32 = VectorType::get(4, i32Type); - auto v2i64 = VectorType::get(2, i64Type); - Value payLoad = rewriter.create(loc, v2i64); - auto createIntConstant = [&](Type type, unsigned value) { - auto attr = rewriter.getIntegerAttr(type, value); - return rewriter.create(loc, type, attr); - }; - auto base = rewriter.create(loc, i64Type, - adaptor.getSource()); - auto idx0 = createIntConstant(i32Type, 0); - payLoad = - rewriter.create(loc, payLoad, base, idx0); - auto tileType = op.getTensorDesc().getType(); - auto rank = tileType.getRank(); - if (rank == 2) { - payLoad = rewriter.create(loc, v4i32, payLoad); - auto createOffset = [&](unsigned idx) -> Value { - Value val; - if (ShapedType::isDynamic(op.getStaticOffsets()[idx])) { - val = op.getOffsets()[idx]; - val = rewriter.create(loc, i32Type, val); - } else { - val = createIntConstant(i32Type, op.getStaticOffsets()[idx]); - } - return val; - }; - auto offsetX = createOffset(1); - auto offsetY = createOffset(0); - auto idx2 = createIntConstant(i32Type, 2); - auto idx3 = createIntConstant(i32Type, 3); - payLoad = rewriter.create(loc, payLoad, - offsetX, idx2); - payLoad = rewriter.create(loc, payLoad, - offsetY, idx3); - payLoad = rewriter.create(loc, v2i64, payLoad); - } - rewriter.replaceOp(op, payLoad); - return success(); - } -}; - template class LoadStorePrefetchNdToGenISA : public OpConversionPattern { public: @@ -1330,6 +1280,8 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern { auto i1Type = rewriter.getI1Type(); auto i8Type = rewriter.getI8Type(); auto i32Type = rewriter.getI32Type(); + auto i64Type = rewriter.getI64Type(); + auto v4i64 = VectorType::get(4, i64Type); auto vnni = false; auto transpose = false; if constexpr (isLoad) { @@ -1347,8 +1299,8 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern { auto nBlks = createIntConstant(i8Type, 1); auto tensorDesc = adaptor.getTensorDesc(); auto idx0 = createIntConstant(i32Type, 0); - auto base = - rewriter.create(loc, tensorDesc, idx0); + auto cast = rewriter.create(loc, v4i64, tensorDesc); + auto base = rewriter.create(loc, cast, idx0); auto [typeStr, newType] = encodeGenISAVectorType(rewriter, vecType, false); SmallVector args; if (rank == 2) { @@ -1360,7 +1312,7 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern { // static memref for now auto createDescOp = op.getTensorDesc().template getDefiningOp(); - auto memType = cast(createDescOp.getSource().getType()); + auto memType = llvm::cast(createDescOp.getSource().getType()); unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth(); auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1; auto surfaceHeight = memType.getShape()[0] - 1; @@ -1369,14 +1321,12 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern { auto surfaceW = createIntConstant(i32Type, surfaceWidth); auto surfaceH = createIntConstant(i32Type, surfaceHeight); auto surfaceP = createIntConstant(i32Type, surfacePitch); - auto v4i32 = VectorType::get(4, i32Type); - tensorDesc = rewriter.create(loc, v4i32, tensorDesc); - auto idx2 = createIntConstant(i32Type, 2); - auto idx3 = createIntConstant(i32Type, 3); + auto idx5 = createIntConstant(i32Type, 5); + auto idx6 = createIntConstant(i32Type, 6); auto offsetX = - rewriter.create(loc, tensorDesc, idx2); + rewriter.create(loc, tensorDesc, idx5); auto offsetY = - rewriter.create(loc, tensorDesc, idx3); + rewriter.create(loc, tensorDesc, idx6); args.assign({base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, elemSize, blockW, blockH, nBlks, trans, transform}); if constexpr (!isLoad && !isPrefetch) { @@ -1391,7 +1341,7 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern { auto funcType = rewriter.getFunctionType(ValueRange(args).getTypes(), newType); Operation *opPtr = op; - lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true); + lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, false); auto funcOp = rewriter.create(loc, newType, funcName, args); auto castTy = this->getTypeConverter()->convertType(op.getType()); @@ -1401,7 +1351,7 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern { } else { auto funcType = rewriter.getFunctionType(ValueRange(args).getTypes(), {}); Operation *opPtr = op; - lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true); + lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, false); rewriter.create(loc, TypeRange(), funcName, args); rewriter.eraseOp(op); } @@ -1472,7 +1422,7 @@ class DpasToGenISA : public OpConversionPattern { auto funcType = rewriter.getFunctionType(ValueRange(args).getTypes(), newType); Operation *opPtr = op; - lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true); + lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, false); auto funcOp = rewriter.create(loc, newType, funcName, args); rewriter.replaceOp(op, funcOp); @@ -1482,7 +1432,7 @@ class DpasToGenISA : public OpConversionPattern { void imex::populateXeGPUToGenISAPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add, LoadStorePrefetchNdToGenISA, LoadStorePrefetchNdToGenISA>( diff --git a/test/Conversion/XeGPUToSPIRV/gemm_basic.mlir b/test/Conversion/XeGPUToSPIRV/gemm_basic.mlir index ac4d5e699..3c0ca946b 100644 --- a/test/Conversion/XeGPUToSPIRV/gemm_basic.mlir +++ b/test/Conversion/XeGPUToSPIRV/gemm_basic.mlir @@ -45,7 +45,7 @@ module @gemm attributes {gpu.container_module} { %1 = memref.get_global @__constant_16x16xf16 : memref<16x16xf16> %2 = call @test(%0, %1) : (memref<8x16xf16>, memref<16x16xf16>) -> memref<8x16xf32> %cast = memref.cast %2 : memref<8x16xf32> to memref<*xf32> - //call @printMemrefF32(%cast) : (memref<*xf32>) -> () + // call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}