Skip to content

Commit

Permalink
Refine #917, allow vnni propagation for bitcast-like Ops (#919)
Browse files Browse the repository at this point in the history
Refine #917, allow vnni propagation for castOps with same input and output bitwidth
  • Loading branch information
chencha3 authored Oct 10, 2024
1 parent 74105b3 commit 9d73c3c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
19 changes: 16 additions & 3 deletions lib/Transforms/VnniTransformation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,22 @@ class LayoutAnalysisImpl

// for non-cast elementwise ops only. Propagation is stopped
// when meeting a cast op, e.g., truncf, in which source and result
// need different vnni factors.
if (mlir::OpTrait::hasElementwiseMappableTraits(op) &&
!mlir::isa<mlir::CastOpInterface>(op)) {
// need different vnni factors. An exception is the bitcast op, whose
// source and result have the same bitwidth.
if (mlir::OpTrait::hasElementwiseMappableTraits(op)) {
// stop propagation for cast ops that are not guaranteed
// to have the same bitwidth between source and result.
if (mlir::isa<mlir::CastOpInterface>(op)) {
auto srcTy = mlir::getElementTypeOrSelf(op->getOperand(0));
auto dstTy = mlir::getElementTypeOrSelf(op->getResult(0));
if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat() ||
srcTy.getIntOrFloatBitWidth() != dstTy.getIntOrFloatBitWidth()) {
for (auto operand : operands)
propagateIfChanged(operand, operand->join(Layout(false)));
return mlir::success();
}
}

Layout layout;

// if the op has results, initialize the layout to be vnni
Expand Down
19 changes: 19 additions & 0 deletions test/Transforms/VnniTransform/unit-tests.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -355,4 +355,23 @@ func.func @test(%arg1 : !xegpu.tensor_desc<8x16xf16>, %arg2 : !xegpu.tensor_desc
}
%2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
return %2 : vector<8x16xf32>
}

// -----

// CHECK-LABEL: @test
// CHECK-SAME: (%[[ARG1:.*]]: !xegpu.tensor_desc<8x16xi16>, %[[ARG2:.*]]: !xegpu.tensor_desc<16x16xi16>)
// CHECK: %[[r0:.*]] = xegpu.load_nd %[[ARG1]] : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
// CHECK: %[[r1:.*]] = xegpu.load_nd %[[ARG2]] <{packed}> : !xegpu.tensor_desc<16x16xi16> -> vector<8x16x2xi16>
// CHECK: %[[r2:.*]] = arith.bitcast %[[r0]] : vector<8x16xi16> to vector<8x16xf16>
// CHECK: %[[r3:.*]] = arith.bitcast %[[r1]] : vector<8x16x2xi16> to vector<8x16x2xf16>
// CHECK: %[[r4:.*]] = xegpu.dpas %[[r2]], %[[r3]] : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
// CHECK: return %[[r4]] : vector<8x16xf32>
// Loads two i16 tiles, bitcasts them to f16, and feeds them to dpas.
// Since i16 and f16 have the same bitwidth, the vnni layout is expected
// to propagate through the bitcast: the B-operand load becomes packed
// (vector<8x16x2xi16>) and the bitcast follows it, per the CHECK lines above.
func.func @test(%arg1 : !xegpu.tensor_desc<8x16xi16>, %arg2 : !xegpu.tensor_desc<16x16xi16>) -> vector<8x16xf32> {
// A operand: plain (non-packed) load.
%a = xegpu.load_nd %arg1 : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
// B operand: expected to be rewritten into a packed load by the transform.
%b = xegpu.load_nd %arg2 : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
// Same-bitwidth casts — vnni propagation must not stop here.
%0 = arith.bitcast %a : vector<8x16xi16> to vector<8x16xf16>
%1 = arith.bitcast %b : vector<16x16xi16> to vector<16x16xf16>
%2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
return %2 : vector<8x16xf32>
}

0 comments on commit 9d73c3c

Please sign in to comment.