Skip to content

Commit

Permalink
[Blocking] Add blocking patterns for xetile gather/scatter ops (#944)
Browse files Browse the repository at this point in the history
  • Loading branch information
chencha3 authored Oct 28, 2024
1 parent ea2ddcf commit 2eda085
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 21 deletions.
12 changes: 6 additions & 6 deletions include/imex/Dialect/XeTile/IR/XeTileOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments,
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_sizes,
OptionalAttr<DenseI64ArrayAttr>: $const_strides,
Optional<VectorOfRankAndType<[1,2], [Index]>>: $indices);
Optional<VectorOfRankAndType<[1,2,4], [Index]>>: $indices);

let results = (outs XeTile: $tile);

Expand Down Expand Up @@ -493,7 +493,7 @@ def XeTile_UpdateTileOffsetOp : XeTile_Op<"update_tile_offset", [AttrSizedOperan
XeTile: $tile,
Optional<Index>: $offset_x,
Optional<Index>: $offset_y,
Optional<FixedVectorOfRankAndType<[1], [Index]>>:$indices);
Optional<FixedVectorOfRankAndType<[1,2,4], [Index]>>:$indices);

let results = (outs
XeTile: $result
Expand Down Expand Up @@ -644,7 +644,7 @@ def XeTile_ConvertLayoutOp: XeTile_Op<"convert_layout", [AllTypesMatch<["source"
}

def XeTile_LoadGatherOp: XeTile_Op<"load", [AllElementTypesMatch<["tile", "value"]>,
AllShapesMatch<["tile", "value", "mask"]>]> {
AllShapesMatch<["value", "mask"]>]> {
let summary = "load a set of scattered data points from memory.";
let description = [{
The `load` operation is used to load data with scattered tile (each element in the tile
Expand All @@ -656,22 +656,22 @@ def XeTile_LoadGatherOp: XeTile_Op<"load", [AllElementTypesMatch<["tile", "value
let arguments = (ins XeTile: $tile,
XeTile_MaskType: $mask,
OptionalAttr<XeTile_PaddingValueAttr>: $padding);
let results = (outs XeTile_1DOr2DVector: $value);
let results = (outs XeTile_1DOr2DOr4DVector: $value);
let assemblyFormat = [{
$tile `` `,` $mask attr-dict `:` qualified(type($tile)) `` `,` type($mask) `->` type($value)
}];
}

def XeTile_StoreScatterOp: XeTile_Op<"store", [AllElementTypesMatch<["value", "tile"]>,
AllShapesMatch<["value", "tile", "mask"]>]> {
AllShapesMatch<["value", "mask"]>]> {
let summary = "load a set of data to scattered memory locations.";
let description = [{
The `store` operation is used to store data into scattered tile (each element in the tile
is interpreted as location, one location per data element). the `mask` operand masks out
memory access so that it is safe to pass out-of-boundary addresses/offsets as long as they
are masked.
}];
let arguments = (ins XeTile_1DOr2DVector: $value,
let arguments = (ins XeTile_1DOr2DOr4DVector: $value,
XeTile: $tile,
XeTile_MaskType: $mask);
let assemblyFormat = [{
Expand Down
11 changes: 7 additions & 4 deletions include/imex/Dialect/XeTile/IR/XeTileTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface],
];

let extraClassDeclaration = [{
using TensorType::clone;
using mlir::ShapedType::Trait<TileType>::getElementTypeBitWidth;
using mlir::ShapedType::Trait<TileType>::getRank;
using mlir::ShapedType::Trait<TileType>::getNumElements;
Expand All @@ -82,6 +81,10 @@ def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface],
using mlir::ShapedType::Trait<TileType>::getDimSize;
using mlir::ShapedType::Trait<TileType>::getDynamicDimIndex;

TileType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, mlir::Type elementType) {
return TileType::get(shape.value_or(getShape()), elementType, getEncoding());
}

TileType clone(mlir::Type elementType) {
return llvm::cast<TileType>(cloneWith(getShape(), elementType));
}
Expand Down Expand Up @@ -156,7 +159,7 @@ def XeTile_IntType : AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI6
def XeTile_FloatType : AnyTypeOf<[F16, F32, F64, BF16, TF32]>;

// Define the scalar type for XeTile
def XeTile_ScalarType : AnyTypeOf<[XeTile_IntType, XeTile_FloatType]>;
def XeTile_ScalarType : AnyTypeOf<[XeTile_IntType, XeTile_FloatType, Index]>;

// define the source type for XeTile init_tile
def XeTile_BaseAddrType : AnyTypeOf<[MemRefOf<[XeTile_ScalarType]>, UI64, UI32, I64, I32]>;
Expand All @@ -167,12 +170,12 @@ def XeTile_2DVector : VectorOfRankAndType<[2], [XeTile_ScalarType]>;
def XeTile_4DVector : VectorOfRankAndType<[4], [XeTile_ScalarType]>;

// define the value type for XeTile load_gather and store_scatter op
def XeTile_1DOr2DVector: VectorOfRankAndType<[1, 2], [XeTile_ScalarType]>;
def XeTile_1DOr2DOr4DVector: VectorOfRankAndType<[1, 2, 4], [XeTile_ScalarType]>;

// define the value type for XeTile load_tile and store_tile op
def XeTile_2DOr4DVector: VectorOfRankAndType<[2, 4], [XeTile_ScalarType]>;

def XeTile_MaskType: VectorOfRankAndType<[1, 2], [I1]>;
def XeTile_MaskType: VectorOfRankAndType<[1, 2, 4], [I1]>;

// define the attribute type allowed for padding values for load op
def XeTile_PaddingValueAttr : AnyAttrOf<[I32Attr, F32Attr]>;
Expand Down
6 changes: 4 additions & 2 deletions lib/Dialect/XeTile/IR/XeTileOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ mlir::LogicalResult InitTileOp::verify() {
if (!tileTy.getScatterAttr())
return emitOpError("Expecting a scattered TileType.");

if (tileTy.getShape() != indices.getType().getShape())
return emitOpError("Shape mismatch between indices and result tile.");
// TODO: temoprary disable it in favor of 4D representation of
// blocking pass
// if (tileTy.getShape() != indices.getType().getShape())
// return emitOpError("Shape mismatch between indices and result tile.");

return mlir::success();
}
Expand Down
124 changes: 116 additions & 8 deletions lib/Dialect/XeTile/Transforms/Blocking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
#include "imex/Utils/DebugUtils.h"
#include "imex/Utils/XeArch.h"

#define DEBUG_TYPE "xetile-blocking"

using namespace mlir;
using namespace llvm;
using namespace imex;
Expand Down Expand Up @@ -218,9 +220,16 @@ struct InitTileOpPattern
auto elemTy = tileTy.getElementType();
auto newTileTy = imex::xetile::TileType::get(shape, elemTy, attr);

llvm::SmallVector<mlir::Value> operands =
llvm::to_vector(adaptor.getOperands());
if (tileTy.getScatterAttr() == mlir::BoolAttr::get(getContext(), true)) {
auto indices =
addPackOp(adaptor.getIndices(), blockSize.asArrayRef(), rewriter);
operands[1] = indices;
}

auto newOp = rewriter.create<xetile::InitTileOp>(
op.getLoc(), mlir::TypeRange({newTileTy}), op->getOperands(),
op->getAttrs());
op.getLoc(), mlir::TypeRange({newTileTy}), operands, op->getAttrs());

rewriter.replaceOp(op, newOp);

Expand Down Expand Up @@ -291,6 +300,40 @@ struct LoadTileOpPattern
}
};

struct LoadGatherOpPattern
: public OpConversionPatternWithAnalysis<xetile::LoadGatherOp,
BlockingAnalysis> {

using OpConversionPatternWithAnalysis<
xetile::LoadGatherOp, BlockingAnalysis>::OpConversionPatternWithAnalysis;

::mlir::LogicalResult
matchAndRewrite(xetile::LoadGatherOp op, OpAdaptor adaptor,
OpPatternRewriter &rewriter) const override {
auto source = adaptor.getTile();
auto tileTy = mlir::cast<xetile::TileType>(source.getType());
auto blockSize = tileTy.getInnerBlocks();
auto rank = op.getValue().getType().getRank();

if (!blockSize || rank == 4)
return rewriter.notifyMatchFailure(
op, "Input is not updated or the op has been updated.\n");

auto shape = tileTy.getShape();
auto vecTy = ::mlir::VectorType::get({shape[0] / blockSize[0],
shape[1] / blockSize[1], blockSize[0],
blockSize[1]},
tileTy.getElementType());
auto mask = addPackOp(adaptor.getMask(), blockSize.asArrayRef(), rewriter);
mlir::Value newOp = rewriter.create<xetile::LoadGatherOp>(
op.getLoc(), vecTy, source, mask,
op.getPadding().value_or(mlir::Attribute()));
newOp = addUnpackOp(newOp, rewriter);
rewriter.replaceOp(op, newOp);
return mlir::success();
}
};

// It updates store_tile to reveal effects of innerblock attribute.
// It uses pack op to align the shape of its vector value to the tile shape.
struct StoreTileOpPattern
Expand Down Expand Up @@ -319,6 +362,36 @@ struct StoreTileOpPattern
}
};

struct StoreScatterOpPattern
: public OpConversionPatternWithAnalysis<xetile::StoreScatterOp,
BlockingAnalysis> {

using OpConversionPatternWithAnalysis<
xetile::StoreScatterOp,
BlockingAnalysis>::OpConversionPatternWithAnalysis;

::mlir::LogicalResult
matchAndRewrite(xetile::StoreScatterOp op, OpAdaptor adaptor,
OpPatternRewriter &rewriter) const override {
auto value = adaptor.getValue();
auto valTy = mlir::dyn_cast<mlir::VectorType>(value.getType());
auto tile = adaptor.getTile();
auto tileTy = mlir::cast<xetile::TileType>(tile.getType());
auto blockSize = tileTy.getInnerBlocks();

// its inputs has not been updated yet.
if (blockSize && valTy.getRank() == 2) {
value = addPackOp(value, blockSize.asArrayRef(), rewriter);
auto mask =
addPackOp(adaptor.getMask(), blockSize.asArrayRef(), rewriter);
rewriter.replaceOpWithNewOp<xetile::StoreScatterOp>(op, value, tile,
mask);
return mlir::success();
}
return mlir::failure();
}
};

// It updates update_tile_offset to reveal effects of innerblock attribute
// by updating the type of it result.
struct UpdateTileOffsetOpPattern
Expand All @@ -339,9 +412,15 @@ struct UpdateTileOffsetOpPattern
if (!blockSize)
return failure();

rewriter.replaceOpWithNewOp<xetile::UpdateTileOffsetOp>(
op, tileTy, tile, adaptor.getOffsetX(), adaptor.getOffsetY(),
adaptor.getIndices());
llvm::SmallVector<mlir::Value> operands =
llvm::to_vector(adaptor.getOperands());
if (tileTy.getScatterAttr() == mlir::BoolAttr::get(getContext(), true)) {
auto indices =
addPackOp(adaptor.getIndices(), blockSize.asArrayRef(), rewriter);
operands[1] = indices;
}
rewriter.replaceOpWithNewOp<xetile::UpdateTileOffsetOp>(op, operands,
op->getAttrs());
return mlir::success();
}
};
Expand Down Expand Up @@ -822,6 +901,34 @@ struct VectorCreateMaskOpPattern
}
};

