diff --git a/lib/Dialect/XeTile/Transforms/Canonicalization.cpp b/lib/Dialect/XeTile/Transforms/Canonicalization.cpp index 32ac30a36..b41a97397 100644 --- a/lib/Dialect/XeTile/Transforms/Canonicalization.cpp +++ b/lib/Dialect/XeTile/Transforms/Canonicalization.cpp @@ -272,10 +272,11 @@ struct VectorBroadcastToXetileBroadcastOpPattern newOp->setDiscardableAttrs(discardableAttrs); return mlir::success(); } - // If ranks are same, inner dimension is stretched in vector.broadcast. So - // broadcast dimension is 1 for this case. + // Ranks match: broadcast along dim 0 when the source's leading extent is + // 1, otherwise along dim 1 (the previously hard-coded inner-dim case). + auto broadcastDim = (sourceShape[0] == 1) ? 0 : 1; auto newOp = rewriter.replaceOpWithNewOp( - op, resultTy, op.getSource(), llvm::ArrayRef({1})); + op, resultTy, op.getSource(), llvm::ArrayRef({broadcastDim})); newOp->setDiscardableAttrs(discardableAttrs); return mlir::success(); } diff --git a/test/Dialect/XeTile/Transforms/canonicalization.mlir b/test/Dialect/XeTile/Transforms/canonicalization.mlir index 922178edb..a3bd6db51 100644 --- a/test/Dialect/XeTile/Transforms/canonicalization.mlir +++ b/test/Dialect/XeTile/Transforms/canonicalization.mlir @@ -268,6 +268,19 @@ gpu.module @test_module { // CHECK: %[[T1:.*]] = xetile.broadcast %[[T0]] [0] : vector<1x16xf32> -> vector<8x16xf32> // CHECK: gpu.return %[[T1]] : vector<8x16xf32> +// ----- +gpu.module @test_module { + gpu.func @test_broadcast_3(%arg0 : vector<1x16xf32>) -> vector<8x16xf32> { + %0 = vector.broadcast %arg0 : vector<1x16xf32> to vector<8x16xf32> + gpu.return %0 : vector<8x16xf32> + } +} + +// CHECK-LABEL: @test_broadcast_3 +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<1x16xf32>) -> vector<8x16xf32> +// CHECK: %[[T0:.*]] = xetile.broadcast %[[ARG0]] [0] : vector<1x16xf32> -> vector<8x16xf32> +// CHECK: gpu.return %[[T0]] : vector<8x16xf32> + // ----- gpu.module @test_module { gpu.func @test_multireduction_1(%arg0 : vector<64x256xf32>, %arg1 : vector<256xf32>) -> vector<256xf32> {