[XeTileToXeGPU] Add lowering patterns for xetile gather/scatter version ops (#948)
chencha3 authored Oct 31, 2024
1 parent 362d432 commit d6876bd
Showing 9 changed files with 531 additions and 93 deletions.
303 changes: 217 additions & 86 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp
@@ -54,7 +54,6 @@ class XeTileConversionTarget : public mlir::ConversionTarget {
addLegalOp<mlir::vector::ReductionOp>();
addLegalOp<mlir::vector::ShuffleOp>();
addLegalOp<mlir::vector::ShapeCastOp>();
- addLegalOp<mlir::vector::SplatOp>();
addLegalOp<mlir::memref::ReinterpretCastOp>();

addLegalDialect<mlir::xegpu::XeGPUDialect>();
@@ -168,6 +167,10 @@ class XeTileConversionTarget : public mlir::ConversionTarget {
[](mlir::vector::TransposeOp op) {
return op.getResult().getType().getRank() == 2;
});

addDynamicallyLegalOp<mlir::vector::SplatOp>([&](mlir::vector::SplatOp op) {
return op.getAggregate().getType().getRank() != 4;
});
}

private:
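Note: the rank-4 case corresponds to the blocked (4D) vector form produced by xetile blocking, so such splats must now be lowered by the conversion rather than kept legal. Below is a minimal sketch of that legality rule using a plain C++ stand-in for the MLIR vector type; the names are illustrative and not part of this commit.

#include <cassert>

// Illustrative stand-in for mlir::VectorType; only the rank matters here.
struct FakeSplatResultType {
  int rank;
};

// Mirrors the dynamic legality rule above: rank-4 splats (the blocked xetile
// form) are illegal and must be converted; splats of any other rank are left
// for later passes (e.g. VectorLinearize).
static bool isSplatLegal(const FakeSplatResultType &resultTy) {
  return resultTy.rank != 4;
}

int main() {
  assert(!isSplatLegal({4})); // blocked 4D splat: must be lowered here
  assert(isSplatLegal({2}));  // e.g. the vector<1x16xindex> splat in the test
  return 0;
}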
8 changes: 6 additions & 2 deletions lib/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.cpp
@@ -185,8 +185,12 @@ XeOneToNTypeConverter::computeTypeMapping(mlir::ValueRange original,
return mlir::failure();
auto shape = tileTy.getShape();
auto blkSZ = tdescTy.getShape();
- auto arr_len = tdescTy.getArrayLength();
- auto size = shape[0] / blkSZ[0] * shape[1] / (blkSZ[1] * arr_len);
+ auto arr_len = tdescTy.isScattered() ? 1 : tdescTy.getArrayLength();
+ auto totalNumElems =
+     std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>{});
+ auto blockNumElems =
+     std::accumulate(blkSZ.begin(), blkSZ.end(), 1, std::multiplies<>{});
+ auto size = totalNumElems / blockNumElems / arr_len;
llvm::ArrayRef<mlir::Type> types(convertedTypes.begin() + j,
convertedTypes.begin() + j + size);
resultMap.addInputs(i, types);
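For reference, the new size computation generalizes the old 2D row/column formula to arbitrary rank and ignores the array length for scattered descriptors. A standalone worked sketch of that arithmetic is shown below (an illustration, not code from this commit); assuming a 4x32 scattered tile is blocked into 1x32 pieces, as the first test file further down suggests, it yields four converted values.

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Number of converted values a tile maps to: total elements divided by
// elements per block, divided by the array length, with the array length
// forced to 1 for scattered descriptors.
static int64_t numConvertedValues(const std::vector<int64_t> &tileShape,
                                  const std::vector<int64_t> &blockShape,
                                  int64_t arrayLength, bool scattered) {
  int64_t arrLen = scattered ? 1 : arrayLength;
  int64_t total = std::accumulate(tileShape.begin(), tileShape.end(),
                                  int64_t{1}, std::multiplies<>{});
  int64_t block = std::accumulate(blockShape.begin(), blockShape.end(),
                                  int64_t{1}, std::multiplies<>{});
  return total / block / arrLen;
}

int main() {
  // 4x32 scattered tile split into 1x32 blocks -> 4 converted values,
  // matching the four xegpu.load/update_offset/store groups in the test.
  assert(numConvertedValues({4, 32}, {1, 32}, /*arrayLength=*/1,
                            /*scattered=*/true) == 4);
  return 0;
}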
6 changes: 4 additions & 2 deletions lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp
@@ -750,9 +750,11 @@ template <typename Integertype>
Block BlockingAnalysisImpl::getInnerBlockSize(
mlir::Operation *op, mlir::Type elemTy, llvm::ArrayRef<Integertype> &shape,
int memorySpace) {
- assert(elemTy.isIntOrFloat() && "only support int or float element type.");

- int elemSize = elemTy.getIntOrFloatBitWidth();
+ // TODO: is it safe to treat index as 32 bit integer?
+ // Expecting index vector is mainly used for gather/scatter ops on SLM.
+ // in which the address is 32-bit.
+ int elemSize = elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : 32;
const int64_t subgroupSize = uArch->getOneGRFSizeBits() / elemSize;

int maxHeight = 0, minHeight = 0, maxWidth = 0, minWidth = 0;
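With this fallback, an index element is sized as 32 bits when picking inner block sizes, so one GRF holds getOneGRFSizeBits() / 32 index lanes. A small sketch of that arithmetic follows; the 512-bit GRF width is an assumption used only for illustration, not a value taken from this commit.

#include <cassert>

// Mirrors the fallback above: int/float elements report their own bit width,
// anything else (in practice the index type) is treated as 32-bit.
static int elementBitWidth(bool isIntOrFloat, int bitWidth) {
  return isIntOrFloat ? bitWidth : 32;
}

int main() {
  const int grfSizeBits = 512; // assumed GRF width, illustration only
  // index offsets for gather/scatter: 512 / 32 = 16 lanes per GRF
  assert(grfSizeBits / elementBitWidth(/*isIntOrFloat=*/false, /*bitWidth=*/0) == 16);
  // f16 data: 512 / 16 = 32 lanes per GRF
  assert(grfSizeBits / elementBitWidth(/*isIntOrFloat=*/true, /*bitWidth=*/16) == 32);
  return 0;
}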
50 changes: 48 additions & 2 deletions lib/Transforms/VectorLinearize.cpp
@@ -33,6 +33,34 @@ namespace imex {
} // namespace imex

namespace {

// rewrite arith.constant op in form of vector<1xmxindex> into 1D form
// (vector<mxindex>)
struct ArithConstantOpConversion final
: public mlir::OpConversionPattern<mlir::arith::ConstantOp> {
using mlir::OpConversionPattern<mlir::arith::ConstantOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(mlir::arith::ConstantOp constOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto value = llvm::dyn_cast<mlir::DenseElementsAttr>(constOp.getValue());
if (!value || value.getType().getRank() != 2)
return mlir::failure();
auto type = value.getType();
auto shape = type.getShape();
auto elemTy = type.getElementType();
if (shape[0] != 1 || !elemTy.isIndex())
return mlir::failure();
auto newTy = mlir::VectorType::get({shape[1]}, elemTy);
value = value.reshape(newTy);
auto newOp =
rewriter.create<mlir::arith::ConstantOp>(constOp.getLoc(), value);
auto castOp = rewriter.create<mlir::vector::ShapeCastOp>(constOp.getLoc(),
type, newOp);
rewriter.replaceOp(constOp, castOp);
return mlir::success();
}
};

struct VectorLoadOpConversion final
: public mlir::OpConversionPattern<mlir::vector::LoadOp> {
using mlir::OpConversionPattern<mlir::vector::LoadOp>::OpConversionPattern;
@@ -485,11 +513,29 @@ struct VectorLinearizePass final
return (op && op.getAggregate().getType().getRank() == 1);
});

// borrowed from upstream with hacking for index type. Currently
// we only target vector<1xmxindex> to vector<mxindex> conversion. It is
// unclear whether others are valid or not; thus they are left untouched.
target.addDynamicallyLegalOp<mlir::arith::ConstantOp>(
[&](mlir::arith::ConstantOp op) -> bool {
auto vecTy = mlir::dyn_cast<mlir::VectorType>(op.getType());
if (!vecTy || vecTy.getRank() == 0)
return true;

auto elemTy = vecTy.getElementType();
if (elemTy.isIndex()) {
if (vecTy.getRank() == 2 && vecTy.getShape()[0] == 1)
return false;
return true;
}
return !mlir::vector::isLinearizableVector(vecTy);
});

patterns.add<VectorExtractStridedSliceConversion, VectorShffleOpConversion,
VectorExtractOpConversion, VectorInsertOpConversion,
VectorSplatOpConversion, VectorLoadOpConversion,
- VectorStoreOpConversion, VectorCreateMaskOpConversion>(
- typeConverter, context);
+ VectorStoreOpConversion, VectorCreateMaskOpConversion,
+ ArithConstantOpConversion>(typeConverter, context);

// Shuffle16x16 will fallback to Shuffle1D for non 16x16 sizes.
mlir::vector::populateVectorTransposeLoweringPatterns(
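Taken together, the new legality rule and ArithConstantOpConversion only rewrite rank-2 index constants whose leading dimension is 1 (the vector<1xNxindex> offset vectors the scattered path produces): such a constant is replaced by its 1D reshape plus a vector.shape_cast back to the original type, while everything else keeps the upstream linearization behavior. A minimal C++ sketch of that predicate with plain stand-ins for the MLIR types (illustrative only, not this commit's code):

#include <cassert>
#include <vector>

// Illustrative stand-in for a vector type: its shape plus an is-index flag.
struct FakeConstantType {
  std::vector<int> shape;
  bool elemIsIndex;
};

// Mirrors the dynamic legality rule above: a constant needs rewriting (is
// "illegal") only when it is a vector<1xNxindex>; other index vectors stay
// untouched, and non-index vectors follow the upstream rule (not modeled).
static bool constantNeedsRewrite(const FakeConstantType &ty) {
  return ty.elemIsIndex && ty.shape.size() == 2 && ty.shape[0] == 1;
}

int main() {
  assert(constantNeedsRewrite({{1, 16}, true}));   // vector<1x16xindex>
  assert(!constantNeedsRewrite({{16}, true}));     // already 1D, kept
  assert(!constantNeedsRewrite({{4, 32}, true}));  // leading dim != 1, kept
  assert(!constantNeedsRewrite({{1, 16}, false})); // non-index, upstream rule
  return 0;
}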
3 changes: 3 additions & 0 deletions lib/Transforms/VnniTransformation.cpp
@@ -113,6 +113,9 @@ static bool isVNNIApplicable(mlir::Type type) {
// VNNI transform only available for 2D vectors.
if (!vecTy || vecTy.getRank() != 2)
return false;
auto elemTy = vecTy.getElementType();
if (!elemTy.isIntOrFloat())
return false;
auto factor = getVnniFactor(vecTy.getElementType());
auto shape = vecTy.getShape();
// factor == 1 means 32-bit data, and no need to apply VNNI.
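The added element-type guard keeps index vectors (which have no int/float bit width) out of the VNNI path, since getVnniFactor is only meaningful for int/float elements. A small sketch of the check with plain stand-ins (illustrative, not this commit's code):

#include <cassert>

// Illustrative stand-in for the element-type properties the check needs.
struct FakeElemType {
  bool isIntOrFloat;
  int bitWidth; // meaningful only when isIntOrFloat is true
};

// Mirrors the guard above: only 2D vectors of int/float elements are VNNI
// candidates; index vectors (e.g. scattered offsets) are skipped. The shape
// checks that follow in the real code are not modeled here.
static bool isVNNICandidate(int rank, const FakeElemType &elemTy) {
  if (rank != 2)
    return false;
  if (!elemTy.isIntOrFloat)
    return false;
  // factor == 1 means 32-bit data, and no need to apply VNNI.
  return (32 / elemTy.bitWidth) > 1;
}

int main() {
  assert(isVNNICandidate(2, {true, 16}));  // f16: factor 2, applicable
  assert(!isVNNICandidate(2, {true, 32})); // 32-bit data: no VNNI needed
  assert(!isVNNICandidate(2, {false, 0})); // index vector: now skipped
  return 0;
}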
78 changes: 78 additions & 0 deletions test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir
@@ -0,0 +1,78 @@
// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \
// RUN: --cse --convert-xetile-to-xegpu --cse --canonicalize %s -verify-diagnostics -o -| FileCheck %s

gpu.module @test {
//CHECK-LABEL: @test_init_tile_for_scattered
//CHECK-SAME: %[[arg0:.*]]: memref<1024xf16>
gpu.func @test_init_tile_for_scattered(%arg0: memref<1024xf16>) {
//CHECK: %[[cst:.*]] = arith.constant dense<true> : vector<32xi1>
//CHECK: %[[cst_0:.*]] = arith.constant dense<1> : vector<32xindex>
//CHECK: %[[r0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst_0]] : memref<1024xf16>, vector<32xindex> -> !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>
//CHECK: %[[r1:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1> -> vector<32xf16>
//CHECK: %[[r2:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1> -> vector<32xf16>
//CHECK: %[[r3:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1> -> vector<32xf16>
//CHECK: %[[r4:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1> -> vector<32xf16>
//CHECK: %[[r5:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xindex>
//CHECK: %[[r6:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xindex>
//CHECK: %[[r7:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xindex>
//CHECK: %[[r8:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xindex>
//CHECK: xegpu.store %[[r1]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1>
//CHECK: xegpu.store %[[r2]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1>
//CHECK: xegpu.store %[[r3]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1>
//CHECK: xegpu.store %[[r4]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1>
%cst = arith.constant dense<true> : vector<4x32xi1>
%cst_0 = arith.constant dense<1> : vector<4x32xindex>
%0 = xetile.init_tile %arg0, %cst_0 : memref<1024xf16>, vector<4x32xindex> -> !xetile.tile<4x32xf16, #xetile.tile_attr<scattered = true>>
%1 = xetile.load %0, %cst : !xetile.tile<4x32xf16, #xetile.tile_attr<scattered = true>>, vector<4x32xi1> -> vector<4x32xf16>
%2 = xetile.update_tile_offset %0, %cst_0 : !xetile.tile<4x32xf16, #xetile.tile_attr<scattered = true>>, vector<4x32xindex>
xetile.store %1, %0, %cst : vector<4x32xf16>, !xetile.tile<4x32xf16, #xetile.tile_attr<scattered = true>>, vector<4x32xi1>
gpu.return
}

//CHECK-LABEL: @add_kernel
//CHECK-SAME: %[[arg0:.*]]: memref<*xf32>, %[[arg1:.*]]: memref<*xf32>, %[[arg2:.*]]: memref<*xf32>
gpu.func @add_kernel(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>) {
//CHECK: %[[cst:.*]] = arith.constant dense<true> : vector<16xi1>
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
//CHECK: %[[cast:.*]] = memref.cast %[[arg0]] : memref<*xf32> to memref<?xf32>
//CHECK: %[[cast_0:.*]] = memref.cast %[[arg1]] : memref<*xf32> to memref<?xf32>
//CHECK: %[[cast_1:.*]] = memref.cast %[[arg2]] : memref<*xf32> to memref<?xf32>
//CHECK: %[[block_id_x:.*]] = gpu.block_id x
//CHECK: %[[r0:.*]] = arith.muli %[[block_id_x]], %[[c1024]] : index
//CHECK: %[[r1:.*]] = vector.splat %[[r0]] : vector<1x16xindex>
//CHECK: %[[r2:.*]] = vector.shape_cast %[[r1]] : vector<1x16xindex> to vector<16xindex>
//CHECK: %[[r3:.*]] = xegpu.create_tdesc %[[cast]], %[[r2]] : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>
//CHECK: %[[r4:.*]] = xegpu.load %[[r3]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<16xi1> -> vector<16xf32>
//CHECK: %[[r5:.*]] = vector.shape_cast %[[r4]] : vector<16xf32> to vector<1x16xf32>
//CHECK: %[[r6:.*]] = xegpu.load %[[r3]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<16xi1> -> vector<16xf32>
//CHECK: %[[r7:.*]] = vector.shape_cast %[[r6]] : vector<16xf32> to vector<1x16xf32>
//CHECK: %[[r8:.*]] = xegpu.create_tdesc %[[cast_0]], %[[r2]] : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>
//CHECK: %[[r9:.*]] = xegpu.load %[[r8]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<16xi1> -> vector<16xf32>
//CHECK: %[[r10:.*]] = vector.shape_cast %[[r9]] : vector<16xf32> to vector<1x16xf32>
//CHECK: %[[r11:.*]] = xegpu.load %[[r8]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<16xi1> -> vector<16xf32>
//CHECK: %[[r12:.*]] = vector.shape_cast %[[r11]] : vector<16xf32> to vector<1x16xf32>
//CHECK: %[[r13:.*]] = arith.addf %[[r5]], %[[r10]] : vector<1x16xf32>
//CHECK: %[[r14:.*]] = arith.addf %[[r7]], %[[r12]] : vector<1x16xf32>
//CHECK: %[[r15:.*]] = xegpu.create_tdesc %[[cast_1]], %[[r2]] : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>
//CHECK: %[[r16:.*]] = vector.shape_cast %[[r13]] : vector<1x16xf32> to vector<16xf32>
//CHECK: xegpu.store %[[r16]], %[[r15]], %[[cst]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<16xi1>
//CHECK: %[[r17:.*]] = vector.shape_cast %[[r14]] : vector<1x16xf32> to vector<16xf32>
//CHECK: xegpu.store %[[r17]], %[[r15]], %[[cst]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<16xi1>
%c1024 = arith.constant 1024 : index
%cst = arith.constant dense<true> : vector<1x32xi1>
%cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
%cast_0 = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
%cast_1 = memref.cast %arg2 : memref<*xf32> to memref<?xf32>
%block_id_x = gpu.block_id x
%0 = arith.muli %block_id_x, %c1024 : index
%1 = vector.splat %0 : vector<1x32xindex>
%2 = xetile.init_tile %cast, %1 : memref<?xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr<scattered = true>>
%3 = xetile.load %2, %cst : !xetile.tile<1x32xf32, #xetile.tile_attr<scattered = true>>, vector<1x32xi1> -> vector<1x32xf32>
%4 = xetile.init_tile %cast_0, %1 : memref<?xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr<scattered = true>>
%5 = xetile.load %4, %cst : !xetile.tile<1x32xf32, #xetile.tile_attr<scattered = true>>, vector<1x32xi1> -> vector<1x32xf32>
%6 = arith.addf %3, %5 : vector<1x32xf32>
%7 = xetile.init_tile %cast_1, %1 : memref<?xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr<scattered = true>>
xetile.store %6, %7, %cst : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr<scattered = true>>, vector<1x32xi1>
gpu.return
}
}