Skip to content

Commit

Permalink
align computation with triton
Browse files (browse the repository at this point in the history)
  • Loading branch information
yifeizh2 committed Jul 11, 2024
1 parent edc6cc0 commit 9ed0b84
Showing 1 changed file with 32 additions and 42 deletions.
74 changes: 32 additions & 42 deletions lib/gc/Transforms/FlashAttentionConversion.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -331,6 +331,12 @@ struct MHAToFlashAttention
ValueRange{curSumSlice, rescaledPrevSumSlice},
ValueRange{reducedShapeOut})
.getResult(0);
Value newSumSliceRecip =
rewriter
.create<linalg::ReciprocalOp>(loc, reducedShapeOut.getType(),
ValueRange{newSumSlice},
ValueRange{reducedShapeOut})
.getResult(0);
SmallVector<int64_t> VShape{cfg.RowBlockSize, headDim};
Value VShapeOut = rewriter.create<tensor::EmptyOp>(loc, VShape, dtype);
Value matmulVOutFilled =
Expand All @@ -341,38 +347,40 @@ struct MHAToFlashAttention
ValueRange{PSlice, collapsedVSlice},
ValueRange{matmulVOutFilled})
.getResult(0);
Value expMaxDiffBroadcasted =
Value newSumSliceRecipBroadcasted =
rewriter
.create<linalg::BroadcastOp>(loc, expMaxDiff, VShapeOut,
.create<linalg::BroadcastOp>(loc, newSumSliceRecip, VShapeOut,
SmallVector<int64_t>{1})
.getResults()[0];
Value expMaxDiffBroadcastedEps =
Value rescaledPrevSumSliceBroadcasted =
rewriter
.create<linalg::GenericOp>(
loc, VShapeOut.getType(), ValueRange{expMaxDiffBroadcasted},
ValueRange{VShapeOut}, indexingMaps,
SmallVector<utils::IteratorType>(2,
utils::IteratorType::parallel),
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange args) {
Value eps = nestedBuilder.create<arith::ConstantOp>(
loc, nestedBuilder.getFloatAttr(dtype, 1e-9));
Value added =
nestedBuilder.create<arith::AddFOp>(loc, args[0], eps);
nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
})
.create<linalg::BroadcastOp>(loc, rescaledPrevSumSlice, VShapeOut,
SmallVector<int64_t>{1})
.getResults()[0];
Value rescaledMatmulV =
rewriter
.create<linalg::MulOp>(
loc, matmulVOutFilled.getType(),
ValueRange{matmulV, newSumSliceRecipBroadcasted},
ValueRange{matmulVOutFilled})
.getResult(0);
Value sumSliceQuotient =
rewriter
.create<linalg::MulOp>(loc, matmulVOutFilled.getType(),
ValueRange{rescaledPrevSumSliceBroadcasted,
newSumSliceRecipBroadcasted},
ValueRange{matmulVOutFilled})
.getResult(0);
Value rescaledOSlice =
rewriter
.create<linalg::DivOp>(
loc, VShapeOut.getType(),
ValueRange{prevOSlice, expMaxDiffBroadcastedEps},
ValueRange{VShapeOut})
.create<linalg::MulOp>(loc, matmulVOutFilled.getType(),
ValueRange{prevOSlice, sumSliceQuotient},
ValueRange{matmulVOutFilled})
.getResult(0);
Value newOSlice =
rewriter
.create<linalg::AddOp>(loc, VShapeOut.getType(),
ValueRange{rescaledOSlice, matmulV},
ValueRange{rescaledOSlice, rescaledMatmulV},
ValueRange{VShapeOut})
.getResult(0);
// yield all the results of the innermost loop.
Expand All @@ -381,25 +389,7 @@ struct MHAToFlashAttention
// yield rowBlockLoop results
rewriter.setInsertionPointToEnd(rowBlockLoop.getBody());
auto innermostLoopResults = columnBlockLoop->getResults();
Value OSliceFinal = innermostLoopResults[0],
sumSliceFinal = innermostLoopResults[2];
Value sliceShapeOut =
rewriter.create<tensor::EmptyOp>(loc, reducedShape, dtype);
Value broadcastedSliceShapeOut =
rewriter.create<tensor::EmptyOp>(loc, VShape, dtype);
Value sumSliceFinalBroadcasted =
rewriter
.create<linalg::BroadcastOp>(loc, sumSliceFinal,
broadcastedSliceShapeOut,
SmallVector<int64_t>{1})
.getResults()[0];
Value rescaledOSliceFinal =
rewriter
.create<linalg::DivOp>(
loc, broadcastedSliceShapeOut.getType(),
ValueRange{OSliceFinal, sumSliceFinalBroadcasted},
ValueRange{broadcastedSliceShapeOut})
.getResult(0);
Value OSliceFinal = innermostLoopResults[0];
SmallVector<OpFoldResult> outputOffsets;
outputOffsets.push_back(getAsOpFoldResult(ivs[0]));
outputOffsets.push_back(getAsOpFoldResult(ivs[1]));
Expand All @@ -409,8 +399,8 @@ struct MHAToFlashAttention
outputSizes[2] = rewriter.getIndexAttr(cfg.RowBlockSize);
outputSizes[3] = rewriter.getIndexAttr(headDim);
Value insertedRescaledOSlice = rewriter.create<tensor::InsertSliceOp>(
loc, rescaledOSliceFinal, rowBlockLoop.getRegionIterArgs()[0],
outputOffsets, outputSizes, strides);
loc, OSliceFinal, rowBlockLoop.getRegionIterArgs()[0], outputOffsets,
outputSizes, strides);
rewriter.create<scf::YieldOp>(loc, ValueRange{insertedRescaledOSlice});
// Add the scf.yield operations for all the outer loops.
for (auto [outerLoop, innerLoop] :
Expand Down

0 comments on commit 9ed0b84

Please sign in to comment.