struct VectorSplatOpPattern
: public OpConversionPatternWithAnalysis<mlir::vector::SplatOp,
BlockingAnalysis> {
using OpConversionPatternWithAnalysis<
mlir::vector::SplatOp, BlockingAnalysis>::OpConversionPatternWithAnalysis;

mlir::LogicalResult
matchAndRewrite(mlir::vector::SplatOp op, OpAdaptor adaptor,
OpPatternRewriter &rewriter) const override {
auto res = op.getAggregate();
auto resTy = res.getType();
auto block = analysis.getDefBlockSize(res);
if (!block || resTy.getRank() != 2)
return mlir::failure();

auto shape = resTy.getShape();
auto newTy = mlir::VectorType::get(
{shape[0] / block[0], shape[1] / block[1], block[0], block[1]},
resTy.getElementType());

auto newOp = rewriter.create<mlir::vector::SplatOp>(
op.getLoc(), newTy, op->getOperands(), op->getAttrs());
auto unpack = addUnpackOp(newOp, rewriter);
rewriter.replaceOp(op, unpack);
return mlir::success();
}
};

} // namespace Blocking

void populateXeTileBlockingPatterns(mlir::RewritePatternSet &patterns,
Expand All @@ -830,11 +937,12 @@ void populateXeTileBlockingPatterns(mlir::RewritePatternSet &patterns,
Blocking::ArithConstantOpPattern, Blocking::InitTileOpPattern,
Blocking::PrefetchTileOpPattern, Blocking::LoadTileOpPattern,
Blocking::StoreTileOpPattern, Blocking::UpdateTileOffsetOpPattern,
Blocking::LoadGatherOpPattern, Blocking::StoreScatterOpPattern,
Blocking::TileMMAOpPattern, Blocking::TileReductionOpPattern,
Blocking::TileBroadcastOpPattern, Blocking::TileTransposeOpPattern,
Blocking::VectorizableOpPattern, Blocking::SCFForOpPattern,
Blocking::SCFYieldOpPattern, Blocking::VectorCreateMaskOpPattern>(
patterns.getContext(), analysis);
Blocking::SCFYieldOpPattern, Blocking::VectorCreateMaskOpPattern,
Blocking::VectorSplatOpPattern>(patterns.getContext(), analysis);
}

// Lowers XeTile to blocked layout with high-dim vector
Expand Down Expand Up @@ -882,7 +990,7 @@ class XeTileBlockingPass : public impl::XeTileBlockingBase<XeTileBlockingPass> {
if (mlir::failed(analysis.run(mod)))
return signalPassFailure();

// analysis.printAnalysisResult();
LLVM_DEBUG(analysis.printAnalysisResult());

mlir::MLIRContext &context = getContext();

Expand Down
Loading

0 comments on commit 2eda085

Please sign in to comment.