Skip to content

Commit

Permalink
[XeTile][Canonicalization] Bug fix in VectorBroadcast Canonicalizatio…
Browse files Browse the repository at this point in the history
…n. (#956)
  • Loading branch information
charithaintc authored Nov 7, 2024
1 parent 42fdea5 commit 0330284
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
7 changes: 4 additions & 3 deletions lib/Dialect/XeTile/Transforms/Canonicalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,11 @@ struct VectorBroadcastToXetileBroadcastOpPattern
newOp->setDiscardableAttrs(discardableAttrs);
return mlir::success();
}
// If ranks are same, inner dimension is stretched in vector.broadcast. So
// broadcast dimension is 1 for this case.
// If ranks are same, decide the broadcast dimension based on the source
// vector shape.
auto broadcastDim = (sourceShape[0] == 1) ? 0 : 1;
auto newOp = rewriter.replaceOpWithNewOp<imex::xetile::BroadcastOp>(
op, resultTy, op.getSource(), llvm::ArrayRef<int64_t>({1}));
op, resultTy, op.getSource(), llvm::ArrayRef<int64_t>({broadcastDim}));
newOp->setDiscardableAttrs(discardableAttrs);
return mlir::success();
}
Expand Down
13 changes: 13 additions & 0 deletions test/Dialect/XeTile/Transforms/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,19 @@ gpu.module @test_module {
// CHECK: %[[T1:.*]] = xetile.broadcast %[[T0]] [0] : vector<1x16xf32> -> vector<8x16xf32>
// CHECK: gpu.return %[[T1]] : vector<8x16xf32>

// -----
gpu.module @test_module {
gpu.func @test_broadcast_3(%arg0 : vector<1x16xf32>) -> vector<8x16xf32> {
%0 = vector.broadcast %arg0 : vector<1x16xf32> to vector<8x16xf32>
gpu.return %0 : vector<8x16xf32>
}
}

// CHECK-LABEL: @test_broadcast_3
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<1x16xf32>) -> vector<8x16xf32>
// CHECK: %[[T0:.*]] = xetile.broadcast %[[ARG0]] [0] : vector<1x16xf32> -> vector<8x16xf32>
// CHECK: gpu.return %[[T0]] : vector<8x16xf32>

// -----
gpu.module @test_module {
gpu.func @test_multireduction_1(%arg0 : vector<64x256xf32>, %arg1 : vector<256xf32>) -> vector<256xf32> {
Expand Down

0 comments on commit 0330284

Please sign in to comment.