Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix rebase] adjust code on top of new tensor descriptor #670

Merged
merged 1 commit into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 16 additions & 66 deletions lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ void lookupOrInsertIntrinsic(ConversionPatternRewriter &rewriter, Operation *op,
/// offsetX, offsetY, blockInfo] for 2D tensor desc
/// -> [base pointer, unused] for 1D and scattered tensor desc
/// only base pointer is i64, others are i32
class CreateNdDescToVCPattern : public OpConversionPattern<CreateNdDescOp> {
class CreateNdDescToSPIRV : public OpConversionPattern<CreateNdDescOp> {
public:
using OpConversionPattern<CreateNdDescOp>::OpConversionPattern;
LogicalResult
Expand Down Expand Up @@ -313,7 +313,7 @@ class UpdateNDOffsetToVCPattern : public OpConversionPattern<UpdateNDOffsetOp> {
auto i32Type = rewriter.getI32Type();
auto offsets = adaptor.getOffsets();
auto desc = adaptor.getTensorDesc();
for (auto i = 0; i < offsets.size(); i++) {
for (size_t i = 0; i < offsets.size(); i++) {
auto offset = offsets[i];
if (auto cst = dyn_cast<spirv::ConstantOp>(offset.getDefiningOp()))
if (auto attr = dyn_cast<mlir::IntegerAttr>(cst.getValue());
Expand Down Expand Up @@ -1198,7 +1198,7 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {

void imex::populateXeGPUToVCIntrinsicsPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<CreateNdDescToVCPattern, CreateDescToVCPattern, DpasToVCPattern,
patterns.add<CreateNdDescToSPIRV, CreateDescToVCPattern, DpasToVCPattern,
AllocNbarrierToVCPattern, CreateNbarrierToVCPattern,
NbarrierArriveToVCPattern, NbarrierWaitToVCPattern,
CompilerHintToVCPattern, MfenceToVCPattern, VectorShapeCast,
Expand Down Expand Up @@ -1245,56 +1245,6 @@ encodeGenISAVectorType(ConversionPatternRewriter &rewriter, VectorType type,
return std::make_pair(str, newType);
}

class CreateNdDescToGenISA : public OpConversionPattern<CreateNdDescOp> {
public:
using OpConversionPattern<CreateNdDescOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(CreateNdDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();
auto v4i32 = VectorType::get(4, i32Type);
auto v2i64 = VectorType::get(2, i64Type);
Value payLoad = rewriter.create<spirv::UndefOp>(loc, v2i64);
auto createIntConstant = [&](Type type, unsigned value) {
auto attr = rewriter.getIntegerAttr(type, value);
return rewriter.create<spirv::ConstantOp>(loc, type, attr);
};
auto base = rewriter.create<spirv::ConvertPtrToUOp>(loc, i64Type,
adaptor.getSource());
auto idx0 = createIntConstant(i32Type, 0);
payLoad =
rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad, base, idx0);
auto tileType = op.getTensorDesc().getType();
auto rank = tileType.getRank();
if (rank == 2) {
payLoad = rewriter.create<spirv::BitcastOp>(loc, v4i32, payLoad);
auto createOffset = [&](unsigned idx) -> Value {
Value val;
if (ShapedType::isDynamic(op.getStaticOffsets()[idx])) {
val = op.getOffsets()[idx];
val = rewriter.create<arith::TruncIOp>(loc, i32Type, val);
} else {
val = createIntConstant(i32Type, op.getStaticOffsets()[idx]);
}
return val;
};
auto offsetX = createOffset(1);
auto offsetY = createOffset(0);
auto idx2 = createIntConstant(i32Type, 2);
auto idx3 = createIntConstant(i32Type, 3);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
offsetX, idx2);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
offsetY, idx3);
payLoad = rewriter.create<spirv::BitcastOp>(loc, v2i64, payLoad);
}
rewriter.replaceOp(op, payLoad);
return success();
}
};

template <typename OpType>
class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
public:
Expand Down Expand Up @@ -1330,6 +1280,8 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
auto i1Type = rewriter.getI1Type();
auto i8Type = rewriter.getI8Type();
auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();
auto v4i64 = VectorType::get(4, i64Type);
auto vnni = false;
auto transpose = false;
if constexpr (isLoad) {
Expand All @@ -1347,8 +1299,8 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
auto nBlks = createIntConstant(i8Type, 1);
auto tensorDesc = adaptor.getTensorDesc();
auto idx0 = createIntConstant(i32Type, 0);
auto base =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx0);
auto cast = rewriter.create<spirv::BitcastOp>(loc, v4i64, tensorDesc);
auto base = rewriter.create<spirv::VectorExtractDynamicOp>(loc, cast, idx0);
auto [typeStr, newType] = encodeGenISAVectorType(rewriter, vecType, false);
SmallVector<Value> args;
if (rank == 2) {
Expand All @@ -1360,7 +1312,7 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
// static memref for now
auto createDescOp =
op.getTensorDesc().template getDefiningOp<CreateNdDescOp>();
auto memType = cast<MemRefType>(createDescOp.getSource().getType());
auto memType = llvm::cast<MemRefType>(createDescOp.getSource().getType());
unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth();
auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1;
auto surfaceHeight = memType.getShape()[0] - 1;
Expand All @@ -1369,14 +1321,12 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
auto surfaceW = createIntConstant(i32Type, surfaceWidth);
auto surfaceH = createIntConstant(i32Type, surfaceHeight);
auto surfaceP = createIntConstant(i32Type, surfacePitch);
auto v4i32 = VectorType::get(4, i32Type);
tensorDesc = rewriter.create<spirv::BitcastOp>(loc, v4i32, tensorDesc);
auto idx2 = createIntConstant(i32Type, 2);
auto idx3 = createIntConstant(i32Type, 3);
auto idx5 = createIntConstant(i32Type, 5);
auto idx6 = createIntConstant(i32Type, 6);
auto offsetX =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx2);
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx5);
auto offsetY =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx3);
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx6);
args.assign({base, surfaceW, surfaceH, surfaceP, offsetX, offsetY,
elemSize, blockW, blockH, nBlks, trans, transform});
if constexpr (!isLoad && !isPrefetch) {
Expand All @@ -1391,7 +1341,7 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
auto funcType =
rewriter.getFunctionType(ValueRange(args).getTypes(), newType);
Operation *opPtr = op;
lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true);
lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, false);
auto funcOp =
rewriter.create<spirv::FunctionCallOp>(loc, newType, funcName, args);
auto castTy = this->getTypeConverter()->convertType(op.getType());
Expand All @@ -1401,7 +1351,7 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
} else {
auto funcType = rewriter.getFunctionType(ValueRange(args).getTypes(), {});
Operation *opPtr = op;
lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true);
lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, false);
rewriter.create<spirv::FunctionCallOp>(loc, TypeRange(), funcName, args);
rewriter.eraseOp(op);
}
Expand Down Expand Up @@ -1472,7 +1422,7 @@ class DpasToGenISA : public OpConversionPattern<DpasOp> {
auto funcType =
rewriter.getFunctionType(ValueRange(args).getTypes(), newType);
Operation *opPtr = op;
lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true);
lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, false);
auto funcOp =
rewriter.create<spirv::FunctionCallOp>(loc, newType, funcName, args);
rewriter.replaceOp(op, funcOp);
Expand All @@ -1482,7 +1432,7 @@ class DpasToGenISA : public OpConversionPattern<DpasOp> {

void imex::populateXeGPUToGenISAPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<CreateNdDescToGenISA, DpasToGenISA,
patterns.add<CreateNdDescToSPIRV, DpasToGenISA,
LoadStorePrefetchNdToGenISA<LoadNDOp>,
LoadStorePrefetchNdToGenISA<StoreNDOp>,
LoadStorePrefetchNdToGenISA<PrefetchNDOp>>(
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/XeGPUToSPIRV/gemm_basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ module @gemm attributes {gpu.container_module} {
%1 = memref.get_global @__constant_16x16xf16 : memref<16x16xf16>
%2 = call @test(%0, %1) : (memref<8x16xf16>, memref<16x16xf16>) -> memref<8x16xf32>
%cast = memref.cast %2 : memref<8x16xf32> to memref<*xf32>
//call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
// call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
return
}
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
Expand Down
Loading