Skip to content

Commit

Permalink
Add transformation pattern for vector.broadcast in wg to sg pass (#946)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbpatel authored Oct 30, 2024
1 parent 4288ffe commit 07920ba
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 11 deletions.
59 changes: 48 additions & 11 deletions lib/Dialect/XeTile/Transforms/WgToSg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ class WGToSGInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
mlir::OneToNTypeMapping newMapping(op.getResult().getType());
newMapping.addInputs(0, newResultTypes);
rewriter.replaceOp(op, newInitTileOps, newMapping);

return mlir::success();
}
};
Expand Down Expand Up @@ -358,7 +357,6 @@ struct WGToSGSCFYieldOpPattern : public XeOneToNConversion<mlir::scf::YieldOp> {
mlir::LogicalResult
matchAndRewrite(mlir::scf::YieldOp op, OpAdaptor adaptor,
imex::XeOneToNPatternRewriter &rewriter) const override {

llvm::SmallVector<mlir::Value> convertedResults;
llvm::SmallVector<mlir::Type> newResultTypes;
for (auto &values : adaptor.getResults())
Expand All @@ -383,7 +381,6 @@ class WGToSGUpdateTileOffsetOpPattern
mlir::LogicalResult
matchAndRewrite(xetile::UpdateTileOffsetOp op, OpAdaptor adaptor,
XeOneToNPatternRewriter &rewriter) const override {

llvm::SmallVector<::mlir::Value> newUpdateTileOffsetOps;
llvm::SmallVector<mlir::Type> newResultTypes;
for (auto tile : adaptor.getTile()) {
Expand Down Expand Up @@ -582,6 +579,47 @@ class WGToSGVectorTranspose
};



/// Rewrites a work-group level vector.broadcast into a sub-group level one.
///
/// The shape of the new result is taken from the `sg_data` entry of the
/// xetile::WorkGroupMapAttr stored in the op's "map" attribute; the element
/// type of the original result is kept. The pattern only applies when:
///   * the original result is a 2D vector,
///   * the op carries a "map" attribute,
///   * the (already converted) source is a 2D vector with one unit dimension
///     whose other dimension equals the matching sub-group result dimension.
/// In every other case the pattern reports failure and leaves the op alone.
class WGToSGVectorBroadcast
    : public XeOneToNConversion<mlir::vector::BroadcastOp> {
  using XeOneToNConversion<mlir::vector::BroadcastOp>::XeOneToNConversion;

  mlir::LogicalResult
  matchAndRewrite(mlir::vector::BroadcastOp op, OpAdaptor adaptor,
                  XeOneToNPatternRewriter &rewriter) const override {
    auto wgResultTy = op.getVector().getType();
    if (wgResultTy.getRank() != 2)
      return mlir::failure();

    // Without a work-group map there is no sub-group shape to rewrite to.
    auto mapAttr =
        llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(op->getAttr("map"));
    if (!mapAttr)
      return mlir::failure();

    // vector.broadcast also accepts scalar and lower-rank sources. Those
    // cannot be described by the 2D checks below, so bail out instead of
    // dereferencing a null type or indexing past the end of the shape.
    mlir::Value source = adaptor.getSource()[0];
    auto srcTy = mlir::dyn_cast<mlir::VectorType>(source.getType());
    if (!srcTy || srcTy.getRank() != 2)
      return mlir::failure();
    auto srcShape = srcTy.getShape();

    auto sgData = mapAttr.getSgData();
    auto newTy = mlir::VectorType::get({sgData[0], sgData[1]},
                                       wgResultTy.getElementType());
    auto dstShape = newTy.getShape();

    // Accept either a row broadcast (1xN -> MxN) or a column broadcast
    // (Mx1 -> MxN) at sub-group granularity.
    const bool isRowBroadcast =
        srcShape[0] == 1 && srcShape[1] == dstShape[1];
    const bool isColBroadcast =
        srcShape[1] == 1 && srcShape[0] == dstShape[0];
    if (!isRowBroadcast && !isColBroadcast)
      return mlir::failure();

    auto newOp = rewriter.create<mlir::vector::BroadcastOp>(op.getLoc(), newTy,
                                                            source);
    rewriter.replaceOp(op, newOp);
    return mlir::success();
  }
};


// TODO: Add more pre-ops
bool isElementWiseOp(mlir::Operation *op) {
return llvm::isa<mlir::arith::AddFOp>(op) ||
Expand Down Expand Up @@ -639,18 +677,15 @@ void analyzeInitTileOps(mlir::Operation *op) {
llvm::cast<mlir::vector::TransposeOp>(*loadUser->user_begin());
ops.push_back(transposeOp);

// Check if the transpose has only one user and that user is a TileMMAOp
// or a pre-op followed by TileMMA
if (!transposeOp->hasOneUse())
return mlir::WalkResult::skip();

auto consumerOp = *transposeOp->user_begin();

// Check if vector.transpose is consumed by TileMMA directly or
// is consumed by some pre-op and then TileMMA.
if(!llvm::isa<imex::xetile::TileMMAOp>(consumerOp)){
if(!isElementWiseOp(consumerOp))
if(!isElementWiseOp(consumerOp) &&
!(llvm::isa<mlir::vector::BroadcastOp>(consumerOp))) {
return mlir::WalkResult::skip();
}
else {
if (!(consumerOp->hasOneUse() &&
llvm::isa<imex::xetile::TileMMAOp>(*consumerOp->user_begin())))
Expand All @@ -676,7 +711,8 @@ void populateXeTileWgToSgPatterns(imex::XeOneToNTypeConverter &converter,
patterns.insert<WGToSGInitTileOpPattern, WGToSGLoadTileOpPattern,
WGToSGTileMMAOpPattern, WGToSGStoreTileOpPattern,
WGToSGSCFForOpPattern, WGToSGUpdateTileOffsetOpPattern,
WGToSGSCFYieldOpPattern, WGToSGVectorTranspose>(patterns.getContext(), converter,
WGToSGSCFYieldOpPattern, WGToSGVectorTranspose,
WGToSGVectorBroadcast>(patterns.getContext(), converter,
analysis);
patterns.insert<WGToSGElementWiseOpPattern<mlir::math::ExpOp, 1>,
WGToSGElementWiseOpPattern<mlir::arith::AddFOp, 2>,
Expand Down Expand Up @@ -777,7 +813,8 @@ class XeTileWgToSgPass
});

target.addDynamicallyLegalOp<mlir::arith::ConstantOp, mlir::arith::AddFOp,
mlir::math::ExpOp, mlir::vector::TransposeOp>(
mlir::math::ExpOp, mlir::vector::TransposeOp,
mlir::vector::BroadcastOp>(
[&](mlir::Operation *op) -> bool {
auto mapAttr = llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(
op->getAttr("map"));
Expand Down
38 changes: 38 additions & 0 deletions test/Dialect/XeTile/Transforms/wg_to_sg_broadcast.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: imex-opt --split-input-file --xetile-wg-to-sg --cse %s -verify-diagnostics | FileCheck %s

// Exercises the wg-to-sg rewrite of vector.broadcast. The work-group level IR
// transposes a 1x32 tile to 32x1 and broadcasts it to 32x512; the broadcast
// carries a map with sg_data = [32, 64], and the expected sub-group level
// broadcast produces vector<32x64xf16>.
gpu.module @test_broadcast {
gpu.func @test_kernel(%arg0: memref<256x384xf16>, %arg1: memref<1x384xf16>, %arg2: memref<256x512xf32>) attributes {gemm_tiles_b = 1 : i64, gemm_tiles_x = dense<[1, 1, 1, 4]> : vector<4xi64>, gemm_tiles_y = dense<[1, 1, 1, 8]> : vector<4xi64>, habana_runner.num_inputs = 2 : i64, habana_runner.tests = [{inputs = [dense<1.000000e+00> : tensor<256x384xf16>, dense<1.000000e+00> : tensor<1x384xf16>], outputs = [dense<3.840000e+02> : tensor<256x512xf32>]}], physical_nd_range = dense<1> : vector<2xi64>, region_partition = 0 : i64, region_size = 1 : i64, syn.fusion_successful, syn.tensor_signature = (tensor<256x384xf16>, tensor<1x384xf16>) -> tensor<256x512xf32>, synFusionGenOps = 6 : i64, synFusionRequiredBeamSize = 1 : i64, synFusionTotalCost = 1000015571.16 : f64} {
%c1 = arith.constant 1 : index
%c1_0 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
%c1_1 = arith.constant 1 : index
gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1_0, %arg11 = %c1_1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c4, %arg13 = %c8, %arg14 = %c1_1) {
%c384 = arith.constant 384 : index
%c32 = arith.constant 32 : index
%cst = arith.constant {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [64, 64]>} dense<0.000000e+00> : vector<256x512xf32>
%c0 = arith.constant 0 : index
%0 = xetile.init_tile %arg0[%c0, %c0] : memref<256x384xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [64, 32]>, inner_blocks = []>>
%1 = xetile.init_tile %arg1[%c0, %c0] : memref<1x384xf16> -> !xetile.tile<1x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [1, 32]>, inner_blocks = []>>
%2:3 = scf.for %arg15 = %c0 to %c384 step %c32 iter_args(%arg16 = %cst, %arg17 = %0, %arg18 = %1) -> (vector<256x512xf32>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [64, 32]>, inner_blocks = []>>, !xetile.tile<1x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [1, 32]>, inner_blocks = []>>) {
%4 = xetile.update_tile_offset %arg18, [%c0, %c32] : !xetile.tile<1x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [1, 32]>, inner_blocks = []>>
%5 = xetile.update_tile_offset %arg17, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [64, 32]>, inner_blocks = []>>
%6 = xetile.load_tile %arg17 { padding = 0.000000e+00 : f32 } : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [64, 32]>, inner_blocks = []>> -> vector<256x32xf16>
%7 = xetile.load_tile %arg18 { padding = 0.000000e+00 : f32 } : !xetile.tile<1x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [1, 32]>, inner_blocks = []>> -> vector<1x32xf16>
// The transpose is expected to keep its 1x32 -> 32x1 shape after the pass.
//CHECK: %[[TRANSPOSE:.*]] = vector.transpose {{%.*}}, [1, 0] : vector<1x32xf16> to vector<32x1xf16>
%8 = vector.transpose %7, [1, 0] {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 1]>} : vector<1x32xf16> to vector<32x1xf16>
// The broadcast result is expected to shrink from 32x512 to the sg_data shape 32x64.
//CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[TRANSPOSE]] : vector<32x1xf16> to vector<32x64xf16>
%9 = vector.broadcast %8 {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 64]>} : vector<32x1xf16> to vector<32x512xf16>
xegpu.compile_hint
// The broadcast result is used as the B operand of the tile_mma below.
%10 = xetile.tile_mma %6, %9, %cst {wg_map_a =#xetile.wg_map<sg_layout = [4, 8], sg_data = [64, 32]>, wg_map_b =#xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 64]>, wg_map_c =#xetile.wg_map<sg_layout = [4, 8], sg_data = [64, 64]>} : vector<256x32xf16>, vector<32x512xf16>, vector<256x512xf32> -> vector<256x512xf32>
xegpu.compile_hint
%11 = arith.addf %arg16, %10 {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [64, 64]>} : vector<256x512xf32>
scf.yield %11, %5, %4 : vector<256x512xf32>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [64, 32]>, inner_blocks = []>>, !xetile.tile<1x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [1, 32]>, inner_blocks = []>>
}
%3 = xetile.init_tile %arg2[%c0, %c0] : memref<256x512xf32> -> !xetile.tile<256x512xf32, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [64, 64]>, inner_blocks = []>>
xetile.store_tile %2#0, %3 : vector<256x512xf32>, !xetile.tile<256x512xf32, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [64, 64]>, inner_blocks = []>>
gpu.terminator
}
gpu.return
}
}

0 comments on commit 07920ba

Please sign in to comment.